1use crate::entities::json::{
18 ContextJsonDeserializationError, ContextJsonParser, NullContextSchema,
19};
20use crate::evaluator::{EvaluationError, RestrictedEvaluator};
21use crate::extensions::Extensions;
22use crate::parser::Loc;
23use miette::Diagnostic;
24use serde::{Deserialize, Serialize};
25use smol_str::{SmolStr, ToSmolStr};
26use std::collections::{BTreeMap, HashMap};
27use std::sync::Arc;
28use thiserror::Error;
29
30use super::{
31 BorrowedRestrictedExpr, BoundedDisplay, EntityType, EntityUID, Expr, ExprKind,
32 ExpressionConstructionError, PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
33};
34
/// An authorization request: principal, action, resource, and context.
///
/// Any component may be unknown. The principal, action, and resource can be
/// [`EntityUIDEntry::Unknown`], and `context` can be `None`. This is how
/// partially-known requests are represented.
#[derive(Debug, Clone)]
pub struct Request {
    /// Principal associated with the request
    pub(crate) principal: EntityUIDEntry,

    /// Action associated with the request
    pub(crate) action: EntityUIDEntry,

    /// Resource associated with the request
    pub(crate) resource: EntityUIDEntry,

    /// Context associated with the request.
    /// `None` means the context is unknown (displayed as "unknown").
    pub(crate) context: Option<Context>,
}
51
/// The principal type, action UID, and resource type of a request whose
/// principal, action, and resource are all known.
///
/// Built by [`Request::to_request_type`]. Field names are (de)serialized in
/// camelCase.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestType {
    /// Principal type
    pub principal: EntityType,
    /// Action UID
    pub action: EntityUID,
    /// Resource type
    pub resource: EntityType,
}
63
/// One entity slot (principal, action, or resource) of a [`Request`].
///
/// The slot is either a concrete entity UID, or an unknown that evaluates to
/// an `unknown` residual (see [`EntityUIDEntry::evaluate`]).
#[derive(Debug, Clone)]
pub enum EntityUIDEntry {
    /// A concrete entity UID
    Known {
        /// The concrete `EntityUID`
        euid: Arc<EntityUID>,
        /// Source location associated with this entry, if any
        loc: Option<Loc>,
    },
    /// An entity UID left unknown
    Unknown {
        /// The entity type of the unknown UID, if that much is known
        ty: Option<EntityType>,

        /// Source location associated with this entry, if any
        loc: Option<Loc>,
    },
}
85
86impl EntityUIDEntry {
87 pub fn evaluate(&self, var: Var) -> PartialValue {
91 match self {
92 EntityUIDEntry::Known { euid, loc } => {
93 Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
94 }
95 EntityUIDEntry::Unknown { ty: None, loc } => {
96 Expr::unknown(Unknown::new_untyped(var.to_smolstr()))
97 .with_maybe_source_loc(loc.clone())
98 .into()
99 }
100 EntityUIDEntry::Unknown {
101 ty: Some(known_type),
102 loc,
103 } => Expr::unknown(Unknown::new_with_type(
104 var.to_smolstr(),
105 super::Type::Entity {
106 ty: known_type.clone(),
107 },
108 ))
109 .with_maybe_source_loc(loc.clone())
110 .into(),
111 }
112 }
113
114 pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
116 Self::Known {
117 euid: Arc::new(euid),
118 loc,
119 }
120 }
121
122 pub fn unknown() -> Self {
124 Self::Unknown {
125 ty: None,
126 loc: None,
127 }
128 }
129
130 pub fn unknown_with_type(ty: EntityType, loc: Option<Loc>) -> Self {
132 Self::Unknown { ty: Some(ty), loc }
133 }
134
135 pub fn uid(&self) -> Option<&EntityUID> {
137 match self {
138 Self::Known { euid, .. } => Some(euid),
139 Self::Unknown { .. } => None,
140 }
141 }
142
143 pub fn get_type(&self) -> Option<&EntityType> {
145 match self {
146 Self::Known { euid, .. } => Some(euid.entity_type()),
147 Self::Unknown { ty, .. } => ty.as_ref(),
148 }
149 }
150}
151
152impl Request {
153 pub fn new<S: RequestSchema>(
158 principal: (EntityUID, Option<Loc>),
159 action: (EntityUID, Option<Loc>),
160 resource: (EntityUID, Option<Loc>),
161 context: Context,
162 schema: Option<&S>,
163 extensions: &Extensions<'_>,
164 ) -> Result<Self, S::Error> {
165 let req = Self {
166 principal: EntityUIDEntry::known(principal.0, principal.1),
167 action: EntityUIDEntry::known(action.0, action.1),
168 resource: EntityUIDEntry::known(resource.0, resource.1),
169 context: Some(context),
170 };
171 if let Some(schema) = schema {
172 schema.validate_request(&req, extensions)?;
173 }
174 Ok(req)
175 }
176
177 pub fn new_with_unknowns<S: RequestSchema>(
183 principal: EntityUIDEntry,
184 action: EntityUIDEntry,
185 resource: EntityUIDEntry,
186 context: Option<Context>,
187 schema: Option<&S>,
188 extensions: &Extensions<'_>,
189 ) -> Result<Self, S::Error> {
190 let req = Self {
191 principal,
192 action,
193 resource,
194 context,
195 };
196 if let Some(schema) = schema {
197 schema.validate_request(&req, extensions)?;
198 }
199 Ok(req)
200 }
201
202 pub fn new_unchecked(
205 principal: EntityUIDEntry,
206 action: EntityUIDEntry,
207 resource: EntityUIDEntry,
208 context: Option<Context>,
209 ) -> Self {
210 Self {
211 principal,
212 action,
213 resource,
214 context,
215 }
216 }
217
218 pub fn principal(&self) -> &EntityUIDEntry {
220 &self.principal
221 }
222
223 pub fn action(&self) -> &EntityUIDEntry {
225 &self.action
226 }
227
228 pub fn resource(&self) -> &EntityUIDEntry {
230 &self.resource
231 }
232
233 pub fn context(&self) -> Option<&Context> {
236 self.context.as_ref()
237 }
238
239 pub fn to_request_type(&self) -> Option<RequestType> {
245 Some(RequestType {
246 principal: self.principal().uid()?.entity_type().clone(),
247 action: self.action().uid()?.clone(),
248 resource: self.resource().uid()?.entity_type().clone(),
249 })
250 }
251}
252
253impl std::fmt::Display for Request {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
256 EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
257 EntityUIDEntry::Unknown { ty: None, .. } => "unknown".to_string(),
258 EntityUIDEntry::Unknown {
259 ty: Some(known_type),
260 ..
261 } => format!("unknown of type {}", known_type),
262 };
263 write!(
264 f,
265 "request with principal {}, action {}, resource {}, and context {}",
266 display_euid(&self.principal),
267 display_euid(&self.action),
268 display_euid(&self.resource),
269 match &self.context {
270 Some(x) => format!("{x}"),
271 None => "unknown".to_string(),
272 }
273 )
274 }
275}
276
/// The `context` of a [`Request`]: a record of named attributes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Context {
    /// The context is a fully concrete record value.
    Value(Arc<BTreeMap<SmolStr, Value>>),
    /// The context is a residual record whose attribute expressions could not
    /// be fully evaluated.
    /// Each `Expr` is treated as a restricted expression (it is converted with
    /// `RestrictedExpr::new_unchecked` when iterated). This is not checked here.
    RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
}
288
289impl Context {
290 pub fn empty() -> Self {
292 Self::Value(Arc::new(BTreeMap::new()))
293 }
294
295 fn from_restricted_partial_val_unchecked(
301 value: PartialValue,
302 ) -> Result<Self, ContextCreationError> {
303 match value {
304 PartialValue::Value(v) => {
305 if let ValueKind::Record(attrs) = v.value {
306 Ok(Context::Value(attrs))
307 } else {
308 Err(ContextCreationError::not_a_record(v.into()))
309 }
310 }
311 PartialValue::Residual(e) => {
312 if let ExprKind::Record(attrs) = e.expr_kind() {
313 Ok(Context::RestrictedResidual(attrs.clone()))
320 } else {
321 Err(ContextCreationError::not_a_record(e))
322 }
323 }
324 }
325 }
326
327 pub fn from_expr(
332 expr: BorrowedRestrictedExpr<'_>,
333 extensions: &Extensions<'_>,
334 ) -> Result<Self, ContextCreationError> {
335 match expr.expr_kind() {
336 ExprKind::Record { .. } => {
337 let evaluator = RestrictedEvaluator::new(extensions);
338 let pval = evaluator.partial_interpret(expr)?;
339 #[allow(clippy::expect_used)]
347 Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
348 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
349 ))
350 }
351 _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
352 }
353 }
354
355 pub fn from_pairs(
361 pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
362 extensions: &Extensions<'_>,
363 ) -> Result<Self, ContextCreationError> {
364 match RestrictedExpr::record(pairs) {
365 Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
366 Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
367 ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
368 ),
369 }
370 }
371
372 pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
379 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
380 .from_json_str(json)
381 }
382
383 pub fn from_json_value(
390 json: serde_json::Value,
391 ) -> Result<Self, ContextJsonDeserializationError> {
392 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
393 .from_json_value(json)
394 }
395
396 pub fn from_json_file(
403 json: impl std::io::Read,
404 ) -> Result<Self, ContextJsonDeserializationError> {
405 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
406 .from_json_file(json)
407 }
408
409 pub fn num_keys(&self) -> usize {
411 match self {
412 Context::Value(record) => record.len(),
413 Context::RestrictedResidual(record) => record.len(),
414 }
415 }
416
417 fn into_pairs(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
424 match self {
425 Context::Value(record) => Box::new(
426 Arc::unwrap_or_clone(record)
427 .into_iter()
428 .map(|(k, v)| (k, RestrictedExpr::from(v))),
429 ),
430 Context::RestrictedResidual(record) => Box::new(
431 Arc::unwrap_or_clone(record)
432 .into_iter()
433 .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
436 ),
437 }
438 }
439
440 pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
444 match self {
445 Context::RestrictedResidual(residual_context) => {
446 let expr = Expr::record_arc(residual_context).substitute(mapping);
451 let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
452
453 let extns = Extensions::all_available();
454 let eval = RestrictedEvaluator::new(extns);
455 let partial_value = eval.partial_interpret(expr)?;
456
457 #[allow(clippy::expect_used)]
465 Ok(
466 Self::from_restricted_partial_val_unchecked(partial_value).expect(
467 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
468 ),
469 )
470 }
471 Context::Value(_) => Ok(self),
472 }
473 }
474}
475
476mod iter {
478 use super::*;
479
480 pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
482
483 impl std::fmt::Debug for IntoIter {
484 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485 write!(f, "IntoIter(<context>)")
486 }
487 }
488
489 impl Iterator for IntoIter {
490 type Item = (SmolStr, RestrictedExpr);
491
492 fn next(&mut self) -> Option<Self::Item> {
493 self.0.next()
494 }
495 }
496}
497
498impl IntoIterator for Context {
499 type Item = (SmolStr, RestrictedExpr);
500 type IntoIter = iter::IntoIter;
501
502 fn into_iter(self) -> Self::IntoIter {
503 iter::IntoIter(self.into_pairs())
504 }
505}
506
507impl From<Context> for RestrictedExpr {
508 fn from(value: Context) -> Self {
509 match value {
510 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
511 Context::RestrictedResidual(attrs) => {
512 RestrictedExpr::new_unchecked(Expr::record_arc(attrs))
516 }
517 }
518 }
519}
520
521impl From<Context> for PartialValue {
522 fn from(ctx: Context) -> PartialValue {
523 match ctx {
524 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
525 Context::RestrictedResidual(attrs) => {
526 PartialValue::Residual(Expr::record_arc(attrs))
531 }
532 }
533 }
534}
535
536impl std::default::Default for Context {
537 fn default() -> Context {
538 Context::empty()
539 }
540}
541
542impl std::fmt::Display for Context {
543 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
544 write!(f, "{}", PartialValue::from(self.clone()))
545 }
546}
547
548impl BoundedDisplay for Context {
549 fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
550 BoundedDisplay::fmt(&PartialValue::from(self.clone()), f, n)
551 }
552}
553
/// Errors that can occur when constructing a [`Context`].
#[derive(Debug, Diagnostic, Error)]
pub enum ContextCreationError {
    /// Tried to create a `Context` out of something other than a record
    #[error(transparent)]
    #[diagnostic(transparent)]
    NotARecord(#[from] context_creation_errors::NotARecord),
    /// Partially evaluating the expression given for the `Context` failed
    #[error(transparent)]
    #[diagnostic(transparent)]
    Evaluation(#[from] EvaluationError),
    /// Constructing the expression given for the `Context` failed
    /// (e.g., a duplicate key in `Context::from_pairs`)
    #[error(transparent)]
    #[diagnostic(transparent)]
    ExpressionConstruction(#[from] ExpressionConstructionError),
}
571
572impl ContextCreationError {
573 pub(crate) fn not_a_record(expr: Expr) -> Self {
574 Self::NotARecord(context_creation_errors::NotARecord {
575 expr: Box::new(expr),
576 })
577 }
578}
579
/// Error subtypes for [`ContextCreationError`]
pub mod context_creation_errors {
    use super::Expr;
    use crate::impl_diagnostic_from_method_on_field;
    use miette::Diagnostic;
    use thiserror::Error;

    /// Error for an expression that needed to be a record but is not one
    #[derive(Debug, Error)]
    #[error("expression is not a record: {expr}")]
    pub struct NotARecord {
        /// The expression that is not a record
        pub(super) expr: Box<Expr>,
    }

    // Diagnostic information is derived from the `source_loc` method of the
    // wrapped `expr`. See `impl_diagnostic_from_method_on_field!` for exactly
    // which `Diagnostic` methods this provides.
    impl Diagnostic for NotARecord {
        impl_diagnostic_from_method_on_field!(expr, source_loc);
    }
}
604
/// A schema that can validate [`Request`]s and their parts.
///
/// When a schema is supplied to [`Request::new`] or
/// [`Request::new_with_unknowns`], `validate_request` is called and its
/// error is propagated to the caller.
pub trait RequestSchema {
    /// Error type returned when validation fails
    type Error: miette::Diagnostic;
    /// Validate `request`, returning `Err` if it does not conform
    fn validate_request(
        &self,
        request: &Request,
        extensions: &Extensions<'_>,
    ) -> Result<(), Self::Error>;

    /// Validate `context` for the given `action`, returning `Err` if it does
    /// not conform
    fn validate_context<'a>(
        &self,
        context: &Context,
        action: &EntityUID,
        extensions: &Extensions<'a>,
    ) -> std::result::Result<(), Self::Error>;

    /// Validate whichever of principal, action, and resource are known
    /// (`None` for an unknown component), returning `Err` if they do not
    /// conform
    fn validate_scope_variables(
        &self,
        principal: Option<&EntityUID>,
        action: Option<&EntityUID>,
        resource: Option<&EntityUID>,
    ) -> std::result::Result<(), Self::Error>;
}
632
633#[derive(Debug, Clone)]
635pub struct RequestSchemaAllPass;
636impl RequestSchema for RequestSchemaAllPass {
637 type Error = Infallible;
638 fn validate_request(
639 &self,
640 _request: &Request,
641 _extensions: &Extensions<'_>,
642 ) -> Result<(), Self::Error> {
643 Ok(())
644 }
645
646 fn validate_context<'a>(
647 &self,
648 _context: &Context,
649 _action: &EntityUID,
650 _extensions: &Extensions<'a>,
651 ) -> std::result::Result<(), Self::Error> {
652 Ok(())
653 }
654
655 fn validate_scope_variables(
656 &self,
657 _principal: Option<&EntityUID>,
658 _action: Option<&EntityUID>,
659 _resource: Option<&EntityUID>,
660 ) -> std::result::Result<(), Self::Error> {
661 Ok(())
662 }
663}
664
/// Error type for [`RequestSchemaAllPass`]. It wraps
/// [`std::convert::Infallible`], so no value of it can ever exist.
///
/// A local newtype is used, presumably because `RequestSchema::Error` must
/// implement `miette::Diagnostic`, which can be derived here.
#[derive(Debug, Diagnostic, Error)]
#[error(transparent)]
pub struct Infallible(pub std::convert::Infallible);
670
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// A non-record cannot become a `Context`. This holds whether it is
    /// supplied as a restricted expression or as JSON.
    #[test]
    fn test_json_from_str_non_record() {
        let from_expr =
            Context::from_expr(RestrictedExpr::val("1").as_borrowed(), Extensions::none());
        assert_matches!(from_expr, Err(ContextCreationError::NotARecord { .. }));

        let from_json = Context::from_json_str("1");
        assert_matches!(
            from_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}