1use super::constraints::{ActionConstraint, PrincipalConstraint, ResourceConstraint};
23use super::expr::{EntityUID, Expr, SlotId};
24use crate::ast;
25use crate::pst::err::error_body::{ContainsSlotError, InvalidExpressionError, LinkingError};
26use crate::pst::PstConstructionError;
27use smol_str::SmolStr;
28use std::collections::{BTreeMap, HashMap, HashSet};
29use std::fmt::Display;
30use std::sync::Arc;
31
32#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
34pub struct PolicyID(pub SmolStr);
35
36impl From<PolicyID> for ast::PolicyID {
37 fn from(id: PolicyID) -> Self {
38 ast::PolicyID::from_smolstr(id.0)
39 }
40}
41
42impl From<ast::PolicyID> for PolicyID {
43 fn from(id: ast::PolicyID) -> Self {
44 Self(id.into_smolstr())
45 }
46}
47
48impl From<&str> for PolicyID {
49 fn from(s: &str) -> Self {
50 Self(s.into())
51 }
52}
53
54impl Display for PolicyID {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
/// The effect of a policy, shown as the `permit` or `forbid` keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Effect {
    /// Rendered as `permit`.
    Permit,
    /// Rendered as `forbid`.
    Forbid,
}

impl std::fmt::Display for Effect {
    /// Writes the keyword for this effect. Width and fill flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Permit => "permit",
            Self::Forbid => "forbid",
        };
        f.write_str(keyword)
    }
}
82
83#[derive(Debug, Clone, PartialEq, Eq)]
91pub enum Clause {
92 When(Arc<Expr>),
94 Unless(Arc<Expr>),
96}
97
98impl std::fmt::Display for Clause {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 match self {
101 Clause::When(expr) => write!(f, "when {{ {} }}", expr),
102 Clause::Unless(expr) => write!(f, "unless {{ {} }}", expr),
103 }
104 }
105}
106
/// A policy template: an effect, scope constraints, conditions and annotations.
///
/// The scope constraints may reference slots. A template whose constraints
/// contain no slots can be turned into a [`StaticPolicy`]; one with slots is
/// filled in via [`Template::link`] or wrapped in a [`LinkedPolicy`].
///
/// Outside this crate, clauses can only be set through
/// [`Template::try_with_clauses`] / [`Template::try_add_clause`], which reject
/// clause expressions containing slots or `Unknown`s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Template {
    /// Identifier of the template (and of a static policy built from it).
    pub id: PolicyID,
    /// Whether the policy permits or forbids.
    pub effect: Effect,
    /// Principal scope constraint; may reference a slot.
    pub principal: PrincipalConstraint,
    /// Action scope constraint.
    pub action: ActionConstraint,
    /// Resource scope constraint; may reference a slot.
    pub resource: ResourceConstraint,
    /// `when`/`unless` conditions in insertion order. Crate-private so that
    /// external callers must go through the validating setters.
    pub(crate) clauses: Vec<Clause>,
    /// Annotations, keyed by annotation name.
    pub annotations: BTreeMap<String, SmolStr>,
    /// Private marker: prevents struct-literal construction outside this
    /// module, so callers must use [`Template::new`].
    _private: (),
}
186
187fn validate_clause(clause: Clause) -> Result<Clause, PstConstructionError> {
189 match &clause {
190 Clause::When(e) | Clause::Unless(e) => {
191 if e.has_slots() {
192 return Err(ContainsSlotError { slots: e.slots() }.into());
193 }
194 if e.has_unknowns() {
195 return Err(InvalidExpressionError::new(
196 "clause contains an `Unknown`".to_string(),
197 )
198 .into());
199 }
200 Ok(clause)
201 }
202 }
203}
204
205impl Template {
206 pub fn new(
209 id: impl Into<PolicyID>,
210 effect: Effect,
211 principal: PrincipalConstraint,
212 action: ActionConstraint,
213 resource: ResourceConstraint,
214 ) -> Self {
215 Self {
216 id: id.into(),
217 effect,
218 principal,
219 action,
220 resource,
221 clauses: vec![],
222 annotations: BTreeMap::new(),
223 _private: (),
224 }
225 }
226
227 pub fn clauses(&self) -> &Vec<Clause> {
229 &self.clauses
230 }
231
232 pub fn into_clauses(self) -> Vec<Clause> {
234 self.clauses
235 }
236
237 pub fn try_with_clauses(
239 self,
240 clauses: impl IntoIterator<Item = Clause>,
241 ) -> Result<Self, PstConstructionError> {
242 let clauses: Vec<Clause> = clauses
243 .into_iter()
244 .map(validate_clause)
245 .collect::<Result<_, PstConstructionError>>()?;
246 Ok(Self { clauses, ..self })
247 }
248
249 pub fn try_add_clause(&mut self, clause: Clause) -> Result<(), PstConstructionError> {
251 self.clauses.push(validate_clause(clause)?);
252 Ok(())
253 }
254
255 pub fn with_annotations(self, annotations: BTreeMap<String, SmolStr>) -> Self {
257 Self {
258 annotations,
259 ..self
260 }
261 }
262
263 pub fn with_id(self, id: PolicyID) -> Self {
265 Self { id, ..self }
266 }
267
268 pub fn link(
272 self,
273 vals: &HashMap<SlotId, EntityUID>,
274 ) -> Result<StaticPolicy, PstConstructionError> {
275 Ok(StaticPolicy::try_from(Template {
276 id: self.id,
277 effect: self.effect,
278 principal: self.principal.link(vals)?,
279 action: self.action.link(vals)?,
280 resource: self.resource.link(vals)?,
281 clauses: self.clauses,
282 annotations: self.annotations,
283 _private: (),
284 })?)
285 }
286
287 pub fn slots(&self) -> HashSet<SlotId> {
289 let mut slots = HashSet::new();
290 slots.extend(self.principal.slot());
291 slots.extend(self.action.slot());
292 slots.extend(self.resource.slot());
293 slots
294 }
295
296 pub fn is_static(&self) -> bool {
298 !(self.principal.has_slot() || self.resource.has_slot() || self.action.has_slot())
300 }
301}
302
303impl std::fmt::Display for Template {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 let est_res: Result<crate::est::Policy, PstConstructionError> = self.clone().try_into();
306 match est_res {
307 Ok(est) => write!(f, "{est}"),
308 Err(e) => write!(f, "<invalid policy: {e}>"),
309 }
310 }
311}
312
/// A policy whose scope constraints contain no slots.
///
/// Obtain one through `TryFrom<Template>` or [`Template::link`]. Both reject
/// templates that still contain slots.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct StaticPolicy {
    /// The slot-free policy body; its `id` is this policy's id.
    pub(crate) body: Template,
}
323
324impl StaticPolicy {
325 pub fn id(&self) -> &PolicyID {
327 &self.body.id
328 }
329
330 pub fn body(&self) -> &Template {
332 &self.body
333 }
334}
335
336impl TryFrom<Template> for StaticPolicy {
337 type Error = ContainsSlotError;
338 fn try_from(body: Template) -> Result<Self, Self::Error> {
339 if body.principal.has_slot() || body.resource.has_slot() || body.action.has_slot() {
342 Err(ContainsSlotError {
343 slots: body.slots(),
344 })
345 } else {
346 Ok(StaticPolicy { body })
347 }
348 }
349}
350
/// A policy produced by binding a template's slots to concrete entities.
///
/// The template is held behind an `Arc`, so several linked policies can share
/// one template. [`LinkedPolicy::new`] checks that every slot of the template
/// has a value. The fields are `pub(crate)`, so code inside the crate can
/// construct a value without that check.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct LinkedPolicy {
    /// The (unlinked) template this policy was created from.
    pub(crate) body: Arc<Template>,
    /// The entity bound to each slot of `body`.
    pub(crate) values: HashMap<SlotId, EntityUID>,
    /// Id of this linked policy; may differ from the template's `body.id`.
    pub(crate) instance_id: PolicyID,
}
365
366impl LinkedPolicy {
367 pub fn new(
370 template: Arc<Template>,
371 values: HashMap<SlotId, EntityUID>,
372 instance_id: PolicyID,
373 ) -> Result<Self, LinkingError> {
374 for slot in template.slots() {
375 if !values.contains_key(&slot) {
376 return Err(LinkingError::MissedSlot { slot });
377 }
378 }
379 Ok(Self {
380 body: template,
381 values,
382 instance_id,
383 })
384 }
385
386 pub fn into_static_policy(&self) -> Result<StaticPolicy, PstConstructionError> {
391 let mut policy = self.body.as_ref().clone().link(&self.values)?;
392 policy.body.id = self.instance_id.clone();
393 Ok(policy)
394 }
395
396 pub fn id(&self) -> &PolicyID {
398 &self.instance_id
399 }
400
401 pub fn body(&self) -> &Template {
403 &self.body
404 }
405
406 pub fn values(&self) -> &HashMap<SlotId, EntityUID> {
408 &self.values
409 }
410}
411
412impl From<StaticPolicy> for Policy {
413 fn from(p: StaticPolicy) -> Self {
414 Policy::Static(p)
415 }
416}
417
418impl From<LinkedPolicy> for Policy {
419 fn from(p: LinkedPolicy) -> Self {
420 Policy::Linked(p)
421 }
422}
423
/// Either a static policy or a template-linked policy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Policy {
    /// A policy whose scope contains no slots.
    Static(StaticPolicy),
    /// A policy created by binding a template's slots to entities.
    Linked(LinkedPolicy),
}
434
435impl Policy {
436 pub fn body(&self) -> &Template {
438 match self {
439 Policy::Static(p) => p.body(),
440 Policy::Linked(p) => p.body(),
441 }
442 }
443
444 pub fn new_id(&self, id: PolicyID) -> Self {
446 match self {
447 Policy::Static(sp) => {
448 let mut body = sp.body.clone();
449 body.id = id;
450 Policy::Static(StaticPolicy { body })
451 }
452 Policy::Linked(lp) => Policy::Linked(LinkedPolicy {
453 body: lp.body.clone(),
454 values: lp.values.clone(),
455 instance_id: id,
456 }),
457 }
458 }
459}
460
461impl std::fmt::Display for StaticPolicy {
462 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463 write!(f, "{}", self.body)
464 }
465}
466
467impl std::fmt::Display for LinkedPolicy {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 write!(f, "{}", self.body)
470 }
471}
472
473impl std::fmt::Display for Policy {
474 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
475 write!(f, "{}", self.body())
476 }
477}
478
// Unit tests for policy ids, templates, clause validation, linking and the
// `Policy` wrapper defined in this module.
#[cfg(test)]
mod tests {
    use cool_asserts::assert_matches;
    use smol_str::ToSmolStr;

