1use std::collections::BTreeSet;
4
5use byteorder::{BigEndian, ReadBytesExt};
6use fnv::{FnvHashMap, FnvHashSet};
7use tracing::error;
8
9use crate::id::{kind::Kind, AttrId, EntityId, PolicyId, PropId};
10
11use super::code::{Bytecode, PolicyValue};
12
/// Failure modes of policy bytecode evaluation.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum EvalError {
    /// The bytecode itself is malformed: an opcode byte could not be decoded,
    /// an operand could not be read (any `std::io::Error` converts to this
    /// variant), or the program ran off its end without returning.
    Program,

    /// The program was decodable but could not be evaluated: an operand was
    /// missing from the stack or had the wrong type, a referenced property id
    /// was absent from the parameters, or an entity kind byte was invalid.
    Type,
}
22
/// Inputs of a single access-control decision: what is known about the
/// subject and about the resource.
#[derive(Default, Debug)]
pub struct AccessControlParams {
    /// Entity ids belonging to the subject, keyed by property id.
    /// Looked up by the `LoadSubjectId` instruction.
    pub subject_eids: FnvHashMap<PropId, EntityId>,

    /// Attributes of the subject. Used for trigger matching and exposed to
    /// bytecode through `LoadSubjectAttrs`.
    pub subject_attrs: FnvHashSet<AttrId>,

    /// Entity ids belonging to the resource, keyed by property id.
    /// Looked up by the `LoadResourceId` instruction.
    pub resource_eids: FnvHashMap<PropId, EntityId>,

    /// Attributes of the resource. Used for trigger matching and exposed to
    /// bytecode through `LoadResourceAttrs`.
    pub resource_attrs: FnvHashSet<AttrId>,
}
43
/// Holds the registered policies and the triggers that decide which of them
/// apply to a given set of subject/resource attributes.
#[derive(Default, Debug)]
pub struct PolicyEngine {
    /// All registered policies by id.
    policies: FnvHashMap<PolicyId, Policy>,

    /// Triggers indexed by the first (smallest, in `BTreeSet` order) attribute
    /// of their matcher, so a lookup by any present attribute finds every
    /// trigger that could start with it.
    trigger_groups: FnvHashMap<AttrId, Vec<PolicyTrigger>>,
}
55
/// Associates a set of attributes with the policies that become applicable
/// when all of those attributes are present.
#[derive(Debug)]
struct PolicyTrigger {
    /// Attributes that must all be found among the subject and resource
    /// attributes for this trigger to fire.
    pub attr_matcher: BTreeSet<AttrId>,

    /// Policies made applicable when the trigger fires.
    pub policy_ids: BTreeSet<PolicyId>,
}
65
/// Observer hooks invoked by [`PolicyEngine::eval`] while it makes a decision.
///
/// Every method has an empty default body, so implementors override only the
/// events they care about.
// `unused` is allowed because the default (empty) method bodies deliberately
// ignore all of their parameters.
#[allow(unused)]
pub trait PolicyTracer {
    /// Reports the ids of the applicable policies of one `class`.
    /// `eval` calls this once for `Deny` and then once for `Allow`.
    fn report_applicable(&mut self, class: PolicyValue, policies: impl Iterator<Item = PolicyId>) {}

    /// Called right before the bytecode of `policy_id` is evaluated.
    fn report_policy_eval_start(&mut self, policy_id: PolicyId) {}

    /// Called with the boolean result of the policy whose evaluation was last
    /// started. Not called when that evaluation fails with an error.
    fn report_policy_eval_end(&mut self, value: bool) {}
}
78
/// A [`PolicyTracer`] that ignores every event.
pub struct NoOpPolicyTracer;

