biscuit_auth/datalog/
symbol.rs

1/*
2 * Copyright (c) 2019 Geoffroy Couprie <contact@geoffroycouprie.com> and Contributors to the Eclipse Foundation.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5//! Symbol table implementation
6use std::collections::HashSet;
7use time::{format_description::well_known::Rfc3339, OffsetDateTime};
8
/// Numeric handle of an interned string.
///
/// Values below `OFFSET` are positions in `DEFAULT_SYMBOLS`; values from
/// `OFFSET` upwards are `OFFSET` plus a position in a table's own symbol list.
pub type SymbolIndex = u64;
10use crate::crypto::PublicKey;
11use crate::token::default_symbol_table;
12use crate::{error, token::public_keys::PublicKeys};
13
14use super::{Check, Fact, Predicate, Rule, Term, World};
15
16#[derive(Clone, Debug, PartialEq, Eq)]
17pub struct SymbolTable {
18    symbols: Vec<String>,
19    pub(crate) public_keys: PublicKeys,
20}
21
/// Built-in strings that every table resolves without storing them.
///
/// The symbol index of an entry is its position in this array, so the order
/// is significant: reordering or removing an entry changes the index under
/// which the remaining strings are found.
const DEFAULT_SYMBOLS: [&str; 28] = [
    "read",
    "write",
    "resource",
    "operation",
    "right",
    "time",
    "role",
    "owner",
    "tenant",
    "namespace",
    "user",
    "team",
    "service",
    "admin",
    "email",
    "group",
    "member",
    "ip_address",
    "client",
    "client_ip",
    "domain",
    "path",
    "version",
    "cluster",
    "node",
    "hostname",
    "nonce",
    "query",
];

/// First index handed out to strings that are not in `DEFAULT_SYMBOLS`.
///
/// Indexes below this value are looked up in `DEFAULT_SYMBOLS`; indexes at or
/// above it are looked up in a table's own symbol list after subtracting it.
const OFFSET: usize = 1024;
54
55impl SymbolTable {
56    pub fn new() -> Self {
57        SymbolTable {
58            symbols: vec![],
59            public_keys: PublicKeys::new(),
60        }
61    }
62
63    pub fn from(symbols: Vec<String>) -> Result<Self, error::Format> {
64        let h1 = DEFAULT_SYMBOLS.iter().copied().collect::<HashSet<_>>();
65        let h2 = symbols.iter().map(|s| s.as_str()).collect::<HashSet<_>>();
66
67        if !h1.is_disjoint(&h2) {
68            return Err(error::Format::SymbolTableOverlap);
69        }
70
71        Ok(SymbolTable {
72            symbols,
73            public_keys: PublicKeys::new(),
74        })
75    }
76
77    pub fn from_symbols_and_public_keys(
78        symbols: Vec<String>,
79        public_keys: Vec<PublicKey>,
80    ) -> Result<Self, error::Format> {
81        let mut table = Self::from(symbols)?;
82        table.public_keys = PublicKeys::from(public_keys);
83        Ok(table)
84    }
85
86    pub fn extend(&mut self, other: &SymbolTable) -> Result<(), error::Format> {
87        if !self.is_disjoint(other) {
88            return Err(error::Format::SymbolTableOverlap);
89        }
90        self.symbols.extend(other.symbols.iter().cloned());
91        self.public_keys.extend(&other.public_keys)?;
92        Ok(())
93    }
94
95    pub fn insert(&mut self, s: &str) -> SymbolIndex {
96        if let Some(index) = DEFAULT_SYMBOLS.iter().position(|sym| *sym == s) {
97            return index as u64;
98        }
99
100        match self.symbols.iter().position(|sym| sym.as_str() == s) {
101            Some(index) => (OFFSET + index) as u64,
102            None => {
103                self.symbols.push(s.to_string());
104                (OFFSET + (self.symbols.len() - 1)) as u64
105            }
106        }
107    }
108
109    pub fn add(&mut self, s: &str) -> Term {
110        let term = self.insert(s);
111        Term::Str(term)
112    }
113
114    pub fn get(&self, s: &str) -> Option<SymbolIndex> {
115        if let Some(index) = DEFAULT_SYMBOLS.iter().position(|sym| *sym == s) {
116            return Some(index as u64);
117        }
118
119        self.symbols
120            .iter()
121            .position(|sym| sym.as_str() == s)
122            .map(|i| (OFFSET + i) as SymbolIndex)
123    }
124
125    pub fn strings(&self) -> Vec<String> {
126        self.symbols.clone()
127    }
128
129    pub fn current_offset(&self) -> usize {
130        self.symbols.len()
131    }
132
133    pub fn split_at(&mut self, offset: usize) -> SymbolTable {
134        let mut table = SymbolTable::new();
135        table.symbols = self.symbols.split_off(offset);
136        table
137    }
138
139    pub fn is_disjoint(&self, other: &SymbolTable) -> bool {
140        let h1 = self.symbols.iter().collect::<HashSet<_>>();
141        let h2 = other.symbols.iter().collect::<HashSet<_>>();
142
143        h1.is_disjoint(&h2)
144    }
145
146    pub fn get_symbol(&self, i: SymbolIndex) -> Option<&str> {
147        if i >= OFFSET as u64 {
148            self.symbols
149                .get((i - OFFSET as u64) as usize)
150                .map(|s| s.as_str())
151        } else {
152            DEFAULT_SYMBOLS.get(i as usize).copied()
153        }
154    }
155
156    pub fn print_symbol(&self, i: SymbolIndex) -> Result<String, error::Format> {
157        self.get_symbol(i)
158            .map(|s| s.to_string())
159            .ok_or(error::Format::UnknownSymbol(i))
160    }
161
162    // infallible symbol printing method
163    pub fn print_symbol_default(&self, i: SymbolIndex) -> String {
164        self.get_symbol(i)
165            .map(|s| s.to_string())
166            .unwrap_or_else(|| format!("<{}?>", i))
167    }
168
169    pub fn print_world(&self, w: &World) -> String {
170        let facts = w
171            .facts
172            .inner
173            .iter()
174            .flat_map(|facts| facts.1.iter())
175            .map(|f| self.print_fact(f))
176            .collect::<Vec<_>>();
177        let rules = w
178            .rules
179            .inner
180            .iter()
181            .flat_map(|rules| rules.1.iter())
182            .map(|(_, r)| self.print_rule(r))
183            .collect::<Vec<_>>();
184        format!("World {{\n  facts: {:#?}\n  rules: {:#?}\n}}", facts, rules)
185    }
186
187    pub fn print_term(&self, term: &Term) -> String {
188        match term {
189            Term::Variable(i) => format!("${}", self.print_symbol_default(*i as u64)),
190            Term::Integer(i) => i.to_string(),
191            Term::Str(index) => format!("\"{}\"", self.print_symbol_default(*index)),
192            Term::Date(d) => OffsetDateTime::from_unix_timestamp(*d as i64)
193                .ok()
194                .and_then(|t| t.format(&Rfc3339).ok())
195                .unwrap_or_else(|| "<invalid date>".to_string()),
196            Term::Bytes(s) => format!("hex:{}", hex::encode(s)),
197            Term::Bool(b) => {
198                if *b {
199                    "true".to_string()
200                } else {
201                    "false".to_string()
202                }
203            }
204            Term::Set(s) => {
205                if s.is_empty() {
206                    "{,}".to_string()
207                } else {
208                    let terms = s
209                        .iter()
210                        .map(|term| self.print_term(term))
211                        .collect::<Vec<_>>();
212                    format!("{{{}}}", terms.join(", "))
213                }
214            }
215            Term::Null => "null".to_string(),
216            Term::Array(a) => {
217                let terms = a
218                    .iter()
219                    .map(|term| self.print_term(term))
220                    .collect::<Vec<_>>();
221                format!("[{}]", terms.join(", "))
222            }
223            Term::Map(m) => {
224                let terms = m
225                    .iter()
226                    .map(|(key, term)| match key {
227                        crate::datalog::MapKey::Integer(i) => {
228                            format!("{}: {}", i, self.print_term(term))
229                        }
230                        crate::datalog::MapKey::Str(s) => {
231                            format!(
232                                "\"{}\": {}",
233                                self.print_symbol_default(*s as u64),
234                                self.print_term(term)
235                            )
236                        }
237                    })
238                    .collect::<Vec<_>>();
239                format!("{{{}}}", terms.join(", "))
240            }
241        }
242    }
243
244    pub fn print_fact(&self, f: &Fact) -> String {
245        self.print_predicate(&f.predicate)
246    }
247
248    pub fn print_predicate(&self, p: &Predicate) -> String {
249        let strings = p
250            .terms
251            .iter()
252            .map(|term| self.print_term(term))
253            .collect::<Vec<_>>();
254        format!(
255            "{}({})",
256            self.get_symbol(p.name).unwrap_or("<?>"),
257            strings.join(", ")
258        )
259    }
260
261    pub fn print_expression(&self, e: &super::expression::Expression) -> String {
262        e.print(self)
263            .unwrap_or_else(|| format!("<invalid expression: {:?}>", e.ops))
264    }
265
266    pub fn print_rule_body(&self, r: &Rule) -> String {
267        let preds: Vec<_> = r.body.iter().map(|p| self.print_predicate(p)).collect();
268
269        let expressions: Vec<_> = r
270            .expressions
271            .iter()
272            .map(|c| self.print_expression(c))
273            .collect();
274
275        let e = if expressions.is_empty() {
276            String::new()
277        } else if preds.is_empty() {
278            expressions.join(", ")
279        } else {
280            format!(", {}", expressions.join(", "))
281        };
282
283        let scopes = if r.scopes.is_empty() {
284            String::new()
285        } else {
286            let s: Vec<_> = r
287                .scopes
288                .iter()
289                .map(|scope| match scope {
290                    crate::token::Scope::Authority => "authority".to_string(),
291                    crate::token::Scope::Previous => "previous".to_string(),
292                    crate::token::Scope::PublicKey(key_id) => {
293                        match self.public_keys.get_key(*key_id) {
294                            Some(key) => key.print(),
295                            None => "<unknown public key id>".to_string(),
296                        }
297                    }
298                })
299                .collect();
300            format!(" trusting {}", s.join(", "))
301        };
302
303        format!("{}{}{}", preds.join(", "), e, scopes)
304    }
305
306    pub fn print_rule(&self, r: &Rule) -> String {
307        let res = self.print_predicate(&r.head);
308
309        format!("{} <- {}", res, self.print_rule_body(r))
310    }
311
312    pub fn print_check(&self, c: &Check) -> String {
313        let queries = c
314            .queries
315            .iter()
316            .map(|r| self.print_rule_body(r))
317            .collect::<Vec<_>>();
318
319        format!(
320            "{} {}",
321            match c.kind {
322                crate::builder::CheckKind::One => "check if",
323                crate::builder::CheckKind::All => "check all",
324                crate::builder::CheckKind::Reject => "reject if",
325            },
326            queries.join(" or ")
327        )
328    }
329}
330
331impl Default for SymbolTable {
332    fn default() -> Self {
333        default_symbol_table()
334    }
335}
336
337#[derive(Clone, Debug, PartialEq, Eq)]
338pub struct TemporarySymbolTable<'a> {
339    base: &'a SymbolTable,
340    offset: usize,
341    symbols: Vec<String>,
342}
343
344impl<'a> TemporarySymbolTable<'a> {
345    pub fn new(base: &'a SymbolTable) -> Self {
346        let offset = OFFSET + base.current_offset();
347
348        TemporarySymbolTable {
349            base,
350            offset,
351            symbols: vec![],
352        }
353    }
354
355    pub fn get_symbol(&self, i: SymbolIndex) -> Option<&str> {
356        if i as usize >= self.offset {
357            self.symbols
358                .get(i as usize - self.offset)
359                .map(|s| s.as_str())
360        } else {
361            self.base.get_symbol(i)
362        }
363    }
364
365    pub fn insert(&mut self, s: &str) -> SymbolIndex {
366        if let Some(index) = self.base.get(s) {
367            return index;
368        }
369
370        match self.symbols.iter().position(|sym| sym.as_str() == s) {
371            Some(index) => (self.offset + index) as u64,
372            None => {
373                self.symbols.push(s.to_string());
374                (self.offset + (self.symbols.len() - 1)) as u64
375            }
376        }
377    }
378}