    use super::*;
    use crate::pst::expr::Literal;

    // `with_annotations` stores the given map verbatim.
    #[test]
    fn test_with_annotations() {
        let mut annotations = BTreeMap::new();
        annotations.insert("author".to_string(), "alice".to_smolstr());
        annotations.insert("version".to_string(), "1.0".to_smolstr());
        let template = Template::new(
            "p",
            Effect::Permit,
            PrincipalConstraint::Any,
            ActionConstraint::Any,
            ResourceConstraint::Any,
        )
        .with_annotations(annotations.clone());
        assert_eq!(template.annotations, annotations);
    }

    // Converting a PST id into an AST id preserves the id text.
    #[test]
    fn test_policy_id_conversion() {
        let pst_id = PolicyID(SmolStr::from("test_policy"));
        let ast_id: ast::PolicyID = pst_id.into();
        assert_eq!(ast_id.to_string(), "test_policy");
    }

    // Test fixture: builds an entity UID with unqualified type `ty` and id `id`.
    // Panics (via `unwrap`) if `ty` is rejected by `Name::unqualified`.
    fn make_uid(ty: &str, id: &str) -> EntityUID {
        EntityUID {
            ty: crate::pst::EntityType::from_name(crate::pst::Name::unqualified(ty).unwrap()),
            eid: SmolStr::from(id),
        }
    }

