cedar_policy_core/ast/
policy_set.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use super::{
18    EntityUID, LinkingError, LiteralPolicy, Policy, PolicyID, ReificationError, SlotId,
19    StaticPolicy, Template,
20};
21use itertools::Itertools;
22use linked_hash_map::{Entry, LinkedHashMap};
23use linked_hash_set::LinkedHashSet;
24
25use miette::Diagnostic;
26use smol_str::format_smolstr;
27use std::{borrow::Borrow, collections::HashMap, sync::Arc};
28use thiserror::Error;
29
/// Represents a set of `Policy`s
///
/// Every executable policy (static or template-linked) is stored in `links`.
/// Every template body is stored in `templates`, including the zero-slot
/// template that backs each static policy. The fields are private to this
/// module; the `PolicySet` methods keep the invariants documented on each
/// field.
///
/// NOTE(review): the derived `PartialEq` compares the `LinkedHashMap` fields,
/// so equality may be sensitive to insertion order — confirm against the
/// `linked_hash_map` crate's `PartialEq` implementation.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PolicySet {
    /// `templates` contains all bodies of policies in the `PolicySet`.
    /// A body is either:
    /// - A Body of a `Template`, which has slots that need to be filled in
    /// - A Body of a `StaticPolicy`, which has been converted into a `Template` that has zero slots.
    ///   The static policy's [`PolicyID`] is the same in both `templates` and `links`.
    templates: LinkedHashMap<PolicyID, Arc<Template>>,
    /// `links` contains all of the executable policies in the `PolicySet`.
    /// A `StaticPolicy` must have exactly one `Policy` in `links`
    ///   (this is maintained by `PolicySet::add` and `PolicySet::add_static`).
    ///   The static policy's PolicyID is the same in both `templates` and `links`.
    /// A `Template` may have zero or many links.
    links: LinkedHashMap<PolicyID, Policy>,

    /// Map from a template `PolicyID` to the set of `PolicyID`s in `links` that are linked to that template.
    /// There is a key `t` iff `templates` contains the key `t`. The value of `t` will be a (possibly empty)
    /// set of every `p` in `links` s.t. `p.template().id() == t`.
    template_to_links_map: LinkedHashMap<PolicyID, LinkedHashSet<PolicyID>>,
}
51
/// A Policy Set that contains less rich information than `PolicySet`.
///
/// In particular, this form is easier to convert to/from the Protobuf
/// representation of a `PolicySet`, because policies are represented as
/// `LiteralPolicy` instead of `Policy`.
///
/// Templates are stored by value rather than behind `Arc`. Converting into a
/// `PolicySet` (via `TryFrom`) re-creates the `Arc`s and the
/// template-to-links index.
#[derive(Debug)]
pub struct LiteralPolicySet {
    /// Like the `templates` field of `PolicySet`, but holding owned `Template`s.
    templates: LinkedHashMap<PolicyID, Template>,
    /// Like the `links` field of `PolicySet`, but maps to `LiteralPolicy` only.
    /// The same invariants apply: e.g., a `StaticPolicy` must have exactly one `Policy` in `links`.
    links: LinkedHashMap<PolicyID, LiteralPolicy>,
}
65
66impl LiteralPolicySet {
67    /// Create a new `LiteralPolicySet`. Caller is responsible for ensuring the
68    /// invariants on `LiteralPolicySet`.
69    pub fn new(
70        templates: impl IntoIterator<Item = (PolicyID, Template)>,
71        links: impl IntoIterator<Item = (PolicyID, LiteralPolicy)>,
72    ) -> Self {
73        Self {
74            templates: templates.into_iter().collect(),
75            links: links.into_iter().collect(),
76        }
77    }
78
79    /// Iterate over the `Template`s in the `LiteralPolicySet`. This will
80    /// include both templates and static policies (represented as templates
81    /// with zero slots)
82    pub fn templates(&self) -> impl Iterator<Item = &Template> {
83        self.templates.values()
84    }
85
86    /// Iterate over the `LiteralPolicy`s in the `LiteralPolicySet`. This will
87    /// include both static and template-linked policies.
88    pub fn policies(&self) -> impl Iterator<Item = &LiteralPolicy> {
89        self.links.values()
90    }
91}
92
93/// Converts a LiteralPolicySet into a PolicySet, ensuring the invariants are met
94/// Every `Policy` must point to a `Template` that exists in the set.
95impl TryFrom<LiteralPolicySet> for PolicySet {
96    type Error = ReificationError;
97    fn try_from(pset: LiteralPolicySet) -> Result<Self, Self::Error> {
98        // Allocate the templates into Arc's
99        let templates = pset
100            .templates
101            .into_iter()
102            .map(|(id, template)| (id, Arc::new(template)))
103            .collect::<LinkedHashMap<PolicyID, Arc<Template>>>();
104        let links = pset
105            .links
106            .into_iter()
107            .map(|(id, literal)| literal.reify(&templates).map(|linked| (id, linked)))
108            .collect::<Result<LinkedHashMap<PolicyID, Policy>, ReificationError>>()?;
109
110        let mut template_to_links_map = LinkedHashMap::new();
111        for template in &templates {
112            template_to_links_map.insert(template.0.clone(), LinkedHashSet::new());
113        }
114        for (link_id, link) in &links {
115            let template = link.template().id();
116            match template_to_links_map.entry(template.clone()) {
117                Entry::Occupied(t) => t.into_mut().insert(link_id.clone()),
118                Entry::Vacant(_) => return Err(ReificationError::NoSuchTemplate(template.clone())),
119            };
120        }
121
122        Ok(Self {
123            templates,
124            links,
125            template_to_links_map,
126        })
127    }
128}
129
130impl From<PolicySet> for LiteralPolicySet {
131    fn from(pset: PolicySet) -> Self {
132        let templates = pset
133            .templates
134            .into_iter()
135            .map(|(id, template)| (id, template.as_ref().clone()))
136            .collect();
137        let links = pset
138            .links
139            .into_iter()
140            .map(|(id, p)| (id, p.into()))
141            .collect();
142        Self { templates, links }
143    }
144}
145
/// Potential errors when working with `PolicySet`s.
///
/// Returned by the `PolicySet` methods that insert policies or templates
/// (`add`, `add_static`, `add_template`, `merge_policyset`).
#[derive(Debug, Diagnostic, Error)]
pub enum PolicySetError {
    /// There was a duplicate [`PolicyID`] encountered in either the set of
    /// templates or the set of policies.
    #[error("duplicate template or policy id `{id}`")]
    Occupied {
        /// [`PolicyID`] that was duplicate
        id: PolicyID,
    },
}
156}
157
/// Errors returned by `PolicySet::get_linked_policies`.
#[derive(Debug, Diagnostic, Error)]
pub enum PolicySetGetLinksError {
    /// The requested [`PolicyID`] is not a template in the set.
    #[error("No template `{0}`")]
    MissingTemplate(PolicyID),
}
165
/// Potential errors when unlinking from a `PolicySet`.
#[derive(Debug, Diagnostic, Error)]
pub enum PolicySetUnlinkError {
    /// There was no policy with this [`PolicyID`] in the set of links.
    #[error("unable to unlink policy id `{0}` because it does not exist")]
    UnlinkingError(PolicyID),
    /// The [`PolicyID`] is present in the set of templates. A template-linked
    /// policy's id never is, so this id refers to a static policy (or a template)
    /// rather than to a link.
    #[error("unable to remove link with policy id `{0}` because it is a static policy")]
    NotLinkError(PolicyID),
}
176
/// Potential errors when removing templates from a `PolicySet`.
#[derive(Debug, Diagnostic, Error)]
pub enum PolicySetTemplateRemovalError {
    /// There was no [`PolicyID`] template in the list of templates.
    #[error("unable to remove template id `{0}` from template list because it does not exist")]
    RemovePolicyNoTemplateError(PolicyID),
    /// There are still active links to template [`PolicyID`].
    #[error(
        "unable to remove template id `{0}` from template list because it still has active links"
    )]
    RemoveTemplateWithLinksError(PolicyID),
    /// The [`PolicyID`] is present in the set of links, so it names an
    /// executable policy, not a pure template. It may be either a static
    /// policy or a template-linked policy, even though the message only
    /// mentions the static case.
    #[error("unable to remove template with policy id `{0}` because it is a static policy")]
    NotTemplateError(PolicyID),
}
192
/// Potential errors when removing static policies from a `PolicySet`
/// (returned by `PolicySet::remove_static`).
#[derive(Debug, Diagnostic, Error)]
pub enum PolicySetPolicyRemovalError {
    /// There was no policy with this [`PolicyID`] in the set of links.
    #[error("unable to remove static policy id `{0}` from link list because it does not exist")]
    RemovePolicyNoLinkError(PolicyID),
    /// The [`PolicyID`] is in the set of links but not in the set of
    /// templates, so it is a template-linked policy rather than a static one.
    #[error(
        "unable to remove static policy id `{0}` from template list because it does not exist"
    )]
    RemovePolicyNoTemplateError(PolicyID),
}
205
206// The public interface of `PolicySet` is intentionally narrow, to allow us
207// maximum flexibility to change the underlying implementation in the future
208impl PolicySet {
209    /// Create a fresh empty `PolicySet`
210    pub fn new() -> Self {
211        Self {
212            templates: LinkedHashMap::new(),
213            links: LinkedHashMap::new(),
214            template_to_links_map: LinkedHashMap::new(),
215        }
216    }
217
218    /// Add a `Policy` to the `PolicySet`.
219    pub fn add(&mut self, policy: Policy) -> Result<(), PolicySetError> {
220        let t = policy.template_arc();
221
222        // we need to check for all possible errors before making any
223        // modifications to `self`.
224        // So we just collect the `ventry` here, and we only do the insertion
225        // once we know there will be no error
226        let template_ventry = match self.templates.entry(t.id().clone()) {
227            Entry::Vacant(ventry) => Some(ventry),
228            Entry::Occupied(oentry) => {
229                if oentry.get() != &t {
230                    return Err(PolicySetError::Occupied {
231                        id: oentry.key().clone(),
232                    });
233                }
234                None
235            }
236        };
237
238        let link_ventry = match self.links.entry(policy.id().clone()) {
239            Entry::Vacant(ventry) => Some(ventry),
240            Entry::Occupied(oentry) => {
241                return Err(PolicySetError::Occupied {
242                    id: oentry.key().clone(),
243                });
244            }
245        };
246
247        // if we get here, there will be no errors.  So actually do the
248        // insertions.
249        if let Some(ventry) = template_ventry {
250            self.template_to_links_map.insert(
251                t.id().clone(),
252                vec![policy.id().clone()]
253                    .into_iter()
254                    .collect::<LinkedHashSet<PolicyID>>(),
255            );
256            ventry.insert(t);
257        } else {
258            //`template_ventry` is None, so `templates` has `t` and we never use the `HashSet::new()`
259            self.template_to_links_map
260                .entry(t.id().clone())
261                .or_default()
262                .insert(policy.id().clone());
263        }
264        if let Some(ventry) = link_ventry {
265            ventry.insert(policy);
266        }
267
268        Ok(())
269    }
270
271    /// Helper function for `merge_policyset` to check if the `PolicyID` pid
272    /// appears in this `PolicySet`'s links or templates.
273    fn policy_id_is_bound(&self, pid: &PolicyID) -> bool {
274        self.templates.contains_key(pid) || self.links.contains_key(pid)
275    }
276
277    /// Helper function for `merge_policyset` to construct a renaming
278    /// that would resolve any conflicting `PolicyID`s. We use the type parameter `T`
279    /// to allow this code to be applied to both Templates and Policies.
280    fn update_renaming<T>(
281        &self,
282        this_contents: &LinkedHashMap<PolicyID, T>,
283        other: &Self,
284        other_contents: &LinkedHashMap<PolicyID, T>,
285        renaming: &mut LinkedHashMap<PolicyID, PolicyID>,
286        start_ind: &mut u32,
287    ) where
288        T: PartialEq + Clone,
289    {
290        for (pid, ot) in other_contents {
291            if let Some(tt) = this_contents.get(pid) {
292                if tt != ot {
293                    let mut new_pid =
294                        PolicyID::from_smolstr(format_smolstr!("policy{}", start_ind));
295                    *start_ind += 1;
296                    while self.policy_id_is_bound(&new_pid) || other.policy_id_is_bound(&new_pid) {
297                        new_pid = PolicyID::from_smolstr(format_smolstr!("policy{}", start_ind));
298                        *start_ind += 1;
299                    }
300                    renaming.insert(pid.clone(), new_pid);
301                }
302            }
303        }
304    }
305
306    /// Merges this `PolicySet` with another `PolicySet`.
307    /// This `PolicySet` is modified while the other `PolicySet`
308    /// remains unchanged.
309    ///
310    /// The flag `rename_duplicates` controls the expected behavior
311    /// when a `PolicyID` in this and the other `PolicySet` conflict.
312    ///
313    /// When `rename_duplicates` is false, conflicting `PolicyID`s result
314    /// in a occupied `PolicySetError`.
315    ///
316    /// Otherwise, when `rename_duplicates` is true, conflicting `PolicyID`s from
317    /// the other `PolicySet` are automatically renamed to avoid conflict.
318    /// This renaming is returned as a Hashmap from the old `PolicyID` to the
319    /// renamed `PolicyID`.
320    pub fn merge_policyset(
321        &mut self,
322        other: &PolicySet,
323        rename_duplicates: bool,
324    ) -> Result<LinkedHashMap<PolicyID, PolicyID>, PolicySetError> {
325        // Check for conflicting policy ids. If there is a conflict either
326        // throw an error or construct a renaming (if `rename_duplicates` is true)
327        let mut min_id = 0;
328        let mut renaming = LinkedHashMap::new();
329        self.update_renaming(
330            &self.templates,
331            other,
332            &other.templates,
333            &mut renaming,
334            &mut min_id,
335        );
336        self.update_renaming(&self.links, other, &other.links, &mut renaming, &mut min_id);
337        // If `rename_dupilicates` is false, then throw an error if any renaming should happen
338        if !rename_duplicates {
339            if let Some(pid) = renaming.keys().next() {
340                return Err(PolicySetError::Occupied { id: pid.clone() });
341            }
342        }
343        // either there are no conflicting policy ids
344        // or we should rename conflicting policy ids (using renaming) to avoid conflicting policy ids
345        for (pid, other_template) in &other.templates {
346            let pid = renaming.get(pid).unwrap_or(pid);
347            self.templates.insert(pid.clone(), other_template.clone());
348        }
349        for (pid, other_policy) in &other.links {
350            let pid = renaming.get(pid).unwrap_or(pid);
351            self.links.insert(pid.clone(), other_policy.clone());
352        }
353        for (tid, other_template_link_set) in &other.template_to_links_map {
354            let tid = renaming.get(tid).unwrap_or(tid);
355            let mut this_template_link_set =
356                self.template_to_links_map.remove(tid).unwrap_or_default();
357            for pid in other_template_link_set {
358                let pid = renaming.get(pid).unwrap_or(pid);
359                this_template_link_set.insert(pid.clone());
360            }
361            self.template_to_links_map
362                .insert(tid.clone(), this_template_link_set);
363        }
364        Ok(renaming)
365    }
366
367    /// Remove a static `Policy`` from the `PolicySet`.
368    pub fn remove_static(
369        &mut self,
370        policy_id: &PolicyID,
371    ) -> Result<Policy, PolicySetPolicyRemovalError> {
372        // Invariant: if `policy_id` is a key in both `self.links` and `self.templates`,
373        // then self.templates[policy_id] has exactly one link: self.links[policy_id]
374        let policy = match self.links.remove(policy_id) {
375            Some(p) => p,
376            None => {
377                return Err(PolicySetPolicyRemovalError::RemovePolicyNoLinkError(
378                    policy_id.clone(),
379                ))
380            }
381        };
382        //links mapped by `PolicyId`, so `policy` is unique
383        match self.templates.remove(policy_id) {
384            Some(_) => {
385                self.template_to_links_map.remove(policy_id);
386                Ok(policy)
387            }
388            None => {
389                //If we removed the link but failed to remove the template
390                //restore the link and return an error
391                self.links.insert(policy_id.clone(), policy);
392                Err(PolicySetPolicyRemovalError::RemovePolicyNoTemplateError(
393                    policy_id.clone(),
394                ))
395            }
396        }
397    }
398
399    /// Add a `StaticPolicy` to the `PolicySet`.
400    pub fn add_static(&mut self, policy: StaticPolicy) -> Result<(), PolicySetError> {
401        let (t, p) = Template::link_static_policy(policy);
402
403        match (
404            self.templates.entry(t.id().clone()),
405            self.links.entry(t.id().clone()),
406        ) {
407            (Entry::Vacant(templates_entry), Entry::Vacant(links_entry)) => {
408                self.template_to_links_map.insert(
409                    t.id().clone(),
410                    vec![p.id().clone()]
411                        .into_iter()
412                        .collect::<LinkedHashSet<PolicyID>>(),
413                );
414                templates_entry.insert(t);
415                links_entry.insert(p);
416                Ok(())
417            }
418            (Entry::Occupied(oentry), _) => Err(PolicySetError::Occupied {
419                id: oentry.key().clone(),
420            }),
421            (_, Entry::Occupied(oentry)) => Err(PolicySetError::Occupied {
422                id: oentry.key().clone(),
423            }),
424        }
425    }
426
427    /// Add a template to the policy set.
428    /// If a link, static policy or template with the same name already exists, this will error.
429    pub fn add_template(&mut self, t: Template) -> Result<(), PolicySetError> {
430        if self.links.contains_key(t.id()) {
431            return Err(PolicySetError::Occupied { id: t.id().clone() });
432        }
433
434        match self.templates.entry(t.id().clone()) {
435            Entry::Occupied(oentry) => Err(PolicySetError::Occupied {
436                id: oentry.key().clone(),
437            }),
438            Entry::Vacant(ventry) => {
439                self.template_to_links_map
440                    .insert(t.id().clone(), LinkedHashSet::new());
441                ventry.insert(Arc::new(t));
442                Ok(())
443            }
444        }
445    }
446
447    /// Remove a template from the policy set.
448    /// This will error if any policy is linked to the template.
449    /// This will error if `policy_id` is not a template.
450    pub fn remove_template(
451        &mut self,
452        policy_id: &PolicyID,
453    ) -> Result<Template, PolicySetTemplateRemovalError> {
454        //A template occurs in templates but not in links.
455        if self.links.contains_key(policy_id) {
456            return Err(PolicySetTemplateRemovalError::NotTemplateError(
457                policy_id.clone(),
458            ));
459        }
460
461        match self.template_to_links_map.get(policy_id) {
462            Some(map) => {
463                if !map.is_empty() {
464                    return Err(PolicySetTemplateRemovalError::RemoveTemplateWithLinksError(
465                        policy_id.clone(),
466                    ));
467                }
468            }
469            None => {
470                return Err(PolicySetTemplateRemovalError::RemovePolicyNoTemplateError(
471                    policy_id.clone(),
472                ))
473            }
474        };
475
476        // PANIC SAFETY: every linked policy should have a template
477        #[allow(clippy::panic)]
478        match self.templates.remove(policy_id) {
479            Some(t) => {
480                self.template_to_links_map.remove(policy_id);
481                Ok(Arc::unwrap_or_clone(t))
482            }
483            None => panic!("Found in template_to_links_map but not in templates"),
484        }
485    }
486
487    /// Get the list of policies linked to `template_id`.
488    /// Returns all p in `links` s.t. `p.template().id() == template_id`
489    pub fn get_linked_policies(
490        &self,
491        template_id: &PolicyID,
492    ) -> Result<impl Iterator<Item = &PolicyID>, PolicySetGetLinksError> {
493        match self.template_to_links_map.get(template_id) {
494            Some(s) => Ok(s.iter()),
495            None => Err(PolicySetGetLinksError::MissingTemplate(template_id.clone())),
496        }
497    }
498
499    /// Attempt to create a new template linked policy and add it to the policy
500    /// set. Returns a references to the new template linked policy if
501    /// successful.
502    ///
503    /// Errors for two reasons
504    ///   1) The the passed SlotEnv either does not match the slots in the templates
505    ///   2) The passed link Id conflicts with an Id already in the set
506    pub fn link(
507        &mut self,
508        template_id: PolicyID,
509        new_id: PolicyID,
510        values: HashMap<SlotId, EntityUID>,
511    ) -> Result<&Policy, LinkingError> {
512        let t =
513            self.get_template_arc(&template_id)
514                .ok_or_else(|| LinkingError::NoSuchTemplate {
515                    id: template_id.clone(),
516                })?;
517        let r = Template::link(t, new_id.clone(), values)?;
518
519        // Both maps must not contain the `new_id`
520        match (
521            self.links.entry(new_id.clone()),
522            self.templates.entry(new_id.clone()),
523        ) {
524            (Entry::Vacant(links_entry), Entry::Vacant(_)) => {
525                //We will never use the .or_default() because we just found `t` above
526                self.template_to_links_map
527                    .entry(template_id)
528                    .or_default()
529                    .insert(new_id);
530                Ok(links_entry.insert(r))
531            }
532            (Entry::Occupied(oentry), _) => Err(LinkingError::PolicyIdConflict {
533                id: oentry.key().clone(),
534            }),
535            (_, Entry::Occupied(oentry)) => Err(LinkingError::PolicyIdConflict {
536                id: oentry.key().clone(),
537            }),
538        }
539    }
540
541    /// Unlink `policy_id`
542    /// If it is not a link this will error
543    pub fn unlink(&mut self, policy_id: &PolicyID) -> Result<Policy, PolicySetUnlinkError> {
544        //A link occurs in links but not in templates.
545        if self.templates.contains_key(policy_id) {
546            return Err(PolicySetUnlinkError::NotLinkError(policy_id.clone()));
547        }
548        match self.links.remove(policy_id) {
549            Some(p) => {
550                // PANIC SAFETY: every linked policy should have a template
551                #[allow(clippy::panic)]
552                match self.template_to_links_map.entry(p.template().id().clone()) {
553                    Entry::Occupied(t) => t.into_mut().remove(policy_id),
554                    Entry::Vacant(_) => {
555                        panic!("No template found for linked policy")
556                    }
557                };
558                Ok(p)
559            }
560            None => Err(PolicySetUnlinkError::UnlinkingError(policy_id.clone())),
561        }
562    }
563
564    /// Iterate over all policies
565    pub fn policies(&self) -> impl Iterator<Item = &Policy> {
566        self.links.values()
567    }
568
569    /// Consume the `PolicySet`, producing an iterator of all the policies in it
570    pub fn into_policies(self) -> impl Iterator<Item = Policy> {
571        self.links.into_iter().map(|(_, p)| p)
572    }
573
574    /// Iterate over everything stored as template, including static policies.
575    /// Ie: all_templates() should equal templates() ++ static_policies().map(|p| p.template())
576    pub fn all_templates(&self) -> impl Iterator<Item = &Template> {
577        self.templates.values().map(|t| t.borrow())
578    }
579
580    /// Iterate over templates with slots
581    pub fn templates(&self) -> impl Iterator<Item = &Template> {
582        self.all_templates().filter(|t| t.slots().count() != 0)
583    }
584
585    /// Iterate over all of the static policies.
586    pub fn static_policies(&self) -> impl Iterator<Item = &Policy> {
587        self.policies().filter(|p| p.is_static())
588    }
589
590    /// Returns true iff the `PolicySet` is empty
591    pub fn is_empty(&self) -> bool {
592        self.templates.is_empty() && self.links.is_empty()
593    }
594
595    /// Lookup a template by policy id, returns [`Option<Arc<Template>>`]
596    pub fn get_template_arc(&self, id: &PolicyID) -> Option<Arc<Template>> {
597        self.templates.get(id).cloned()
598    }
599
600    /// Lookup a template by policy id, returns [`Option<&Template>`]
601    pub fn get_template(&self, id: &PolicyID) -> Option<&Template> {
602        self.templates.get(id).map(AsRef::as_ref)
603    }
604
605    /// Lookup an policy by policy id
606    pub fn get(&self, id: &PolicyID) -> Option<&Policy> {
607        self.links.get(id)
608    }
609
610    /// Attempt to collect an iterator over policies into a PolicySet
611    pub fn try_from_iter<T: IntoIterator<Item = Policy>>(iter: T) -> Result<Self, PolicySetError> {
612        let mut set = Self::new();
613        for p in iter {
614            set.add(p)?;
615        }
616        Ok(set)
617    }
618}
619
620impl std::fmt::Display for PolicySet {
621    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
622        // we don't show the ID, because the Display impl for Policy itself shows the ID
623        if self.is_empty() {
624            write!(f, "<empty policyset>")
625        } else {
626            write!(
627                f,
628                "Templates:\n{}, Template Linked Policies:\n{}",
629                self.all_templates().join("\n"),
630                self.policies().join("\n")
631            )
632        }
633    }
634}
635
636// PANIC SAFETY tests
637#[allow(clippy::panic)]
638// PANIC SAFETY tests
639#[allow(clippy::indexing_slicing)]
640#[cfg(test)]
641mod test {
642    use super::*;
643    use crate::{
644        ast::{
645            annotation::Annotations, ActionConstraint, Effect, PrincipalConstraint,
646            ResourceConstraint,
647        },
648        parser,
649    };
650
651    use std::collections::HashMap;
652
653    #[test]
654    fn link_conflicts() {
655        let mut pset = PolicySet::new();
656        let p1 = parser::parse_policy(
657            Some(PolicyID::from_string("id")),
658            "permit(principal,action,resource);",
659        )
660        .expect("Failed to parse");
661        pset.add_static(p1).expect("Failed to add!");
662        let template = parser::parse_policy_or_template(
663            Some(PolicyID::from_string("t")),
664            "permit(principal == ?principal, action, resource);",
665        )
666        .expect("Failed to parse");
667        pset.add_template(template).expect("Add failed");
668
669        let env: HashMap<SlotId, EntityUID> = HashMap::from([(
670            SlotId::principal(),
671            r#"Test::"test""#.parse().expect("Failed to parse"),
672        )]);
673
674        let r = pset.link(PolicyID::from_string("t"), PolicyID::from_string("id"), env);
675
676        match r {
677            Ok(_) => panic!("Should have failed due to conflict"),
678            Err(LinkingError::PolicyIdConflict { id }) => {
679                assert_eq!(id, PolicyID::from_string("id"))
680            }
681            Err(e) => panic!("Incorrect error: {e}"),
682        };
683    }
684
685    /// This test focuses on `PolicySet::add()`, while other tests mostly use
686    /// `PolicySet::add_static()` and `PolicySet::link()`.
687    #[test]
688    fn policyset_add() {
689        let mut pset = PolicySet::new();
690        let static_policy = parser::parse_policy(
691            Some(PolicyID::from_string("id")),
692            "permit(principal,action,resource);",
693        )
694        .expect("Failed to parse");
695        let static_policy: Policy = static_policy.into();
696        pset.add(static_policy)
697            .expect("Adding static policy in Policy form should succeed");
698
699        let template = Arc::new(
700            parser::parse_policy_or_template(
701                Some(PolicyID::from_string("t")),
702                "permit(principal == ?principal, action, resource);",
703            )
704            .expect("Failed to parse"),
705        );
706        let env1: HashMap<SlotId, EntityUID> = HashMap::from([(
707            SlotId::principal(),
708            r#"Test::"test1""#.parse().expect("Failed to parse"),
709        )]);
710
711        let p1 = Template::link(Arc::clone(&template), PolicyID::from_string("link"), env1)
712            .expect("Failed to link");
713        pset.add(p1).expect(
714            "Adding link should succeed, even though the template wasn't previously in the pset",
715        );
716        assert!(
717            pset.get_template_arc(&PolicyID::from_string("t")).is_some(),
718            "Adding link should implicitly add the template"
719        );
720
721        let env2: HashMap<SlotId, EntityUID> = HashMap::from([(
722            SlotId::principal(),
723            r#"Test::"test2""#.parse().expect("Failed to parse"),
724        )]);
725
726        let p2 = Template::link(
727            Arc::clone(&template),
728            PolicyID::from_string("link"),
729            env2.clone(),
730        )
731        .expect("Failed to link");
732        match pset.add(p2) {
733            Ok(_) => panic!("Should have failed due to conflict with existing link id"),
734            Err(PolicySetError::Occupied { id }) => assert_eq!(id, PolicyID::from_string("link")),
735        }
736
737        let p3 = Template::link(Arc::clone(&template), PolicyID::from_string("link2"), env2)
738            .expect("Failed to link");
739        pset.add(p3).expect(
740            "Adding link should succeed, even though the template already existed in the pset",
741        );
742
743        let template2 = Arc::new(
744            parser::parse_policy_or_template(
745                Some(PolicyID::from_string("t")),
746                "forbid(principal, action, resource == ?resource);",
747            )
748            .expect("Failed to parse"),
749        );
750        let env3: HashMap<SlotId, EntityUID> = HashMap::from([(
751            SlotId::resource(),
752            r#"Test::"test3""#.parse().expect("Failed to parse"),
753        )]);
754
755        let p4 = Template::link(
756            Arc::clone(&template2),
757            PolicyID::from_string("unique3"),
758            env3,
759        )
760        .expect("Failed to link");
761        match pset.add(p4) {
762            Ok(_) => panic!("Should have failed due to conflict on template id"),
763            Err(PolicySetError::Occupied { id }) => {
764                assert_eq!(id, PolicyID::from_string("t"))
765            }
766        }
767    }
768
769    #[test]
770    fn policy_merge_no_conflicts() {
771        let p1 = parser::parse_policy(
772            Some(PolicyID::from_string("policy0")),
773            "permit(principal,action,resource);",
774        )
775        .expect("Failed to parse");
776        let p2 = parser::parse_policy(
777            Some(PolicyID::from_string("policy1")),
778            "permit(principal,action,resource) when { false };",
779        )
780        .expect("Failed to parse");
781        let p3 = parser::parse_policy(
782            Some(PolicyID::from_string("policy0")),
783            "permit(principal,action,resource);",
784        )
785        .expect("Failed to parse");
786        let p4 = parser::parse_policy(
787            Some(PolicyID::from_string("policy2")),
788            "permit(principal,action,resource) when { true };",
789        )
790        .expect("Failed to parse");
791        let mut pset1 = PolicySet::new();
792        let mut pset2 = PolicySet::new();
793        pset1.add_static(p1).expect("Failed to add!");
794        pset1.add_static(p2).expect("Failed to add!");
795        pset2.add_static(p3).expect("Failed to add!");
796        pset2.add_static(p4).expect("Failed to add!");
797        // should not conflict because p1 == p3
798        match pset1.merge_policyset(&pset2, false) {
799            Ok(_) => (),
800            Err(PolicySetError::Occupied { id }) => {
801                panic!("There should not have been an error! Unexpected conflict for id {id}")
802            }
803        }
804    }
805
806    #[test]
807    fn policy_merge_with_conflicts() {
808        let pid0 = PolicyID::from_string("policy0");
809        let pid1 = PolicyID::from_string("policy1");
810        let pid2 = PolicyID::from_string("policy2");
811        let p1 = parser::parse_policy(Some(pid0.clone()), "permit(principal,action,resource);")
812            .expect("Failed to parse");
813        let p2 = parser::parse_policy(
814            Some(pid1.clone()),
815            "permit(principal,action,resource) when { false };",
816        )
817        .expect("Failed to parse");
818        let p3 = parser::parse_policy(Some(pid1.clone()), "permit(principal,action,resource);")
819            .expect("Failed to parse");
820        let p4 = parser::parse_policy(
821            Some(pid2.clone()),
822            "permit(principal,action,resource) when { true };",
823        )
824        .expect("Failed to parse");
825        let mut pset1 = PolicySet::new();
826        let mut pset2 = PolicySet::new();
827        pset1.add_static(p1.clone()).expect("Failed to add!");
828        pset1.add_static(p2.clone()).expect("Failed to add!");
829        pset2.add_static(p3.clone()).expect("Failed to add!");
830        pset2.add_static(p4.clone()).expect("Failed to add!");
831        // should conclict on pid "policy1"
832        match pset1.merge_policyset(&pset2, false) {
833            Ok(_) => panic!("`pset1` and `pset2` should conflict for PolicyID `policy1`"),
834            Err(PolicySetError::Occupied { id }) => {
835                assert_eq!(id, PolicyID::from_string("policy1"));
836            }
837        }
838        // should not conflict because of auto-renaming of conflicting policies
839        match pset1.merge_policyset(&pset2, true) {
840            Ok(renaming) => {
841                // ensure `policy1` was renamed
842                let new_pid1 = match renaming.get(&pid1) {
843                    Some(new_pid1) => new_pid1,
844                    None => panic!("Error: `policy1` is a conflict and should be renamed"),
845                };
846                // ensure no other policy was renamed
847                assert_eq!(renaming.keys().len(), 1);
848                if let Some(new_p1) = pset1.get(&pid0) {
849                    assert_eq!(Policy::from(p1), new_p1.clone());
850                }
851                if let Some(new_p2) = pset1.get(&pid1) {
852                    assert_eq!(Policy::from(p2), new_p2.clone());
853                }
854                if let Some(new_p3) = pset1.get(new_pid1) {
855                    assert_eq!(Policy::from(p3), new_p3.clone());
856                }
857                if let Some(new_p4) = pset1.get(&pid2) {
858                    assert_eq!(Policy::from(p4), new_p4.clone());
859                }
860            }
861            Err(PolicySetError::Occupied { id }) => {
862                panic!("There should not have been an error! Unexpected conflict for id {id}")
863            }
864        }
865    }
866
867    #[test]
868    fn policy_conflicts() {
869        let mut pset = PolicySet::new();
870        let p1 = parser::parse_policy(
871            Some(PolicyID::from_string("id")),
872            "permit(principal,action,resource);",
873        )
874        .expect("Failed to parse");
875        let p2 = parser::parse_policy(
876            Some(PolicyID::from_string("id")),
877            "permit(principal,action,resource) when { false };",
878        )
879        .expect("Failed to parse");
880        pset.add_static(p1).expect("Failed to add!");
881        match pset.add_static(p2) {
882            Ok(_) => panic!("Should have failed to due name conflict"),
883            Err(PolicySetError::Occupied { id }) => assert_eq!(id, PolicyID::from_string("id")),
884        }
885    }
886
887    #[test]
888    fn template_filtering() {
889        let template = parser::parse_policy_or_template(
890            Some(PolicyID::from_string("template")),
891            "permit(principal == ?principal, action, resource);",
892        )
893        .expect("Template Parse Failure");
894        let static_policy = parser::parse_policy(
895            Some(PolicyID::from_string("static")),
896            "permit(principal, action, resource);",
897        )
898        .expect("Static parse failure");
899        let mut set = PolicySet::new();
900        set.add_template(template).unwrap();
901        set.add_static(static_policy).unwrap();
902
903        assert_eq!(set.all_templates().count(), 2);
904        assert_eq!(set.templates().count(), 1);
905        assert_eq!(set.static_policies().count(), 1);
906        assert_eq!(set.policies().count(), 1);
907        set.link(
908            PolicyID::from_string("template"),
909            PolicyID::from_string("id"),
910            HashMap::from([(SlotId::principal(), EntityUID::with_eid("eid"))]),
911        )
912        .expect("Linking failed!");
913        assert_eq!(set.static_policies().count(), 1);
914        assert_eq!(set.policies().count(), 2);
915    }
916
917    #[test]
918    fn linking_missing_template() {
919        let tid = PolicyID::from_string("template");
920        let lid = PolicyID::from_string("link");
921        let t = Template::new(
922            tid.clone(),
923            None,
924            Annotations::new(),
925            Effect::Permit,
926            PrincipalConstraint::any(),
927            ActionConstraint::any(),
928            ResourceConstraint::any(),
929            None,
930        );
931
932        let mut s = PolicySet::new();
933        let e = s
934            .link(tid.clone(), lid.clone(), HashMap::new())
935            .expect_err("Should fail");
936
937        match e {
938            LinkingError::NoSuchTemplate { id } => assert_eq!(tid, id),
939            e => panic!("Wrong error {e}"),
940        };
941
942        s.add_template(t).unwrap();
943        s.link(tid, lid, HashMap::new()).expect("Should succeed");
944    }
945
946    #[test]
947    fn linkinv_valid_link() {
948        let tid = PolicyID::from_string("template");
949        let lid = PolicyID::from_string("link");
950        let t = Template::new(
951            tid.clone(),
952            None,
953            Annotations::new(),
954            Effect::Permit,
955            PrincipalConstraint::is_eq_slot(),
956            ActionConstraint::any(),
957            ResourceConstraint::is_in_slot(),
958            None,
959        );
960
961        let mut s = PolicySet::new();
962        s.add_template(t).unwrap();
963
964        let mut vals = HashMap::new();
965        vals.insert(SlotId::principal(), EntityUID::with_eid("p"));
966        vals.insert(SlotId::resource(), EntityUID::with_eid("a"));
967
968        s.link(tid.clone(), lid.clone(), vals).expect("Should link");
969
970        let v: Vec<_> = s.policies().collect();
971
972        assert_eq!(v[0].id(), &lid);
973        assert_eq!(v[0].template().id(), &tid);
974    }
975
976    #[test]
977    fn linking_empty_set() {
978        let s = PolicySet::new();
979        assert_eq!(s.policies().count(), 0);
980    }
981
982    #[test]
983    fn linking_raw_policy() {
984        let mut s = PolicySet::new();
985        let id = PolicyID::from_string("id");
986        let p = StaticPolicy::new(
987            id.clone(),
988            None,
989            Annotations::new(),
990            Effect::Forbid,
991            PrincipalConstraint::any(),
992            ActionConstraint::any(),
993            ResourceConstraint::any(),
994            None,
995        )
996        .expect("Policy Creation Failed");
997        s.add_static(p).unwrap();
998
999        let mut iter = s.policies();
1000        match iter.next() {
1001            Some(pol) => {
1002                assert_eq!(pol.id(), &id);
1003                assert_eq!(pol.effect(), Effect::Forbid);
1004                assert!(pol.env().is_empty())
1005            }
1006            None => panic!("Linked Record Not Present"),
1007        };
1008    }
1009
1010    #[test]
1011    fn link_slotmap() {
1012        let mut s = PolicySet::new();
1013        let template_id = PolicyID::from_string("template");
1014        let link_id = PolicyID::from_string("link");
1015        let t = Template::new(
1016            template_id.clone(),
1017            None,
1018            Annotations::new(),
1019            Effect::Forbid,
1020            PrincipalConstraint::is_eq_slot(),
1021            ActionConstraint::any(),
1022            ResourceConstraint::any(),
1023            None,
1024        );
1025        s.add_template(t).unwrap();
1026
1027        let mut v = HashMap::new();
1028        let entity = EntityUID::with_eid("eid");
1029        v.insert(SlotId::principal(), entity.clone());
1030        s.link(template_id.clone(), link_id.clone(), v)
1031            .expect("Linking failed!");
1032
1033        let link = s.get(&link_id).expect("Link should exist");
1034        assert_eq!(&link_id, link.id());
1035        assert_eq!(&template_id, link.template().id());
1036        assert_eq!(
1037            &entity,
1038            link.env()
1039                .get(&SlotId::principal())
1040                .expect("Mapping was incorrect")
1041        );
1042    }
1043
1044    #[test]
1045    fn policy_sets() {
1046        let mut pset = PolicySet::new();
1047        assert!(pset.is_empty());
1048        let id1 = PolicyID::from_string("id1");
1049        let tid1 = PolicyID::from_string("template");
1050        let policy1 = StaticPolicy::new(
1051            id1.clone(),
1052            None,
1053            Annotations::new(),
1054            Effect::Permit,
1055            PrincipalConstraint::any(),
1056            ActionConstraint::any(),
1057            ResourceConstraint::any(),
1058            None,
1059        )
1060        .expect("Policy Creation Failed");
1061        let template1 = Template::new(
1062            tid1.clone(),
1063            None,
1064            Annotations::new(),
1065            Effect::Permit,
1066            PrincipalConstraint::any(),
1067            ActionConstraint::any(),
1068            ResourceConstraint::any(),
1069            None,
1070        );
1071        let added = pset.add_static(policy1.clone()).is_ok();
1072        assert!(added);
1073        let added = pset.add_static(policy1).is_ok();
1074        assert!(!added);
1075        let added = pset.add_template(template1.clone()).is_ok();
1076        assert!(added);
1077        let added = pset.add_template(template1).is_ok();
1078        assert!(!added);
1079        assert!(!pset.is_empty());
1080        let id2 = PolicyID::from_string("id2");
1081        let policy2 = StaticPolicy::new(
1082            id2.clone(),
1083            None,
1084            Annotations::new(),
1085            Effect::Forbid,
1086            PrincipalConstraint::is_eq(Arc::new(EntityUID::with_eid("jane"))),
1087            ActionConstraint::any(),
1088            ResourceConstraint::any(),
1089            None,
1090        )
1091        .expect("Policy Creation Failed");
1092        let added = pset.add_static(policy2).is_ok();
1093        assert!(added);
1094
1095        let tid2 = PolicyID::from_string("template2");
1096        let template2 = Template::new(
1097            tid2.clone(),
1098            None,
1099            Annotations::new(),
1100            Effect::Permit,
1101            PrincipalConstraint::is_eq_slot(),
1102            ActionConstraint::any(),
1103            ResourceConstraint::any(),
1104            None,
1105        );
1106        let id3 = PolicyID::from_string("link");
1107        let added = pset.add_template(template2).is_ok();
1108        assert!(added);
1109
1110        let r = pset.link(
1111            tid2.clone(),
1112            id3.clone(),
1113            HashMap::from([(SlotId::principal(), EntityUID::with_eid("example"))]),
1114        );
1115        r.expect("Linking failed");
1116
1117        assert_eq!(pset.get(&id1).expect("should find the policy").id(), &id1);
1118        assert_eq!(pset.get(&id2).expect("should find the policy").id(), &id2);
1119        assert_eq!(pset.get(&id3).expect("should find link").id(), &id3);
1120        assert_eq!(
1121            pset.get(&id3).expect("should find link").template().id(),
1122            &tid2
1123        );
1124        assert!(pset.get(&tid2).is_none());
1125        assert!(pset.get_template_arc(&id1).is_some()); // Static policies are also templates
1126        assert!(pset.get_template_arc(&id2).is_some()); // Static policies are also templates
1127        assert!(pset.get_template_arc(&tid2).is_some());
1128        assert_eq!(pset.policies().count(), 3);
1129
1130        assert_eq!(
1131            pset.get_template_arc(&tid1)
1132                .expect("should find the template")
1133                .id(),
1134            &tid1
1135        );
1136        assert!(pset.get(&tid1).is_none());
1137        assert_eq!(pset.all_templates().count(), 4);
1138    }
1139
1140    #[test]
1141    fn policy_set_insertion_order() {
1142        let mut pset = PolicySet::new();
1143        assert!(pset.is_empty());
1144
1145        let src = "permit(principal, action, resource);";
1146        let ids: Vec<PolicyID> = (1..=4)
1147            .map(|i| {
1148                let id = PolicyID::from_string(format!("id{i}"));
1149                let p = parser::parse_policy(Some(id.clone()), src).unwrap();
1150                let added = pset.add(p.into()).is_ok();
1151                assert!(added);
1152                id
1153            })
1154            .collect();
1155
1156        assert_eq!(
1157            pset.into_policies()
1158                .map(|p| p.id().clone())
1159                .collect::<Vec<PolicyID>>(),
1160            ids
1161        );
1162    }
1163}