1use crate::entities::json::{
18 err::JsonSerializationError, ContextJsonDeserializationError, ContextJsonParser,
19 NullContextSchema,
20};
21use crate::entities::CedarValueJson;
22use crate::evaluator::{EvaluationError, RestrictedEvaluator};
23use crate::extensions::Extensions;
24use crate::parser::Loc;
25use miette::Diagnostic;
26use smol_str::{SmolStr, ToSmolStr};
27use std::collections::{BTreeMap, HashMap};
28use std::sync::Arc;
29use thiserror::Error;
30
31use super::{
32 BorrowedRestrictedExpr, BoundedDisplay, EntityType, EntityUID, Expr, ExprKind,
33 ExpressionConstructionError, PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
34};
35
/// An authorization request: a principal, an action, a resource, and a context.
#[derive(Debug, Clone)]
pub struct Request {
    /// Principal of the request (either a known entity UID or an unknown).
    pub(crate) principal: EntityUIDEntry,

    /// Action of the request (either a known entity UID or an unknown).
    pub(crate) action: EntityUIDEntry,

    /// Resource of the request (either a known entity UID or an unknown).
    pub(crate) resource: EntityUIDEntry,

    /// Context of the request.
    /// `None` is rendered as "unknown" by this type's `Display` impl.
    pub(crate) context: Option<Context>,
}
52
/// The "type" of a request: the entity types of the principal and resource,
/// together with the concrete action UID.
///
/// Gains `serde` (de)serialization when the `entity-manifest` feature is enabled.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "entity-manifest",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RequestType {
    /// Entity type of the principal
    pub principal: EntityType,
    /// UID of the action
    pub action: EntityUID,
    /// Entity type of the resource
    pub resource: EntityType,
}
67
/// A principal/action/resource slot of a [`Request`]: either a concrete
/// entity UID, or an unknown (optionally with a known entity type).
#[derive(Debug, Clone)]
pub enum EntityUIDEntry {
    /// A concrete entity UID
    Known {
        /// The entity UID itself
        euid: Arc<EntityUID>,
        /// Optional source location attached to this entry
        loc: Option<Loc>,
    },
    /// An unknown entity; `evaluate` turns this into an unknown expression
    /// named after the request variable
    Unknown {
        /// Entity type of the unknown, when it is known
        ty: Option<EntityType>,

        /// Optional source location attached to this entry
        loc: Option<Loc>,
    },
}
89
90impl From<EntityUID> for EntityUIDEntry {
91 fn from(euid: EntityUID) -> Self {
92 Self::Known {
93 euid: Arc::new(euid.clone()),
94 loc: match &euid {
95 EntityUID::EntityUID(euid) => euid.loc(),
96 #[cfg(feature = "tolerant-ast")]
97 EntityUID::Error => None,
98 },
99 }
100 }
101}
102
103impl EntityUIDEntry {
104 pub fn evaluate(&self, var: Var) -> PartialValue {
108 match self {
109 EntityUIDEntry::Known { euid, loc } => {
110 Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
111 }
112 EntityUIDEntry::Unknown { ty: None, loc } => {
113 Expr::unknown(Unknown::new_untyped(var.to_smolstr()))
114 .with_maybe_source_loc(loc.clone())
115 .into()
116 }
117 EntityUIDEntry::Unknown {
118 ty: Some(known_type),
119 loc,
120 } => Expr::unknown(Unknown::new_with_type(
121 var.to_smolstr(),
122 super::Type::Entity {
123 ty: known_type.clone(),
124 },
125 ))
126 .with_maybe_source_loc(loc.clone())
127 .into(),
128 }
129 }
130
131 pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
133 Self::Known {
134 euid: Arc::new(euid),
135 loc,
136 }
137 }
138
139 pub fn unknown() -> Self {
141 Self::Unknown {
142 ty: None,
143 loc: None,
144 }
145 }
146
147 pub fn unknown_with_type(ty: EntityType, loc: Option<Loc>) -> Self {
149 Self::Unknown { ty: Some(ty), loc }
150 }
151
152 pub fn uid(&self) -> Option<&EntityUID> {
154 match self {
155 Self::Known { euid, .. } => Some(euid),
156 Self::Unknown { .. } => None,
157 }
158 }
159
160 pub fn get_type(&self) -> Option<&EntityType> {
162 match self {
163 Self::Known { euid, .. } => Some(euid.entity_type()),
164 Self::Unknown { ty, .. } => ty.as_ref(),
165 }
166 }
167}
168
169impl Request {
170 pub fn new<S: RequestSchema>(
175 principal: (EntityUID, Option<Loc>),
176 action: (EntityUID, Option<Loc>),
177 resource: (EntityUID, Option<Loc>),
178 context: Context,
179 schema: Option<&S>,
180 extensions: &Extensions<'_>,
181 ) -> Result<Self, S::Error> {
182 let req = Self {
183 principal: EntityUIDEntry::known(principal.0, principal.1),
184 action: EntityUIDEntry::known(action.0, action.1),
185 resource: EntityUIDEntry::known(resource.0, resource.1),
186 context: Some(context),
187 };
188 if let Some(schema) = schema {
189 schema.validate_request(&req, extensions)?;
190 }
191 Ok(req)
192 }
193
194 pub fn new_with_unknowns<S: RequestSchema>(
200 principal: EntityUIDEntry,
201 action: EntityUIDEntry,
202 resource: EntityUIDEntry,
203 context: Option<Context>,
204 schema: Option<&S>,
205 extensions: &Extensions<'_>,
206 ) -> Result<Self, S::Error> {
207 let req = Self {
208 principal,
209 action,
210 resource,
211 context,
212 };
213 if let Some(schema) = schema {
214 schema.validate_request(&req, extensions)?;
215 }
216 Ok(req)
217 }
218
219 pub fn new_unchecked(
222 principal: EntityUIDEntry,
223 action: EntityUIDEntry,
224 resource: EntityUIDEntry,
225 context: Option<Context>,
226 ) -> Self {
227 Self {
228 principal,
229 action,
230 resource,
231 context,
232 }
233 }
234
235 pub fn principal(&self) -> &EntityUIDEntry {
237 &self.principal
238 }
239
240 pub fn action(&self) -> &EntityUIDEntry {
242 &self.action
243 }
244
245 pub fn resource(&self) -> &EntityUIDEntry {
247 &self.resource
248 }
249
250 pub fn context(&self) -> Option<&Context> {
253 self.context.as_ref()
254 }
255
256 pub fn to_request_type(&self) -> Option<RequestType> {
262 Some(RequestType {
263 principal: self.principal().uid()?.entity_type().clone(),
264 action: self.action().uid()?.clone(),
265 resource: self.resource().uid()?.entity_type().clone(),
266 })
267 }
268}
269
270impl std::fmt::Display for Request {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
273 EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
274 EntityUIDEntry::Unknown { ty: None, .. } => "unknown".to_string(),
275 EntityUIDEntry::Unknown {
276 ty: Some(known_type),
277 ..
278 } => format!("unknown of type {known_type}"),
279 };
280 write!(
281 f,
282 "request with principal {}, action {}, resource {}, and context {}",
283 display_euid(&self.principal),
284 display_euid(&self.action),
285 display_euid(&self.resource),
286 match &self.context {
287 Some(x) => format!("{x}"),
288 None => "unknown".to_string(),
289 }
290 )
291 }
292}
293
/// The context of a [`Request`]: a record mapping attribute names to either
/// concrete values or (residual) expressions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Context {
    /// A fully concrete context: every attribute is a [`Value`]
    Value(Arc<BTreeMap<SmolStr, Value>>),
    /// A context whose attributes are expressions.
    /// NOTE(review): code in this file wraps these expressions with
    /// `new_unchecked` restricted-expression constructors, so they are
    /// presumably required to be restricted expressions — confirm.
    RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
}
305
306impl Context {
307 pub fn empty() -> Self {
309 Self::Value(Arc::new(BTreeMap::new()))
310 }
311
312 fn from_restricted_partial_val_unchecked(
318 value: PartialValue,
319 ) -> Result<Self, ContextCreationError> {
320 match value {
321 PartialValue::Value(v) => {
322 if let ValueKind::Record(attrs) = v.value {
323 Ok(Context::Value(attrs))
324 } else {
325 Err(ContextCreationError::not_a_record(v.into()))
326 }
327 }
328 PartialValue::Residual(e) => {
329 if let ExprKind::Record(attrs) = e.expr_kind() {
330 Ok(Context::RestrictedResidual(attrs.clone()))
337 } else {
338 Err(ContextCreationError::not_a_record(e))
339 }
340 }
341 }
342 }
343
344 pub fn from_expr(
349 expr: BorrowedRestrictedExpr<'_>,
350 extensions: &Extensions<'_>,
351 ) -> Result<Self, ContextCreationError> {
352 match expr.expr_kind() {
353 ExprKind::Record { .. } => {
354 let evaluator = RestrictedEvaluator::new(extensions);
355 let pval = evaluator.partial_interpret(expr)?;
356 #[expect(clippy::expect_used, reason = "See above")]
363 Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
364 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
365 ))
366 }
367 _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
368 }
369 }
370
371 pub fn from_pairs(
377 pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
378 extensions: &Extensions<'_>,
379 ) -> Result<Self, ContextCreationError> {
380 match RestrictedExpr::record(pairs) {
381 Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
382 Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
383 ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
384 ),
385 }
386 }
387
388 pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
395 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
396 .from_json_str(json)
397 }
398
399 pub fn from_json_value(
406 json: serde_json::Value,
407 ) -> Result<Self, ContextJsonDeserializationError> {
408 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
409 .from_json_value(json)
410 }
411
412 pub fn from_json_file(
419 json: impl std::io::Read,
420 ) -> Result<Self, ContextJsonDeserializationError> {
421 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
422 .from_json_file(json)
423 }
424
425 pub fn to_json_value(&self) -> Result<serde_json::Value, JsonSerializationError> {
427 match self {
428 Self::Value(record) => record
429 .iter()
430 .map(|(k, v)| {
431 let cjson = CedarValueJson::from_value(v.clone())?;
432 Ok((k.to_string(), serde_json::to_value(cjson)?))
433 })
434 .collect(),
435 Self::RestrictedResidual(record) => record
436 .iter()
437 .map(|(k, v)| {
438 let cjson =
440 CedarValueJson::from_expr(BorrowedRestrictedExpr::new_unchecked(v))?;
441 Ok((k.to_string(), serde_json::to_value(cjson)?))
442 })
443 .collect(),
444 }
445 }
446
447 pub fn num_keys(&self) -> usize {
449 match self {
450 Context::Value(record) => record.len(),
451 Context::RestrictedResidual(record) => record.len(),
452 }
453 }
454
455 fn into_pairs(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
462 match self {
463 Context::Value(record) => Box::new(
464 Arc::unwrap_or_clone(record)
465 .into_iter()
466 .map(|(k, v)| (k, RestrictedExpr::from(v))),
467 ),
468 Context::RestrictedResidual(record) => Box::new(
469 Arc::unwrap_or_clone(record)
470 .into_iter()
471 .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
474 ),
475 }
476 }
477
478 pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
482 match self {
483 Context::RestrictedResidual(residual_context) => {
484 let expr = Expr::record_arc(residual_context).substitute(mapping);
489 let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
490
491 let extns = Extensions::all_available();
492 let eval = RestrictedEvaluator::new(extns);
493 let partial_value = eval.partial_interpret(expr)?;
494
495 #[expect(clippy::expect_used, reason = "See above")]
502 Ok(
503 Self::from_restricted_partial_val_unchecked(partial_value).expect(
504 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
505 ),
506 )
507 }
508 Context::Value(_) => Ok(self),
509 }
510 }
511}
512
513mod iter {
515 use super::*;
516
517 pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
519
520 impl std::fmt::Debug for IntoIter {
521 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
522 write!(f, "IntoIter(<context>)")
523 }
524 }
525
526 impl Iterator for IntoIter {
527 type Item = (SmolStr, RestrictedExpr);
528
529 fn next(&mut self) -> Option<Self::Item> {
530 self.0.next()
531 }
532 }
533}
534
535impl IntoIterator for Context {
536 type Item = (SmolStr, RestrictedExpr);
537 type IntoIter = iter::IntoIter;
538
539 fn into_iter(self) -> Self::IntoIter {
540 iter::IntoIter(self.into_pairs())
541 }
542}
543
544impl From<Context> for RestrictedExpr {
545 fn from(value: Context) -> Self {
546 match value {
547 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
548 Context::RestrictedResidual(attrs) => {
549 RestrictedExpr::new_unchecked(Expr::record_arc(attrs))
553 }
554 }
555 }
556}
557
558impl From<Context> for PartialValue {
559 fn from(ctx: Context) -> PartialValue {
560 match ctx {
561 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
562 Context::RestrictedResidual(attrs) => {
563 PartialValue::Residual(Expr::record_arc(attrs))
568 }
569 }
570 }
571}
572
573impl std::default::Default for Context {
574 fn default() -> Context {
575 Context::empty()
576 }
577}
578
579impl std::fmt::Display for Context {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 write!(f, "{}", PartialValue::from(self.clone()))
582 }
583}
584
585impl BoundedDisplay for Context {
586 fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
587 BoundedDisplay::fmt(&PartialValue::from(self.clone()), f, n)
588 }
589}
590
/// Errors that can occur while constructing a [`Context`].
#[derive(Debug, Diagnostic, Error)]
pub enum ContextCreationError {
    /// The expression (or evaluated value) supplied for the context is not a record
    #[error(transparent)]
    #[diagnostic(transparent)]
    NotARecord(#[from] context_creation_errors::NotARecord),
    /// Evaluating the context expression failed
    #[error(transparent)]
    #[diagnostic(transparent)]
    Evaluation(#[from] EvaluationError),
    /// Building the context's record expression failed; `Context::from_pairs`
    /// produces this for a duplicate key
    #[error(transparent)]
    #[diagnostic(transparent)]
    ExpressionConstruction(#[from] ExpressionConstructionError),
}
608
609impl ContextCreationError {
610 pub(crate) fn not_a_record(expr: Expr) -> Self {
611 Self::NotARecord(context_creation_errors::NotARecord {
612 expr: Box::new(expr),
613 })
614 }
615}
616
/// Error subtypes for [`ContextCreationError`](super::ContextCreationError).
pub mod context_creation_errors {
    use super::Expr;
    use crate::impl_diagnostic_from_method_on_field;
    use miette::Diagnostic;
    use thiserror::Error;

    /// Error reported when an expression that must be a record is not one.
    #[derive(Debug, Error)]
    #[error("expression is not a record: {expr}")]
    pub struct NotARecord {
        /// The offending (non-record) expression
        pub(super) expr: Box<Expr>,
    }

    impl Diagnostic for NotARecord {
        // NOTE(review): presumably derives the diagnostic's source information
        // from `expr.source_loc()` — confirm against the macro definition.
        impl_diagnostic_from_method_on_field!(expr, source_loc);
    }
}
641
/// Interface for schemas that can validate a [`Request`] or its parts.
pub trait RequestSchema {
    /// Error type returned when validation fails
    type Error: miette::Diagnostic;
    /// Validate a whole request. `Request::new` and
    /// `Request::new_with_unknowns` call this when given a schema.
    fn validate_request(
        &self,
        request: &Request,
        extensions: &Extensions<'_>,
    ) -> Result<(), Self::Error>;

    /// Validate a context on its own, given the action it accompanies.
    fn validate_context<'a>(
        &self,
        context: &Context,
        action: &EntityUID,
        extensions: &Extensions<'a>,
    ) -> std::result::Result<(), Self::Error>;

    /// Validate the principal, action, and resource on their own; each one is
    /// optional.
    fn validate_scope_variables(
        &self,
        principal: Option<&EntityUID>,
        action: Option<&EntityUID>,
        resource: Option<&EntityUID>,
    ) -> std::result::Result<(), Self::Error>;
}
669
670#[derive(Debug, Clone)]
672pub struct RequestSchemaAllPass;
673impl RequestSchema for RequestSchemaAllPass {
674 type Error = Infallible;
675 fn validate_request(
676 &self,
677 _request: &Request,
678 _extensions: &Extensions<'_>,
679 ) -> Result<(), Self::Error> {
680 Ok(())
681 }
682
683 fn validate_context<'a>(
684 &self,
685 _context: &Context,
686 _action: &EntityUID,
687 _extensions: &Extensions<'a>,
688 ) -> std::result::Result<(), Self::Error> {
689 Ok(())
690 }
691
692 fn validate_scope_variables(
693 &self,
694 _principal: Option<&EntityUID>,
695 _action: Option<&EntityUID>,
696 _resource: Option<&EntityUID>,
697 ) -> std::result::Result<(), Self::Error> {
698 Ok(())
699 }
700}
701
/// Newtype around [`std::convert::Infallible`] that derives `Diagnostic` and
/// `Error`; used as the error type of `RequestSchemaAllPass`.
#[derive(Debug, Diagnostic, Error)]
#[error(transparent)]
pub struct Infallible(pub std::convert::Infallible);
707
#[cfg(test)]
mod test {
    use super::super::Name;
    use super::*;
    use cool_asserts::assert_matches;
    use std::str::FromStr;