    // Linking with values for every slot replaces the principal and resource
    // slots with entities and keeps the (slot-free) clauses intact. Also checks
    // that a clause containing a slot is rejected when it is added.
    #[test]
    fn test_policy_link_replaces_all_slots() {
        use crate::pst::constraints::*;
        use crate::pst::expr::SlotId;

        let mut template = Template::new(
            "t1",
            Effect::Permit,
            PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
            ActionConstraint::Eq(make_uid("Action", "view")),
            ResourceConstraint::In(EntityOrSlot::Slot(SlotId::Resource)),
        );
        // Slots are allowed in the scope constraints but not inside clauses.
        assert!(matches!(
            template
                .clone()
                .try_add_clause(Clause::When(Arc::new(Expr::Slot(SlotId::Principal)))),
            Err(PstConstructionError::ContainsSlots(..))
        ));

        template
            .try_add_clause(Clause::When(Arc::new(Expr::Literal(Literal::Bool(true)))))
            .unwrap();

        let mut vals = HashMap::new();
        vals.insert(SlotId::Principal, make_uid("User", "alice"));
        vals.insert(SlotId::Resource, make_uid("Album", "vacation"));

        let linked = template.link(&vals).unwrap();

        assert_eq!(
            linked.body.principal,
            PrincipalConstraint::Eq(EntityOrSlot::Entity(make_uid("User", "alice")))
        );
        assert_eq!(
            linked.body.resource,
            ResourceConstraint::In(EntityOrSlot::Entity(make_uid("Album", "vacation")))
        );
        assert_eq!(
            linked.body.clauses,
            vec![Clause::When(Arc::new(Expr::Literal(Literal::Bool(true))))]
        );
    }

