use std::collections::HashSet;
use std::convert::TryFrom;
use time::{format_description::well_known::Rfc3339, OffsetDateTime};
8
/// Numeric identifier of an interned symbol.
///
/// Indices below `OFFSET` refer to entries of `DEFAULT_SYMBOLS`; indices from
/// `OFFSET` upward refer to a table's custom symbols (`OFFSET + position`).
pub type SymbolIndex = u64;
10use crate::crypto::PublicKey;
11use crate::token::default_symbol_table;
12use crate::{error, token::public_keys::PublicKeys};
13
14use super::{Check, Fact, Predicate, Rule, Term, World};
15
/// Interning table mapping strings to [`SymbolIndex`] values.
///
/// Only custom symbols are stored here. The default symbols are resolved
/// through `DEFAULT_SYMBOLS`, and `from`/`insert` never place them in
/// `symbols`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SymbolTable {
    // Custom symbols; the entry at position `i` has index `OFFSET + i`.
    symbols: Vec<String>,
    // Public keys looked up by id when printing `Scope::PublicKey` in
    // `trusting` clauses (other uses, if any, live outside this view).
    pub(crate) public_keys: PublicKeys,
}
21
/// Symbols every table knows implicitly, without storing them.
///
/// A symbol's position in this array is its `SymbolIndex`. Reordering,
/// inserting or removing entries would change the index of existing
/// symbols. NOTE(review): indices are presumably persisted in serialized
/// data, so treat this list as append-only — confirm.
const DEFAULT_SYMBOLS: [&str; 28] = [
    "read",
    "write",
    "resource",
    "operation",
    "right",
    "time",
    "role",
    "owner",
    "tenant",
    "namespace",
    "user",
    "team",
    "service",
    "admin",
    "email",
    "group",
    "member",
    "ip_address",
    "client",
    "client_ip",
    "domain",
    "path",
    "version",
    "cluster",
    "node",
    "hostname",
    "nonce",
    "query",
];

