biscuit_auth/token/builder/term.rs

1/*
2 * Copyright (c) 2019 Geoffroy Couprie <contact@geoffroycouprie.com> and Contributors to the Eclipse Foundation.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5use std::{
6    collections::{BTreeMap, BTreeSet, HashMap},
7    convert::{TryFrom, TryInto},
8    fmt,
9    time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12use crate::{
13    datalog::{self, SymbolTable, TemporarySymbolTable},
14    error,
15};
16
17#[cfg(feature = "datalog-macro")]
18use super::AnyParam;
19use super::{set, Convert, Fact, ToAnyParam};
20
/// Builder for a Datalog value
///
/// Unlike [`datalog::Term`], strings and variable names are stored inline here
/// rather than as symbol-table indices; `to_datalog` / `Convert::convert` intern
/// them. `Parameter` placeholders must be substituted (see `apply_parameters`)
/// before conversion, otherwise conversion panics with "Remaining parameter".
///
/// NOTE: the derived `PartialOrd`/`Ord` follow variant declaration order, so
/// reordering variants changes how `Set` elements and `Map` values sort.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Term {
    /// A variable; rendered as `$name` by `Display`.
    Variable(String),
    /// A signed 64-bit integer.
    Integer(i64),
    /// A string; interned into the symbol table on conversion.
    Str(String),
    /// A date as whole seconds since the Unix epoch (see `From<SystemTime>`).
    Date(u64),
    /// Raw bytes; rendered as `hex:...` by `Display`.
    Bytes(Vec<u8>),
    /// A boolean.
    Bool(bool),
    /// A set of terms.
    Set(BTreeSet<Term>),
    /// A named placeholder, rendered as `{name}`, replaced by `apply_parameters`.
    Parameter(String),
    /// The null value.
    Null,
    /// An ordered list of terms.
    Array(Vec<Term>),
    /// A map keyed by integers, strings or parameter placeholders.
    Map(BTreeMap<MapKey, Term>),
}
36
37impl Term {
38    pub(super) fn extract_parameters(&self, parameters: &mut HashMap<String, Option<Term>>) {
39        match self {
40            Term::Parameter(name) => {
41                parameters.insert(name.to_string(), None);
42            }
43            Term::Set(s) => {
44                for term in s {
45                    term.extract_parameters(parameters);
46                }
47            }
48            Term::Array(a) => {
49                for term in a {
50                    term.extract_parameters(parameters);
51                }
52            }
53            Term::Map(m) => {
54                for (key, term) in m {
55                    if let MapKey::Parameter(name) = key {
56                        parameters.insert(name.to_string(), None);
57                    }
58                    term.extract_parameters(parameters);
59                }
60            }
61            _ => {}
62        }
63    }
64
65    pub(super) fn apply_parameters(self, parameters: &HashMap<String, Option<Term>>) -> Term {
66        match self {
67            Term::Parameter(name) => {
68                if let Some(Some(term)) = parameters.get(&name) {
69                    term.clone()
70                } else {
71                    Term::Parameter(name)
72                }
73            }
74            Term::Map(m) => Term::Map(
75                m.into_iter()
76                    .map(|(key, term)| {
77                        (
78                            match key {
79                                MapKey::Parameter(name) => {
80                                    if let Some(Some(key_term)) = parameters.get(&name) {
81                                        match key_term {
82                                            Term::Integer(i) => MapKey::Integer(*i),
83                                            Term::Str(s) => MapKey::Str(s.clone()),
84                                            //FIXME: we should return an error
85                                            _ => MapKey::Parameter(name),
86                                        }
87                                    } else {
88                                        MapKey::Parameter(name)
89                                    }
90                                }
91                                _ => key,
92                            },
93                            term.apply_parameters(parameters),
94                        )
95                    })
96                    .collect(),
97            ),
98            Term::Array(array) => Term::Array(
99                array
100                    .into_iter()
101                    .map(|term| term.apply_parameters(parameters))
102                    .collect(),
103            ),
104            Term::Set(set) => Term::Set(
105                set.into_iter()
106                    .map(|term| term.apply_parameters(parameters))
107                    .collect(),
108            ),
109            _ => self,
110        }
111    }
112}
113
/// Key of a [`Term::Map`] entry in the builder representation.
///
/// `Parameter` keys are placeholders. `Term::apply_parameters` replaces them with
/// an `Integer` or `Str` key when the parameter is bound to a term of that kind.
/// Converting a map that still holds one panics with "Remaining parameter".
///
/// NOTE: the derived `Ord` follows variant declaration order, which decides how
/// keys are sorted inside the `BTreeMap`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MapKey {
    /// Integer key.
    Integer(i64),
    /// String key; interned into the symbol table on conversion.
    Str(String),
    /// Named placeholder key; rendered as `{name}` by `Display`.
    Parameter(String),
}
120
121impl Term {
122    pub fn to_datalog(self, symbols: &mut TemporarySymbolTable) -> datalog::Term {
123        match self {
124            Term::Variable(s) => datalog::Term::Variable(symbols.insert(&s) as u32),
125            Term::Integer(i) => datalog::Term::Integer(i),
126            Term::Str(s) => datalog::Term::Str(symbols.insert(&s)),
127            Term::Date(d) => datalog::Term::Date(d),
128            Term::Bytes(s) => datalog::Term::Bytes(s),
129            Term::Bool(b) => datalog::Term::Bool(b),
130            Term::Set(s) => {
131                datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect())
132            }
133            Term::Null => datalog::Term::Null,
134            Term::Array(a) => {
135                datalog::Term::Array(a.into_iter().map(|i| i.to_datalog(symbols)).collect())
136            }
137            Term::Map(m) => datalog::Term::Map(
138                m.into_iter()
139                    .map(|(k, i)| {
140                        (
141                            match k {
142                                MapKey::Integer(i) => datalog::MapKey::Integer(i),
143                                MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(&s)),
144                                // The error is caught in the `add_xxx` functions, so this should
145                                // not happen™
146                                MapKey::Parameter(s) => panic!("Remaining parameter {}", &s),
147                            },
148                            i.to_datalog(symbols),
149                        )
150                    })
151                    .collect(),
152            ),
153            // The error is caught in the `add_xxx` functions, so this should
154            // not happen™
155            Term::Parameter(s) => panic!("Remaining parameter {}", &s),
156        }
157    }
158
159    pub fn from_datalog(
160        term: datalog::Term,
161        symbols: &TemporarySymbolTable,
162    ) -> Result<Self, error::Expression> {
163        Ok(match term {
164            datalog::Term::Variable(s) => Term::Variable(
165                symbols
166                    .get_symbol(s as u64)
167                    .ok_or(error::Expression::UnknownVariable(s))?
168                    .to_string(),
169            ),
170            datalog::Term::Integer(i) => Term::Integer(i),
171            datalog::Term::Str(s) => Term::Str(
172                symbols
173                    .get_symbol(s)
174                    .ok_or(error::Expression::UnknownSymbol(s))?
175                    .to_string(),
176            ),
177            datalog::Term::Date(d) => Term::Date(d),
178            datalog::Term::Bytes(s) => Term::Bytes(s),
179            datalog::Term::Bool(b) => Term::Bool(b),
180            datalog::Term::Set(s) => Term::Set(
181                s.into_iter()
182                    .map(|i| Self::from_datalog(i, symbols))
183                    .collect::<Result<_, _>>()?,
184            ),
185            datalog::Term::Null => Term::Null,
186            datalog::Term::Array(a) => Term::Array(
187                a.into_iter()
188                    .map(|i| Self::from_datalog(i, symbols))
189                    .collect::<Result<_, _>>()?,
190            ),
191            datalog::Term::Map(m) => Term::Map(
192                m.into_iter()
193                    .map(|(k, i)| {
194                        Ok((
195                            match k {
196                                datalog::MapKey::Integer(i) => MapKey::Integer(i),
197                                datalog::MapKey::Str(s) => MapKey::Str(
198                                    symbols
199                                        .get_symbol(s)
200                                        .ok_or(error::Expression::UnknownSymbol(s))?
201                                        .to_string(),
202                                ),
203                            },
204                            Self::from_datalog(i, symbols)?,
205                        ))
206                    })
207                    .collect::<Result<_, _>>()?,
208            ),
209        })
210    }
211}
212
213impl Convert<datalog::Term> for Term {
214    fn convert(&self, symbols: &mut SymbolTable) -> datalog::Term {
215        match self {
216            Term::Variable(s) => datalog::Term::Variable(symbols.insert(s) as u32),
217            Term::Integer(i) => datalog::Term::Integer(*i),
218            Term::Str(s) => datalog::Term::Str(symbols.insert(s)),
219            Term::Date(d) => datalog::Term::Date(*d),
220            Term::Bytes(s) => datalog::Term::Bytes(s.clone()),
221            Term::Bool(b) => datalog::Term::Bool(*b),
222            Term::Set(s) => datalog::Term::Set(s.iter().map(|i| i.convert(symbols)).collect()),
223            Term::Null => datalog::Term::Null,
224            // The error is caught in the `add_xxx` functions, so this should
225            // not happen™
226            Term::Parameter(s) => panic!("Remaining parameter {}", &s),
227            Term::Array(a) => datalog::Term::Array(a.iter().map(|i| i.convert(symbols)).collect()),
228            Term::Map(m) => datalog::Term::Map(
229                m.iter()
230                    .map(|(key, term)| {
231                        let key = match key {
232                            MapKey::Integer(i) => datalog::MapKey::Integer(*i),
233                            MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(s)),
234                            MapKey::Parameter(s) => panic!("Remaining parameter {}", &s),
235                        };
236
237                        (key, term.convert(symbols))
238                    })
239                    .collect(),
240            ),
241        }
242    }
243
244    fn convert_from(f: &datalog::Term, symbols: &SymbolTable) -> Result<Self, error::Format> {
245        Ok(match f {
246            datalog::Term::Variable(s) => Term::Variable(symbols.print_symbol(*s as u64)?),
247            datalog::Term::Integer(i) => Term::Integer(*i),
248            datalog::Term::Str(s) => Term::Str(symbols.print_symbol(*s)?),
249            datalog::Term::Date(d) => Term::Date(*d),
250            datalog::Term::Bytes(s) => Term::Bytes(s.clone()),
251            datalog::Term::Bool(b) => Term::Bool(*b),
252            datalog::Term::Set(s) => Term::Set(
253                s.iter()
254                    .map(|i| Term::convert_from(i, symbols))
255                    .collect::<Result<BTreeSet<_>, error::Format>>()?,
256            ),
257            datalog::Term::Null => Term::Null,
258            datalog::Term::Array(a) => Term::Array(
259                a.iter()
260                    .map(|i| Term::convert_from(i, symbols))
261                    .collect::<Result<Vec<_>, error::Format>>()?,
262            ),
263            datalog::Term::Map(m) => Term::Map(
264                m.iter()
265                    .map(|(key, term)| {
266                        let key = match key {
267                            datalog::MapKey::Integer(i) => Ok(MapKey::Integer(*i)),
268                            datalog::MapKey::Str(s) => symbols.print_symbol(*s).map(MapKey::Str),
269                        };
270
271                        key.and_then(|k| Term::convert_from(term, symbols).map(|term| (k, term)))
272                    })
273                    .collect::<Result<BTreeMap<_, _>, error::Format>>()?,
274            ),
275        })
276    }
277}
278
279impl From<&Term> for Term {
280    fn from(i: &Term) -> Self {
281        match i {
282            Term::Variable(ref v) => Term::Variable(v.clone()),
283            Term::Integer(ref i) => Term::Integer(*i),
284            Term::Str(ref s) => Term::Str(s.clone()),
285            Term::Date(ref d) => Term::Date(*d),
286            Term::Bytes(ref s) => Term::Bytes(s.clone()),
287            Term::Bool(b) => Term::Bool(*b),
288            Term::Set(ref s) => Term::Set(s.clone()),
289            Term::Parameter(ref p) => Term::Parameter(p.clone()),
290            Term::Null => Term::Null,
291            Term::Array(ref a) => Term::Array(a.clone()),
292            Term::Map(m) => Term::Map(m.clone()),
293        }
294    }
295}
296
297impl From<biscuit_parser::builder::Term> for Term {
298    fn from(t: biscuit_parser::builder::Term) -> Self {
299        match t {
300            biscuit_parser::builder::Term::Variable(v) => Term::Variable(v),
301            biscuit_parser::builder::Term::Integer(i) => Term::Integer(i),
302            biscuit_parser::builder::Term::Str(s) => Term::Str(s),
303            biscuit_parser::builder::Term::Date(d) => Term::Date(d),
304            biscuit_parser::builder::Term::Bytes(s) => Term::Bytes(s),
305            biscuit_parser::builder::Term::Bool(b) => Term::Bool(b),
306            biscuit_parser::builder::Term::Set(s) => {
307                Term::Set(s.into_iter().map(|t| t.into()).collect())
308            }
309            biscuit_parser::builder::Term::Null => Term::Null,
310            biscuit_parser::builder::Term::Parameter(ref p) => Term::Parameter(p.clone()),
311            biscuit_parser::builder::Term::Array(a) => {
312                Term::Array(a.into_iter().map(|t| t.into()).collect())
313            }
314            biscuit_parser::builder::Term::Map(a) => Term::Map(
315                a.into_iter()
316                    .map(|(key, term)| {
317                        (
318                            match key {
319                                biscuit_parser::builder::MapKey::Parameter(s) => {
320                                    MapKey::Parameter(s)
321                                }
322                                biscuit_parser::builder::MapKey::Integer(i) => MapKey::Integer(i),
323                                biscuit_parser::builder::MapKey::Str(s) => MapKey::Str(s),
324                            },
325                            term.into(),
326                        )
327                    })
328                    .collect(),
329            ),
330        }
331    }
332}
333
334impl AsRef<Term> for Term {
335    fn as_ref(&self) -> &Term {
336        self
337    }
338}
339
340impl fmt::Display for Term {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        match self {
343            Term::Variable(i) => write!(f, "${}", i),
344            Term::Integer(i) => write!(f, "{}", i),
345            Term::Str(s) => write!(f, "\"{}\"", s),
346            Term::Date(d) => {
347                let date = time::OffsetDateTime::from_unix_timestamp(*d as i64)
348                    .ok()
349                    .and_then(|t| {
350                        t.format(&time::format_description::well_known::Rfc3339)
351                            .ok()
352                    })
353                    .unwrap_or_else(|| "<invalid date>".to_string());
354
355                write!(f, "{}", date)
356            }
357            Term::Bytes(s) => write!(f, "hex:{}", hex::encode(s)),
358            Term::Bool(b) => {
359                if *b {
360                    write!(f, "true")
361                } else {
362                    write!(f, "false")
363                }
364            }
365            Term::Set(s) => {
366                if s.is_empty() {
367                    write!(f, "{{,}}")
368                } else {
369                    let terms = s.iter().map(|term| term.to_string()).collect::<Vec<_>>();
370                    write!(f, "{{{}}}", terms.join(", "))
371                }
372            }
373            Term::Parameter(s) => {
374                write!(f, "{{{}}}", s)
375            }
376            Term::Null => write!(f, "null"),
377            Term::Array(a) => {
378                let terms = a.iter().map(|term| term.to_string()).collect::<Vec<_>>();
379                write!(f, "[{}]", terms.join(", "))
380            }
381            Term::Map(m) => {
382                let terms = m
383                    .iter()
384                    .map(|(key, term)| match key {
385                        MapKey::Integer(i) => format!("{i}: {}", term),
386                        MapKey::Str(s) => format!("\"{s}\": {}", term),
387                        MapKey::Parameter(s) => format!("{{{s}}}: {}", term),
388                    })
389                    .collect::<Vec<_>>();
390                write!(f, "{{{}}}", terms.join(", "))
391            }
392        }
393    }
394}
395
#[cfg(feature = "datalog-macro")]
impl ToAnyParam for Term {
    /// Wraps an owned copy of this term as a macro parameter.
    fn to_any_param(&self) -> AnyParam {
        AnyParam::Term(self.to_owned())
    }
}
402
403impl From<i64> for Term {
404    fn from(i: i64) -> Self {
405        Term::Integer(i)
406    }
407}
408
409#[cfg(feature = "datalog-macro")]
410impl ToAnyParam for i64 {
411    fn to_any_param(&self) -> AnyParam {
412        AnyParam::Term((*self).into())
413    }
414}
415
416impl TryFrom<Term> for i64 {
417    type Error = error::Token;
418    fn try_from(value: Term) -> Result<Self, Self::Error> {
419        match value {
420            Term::Integer(i) => Ok(i),
421            _ => Err(error::Token::ConversionError(format!(
422                "expected integer, got {:?}",
423                value
424            ))),
425        }
426    }
427}
428
429impl From<bool> for Term {
430    fn from(b: bool) -> Self {
431        Term::Bool(b)
432    }
433}
434
435#[cfg(feature = "datalog-macro")]
436impl ToAnyParam for bool {
437    fn to_any_param(&self) -> AnyParam {
438        AnyParam::Term((*self).into())
439    }
440}
441
442impl TryFrom<Term> for bool {
443    type Error = error::Token;
444    fn try_from(value: Term) -> Result<Self, Self::Error> {
445        match value {
446            Term::Bool(b) => Ok(b),
447            _ => Err(error::Token::ConversionError(format!(
448                "expected boolean, got {:?}",
449                value
450            ))),
451        }
452    }
453}
454
455impl From<String> for Term {
456    fn from(s: String) -> Self {
457        Term::Str(s)
458    }
459}
460
461#[cfg(feature = "datalog-macro")]
462impl ToAnyParam for String {
463    fn to_any_param(&self) -> AnyParam {
464        AnyParam::Term((self.clone()).into())
465    }
466}
467
468impl From<&str> for Term {
469    fn from(s: &str) -> Self {
470        Term::Str(s.into())
471    }
472}
473
474#[cfg(feature = "datalog-macro")]
475impl ToAnyParam for &str {
476    fn to_any_param(&self) -> AnyParam {
477        AnyParam::Term(self.to_string().into())
478    }
479}
480
481impl TryFrom<Term> for String {
482    type Error = error::Token;
483    fn try_from(value: Term) -> Result<Self, Self::Error> {
484        match value {
485            Term::Str(s) => Ok(s),
486            _ => Err(error::Token::ConversionError(format!(
487                "expected string or symbol, got {:?}",
488                value
489            ))),
490        }
491    }
492}
493
494impl From<Vec<u8>> for Term {
495    fn from(v: Vec<u8>) -> Self {
496        Term::Bytes(v)
497    }
498}
499
500#[cfg(feature = "datalog-macro")]
501impl ToAnyParam for Vec<u8> {
502    fn to_any_param(&self) -> AnyParam {
503        AnyParam::Term((self.clone()).into())
504    }
505}
506
507impl TryFrom<Term> for Vec<u8> {
508    type Error = error::Token;
509    fn try_from(value: Term) -> Result<Self, Self::Error> {
510        match value {
511            Term::Bytes(b) => Ok(b),
512            _ => Err(error::Token::ConversionError(format!(
513                "expected byte array, got {:?}",
514                value
515            ))),
516        }
517    }
518}
519
520impl From<&[u8]> for Term {
521    fn from(v: &[u8]) -> Self {
522        Term::Bytes(v.into())
523    }
524}
525
526#[cfg(feature = "datalog-macro")]
527impl ToAnyParam for [u8] {
528    fn to_any_param(&self) -> AnyParam {
529        AnyParam::Term(self.into())
530    }
531}
532
533#[cfg(feature = "uuid")]
534impl ToAnyParam for uuid::Uuid {
535    fn to_any_param(&self) -> AnyParam {
536        AnyParam::Term(Term::Bytes(self.as_bytes().to_vec()))
537    }
538}
539
540impl From<SystemTime> for Term {
541    fn from(t: SystemTime) -> Self {
542        let dur = t.duration_since(UNIX_EPOCH).unwrap();
543        Term::Date(dur.as_secs())
544    }
545}
546
547#[cfg(feature = "datalog-macro")]
548impl ToAnyParam for SystemTime {
549    fn to_any_param(&self) -> AnyParam {
550        AnyParam::Term((*self).into())
551    }
552}
553
554impl TryFrom<Term> for SystemTime {
555    type Error = error::Token;
556    fn try_from(value: Term) -> Result<Self, Self::Error> {
557        match value {
558            Term::Date(d) => Ok(UNIX_EPOCH + Duration::from_secs(d)),
559            _ => Err(error::Token::ConversionError(format!(
560                "expected date, got {:?}",
561                value
562            ))),
563        }
564    }
565}
566
567impl From<BTreeSet<Term>> for Term {
568    fn from(value: BTreeSet<Term>) -> Term {
569        set(value)
570    }
571}
572
573#[cfg(feature = "datalog-macro")]
574impl ToAnyParam for BTreeSet<Term> {
575    fn to_any_param(&self) -> AnyParam {
576        AnyParam::Term((self.clone()).into())
577    }
578}
579
580impl<T: Ord + TryFrom<Term, Error = error::Token>> TryFrom<Term> for BTreeSet<T> {
581    type Error = error::Token;
582    fn try_from(value: Term) -> Result<Self, Self::Error> {
583        match value {
584            Term::Set(d) => d.iter().cloned().map(TryFrom::try_from).collect(),
585            _ => Err(error::Token::ConversionError(format!(
586                "expected set, got {:?}",
587                value
588            ))),
589        }
590    }
591}
592
593// TODO: From and ToAnyParam for arrays and maps
594impl TryFrom<serde_json::Value> for Term {
595    type Error = &'static str;
596
597    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
598        match value {
599            serde_json::Value::Null => Ok(Term::Null),
600            serde_json::Value::Bool(b) => Ok(Term::Bool(b)),
601            serde_json::Value::Number(i) => match i.as_i64() {
602                Some(i) => Ok(Term::Integer(i)),
603                None => Err("Biscuit values do not support floating point numbers"),
604            },
605            serde_json::Value::String(s) => Ok(Term::Str(s)),
606            serde_json::Value::Array(array) => Ok(Term::Array(
607                array
608                    .into_iter()
609                    .map(|v| v.try_into())
610                    .collect::<Result<_, _>>()?,
611            )),
612            serde_json::Value::Object(o) => Ok(Term::Map(
613                o.into_iter()
614                    .map(|(key, value)| {
615                        let value: Term = value.try_into()?;
616                        Ok::<_, &'static str>((MapKey::Str(key), value))
617                    })
618                    .collect::<Result<_, _>>()?,
619            )),
620        }
621    }
622}
623
// Recursively generates `TryFrom<Fact>` impls (through `tuple_try_from_impl!`)
// for every tuple prefix of the given type parameters, starting at two elements.
// It needs at least three type parameters.
// Example: `tuple_try_from!(A, B, C, D)` emits impls for `(A, B)`, `(A, B, C)`
// and `(A, B, C, D)`.
// In the internal `__impl` arms, the types before `;` form the current prefix
// and the types after it are still to be appended.
macro_rules! tuple_try_from(
    // Entry point: seed the prefix with the first two types.
    ($ty1:ident, $ty2:ident, $($ty:ident),*) => (
        tuple_try_from!(__impl $ty1, $ty2; $($ty),*);
        );
    // Emit the impl for the current prefix, then append the next type.
    (__impl $($ty: ident),+; $ty1:ident, $($ty2:ident),*) => (
        tuple_try_from_impl!($($ty),+);
        tuple_try_from!(__impl $($ty),+ , $ty1; $($ty2),*);
        );
    // Last remaining type: emit the current prefix and the full tuple, then stop.
    (__impl $($ty: ident),+; $ty1:ident) => (
        tuple_try_from_impl!($($ty),+);
        tuple_try_from_impl!($($ty),+, $ty1);
        );
    );
637
638impl<A: TryFrom<Term, Error = error::Token>> TryFrom<Fact> for (A,) {
639    type Error = error::Token;
640    fn try_from(fact: Fact) -> Result<Self, Self::Error> {
641        let mut terms = fact.predicate.terms;
642        let mut it = terms.drain(..);
643
644        Ok((it
645            .next()
646            .ok_or_else(|| error::Token::ConversionError("not enough terms in fact".to_string()))
647            .and_then(A::try_from)?,))
648    }
649}
650
651macro_rules! tuple_try_from_impl(
652    ($($ty: ident),+) => (
653        impl<$($ty: TryFrom<Term, Error = error::Token>),+> TryFrom<Fact> for ($($ty),+) {
654            type Error = error::Token;
655            fn try_from(fact: Fact) -> Result<Self, Self::Error> {
656                let mut terms = fact.predicate.terms;
657                let mut it = terms.drain(..);
658
659                Ok((
660                        $(
661                            it.next().ok_or(error::Token::ConversionError("not enough terms in fact".to_string())).and_then($ty::try_from)?
662                         ),+
663                   ))
664
665            }
666        }
667        );
668    );
669
670tuple_try_from!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U);