    // An unbound slot is reported by `Template::link` (as `LinkingFailed`) and
    // by `LinkedPolicy::new` (as `MissedSlot`).
    #[test]
    fn test_policy_link_or_new_linked_policy_missing_slot_errors() {
        use crate::pst::constraints::*;
        use crate::pst::expr::SlotId;

        let template = Template::new(
            "t2",
            Effect::Forbid,
            PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
            ActionConstraint::Any,
            ResourceConstraint::Any,
        );

        let result = template.clone().link(&HashMap::new());
        assert!(matches!(
            result,
            Err(PstConstructionError::LinkingFailed(..))
        ));
        let new_policy = LinkedPolicy::new(Arc::new(template), HashMap::new(), "test0".into());
        assert!(matches!(
            new_policy,
            Err(LinkingError::MissedSlot {
                slot: SlotId::Principal
            })
        ));
    }

    // A slot-free template converts to a `StaticPolicy` that keeps its id and
    // body unchanged.
    #[test]
    fn test_static_policy() {
        let mut template = Template::new(
            "my_policy",
            Effect::Permit,
            PrincipalConstraint::Any,
            ActionConstraint::Any,
            ResourceConstraint::Any,
        );
        template
            .try_add_clause(Clause::When(Arc::new(Expr::Literal(Literal::Bool(true)))))
            .unwrap();
        let original = template.clone();
        let static_policy = StaticPolicy::try_from(template).unwrap();
        assert_eq!(static_policy.id().0.as_str(), "my_policy");
        assert_eq!(static_policy.body, original);
        // Only checks that rendering does not panic; the text is not pinned.
        let _ = static_policy.to_string();
    }

    // Pins the exact `Display` output of `Effect` and `Clause`.
    #[test]
    fn test_effect_and_clause_display() {
        assert_eq!(Effect::Permit.to_string(), "permit");
        assert_eq!(Effect::Forbid.to_string(), "forbid");
        assert_eq!(
            Clause::When(Arc::new(Expr::Literal(Literal::Bool(true)))).to_string(),
            "when { true }"
        );
        assert_eq!(
            Clause::Unless(Arc::new(Expr::Literal(Literal::Bool(false)))).to_string(),
            "unless { false }"
        );
    }

