1use crate::entities::json::{
18 ContextJsonDeserializationError, ContextJsonParser, NullContextSchema,
19};
20use crate::evaluator::{EvaluationError, RestrictedEvaluator};
21use crate::extensions::Extensions;
22use crate::parser::Loc;
23use miette::Diagnostic;
24use smol_str::{SmolStr, ToSmolStr};
25use std::collections::{BTreeMap, HashMap};
26use std::sync::Arc;
27use thiserror::Error;
28
29use super::{
30 BorrowedRestrictedExpr, BoundedDisplay, EntityType, EntityUID, Expr, ExprKind,
31 ExpressionConstructionError, PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
32};
33
34#[derive(Debug, Clone)]
36pub struct Request {
37 pub(crate) principal: EntityUIDEntry,
39
40 pub(crate) action: EntityUIDEntry,
42
43 pub(crate) resource: EntityUIDEntry,
45
46 pub(crate) context: Option<Context>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
53#[cfg_attr(
54 feature = "entity-manifest",
55 derive(serde::Serialize, serde::Deserialize)
56)]
57pub struct RequestType {
58 pub principal: EntityType,
60 pub action: EntityUID,
62 pub resource: EntityType,
64}
65
66#[derive(Debug, Clone)]
70pub enum EntityUIDEntry {
71 Known {
73 euid: Arc<EntityUID>,
75 loc: Option<Loc>,
77 },
78 Unknown {
80 ty: Option<EntityType>,
82
83 loc: Option<Loc>,
85 },
86}
87
88impl From<EntityUID> for EntityUIDEntry {
89 fn from(euid: EntityUID) -> Self {
90 Self::Known {
91 euid: Arc::new(euid.clone()),
92 loc: match &euid {
93 EntityUID::EntityUID(euid) => euid.loc(),
94 #[cfg(feature = "tolerant-ast")]
95 EntityUID::Error => None,
96 },
97 }
98 }
99}
100
101impl EntityUIDEntry {
102 pub fn evaluate(&self, var: Var) -> PartialValue {
106 match self {
107 EntityUIDEntry::Known { euid, loc } => {
108 Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
109 }
110 EntityUIDEntry::Unknown { ty: None, loc } => {
111 Expr::unknown(Unknown::new_untyped(var.to_smolstr()))
112 .with_maybe_source_loc(loc.clone())
113 .into()
114 }
115 EntityUIDEntry::Unknown {
116 ty: Some(known_type),
117 loc,
118 } => Expr::unknown(Unknown::new_with_type(
119 var.to_smolstr(),
120 super::Type::Entity {
121 ty: known_type.clone(),
122 },
123 ))
124 .with_maybe_source_loc(loc.clone())
125 .into(),
126 }
127 }
128
129 pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
131 Self::Known {
132 euid: Arc::new(euid),
133 loc,
134 }
135 }
136
137 pub fn unknown() -> Self {
139 Self::Unknown {
140 ty: None,
141 loc: None,
142 }
143 }
144
145 pub fn unknown_with_type(ty: EntityType, loc: Option<Loc>) -> Self {
147 Self::Unknown { ty: Some(ty), loc }
148 }
149
150 pub fn uid(&self) -> Option<&EntityUID> {
152 match self {
153 Self::Known { euid, .. } => Some(euid),
154 Self::Unknown { .. } => None,
155 }
156 }
157
158 pub fn get_type(&self) -> Option<&EntityType> {
160 match self {
161 Self::Known { euid, .. } => Some(euid.entity_type()),
162 Self::Unknown { ty, .. } => ty.as_ref(),
163 }
164 }
165}
166
167impl Request {
168 pub fn new<S: RequestSchema>(
173 principal: (EntityUID, Option<Loc>),
174 action: (EntityUID, Option<Loc>),
175 resource: (EntityUID, Option<Loc>),
176 context: Context,
177 schema: Option<&S>,
178 extensions: &Extensions<'_>,
179 ) -> Result<Self, S::Error> {
180 let req = Self {
181 principal: EntityUIDEntry::known(principal.0, principal.1),
182 action: EntityUIDEntry::known(action.0, action.1),
183 resource: EntityUIDEntry::known(resource.0, resource.1),
184 context: Some(context),
185 };
186 if let Some(schema) = schema {
187 schema.validate_request(&req, extensions)?;
188 }
189 Ok(req)
190 }
191
192 pub fn new_with_unknowns<S: RequestSchema>(
198 principal: EntityUIDEntry,
199 action: EntityUIDEntry,
200 resource: EntityUIDEntry,
201 context: Option<Context>,
202 schema: Option<&S>,
203 extensions: &Extensions<'_>,
204 ) -> Result<Self, S::Error> {
205 let req = Self {
206 principal,
207 action,
208 resource,
209 context,
210 };
211 if let Some(schema) = schema {
212 schema.validate_request(&req, extensions)?;
213 }
214 Ok(req)
215 }
216
217 pub fn new_unchecked(
220 principal: EntityUIDEntry,
221 action: EntityUIDEntry,
222 resource: EntityUIDEntry,
223 context: Option<Context>,
224 ) -> Self {
225 Self {
226 principal,
227 action,
228 resource,
229 context,
230 }
231 }
232
233 pub fn principal(&self) -> &EntityUIDEntry {
235 &self.principal
236 }
237
238 pub fn action(&self) -> &EntityUIDEntry {
240 &self.action
241 }
242
243 pub fn resource(&self) -> &EntityUIDEntry {
245 &self.resource
246 }
247
248 pub fn context(&self) -> Option<&Context> {
251 self.context.as_ref()
252 }
253
254 pub fn to_request_type(&self) -> Option<RequestType> {
260 Some(RequestType {
261 principal: self.principal().uid()?.entity_type().clone(),
262 action: self.action().uid()?.clone(),
263 resource: self.resource().uid()?.entity_type().clone(),
264 })
265 }
266}
267
268impl std::fmt::Display for Request {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
271 EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
272 EntityUIDEntry::Unknown { ty: None, .. } => "unknown".to_string(),
273 EntityUIDEntry::Unknown {
274 ty: Some(known_type),
275 ..
276 } => format!("unknown of type {known_type}"),
277 };
278 write!(
279 f,
280 "request with principal {}, action {}, resource {}, and context {}",
281 display_euid(&self.principal),
282 display_euid(&self.action),
283 display_euid(&self.resource),
284 match &self.context {
285 Some(x) => format!("{x}"),
286 None => "unknown".to_string(),
287 }
288 )
289 }
290}
291
292#[derive(Debug, Clone, PartialEq, Eq)]
294pub enum Context {
295 Value(Arc<BTreeMap<SmolStr, Value>>),
297 RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
302}
303
304impl Context {
305 pub fn empty() -> Self {
307 Self::Value(Arc::new(BTreeMap::new()))
308 }
309
310 fn from_restricted_partial_val_unchecked(
316 value: PartialValue,
317 ) -> Result<Self, ContextCreationError> {
318 match value {
319 PartialValue::Value(v) => {
320 if let ValueKind::Record(attrs) = v.value {
321 Ok(Context::Value(attrs))
322 } else {
323 Err(ContextCreationError::not_a_record(v.into()))
324 }
325 }
326 PartialValue::Residual(e) => {
327 if let ExprKind::Record(attrs) = e.expr_kind() {
328 Ok(Context::RestrictedResidual(attrs.clone()))
335 } else {
336 Err(ContextCreationError::not_a_record(e))
337 }
338 }
339 }
340 }
341
342 pub fn from_expr(
347 expr: BorrowedRestrictedExpr<'_>,
348 extensions: &Extensions<'_>,
349 ) -> Result<Self, ContextCreationError> {
350 match expr.expr_kind() {
351 ExprKind::Record { .. } => {
352 let evaluator = RestrictedEvaluator::new(extensions);
353 let pval = evaluator.partial_interpret(expr)?;
354 #[allow(clippy::expect_used)]
362 Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
363 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
364 ))
365 }
366 _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
367 }
368 }
369
370 pub fn from_pairs(
376 pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
377 extensions: &Extensions<'_>,
378 ) -> Result<Self, ContextCreationError> {
379 match RestrictedExpr::record(pairs) {
380 Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
381 Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
382 ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
383 ),
384 }
385 }
386
387 pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
394 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
395 .from_json_str(json)
396 }
397
398 pub fn from_json_value(
405 json: serde_json::Value,
406 ) -> Result<Self, ContextJsonDeserializationError> {
407 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
408 .from_json_value(json)
409 }
410
411 pub fn from_json_file(
418 json: impl std::io::Read,
419 ) -> Result<Self, ContextJsonDeserializationError> {
420 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
421 .from_json_file(json)
422 }
423
424 pub fn num_keys(&self) -> usize {
426 match self {
427 Context::Value(record) => record.len(),
428 Context::RestrictedResidual(record) => record.len(),
429 }
430 }
431
432 fn into_pairs(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
439 match self {
440 Context::Value(record) => Box::new(
441 Arc::unwrap_or_clone(record)
442 .into_iter()
443 .map(|(k, v)| (k, RestrictedExpr::from(v))),
444 ),
445 Context::RestrictedResidual(record) => Box::new(
446 Arc::unwrap_or_clone(record)
447 .into_iter()
448 .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
451 ),
452 }
453 }
454
455 pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
459 match self {
460 Context::RestrictedResidual(residual_context) => {
461 let expr = Expr::record_arc(residual_context).substitute(mapping);
466 let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
467
468 let extns = Extensions::all_available();
469 let eval = RestrictedEvaluator::new(extns);
470 let partial_value = eval.partial_interpret(expr)?;
471
472 #[allow(clippy::expect_used)]
480 Ok(
481 Self::from_restricted_partial_val_unchecked(partial_value).expect(
482 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
483 ),
484 )
485 }
486 Context::Value(_) => Ok(self),
487 }
488 }
489}
490
491mod iter {
493 use super::*;
494
495 pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
497
498 impl std::fmt::Debug for IntoIter {
499 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500 write!(f, "IntoIter(<context>)")
501 }
502 }
503
504 impl Iterator for IntoIter {
505 type Item = (SmolStr, RestrictedExpr);
506
507 fn next(&mut self) -> Option<Self::Item> {
508 self.0.next()
509 }
510 }
511}
512
513impl IntoIterator for Context {
514 type Item = (SmolStr, RestrictedExpr);
515 type IntoIter = iter::IntoIter;
516
517 fn into_iter(self) -> Self::IntoIter {
518 iter::IntoIter(self.into_pairs())
519 }
520}
521
522impl From<Context> for RestrictedExpr {
523 fn from(value: Context) -> Self {
524 match value {
525 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
526 Context::RestrictedResidual(attrs) => {
527 RestrictedExpr::new_unchecked(Expr::record_arc(attrs))
531 }
532 }
533 }
534}
535
536impl From<Context> for PartialValue {
537 fn from(ctx: Context) -> PartialValue {
538 match ctx {
539 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
540 Context::RestrictedResidual(attrs) => {
541 PartialValue::Residual(Expr::record_arc(attrs))
546 }
547 }
548 }
549}
550
551impl std::default::Default for Context {
552 fn default() -> Context {
553 Context::empty()
554 }
555}
556
557impl std::fmt::Display for Context {
558 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559 write!(f, "{}", PartialValue::from(self.clone()))
560 }
561}
562
563impl BoundedDisplay for Context {
564 fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
565 BoundedDisplay::fmt(&PartialValue::from(self.clone()), f, n)
566 }
567}
568
569#[derive(Debug, Diagnostic, Error)]
571pub enum ContextCreationError {
572 #[error(transparent)]
574 #[diagnostic(transparent)]
575 NotARecord(#[from] context_creation_errors::NotARecord),
576 #[error(transparent)]
578 #[diagnostic(transparent)]
579 Evaluation(#[from] EvaluationError),
580 #[error(transparent)]
583 #[diagnostic(transparent)]
584 ExpressionConstruction(#[from] ExpressionConstructionError),
585}
586
587impl ContextCreationError {
588 pub(crate) fn not_a_record(expr: Expr) -> Self {
589 Self::NotARecord(context_creation_errors::NotARecord {
590 expr: Box::new(expr),
591 })
592 }
593}
594
595pub mod context_creation_errors {
597 use super::Expr;
598 use crate::impl_diagnostic_from_method_on_field;
599 use miette::Diagnostic;
600 use thiserror::Error;
601
602 #[derive(Debug, Error)]
608 #[error("expression is not a record: {expr}")]
609 pub struct NotARecord {
610 pub(super) expr: Box<Expr>,
612 }
613
614 impl Diagnostic for NotARecord {
616 impl_diagnostic_from_method_on_field!(expr, source_loc);
617 }
618}
619
620pub trait RequestSchema {
622 type Error: miette::Diagnostic;
624 fn validate_request(
626 &self,
627 request: &Request,
628 extensions: &Extensions<'_>,
629 ) -> Result<(), Self::Error>;
630
631 fn validate_context<'a>(
633 &self,
634 context: &Context,
635 action: &EntityUID,
636 extensions: &Extensions<'a>,
637 ) -> std::result::Result<(), Self::Error>;
638
639 fn validate_scope_variables(
641 &self,
642 principal: Option<&EntityUID>,
643 action: Option<&EntityUID>,
644 resource: Option<&EntityUID>,
645 ) -> std::result::Result<(), Self::Error>;
646}
647
648#[derive(Debug, Clone)]
650pub struct RequestSchemaAllPass;
651impl RequestSchema for RequestSchemaAllPass {
652 type Error = Infallible;
653 fn validate_request(
654 &self,
655 _request: &Request,
656 _extensions: &Extensions<'_>,
657 ) -> Result<(), Self::Error> {
658 Ok(())
659 }
660
661 fn validate_context<'a>(
662 &self,
663 _context: &Context,
664 _action: &EntityUID,
665 _extensions: &Extensions<'a>,
666 ) -> std::result::Result<(), Self::Error> {
667 Ok(())
668 }
669
670 fn validate_scope_variables(
671 &self,
672 _principal: Option<&EntityUID>,
673 _action: Option<&EntityUID>,
674 _resource: Option<&EntityUID>,
675 ) -> std::result::Result<(), Self::Error> {
676 Ok(())
677 }
678}
679
680#[derive(Debug, Diagnostic, Error)]
683#[error(transparent)]
684pub struct Infallible(pub std::convert::Infallible);
685
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// A non-record expression, or JSON that is not an object, must be
    /// rejected as a context.
    #[test]
    fn test_json_from_str_non_record() {
        let from_expr =
            Context::from_expr(RestrictedExpr::val("1").as_borrowed(), Extensions::none());
        assert_matches!(from_expr, Err(ContextCreationError::NotARecord { .. }));

        let from_json = Context::from_json_str("1");
        assert_matches!(
            from_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}