biscuit_auth/datalog/
origin.rs1use std::collections::BTreeSet;
6use std::collections::HashMap;
7use std::fmt::Display;
8use std::hash::Hash;
9use std::iter::FromIterator;
10
11use crate::token::Scope;
12
/// A set of block ids, e.g. the blocks a fact originates from.
///
/// Backed by a `BTreeSet`, so ids are deduplicated and iterated in ascending
/// order. `usize::MAX` is used as a sentinel for the authorizer.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Origin {
    pub(crate) inner: BTreeSet<usize>,
}

impl Origin {
    /// Adds block id `i`; adding an id that is already present is a no-op.
    pub fn insert(&mut self, i: usize) {
        let _newly_added = self.inner.insert(i);
    }

    /// Returns a new origin holding every id found in `self` or `other`.
    pub fn union(&self, other: &Self) -> Self {
        let mut merged = self.inner.clone();
        merged.extend(other.inner.iter().copied());
        Origin { inner: merged }
    }

    /// Returns `true` when every id of `other` is also contained in `self`.
    pub fn is_superset(&self, other: &Self) -> bool {
        other.inner.is_subset(&self.inner)
    }
}
33
34impl<'a> Extend<&'a usize> for Origin {
35 fn extend<T: IntoIterator<Item = &'a usize>>(&mut self, iter: T) {
36 self.inner.extend(iter)
37 }
38}
39
40impl Extend<usize> for Origin {
41 fn extend<T: IntoIterator<Item = usize>>(&mut self, iter: T) {
42 self.inner.extend(iter)
43 }
44}
45
46impl<'a> FromIterator<&'a usize> for Origin {
47 fn from_iter<T: IntoIterator<Item = &'a usize>>(iter: T) -> Self {
48 Self {
49 inner: iter.into_iter().cloned().collect(),
50 }
51 }
52}
53
54impl FromIterator<usize> for Origin {
55 fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
56 Self {
57 inner: iter.into_iter().collect(),
58 }
59 }
60}
61
62impl Display for Origin {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 let mut it = self.inner.iter();
65
66 if let Some(i) = it.next() {
67 if *i == usize::MAX {
68 write!(f, "authorizer")?;
69 } else {
70 write!(f, "{i}")?;
71 }
72 }
73
74 for i in it {
75 if *i == usize::MAX {
76 write!(f, ", authorizer")?;
77 } else {
78 write!(f, ", {i}")?;
79 }
80 }
81 Ok(())
82 }
83}
84
85#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
87pub struct TrustedOrigins(Origin);
88
89impl TrustedOrigins {
90 pub fn default() -> TrustedOrigins {
91 let mut origins = Origin::default();
92 origins.insert(usize::MAX);
93 origins.insert(0);
94 TrustedOrigins(origins)
95 }
96
97 pub fn from_scopes(
98 rule_scopes: &[Scope],
99 default_origins: &TrustedOrigins,
100 current_block: usize,
101 public_key_to_block_id: &HashMap<usize, Vec<usize>>,
102 ) -> TrustedOrigins {
103 if rule_scopes.is_empty() {
104 let mut origins = default_origins.clone();
105 origins.0.insert(current_block);
106 origins.0.insert(usize::MAX);
107 return origins;
108 }
109
110 let mut origins = Origin::default();
111 origins.insert(usize::MAX);
112 origins.insert(current_block);
113
114 for scope in rule_scopes {
115 match scope {
116 Scope::Authority => {
117 origins.insert(0);
118 }
119 Scope::Previous => {
120 if current_block != usize::MAX {
121 origins.extend(0..current_block + 1)
122 }
123 }
124 Scope::PublicKey(key_id) => {
125 if let Some(block_ids) = public_key_to_block_id.get(&(*key_id as usize)) {
126 origins.extend(block_ids.iter())
127 }
128 }
129 }
130 }
131
132 TrustedOrigins(origins)
133 }
134
135 pub fn contains(&self, fact_origin: &Origin) -> bool {
136 self.0.is_superset(fact_origin)
137 }
138}
139
140impl FromIterator<usize> for TrustedOrigins {
141 fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
142 Self(iter.into_iter().collect())
143 }
144}
145
146impl<'a> FromIterator<&'a usize> for TrustedOrigins {
147 fn from_iter<T: IntoIterator<Item = &'a usize>>(iter: T) -> Self {
148 Self(iter.into_iter().cloned().collect())
149 }
150}