    // Exercises the clause accessors, `is_static`, `slots` and `Display` on a
    // static template, and `is_static`/`slots` on a template with a slot.
    #[test]
    fn test_template_methods() {
        use crate::pst::constraints::*;
        use crate::pst::expr::SlotId;

        let clause = Clause::When(Arc::new(Expr::Literal(Literal::Bool(true))));
        let mut template = Template::new(
            "p",
            Effect::Permit,
            PrincipalConstraint::Any,
            ActionConstraint::Any,
            ResourceConstraint::Any,
        );
        template.try_add_clause(clause.clone()).unwrap();

        assert_eq!(template.clauses(), &vec![clause.clone()]);
        assert!(template.is_static());
        assert!(template.slots().is_empty());
        let s = template.to_string();
        assert!(s.contains("permit") && s.contains("when"));
        assert_eq!(template.into_clauses(), vec![clause]);

        let slotted = Template::new(
            "t",
            Effect::Permit,
            PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
            ActionConstraint::Any,
            ResourceConstraint::Any,
        );
        assert!(!slotted.is_static());
        assert!(slotted.slots().contains(&SlotId::Principal));
    }

    // `try_with_clauses` rejects a slot inside a clause, and a template with
    // a scope slot cannot become a `StaticPolicy`.
    #[test]
    fn test_slot_error_paths() {
        use crate::pst::constraints::*;
        use crate::pst::expr::SlotId;

        let template = Template::new(
            "t",
            Effect::Permit,
            PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
            ActionConstraint::Any,
            ResourceConstraint::Any,
        );
        assert!(matches!(
            template
                .clone()
                .try_with_clauses(vec![Clause::When(Arc::new(Expr::Slot(SlotId::Principal)))]),
            Err(PstConstructionError::ContainsSlots(..))
        ));
        assert!(StaticPolicy::try_from(template).is_err());
    }

    // `Unknown` expressions are rejected in clauses, whether at top level or
    // nested, through both `try_with_clauses` and `try_add_clause`.
    #[test]
    fn test_unknown_rejected_in_clauses() {
        let unknown = Arc::new(Expr::Unknown {
            name: SmolStr::from("x"),
        });

        let template = Template::new(
            "p",
            Effect::Permit,
            PrincipalConstraint::Any,
            ActionConstraint::Any,
            ResourceConstraint::Any,
        );

        // Top-level `Unknown` via `try_with_clauses`.
        let err = template
            .clone()
            .try_with_clauses(vec![Clause::When(unknown.clone())])
            .unwrap_err();
        assert!(
            matches!(err,
            PstConstructionError::InvalidExpression(ref e)
            if e.to_string().contains("clause contains an `Unknown`")),
            "expected InvalidExpression mentioning unknown, got: {err}"
        );

        // Top-level `Unknown` via `try_add_clause`.
        let mut t2 = template.clone();
        let err = t2
            .try_add_clause(Clause::When(unknown.clone()))
            .unwrap_err();
        assert!(
            matches!(err,
            PstConstructionError::InvalidExpression(ref e)
            if e.to_string().contains("clause contains an `Unknown`")),
            "expected InvalidExpression mentioning unknown, got: {err}"
        );

        // `Unknown` nested inside a binary operation is still detected.
        let nested = Arc::new(Expr::BinaryOp {
            op: crate::pst::BinaryOp::And,
            left: Arc::new(Expr::Literal(Literal::Bool(true))),
            right: unknown,
        });
        let err = template
            .clone()
            .try_with_clauses(vec![Clause::Unless(nested.clone())])
            .unwrap_err();
        assert!(
            matches!(err, PstConstructionError::InvalidExpression(ref e)
            if e.to_string().contains("clause contains an `Unknown`")),
            "expected nested unknown to be caught, got: {err}"
        );

        // A concrete clause is accepted.
        let ok_clause = Clause::When(Arc::new(Expr::Literal(Literal::Bool(true))));
        assert!(template.try_with_clauses(vec![ok_clause]).is_ok());
    }