    /// Serialize `context` to JSON and parse it back, panicking on any failure.
    #[track_caller]
    fn roundtrip_json(context: &Context) -> Context {
        let json = context.to_json_value().unwrap();
        Context::from_json_value(json).unwrap()
    }

    /// Restricted expression holding the entity UID parsed from `src`.
    #[track_caller]
    fn uid_expr(src: &'static str) -> RestrictedExpr {
        RestrictedExpr::val(EntityUID::from_str(src).unwrap())
    }

    /// Restricted expression calling the extension function `fn_name` on the
    /// single string argument `arg`.
    #[track_caller]
    fn ext_call(fn_name: &'static str, arg: &'static str) -> RestrictedExpr {
        RestrictedExpr::call_extension_fn(
            Name::parse_unqualified_name(fn_name).unwrap(),
            [RestrictedExpr::val(arg)],
        )
    }

    /// Non-record inputs are rejected, both as expressions and as JSON.
    #[test]
    fn test_json_from_str_non_record() {
        let from_expr =
            Context::from_expr(RestrictedExpr::val("1").as_borrowed(), Extensions::none());
        assert_matches!(from_expr, Err(ContextCreationError::NotARecord { .. }));

        let from_json = Context::from_json_str("1");
        assert_matches!(
            from_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }

    /// The empty context survives a JSON round trip.
    #[test]
    fn test_roundtrip_empty() {
        let context = Context::empty();
        assert_eq!(context, roundtrip_json(&context));
    }

    /// A context exercising primitives, entity UIDs, sets, nested records,
    /// and extension values survives a JSON round trip.
    #[test]
    fn test_roundtrip_complex() {
        let inner_record = RestrictedExpr::record([
            ("inner".into(), RestrictedExpr::val(-210)),
            ("inner_uid".into(), uid_expr("Group::\"interns\"")),
            (
                "inner_set".into(),
                RestrictedExpr::set([
                    RestrictedExpr::val("my name is"),
                    RestrictedExpr::val("inigo montoya"),
                ]),
            ),
        ])
        .unwrap();

        let pairs: Vec<(SmolStr, RestrictedExpr)> = vec![
            ("b".into(), RestrictedExpr::val(false)),
            ("i".into(), RestrictedExpr::val(32)),
            (
                "s".into(),
                RestrictedExpr::val("hi I have spaces and \" special ch@ract&rs: !{} \""),
            ),
            ("uid".into(), uid_expr("Group::\"admins\"")),
            (
                "multi".into(),
                RestrictedExpr::set([
                    RestrictedExpr::val(0),
                    RestrictedExpr::val(22),
                    RestrictedExpr::val(-310),
                ]),
            ),
            ("record".into(), inner_record),
            ("dec".into(), ext_call("decimal", "-1.111")),
            ("ipv6".into(), ext_call("ip", "ffff::1/16")),
            ("dt".into(), ext_call("datetime", "2026-01-01T03:04:05Z")),
        ];

        let context = Context::from_pairs(pairs, Extensions::all_available()).unwrap();
        assert_eq!(context, roundtrip_json(&context));
    }
}
807}