authly_common/policy/engine.rs

1//! Policy evaluation engine that implements a Policy Decision Point (PDP).
2
3use std::collections::BTreeSet;
4
5use byteorder::{BigEndian, ReadBytesExt};
6use fnv::{FnvHashMap, FnvHashSet};
7use tracing::error;
8
9use crate::id::{kind::Kind, AttrId, EntityId, PolicyId, PropId};
10
11use super::code::{Bytecode, PolicyValue};
12
/// Error raised when a policy program cannot be evaluated to a boolean outcome.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum EvalError {
    /// The bytecode itself is malformed: an unknown opcode, an immediate operand cut short
    /// by the end of the program, or a program that ends without returning a value.
    Program,

    /// The program is well-formed, but an instruction found too few operands or an operand
    /// of the wrong kind, referenced an entity ID the access control parameters do not
    /// provide, or named an invalid constant entity kind.
    Type,
}
22
23/// The parameters to an policy-based access control evaluation.
24///
25/// The access control paramaters generall consists of attributes related to a `subject` and a `resource`.
26///
27/// The `subject` represents the entity or entities requesting access.
28/// The `resource` is a representation of the abstract object being requested.
29#[derive(Default, Debug)]
30pub struct AccessControlParams {
31    /// Entity IDs related to the `subject`.
32    pub subject_eids: FnvHashMap<PropId, EntityId>,
33
34    /// Attributes related to the `subject`.
35    pub subject_attrs: FnvHashSet<AttrId>,
36
37    /// Entity IDs related to the `resource`.
38    pub resource_eids: FnvHashMap<PropId, EntityId>,
39
40    /// Attributes related to the `resource`.
41    pub resource_attrs: FnvHashSet<AttrId>,
42}
43
44/// The state of the policy engine.
45///
46/// Contains compiled policies and their triggers.
47#[derive(Default, Debug)]
48pub struct PolicyEngine {
49    policies: FnvHashMap<PolicyId, Policy>,
50
51    /// The triggers in this map are keyed by the one of the
52    /// attributes that has to match the trigger.
53    trigger_groups: FnvHashMap<AttrId, Vec<PolicyTrigger>>,
54}
55
56/// The policy trigger maps a set of attributes to a set of policies.
57#[derive(Debug)]
58struct PolicyTrigger {
59    /// The set of attributes that has to match for this policy to trigger
60    pub attr_matcher: BTreeSet<AttrId>,
61
62    /// The policy which gets triggered by this attribute matcher
63    pub policy_ids: BTreeSet<PolicyId>,
64}
65
66/// A tracer used to collect debugging information from the policy engine
67#[allow(unused)]
68pub trait PolicyTracer {
69    /// Reports applicable policies of a specific class
70    fn report_applicable(&mut self, class: PolicyValue, policies: impl Iterator<Item = PolicyId>) {}
71
72    /// Report start of a policy evaluation
73    fn report_policy_eval_start(&mut self, policy_id: PolicyId) {}
74
75    /// Reports the value of policy after it has been evaluated
76    fn report_policy_eval_end(&mut self, value: bool) {}
77}
78
79/// A [PolicyTracer] that does nothing.
80pub struct NoOpPolicyTracer;
81
82impl PolicyTracer for NoOpPolicyTracer {}
83
84#[derive(Debug)]
85struct Policy {
86    class: PolicyValue,
87    bytecode: Vec<u8>,
88}
89
90#[derive(PartialEq, Eq, Debug)]
91enum StackItem<'a> {
92    Uint(u64),
93    AttrIdSet(&'a FnvHashSet<AttrId>),
94    EntityId(EntityId),
95    AttrId(AttrId),
96}
97
98#[derive(Debug)]
99struct EvalCtx<'e> {
100    applicable_allow: FnvHashMap<PolicyId, &'e Policy>,
101    applicable_deny: FnvHashMap<PolicyId, &'e Policy>,
102}
103
104impl PolicyEngine {
105    /// Adds a new policy to the engine.
106    pub fn add_policy(&mut self, id: PolicyId, class: PolicyValue, bytecode: Vec<u8>) {
107        self.policies.insert(id, Policy { class, bytecode });
108    }
109
110    /// Adds a new policy trigger to the engine.
111    pub fn add_trigger(
112        &mut self,
113        attr_matcher: impl Into<BTreeSet<AttrId>>,
114        policy_ids: impl Into<BTreeSet<PolicyId>>,
115    ) {
116        let attr_matcher = attr_matcher.into();
117        let policy_ids = policy_ids.into();
118
119        if let Some(first_attr) = attr_matcher.iter().next() {
120            self.trigger_groups
121                .entry(*first_attr)
122                .or_default()
123                .push(PolicyTrigger {
124                    attr_matcher,
125                    policy_ids,
126                });
127        }
128    }
129
130    /// Get the number of policies currently in the engine.
131    pub fn get_policy_count(&self) -> usize {
132        self.policies.len()
133    }
134
135    /// Get the number of policy triggers currently in the engine.
136    pub fn get_trigger_count(&self) -> usize {
137        self.trigger_groups.values().map(Vec::len).sum()
138    }
139
140    /// Perform an access control evalution of the given parameters within this engine.
141    pub fn eval(
142        &self,
143        params: &AccessControlParams,
144        tracer: &mut impl PolicyTracer,
145    ) -> Result<PolicyValue, EvalError> {
146        let mut eval_ctx = EvalCtx {
147            applicable_allow: Default::default(),
148            applicable_deny: Default::default(),
149        };
150
151        for attr in &params.subject_attrs {
152            self.collect_applicable(*attr, params, &mut eval_ctx)?;
153        }
154
155        for attr in &params.resource_attrs {
156            self.collect_applicable(*attr, params, &mut eval_ctx)?;
157        }
158
159        {
160            tracer.report_applicable(PolicyValue::Deny, eval_ctx.applicable_deny.keys().copied());
161            tracer.report_applicable(
162                PolicyValue::Allow,
163                eval_ctx.applicable_allow.keys().copied(),
164            );
165        }
166
167        let has_allow = !eval_ctx.applicable_allow.is_empty();
168        let has_deny = !eval_ctx.applicable_deny.is_empty();
169
170        match (has_allow, has_deny) {
171            (false, false) => {
172                // idea: Fallback mode, no policies matched
173                for subj_attr in &params.subject_attrs {
174                    if params.resource_attrs.contains(subj_attr) {
175                        return Ok(PolicyValue::Allow);
176                    }
177                }
178
179                Ok(PolicyValue::Deny)
180            }
181            (true, false) => {
182                // starts in Deny state, try to prove Allow
183                let is_allow =
184                    eval_policies_disjunctive(eval_ctx.applicable_allow, params, tracer)?;
185                Ok(PolicyValue::from(is_allow))
186            }
187            (false, true) => {
188                // starts in Allow state, try to prove Deny
189                let is_deny = eval_policies_disjunctive(eval_ctx.applicable_deny, params, tracer)?;
190                Ok(PolicyValue::from(!is_deny))
191            }
192            (true, true) => {
193                // starts in Deny state, try to prove Allow
194                let is_allow =
195                    eval_policies_disjunctive(eval_ctx.applicable_allow, params, tracer)?;
196                if !is_allow {
197                    return Ok(PolicyValue::Deny);
198                }
199
200                // moved into in Allow state, try to prove Deny
201                let is_deny = eval_policies_disjunctive(eval_ctx.applicable_deny, params, tracer)?;
202                Ok(PolicyValue::from(!is_deny))
203            }
204        }
205    }
206
207    fn collect_applicable<'e>(
208        &'e self,
209        attr: AttrId,
210        params: &AccessControlParams,
211        eval_ctx: &mut EvalCtx<'e>,
212    ) -> Result<(), EvalError> {
213        // Find all potential triggers to investigate for this attribute
214        let Some(policy_triggers) = self.trigger_groups.get(&attr) else {
215            return Ok(());
216        };
217
218        for policy_trigger in policy_triggers {
219            if policy_trigger.attr_matcher.len() > 1 {
220                // a multi-attribute trigger: needs some post-processing
221                // to figure out if it applies
222                let mut matches: BTreeSet<AttrId> = Default::default();
223
224                for attrs in [&params.subject_attrs, &params.resource_attrs] {
225                    for attr in attrs {
226                        if policy_trigger.attr_matcher.contains(attr) {
227                            matches.insert(*attr);
228                        }
229                    }
230                }
231
232                if matches != policy_trigger.attr_matcher {
233                    // not applicable
234                    continue;
235                }
236            }
237
238            // The trigger applies; register all its policies as applicable
239            for policy_id in policy_trigger.policy_ids.iter().copied() {
240                let Some(policy) = self.policies.get(&policy_id) else {
241                    error!(?policy_id, "policy is missing");
242
243                    // internal error, which is not exposed
244                    continue;
245                };
246
247                match policy.class {
248                    PolicyValue::Deny => {
249                        eval_ctx.applicable_deny.insert(policy_id, policy);
250                    }
251                    PolicyValue::Allow => {
252                        eval_ctx.applicable_allow.insert(policy_id, policy);
253                    }
254                }
255            }
256        }
257
258        Ok(())
259    }
260}
261
262/// Evaluate set of policies, map their outputs to a boolean value and return the OR function applied to those values.
263fn eval_policies_disjunctive(
264    map: FnvHashMap<PolicyId, &Policy>,
265    params: &AccessControlParams,
266    tracer: &mut impl PolicyTracer,
267) -> Result<bool, EvalError> {
268    for (policy_id, policy) in &map {
269        tracer.report_policy_eval_start(*policy_id);
270
271        let value = eval_policy(&policy.bytecode, params)?;
272
273        tracer.report_policy_eval_end(value);
274
275        if value {
276            return Ok(true);
277        }
278    }
279
280    Ok(false)
281}
282
283/// Evaluate one standalone policy on the given access control parameters
284fn eval_policy(mut pc: &[u8], params: &AccessControlParams) -> Result<bool, EvalError> {
285    let mut stack: Vec<StackItem> = Vec::with_capacity(16);
286
287    while let Some(code) = pc.first() {
288        pc = &pc[1..];
289
290        let Ok(code) = Bytecode::try_from(*code) else {
291            return Err(EvalError::Program);
292        };
293
294        match code {
295            Bytecode::LoadSubjectId => {
296                let prop_id = PropId::from_uint(pc.read_u128::<BigEndian>()?);
297                let Some(id) = params.subject_eids.get(&prop_id) else {
298                    return Err(EvalError::Type);
299                };
300                stack.push(StackItem::EntityId(*id));
301            }
302            Bytecode::LoadSubjectAttrs => {
303                stack.push(StackItem::AttrIdSet(&params.subject_attrs));
304            }
305            Bytecode::LoadResourceId => {
306                let prop_id = PropId::from_uint(pc.read_u128::<BigEndian>()?);
307                let Some(id) = params.resource_eids.get(&prop_id) else {
308                    return Err(EvalError::Type);
309                };
310                stack.push(StackItem::EntityId(*id));
311            }
312            Bytecode::LoadResourceAttrs => {
313                stack.push(StackItem::AttrIdSet(&params.resource_attrs));
314            }
315            Bytecode::LoadConstEntityId => {
316                let Ok(kind) = Kind::try_from(pc.read_u8()?) else {
317                    return Err(EvalError::Type);
318                };
319                let uint = pc.read_u128::<BigEndian>()?;
320                stack.push(StackItem::EntityId(EntityId::new(kind, uint.to_be_bytes())));
321            }
322            Bytecode::LoadConstAttrId => {
323                let attr_id = AttrId::from_uint(pc.read_u128::<BigEndian>()?);
324                stack.push(StackItem::AttrId(attr_id));
325            }
326            Bytecode::IsEq => {
327                let Some(a) = stack.pop() else {
328                    return Err(EvalError::Type);
329                };
330                let Some(b) = stack.pop() else {
331                    return Err(EvalError::Type);
332                };
333                let is_eq = match (a, b) {
334                    (StackItem::AttrId(a), StackItem::AttrId(b)) => a == b,
335                    (StackItem::EntityId(a), StackItem::EntityId(b)) => a == b,
336                    (StackItem::AttrIdSet(set), StackItem::AttrId(id)) => set.contains(&id),
337                    (StackItem::AttrId(id), StackItem::AttrIdSet(set)) => set.contains(&id),
338                    _ => false,
339                };
340                stack.push(StackItem::Uint(if is_eq { 1 } else { 0 }));
341            }
342            Bytecode::SupersetOf => {
343                let Some(StackItem::AttrIdSet(a)) = stack.pop() else {
344                    return Err(EvalError::Type);
345                };
346                let Some(StackItem::AttrIdSet(b)) = stack.pop() else {
347                    return Err(EvalError::Type);
348                };
349                stack.push(StackItem::Uint(if a.is_superset(b) { 1 } else { 0 }));
350            }
351            Bytecode::IdSetContains => {
352                let Some(a) = stack.pop() else {
353                    return Err(EvalError::Type);
354                };
355                let Some(b) = stack.pop() else {
356                    return Err(EvalError::Type);
357                };
358
359                match (a, b) {
360                    (StackItem::AttrIdSet(a), StackItem::AttrId(b)) => {
361                        // BUG: Does not support u128?
362                        stack.push(StackItem::Uint(if a.contains(&b) { 1 } else { 0 }));
363                    }
364                    _ => {
365                        return Err(EvalError::Type);
366                    }
367                }
368            }
369            Bytecode::And => {
370                let Some(StackItem::Uint(rhs)) = stack.pop() else {
371                    return Err(EvalError::Type);
372                };
373                let Some(StackItem::Uint(lhs)) = stack.pop() else {
374                    return Err(EvalError::Type);
375                };
376                stack.push(StackItem::Uint(if rhs > 0 && lhs > 0 { 1 } else { 0 }));
377            }
378            Bytecode::Or => {
379                let Some(StackItem::Uint(rhs)) = stack.pop() else {
380                    return Err(EvalError::Type);
381                };
382                let Some(StackItem::Uint(lhs)) = stack.pop() else {
383                    return Err(EvalError::Type);
384                };
385                stack.push(StackItem::Uint(if rhs > 0 || lhs > 0 { 1 } else { 0 }));
386            }
387            Bytecode::Not => {
388                let Some(StackItem::Uint(val)) = stack.pop() else {
389                    return Err(EvalError::Type);
390                };
391                stack.push(StackItem::Uint(if val > 0 { 0 } else { 1 }));
392            }
393            Bytecode::Return => {
394                let Some(StackItem::Uint(u)) = stack.pop() else {
395                    return Err(EvalError::Type);
396                };
397                return Ok(u > 0);
398            }
399        }
400    }
401
402    Err(EvalError::Program)
403}
404
405impl From<std::io::Error> for EvalError {
406    fn from(_value: std::io::Error) -> Self {
407        EvalError::Program
408    }
409}