// Relies entirely on the trait's empty default methods.
impl PolicyTracer for NoOpPolicyTracer {}
83
/// A registered policy: its class plus the program that decides whether it
/// holds for a given request.
#[derive(Debug)]
struct Policy {
    /// Whether this is an allow or a deny policy.
    class: PolicyValue,
    /// Program executed by `eval_policy`; yields `true` or `false`.
    bytecode: Vec<u8>,
}
89
/// A value on the evaluation stack of `eval_policy`.
#[derive(PartialEq, Eq, Debug)]
enum StackItem<'a> {
    /// Integer value; logical results are pushed as `0` (false) or `1` (true)
    /// and any value greater than zero is read back as true.
    Uint(u64),
    /// Borrowed attribute set (the subject's or the resource's).
    AttrIdSet(&'a FnvHashSet<AttrId>),
    /// An entity id, loaded from the parameters or from a bytecode constant.
    EntityId(EntityId),
    /// A constant attribute id embedded in the bytecode.
    AttrId(AttrId),
}
97
/// Scratch state of one `PolicyEngine::eval` call: the policies found to be
/// applicable, split by class and borrowed from the engine.
#[derive(Debug)]
struct EvalCtx<'e> {
    /// Applicable policies whose class is `PolicyValue::Allow`.
    applicable_allow: FnvHashMap<PolicyId, &'e Policy>,
    /// Applicable policies whose class is `PolicyValue::Deny`.
    applicable_deny: FnvHashMap<PolicyId, &'e Policy>,
}
103
104impl PolicyEngine {
105 pub fn add_policy(&mut self, id: PolicyId, class: PolicyValue, bytecode: Vec<u8>) {
107 self.policies.insert(id, Policy { class, bytecode });
108 }
109
110 pub fn add_trigger(
112 &mut self,
113 attr_matcher: impl Into<BTreeSet<AttrId>>,
114 policy_ids: impl Into<BTreeSet<PolicyId>>,
115 ) {
116 let attr_matcher = attr_matcher.into();
117 let policy_ids = policy_ids.into();
118
119 if let Some(first_attr) = attr_matcher.iter().next() {
120 self.trigger_groups
121 .entry(*first_attr)
122 .or_default()
123 .push(PolicyTrigger {
124 attr_matcher,
125 policy_ids,
126 });
127 }
128 }
129
130 pub fn get_policy_count(&self) -> usize {
132 self.policies.len()
133 }
134
135 pub fn get_trigger_count(&self) -> usize {
137 self.trigger_groups.values().map(Vec::len).sum()
138 }
139
140 pub fn eval(
142 &self,
143 params: &AccessControlParams,
144 tracer: &mut impl PolicyTracer,
145 ) -> Result<PolicyValue, EvalError> {
146 let mut eval_ctx = EvalCtx {
147 applicable_allow: Default::default(),
148 applicable_deny: Default::default(),
149 };
150
151 for attr in ¶ms.subject_attrs {
152 self.collect_applicable(*attr, params, &mut eval_ctx)?;
153 }
154
155 for attr in ¶ms.resource_attrs {
156 self.collect_applicable(*attr, params, &mut eval_ctx)?;
157 }
158
159 {
160 tracer.report_applicable(PolicyValue::Deny, eval_ctx.applicable_deny.keys().copied());
161 tracer.report_applicable(
162 PolicyValue::Allow,
163 eval_ctx.applicable_allow.keys().copied(),
164 );
165 }
166
167 let has_allow = !eval_ctx.applicable_allow.is_empty();
168 let has_deny = !eval_ctx.applicable_deny.is_empty();
169
170 match (has_allow, has_deny) {
171 (false, false) => {
172 for subj_attr in ¶ms.subject_attrs {
174 if params.resource_attrs.contains(subj_attr) {
175 return Ok(PolicyValue::Allow);
176 }
177 }
178
179 Ok(PolicyValue::Deny)
180 }
181 (true, false) => {
182 let is_allow =
184 eval_policies_disjunctive(eval_ctx.applicable_allow, params, tracer)?;
185 Ok(PolicyValue::from(is_allow))
186 }
187 (false, true) => {
188 let is_deny = eval_policies_disjunctive(eval_ctx.applicable_deny, params, tracer)?;
190 Ok(PolicyValue::from(!is_deny))
191 }
192 (true, true) => {
193 let is_allow =
195 eval_policies_disjunctive(eval_ctx.applicable_allow, params, tracer)?;
196 if !is_allow {
197 return Ok(PolicyValue::Deny);
198 }
199
200 let is_deny = eval_policies_disjunctive(eval_ctx.applicable_deny, params, tracer)?;
202 Ok(PolicyValue::from(!is_deny))
203 }
204 }
205 }
206
207 fn collect_applicable<'e>(
208 &'e self,
209 attr: AttrId,
210 params: &AccessControlParams,
211 eval_ctx: &mut EvalCtx<'e>,
212 ) -> Result<(), EvalError> {
213 let Some(policy_triggers) = self.trigger_groups.get(&attr) else {
215 return Ok(());
216 };
217
218 for policy_trigger in policy_triggers {
219 if policy_trigger.attr_matcher.len() > 1 {
220 let mut matches: BTreeSet<AttrId> = Default::default();
223
224 for attrs in [¶ms.subject_attrs, ¶ms.resource_attrs] {
225 for attr in attrs {
226 if policy_trigger.attr_matcher.contains(attr) {
227 matches.insert(*attr);
228 }
229 }
230 }
231
232 if matches != policy_trigger.attr_matcher {
233 continue;
235 }
236 }
237
238 for policy_id in policy_trigger.policy_ids.iter().copied() {
240 let Some(policy) = self.policies.get(&policy_id) else {
241 error!(?policy_id, "policy is missing");
242
243 continue;
245 };
246
247 match policy.class {
248 PolicyValue::Deny => {
249 eval_ctx.applicable_deny.insert(policy_id, policy);
250 }
251 PolicyValue::Allow => {
252 eval_ctx.applicable_allow.insert(policy_id, policy);
253 }
254 }
255 }
256 }
257
258 Ok(())
259 }
260}
261
262fn eval_policies_disjunctive(
264 map: FnvHashMap<PolicyId, &Policy>,
265 params: &AccessControlParams,
266 tracer: &mut impl PolicyTracer,
267) -> Result<bool, EvalError> {
268 for (policy_id, policy) in &map {
269 tracer.report_policy_eval_start(*policy_id);
270
271 let value = eval_policy(&policy.bytecode, params)?;
272
273 tracer.report_policy_eval_end(value);
274
275 if value {
276 return Ok(true);
277 }
278 }
279
280 Ok(false)
281}
282
283fn eval_policy(mut pc: &[u8], params: &AccessControlParams) -> Result<bool, EvalError> {
285 let mut stack: Vec<StackItem> = Vec::with_capacity(16);
286
287 while let Some(code) = pc.first() {
288 pc = &pc[1..];
289
290 let Ok(code) = Bytecode::try_from(*code) else {
291 return Err(EvalError::Program);
292 };
293
294 match code {
295 Bytecode::LoadSubjectId => {
296 let prop_id = PropId::from_uint(pc.read_u128::<BigEndian>()?);
297 let Some(id) = params.subject_eids.get(&prop_id) else {
298 return Err(EvalError::Type);
299 };
300 stack.push(StackItem::EntityId(*id));
301 }
302 Bytecode::LoadSubjectAttrs => {
303 stack.push(StackItem::AttrIdSet(¶ms.subject_attrs));
304 }
305 Bytecode::LoadResourceId => {
306 let prop_id = PropId::from_uint(pc.read_u128::<BigEndian>()?);
307 let Some(id) = params.resource_eids.get(&prop_id) else {
308 return Err(EvalError::Type);
309 };
310 stack.push(StackItem::EntityId(*id));
311 }
312 Bytecode::LoadResourceAttrs => {
313 stack.push(StackItem::AttrIdSet(¶ms.resource_attrs));
314 }
315 Bytecode::LoadConstEntityId => {
316 let Ok(kind) = Kind::try_from(pc.read_u8()?) else {
317 return Err(EvalError::Type);
318 };
319 let uint = pc.read_u128::<BigEndian>()?;
320 stack.push(StackItem::EntityId(EntityId::new(kind, uint.to_be_bytes())));
321 }
322 Bytecode::LoadConstAttrId => {
323 let attr_id = AttrId::from_uint(pc.read_u128::<BigEndian>()?);
324 stack.push(StackItem::AttrId(attr_id));
325 }
326 Bytecode::IsEq => {
327 let Some(a) = stack.pop() else {
328 return Err(EvalError::Type);
329 };
330 let Some(b) = stack.pop() else {
331 return Err(EvalError::Type);
332 };
333 let is_eq = match (a, b) {
334 (StackItem::AttrId(a), StackItem::AttrId(b)) => a == b,
335 (StackItem::EntityId(a), StackItem::EntityId(b)) => a == b,
336 (StackItem::AttrIdSet(set), StackItem::AttrId(id)) => set.contains(&id),
337 (StackItem::AttrId(id), StackItem::AttrIdSet(set)) => set.contains(&id),
338 _ => false,
339 };
340 stack.push(StackItem::Uint(if is_eq { 1 } else { 0 }));
341 }
342 Bytecode::SupersetOf => {
343 let Some(StackItem::AttrIdSet(a)) = stack.pop() else {
344 return Err(EvalError::Type);
345 };
346 let Some(StackItem::AttrIdSet(b)) = stack.pop() else {
347 return Err(EvalError::Type);
348 };
349 stack.push(StackItem::Uint(if a.is_superset(b) { 1 } else { 0 }));
350 }
351 Bytecode::IdSetContains => {
352 let Some(a) = stack.pop() else {
353 return Err(EvalError::Type);
354 };
355 let Some(b) = stack.pop() else {
356 return Err(EvalError::Type);
357 };
358
359 match (a, b) {
360 (StackItem::AttrIdSet(a), StackItem::AttrId(b)) => {
361 stack.push(StackItem::Uint(if a.contains(&b) { 1 } else { 0 }));
363 }
364 _ => {
365 return Err(EvalError::Type);
366 }
367 }
368 }
369 Bytecode::And => {
370 let Some(StackItem::Uint(rhs)) = stack.pop() else {
371 return Err(EvalError::Type);
372 };
373 let Some(StackItem::Uint(lhs)) = stack.pop() else {
374 return Err(EvalError::Type);
375 };
376 stack.push(StackItem::Uint(if rhs > 0 && lhs > 0 { 1 } else { 0 }));
377 }
378 Bytecode::Or => {
379 let Some(StackItem::Uint(rhs)) = stack.pop() else {
380 return Err(EvalError::Type);
381 };
382 let Some(StackItem::Uint(lhs)) = stack.pop() else {
383 return Err(EvalError::Type);
384 };
385 stack.push(StackItem::Uint(if rhs > 0 || lhs > 0 { 1 } else { 0 }));
386 }
387 Bytecode::Not => {
388 let Some(StackItem::Uint(val)) = stack.pop() else {
389 return Err(EvalError::Type);
390 };
391 stack.push(StackItem::Uint(if val > 0 { 0 } else { 1 }));
392 }
393 Bytecode::Return => {
394 let Some(StackItem::Uint(u)) = stack.pop() else {
395 return Err(EvalError::Type);
396 };
397 return Ok(u > 0);
398 }
399 }
400 }
401
402 Err(EvalError::Program)
403}
404
405impl From<std::io::Error> for EvalError {
406 fn from(_value: std::io::Error) -> Self {
407 EvalError::Program
408 }
409}