/// Index assigned to the first custom (non-default) symbol.
///
/// Indices between `DEFAULT_SYMBOLS.len()` and `OFFSET` are not handed out
/// by this module; `SymbolTable::get_symbol` returns `None` for them.
const OFFSET: usize = 1024;
54
55impl SymbolTable {
56 pub fn new() -> Self {
57 SymbolTable {
58 symbols: vec![],
59 public_keys: PublicKeys::new(),
60 }
61 }
62
63 pub fn from(symbols: Vec<String>) -> Result<Self, error::Format> {
64 let h1 = DEFAULT_SYMBOLS.iter().copied().collect::<HashSet<_>>();
65 let h2 = symbols.iter().map(|s| s.as_str()).collect::<HashSet<_>>();
66
67 if !h1.is_disjoint(&h2) {
68 return Err(error::Format::SymbolTableOverlap);
69 }
70
71 Ok(SymbolTable {
72 symbols,
73 public_keys: PublicKeys::new(),
74 })
75 }
76
77 pub fn from_symbols_and_public_keys(
78 symbols: Vec<String>,
79 public_keys: Vec<PublicKey>,
80 ) -> Result<Self, error::Format> {
81 let mut table = Self::from(symbols)?;
82 table.public_keys = PublicKeys::from(public_keys);
83 Ok(table)
84 }
85
86 pub fn extend(&mut self, other: &SymbolTable) -> Result<(), error::Format> {
87 if !self.is_disjoint(other) {
88 return Err(error::Format::SymbolTableOverlap);
89 }
90 self.symbols.extend(other.symbols.iter().cloned());
91 self.public_keys.extend(&other.public_keys)?;
92 Ok(())
93 }
94
95 pub fn insert(&mut self, s: &str) -> SymbolIndex {
96 if let Some(index) = DEFAULT_SYMBOLS.iter().position(|sym| *sym == s) {
97 return index as u64;
98 }
99
100 match self.symbols.iter().position(|sym| sym.as_str() == s) {
101 Some(index) => (OFFSET + index) as u64,
102 None => {
103 self.symbols.push(s.to_string());
104 (OFFSET + (self.symbols.len() - 1)) as u64
105 }
106 }
107 }
108
109 pub fn add(&mut self, s: &str) -> Term {
110 let term = self.insert(s);
111 Term::Str(term)
112 }
113
114 pub fn get(&self, s: &str) -> Option<SymbolIndex> {
115 if let Some(index) = DEFAULT_SYMBOLS.iter().position(|sym| *sym == s) {
116 return Some(index as u64);
117 }
118
119 self.symbols
120 .iter()
121 .position(|sym| sym.as_str() == s)
122 .map(|i| (OFFSET + i) as SymbolIndex)
123 }
124
125 pub fn strings(&self) -> Vec<String> {
126 self.symbols.clone()
127 }
128
129 pub fn current_offset(&self) -> usize {
130 self.symbols.len()
131 }
132
133 pub fn split_at(&mut self, offset: usize) -> SymbolTable {
134 let mut table = SymbolTable::new();
135 table.symbols = self.symbols.split_off(offset);
136 table
137 }
138
139 pub fn is_disjoint(&self, other: &SymbolTable) -> bool {
140 let h1 = self.symbols.iter().collect::<HashSet<_>>();
141 let h2 = other.symbols.iter().collect::<HashSet<_>>();
142
143 h1.is_disjoint(&h2)
144 }
145
146 pub fn get_symbol(&self, i: SymbolIndex) -> Option<&str> {
147 if i >= OFFSET as u64 {
148 self.symbols
149 .get((i - OFFSET as u64) as usize)
150 .map(|s| s.as_str())
151 } else {
152 DEFAULT_SYMBOLS.get(i as usize).copied()
153 }
154 }
155
156 pub fn print_symbol(&self, i: SymbolIndex) -> Result<String, error::Format> {
157 self.get_symbol(i)
158 .map(|s| s.to_string())
159 .ok_or(error::Format::UnknownSymbol(i))
160 }
161
162 pub fn print_symbol_default(&self, i: SymbolIndex) -> String {
164 self.get_symbol(i)
165 .map(|s| s.to_string())
166 .unwrap_or_else(|| format!("<{}?>", i))
167 }
168
169 pub fn print_world(&self, w: &World) -> String {
170 let facts = w
171 .facts
172 .inner
173 .iter()
174 .flat_map(|facts| facts.1.iter())
175 .map(|f| self.print_fact(f))
176 .collect::<Vec<_>>();
177 let rules = w
178 .rules
179 .inner
180 .iter()
181 .flat_map(|rules| rules.1.iter())
182 .map(|(_, r)| self.print_rule(r))
183 .collect::<Vec<_>>();
184 format!("World {{\n facts: {:#?}\n rules: {:#?}\n}}", facts, rules)
185 }
186
187 pub fn print_term(&self, term: &Term) -> String {
188 match term {
189 Term::Variable(i) => format!("${}", self.print_symbol_default(*i as u64)),
190 Term::Integer(i) => i.to_string(),
191 Term::Str(index) => format!("\"{}\"", self.print_symbol_default(*index)),
192 Term::Date(d) => OffsetDateTime::from_unix_timestamp(*d as i64)
193 .ok()
194 .and_then(|t| t.format(&Rfc3339).ok())
195 .unwrap_or_else(|| "<invalid date>".to_string()),
196 Term::Bytes(s) => format!("hex:{}", hex::encode(s)),
197 Term::Bool(b) => {
198 if *b {
199 "true".to_string()
200 } else {
201 "false".to_string()
202 }
203 }
204 Term::Set(s) => {
205 if s.is_empty() {
206 "{,}".to_string()
207 } else {
208 let terms = s
209 .iter()
210 .map(|term| self.print_term(term))
211 .collect::<Vec<_>>();
212 format!("{{{}}}", terms.join(", "))
213 }
214 }
215 Term::Null => "null".to_string(),
216 Term::Array(a) => {
217 let terms = a
218 .iter()
219 .map(|term| self.print_term(term))
220 .collect::<Vec<_>>();
221 format!("[{}]", terms.join(", "))
222 }
223 Term::Map(m) => {
224 let terms = m
225 .iter()
226 .map(|(key, term)| match key {
227 crate::datalog::MapKey::Integer(i) => {
228 format!("{}: {}", i, self.print_term(term))
229 }
230 crate::datalog::MapKey::Str(s) => {
231 format!(
232 "\"{}\": {}",
233 self.print_symbol_default(*s as u64),
234 self.print_term(term)
235 )
236 }
237 })
238 .collect::<Vec<_>>();
239 format!("{{{}}}", terms.join(", "))
240 }
241 }
242 }
243
244 pub fn print_fact(&self, f: &Fact) -> String {
245 self.print_predicate(&f.predicate)
246 }
247
248 pub fn print_predicate(&self, p: &Predicate) -> String {
249 let strings = p
250 .terms
251 .iter()
252 .map(|term| self.print_term(term))
253 .collect::<Vec<_>>();
254 format!(
255 "{}({})",
256 self.get_symbol(p.name).unwrap_or("<?>"),
257 strings.join(", ")
258 )
259 }
260
261 pub fn print_expression(&self, e: &super::expression::Expression) -> String {
262 e.print(self)
263 .unwrap_or_else(|| format!("<invalid expression: {:?}>", e.ops))
264 }
265
266 pub fn print_rule_body(&self, r: &Rule) -> String {
267 let preds: Vec<_> = r.body.iter().map(|p| self.print_predicate(p)).collect();
268
269 let expressions: Vec<_> = r
270 .expressions
271 .iter()
272 .map(|c| self.print_expression(c))
273 .collect();
274
275 let e = if expressions.is_empty() {
276 String::new()
277 } else if preds.is_empty() {
278 expressions.join(", ")
279 } else {
280 format!(", {}", expressions.join(", "))
281 };
282
283 let scopes = if r.scopes.is_empty() {
284 String::new()
285 } else {
286 let s: Vec<_> = r
287 .scopes
288 .iter()
289 .map(|scope| match scope {
290 crate::token::Scope::Authority => "authority".to_string(),
291 crate::token::Scope::Previous => "previous".to_string(),
292 crate::token::Scope::PublicKey(key_id) => {
293 match self.public_keys.get_key(*key_id) {
294 Some(key) => key.print(),
295 None => "<unknown public key id>".to_string(),
296 }
297 }
298 })
299 .collect();
300 format!(" trusting {}", s.join(", "))
301 };
302
303 format!("{}{}{}", preds.join(", "), e, scopes)
304 }
305
306 pub fn print_rule(&self, r: &Rule) -> String {
307 let res = self.print_predicate(&r.head);
308
309 format!("{} <- {}", res, self.print_rule_body(r))
310 }
311
312 pub fn print_check(&self, c: &Check) -> String {
313 let queries = c
314 .queries
315 .iter()
316 .map(|r| self.print_rule_body(r))
317 .collect::<Vec<_>>();
318
319 format!(
320 "{} {}",
321 match c.kind {
322 crate::builder::CheckKind::One => "check if",
323 crate::builder::CheckKind::All => "check all",
324 crate::builder::CheckKind::Reject => "reject if",
325 },
326 queries.join(" or ")
327 )
328 }
329}
330
331impl Default for SymbolTable {
332 fn default() -> Self {
333 default_symbol_table()
334 }
335}
336
/// Overlay on a borrowed [`SymbolTable`] that can intern new symbols
/// without mutating the base table (it is only borrowed immutably).
///
/// Local symbols are numbered from `offset` upward. That start is fixed at
/// construction to follow the base table's last custom symbol.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TemporarySymbolTable<'a> {
    // Underlying table; `insert` consults it before the local symbols.
    base: &'a SymbolTable,
    // Index of the first local symbol: `OFFSET + base.current_offset()`.
    offset: usize,
    // Locally interned symbols; position `i` has index `offset + i`.
    symbols: Vec<String>,
}
343
344impl<'a> TemporarySymbolTable<'a> {
345 pub fn new(base: &'a SymbolTable) -> Self {
346 let offset = OFFSET + base.current_offset();
347
348 TemporarySymbolTable {
349 base,
350 offset,
351 symbols: vec![],
352 }
353 }
354
355 pub fn get_symbol(&self, i: SymbolIndex) -> Option<&str> {
356 if i as usize >= self.offset {
357 self.symbols
358 .get(i as usize - self.offset)
359 .map(|s| s.as_str())
360 } else {
361 self.base.get_symbol(i)
362 }
363 }
364
365 pub fn insert(&mut self, s: &str) -> SymbolIndex {
366 if let Some(index) = self.base.get(s) {
367 return index;
368 }
369
370 match self.symbols.iter().position(|sym| sym.as_str() == s) {
371 Some(index) => (self.offset + index) as u64,
372 None => {
373 self.symbols.push(s.to_string());
374 (self.offset + (self.symbols.len() - 1)) as u64
375 }
376 }
377 }
378}