    // A linked policy reports its instance id, converts to a static policy with
    // that id, and `Policy::body` returns the template for both variants.
    #[test]
    fn test_linked_policy() {
        use crate::pst::constraints::*;
        use crate::pst::expr::SlotId;

        let mut vals = HashMap::new();
        vals.insert(SlotId::Principal, make_uid("User", "alice"));
        let linked = LinkedPolicy {
            body: Arc::new(Template::new(
                "tmpl",
                Effect::Permit,
                PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
                ActionConstraint::Any,
                ResourceConstraint::Any,
            )),
            values: vals,
            instance_id: PolicyID("link1".into()),
        };
        assert_eq!(linked.id().0.as_str(), "link1");
        let _ = linked.to_string();
        let static_policy = linked.into_static_policy().unwrap();
        assert_eq!(static_policy.id().0.as_str(), "link1");

        let static_p = Policy::Static(
            StaticPolicy::try_from(Template::new(
                "p",
                Effect::Permit,
                PrincipalConstraint::Any,
                ActionConstraint::Any,
                ResourceConstraint::Any,
            ))
            .unwrap(),
        );
        assert_matches!(
            static_p.body(),
            Template {
                effect: Effect::Permit,
                action: ActionConstraint::Any,
                clauses: v,
                ..
            } if v.is_empty()
        );
        let _ = static_p.to_string();

        let linked_p = Policy::Linked(LinkedPolicy {
            body: Arc::new(Template::new(
                "tmpl2",
                Effect::Forbid,
                PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
                ActionConstraint::Any,
                ResourceConstraint::Any,
            )),
            values: {
                let mut m = HashMap::new();
                m.insert(SlotId::Principal, make_uid("User", "bob"));
                m
            },
            instance_id: PolicyID("link2".into()),
        });
        assert_matches!(
            linked_p.body(),
            Template {
                effect: Effect::Forbid,
                action: ActionConstraint::Any,
                clauses: v,
                ..
            } if v.is_empty()
        );
        match &linked_p {
            Policy::Linked(lp) => assert_eq!(lp.values().len(), 1),
            _ => (),
        };
        let _ = linked_p.to_string();
    }

    // `new_id` on a static policy changes the policy's (body) id.
    #[test]
    fn test_new_id_static() {
        let policy = Policy::Static(
            StaticPolicy::try_from(Template::new(
                "old",
                Effect::Permit,
                PrincipalConstraint::Any,
                ActionConstraint::Any,
                ResourceConstraint::Any,
            ))
            .unwrap(),
        );
        let renamed = policy.new_id("new".into());
        match &renamed {
            Policy::Static(sp) => assert_eq!(sp.id().0.as_str(), "new"),
            Policy::Linked(_) => panic!("expected Static"),
        }
    }

    // `new_id` on a linked policy changes only the instance id; the template's
    // own id is unchanged.
    #[test]
    fn test_new_id_linked() {
        use crate::pst::constraints::*;
        use crate::pst::expr::SlotId;

        let template = Arc::new(Template::new(
            "tmpl",
            Effect::Permit,
            PrincipalConstraint::Eq(EntityOrSlot::Slot(SlotId::Principal)),
            ActionConstraint::Any,
            ResourceConstraint::Any,
        ));
        let policy = Policy::Linked(
            LinkedPolicy::new(
                template.clone(),
                HashMap::from([(SlotId::Principal, make_uid("User", "alice"))]),
                "old_link".into(),
            )
            .unwrap(),
        );
        let renamed = policy.new_id("new_link".into());
        match &renamed {
            Policy::Linked(lp) => {
                assert_eq!(lp.id().0.as_str(), "new_link");
                assert_eq!(lp.body.id.0.as_str(), "tmpl");
            }
            Policy::Static(_) => panic!("expected Linked"),
        }
    }
}
860}