biscuit_auth/token/builder/expression.rs

1/*
2 * Copyright (c) 2019 Geoffroy Couprie <contact@geoffroycouprie.com> and Contributors to the Eclipse Foundation.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5use std::{collections::HashMap, fmt};
6
7use crate::{
8    datalog::{self, SymbolTable},
9    error,
10    token::default_symbol_table,
11};
12
13use super::{Convert, Term};
14
/// Builder-side representation of a unary operator.
///
/// Every variant converts one-to-one to the same-named `datalog::Unary` operator.
/// The declaration order of the variants is significant: it defines the derived
/// `PartialOrd`/`Ord` ordering.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum Unary {
    /// Converts to `datalog::Unary::Negate`.
    Negate,
    /// Converts to `datalog::Unary::Parens`.
    Parens,
    /// Converts to `datalog::Unary::Length`.
    Length,
    /// Converts to `datalog::Unary::TypeOf`.
    TypeOf,
    /// An operator identified by name; the name is interned in the symbol table
    /// when converted to `datalog::Unary::Ffi`.
    Ffi(String),
}
24
/// Builder-side representation of a binary operator.
///
/// Every variant converts one-to-one to the same-named `datalog::Binary` operator.
/// The declaration order of the variants is significant: it defines the derived
/// `PartialOrd`/`Ord` ordering.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum Binary {
    LessThan,
    GreaterThan,
    LessOrEqual,
    GreaterOrEqual,
    Equal,
    Contains,
    Prefix,
    Suffix,
    Regex,
    Add,
    Sub,
    Mul,
    Div,
    And,
    Or,
    Intersection,
    Union,
    BitwiseAnd,
    BitwiseOr,
    BitwiseXor,
    NotEqual,
    HeterogeneousEqual,
    HeterogeneousNotEqual,
    LazyAnd,
    LazyOr,
    All,
    Any,
    Get,
    /// An operator identified by name; the name is interned in the symbol table
    /// when converted to `datalog::Binary::Ffi`.
    Ffi(String),
    TryOr,
}
59
60/// Builder for a Datalog expression
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct Expression {
63    pub ops: Vec<Op>,
64}
// TODO: track parameters
66
67impl Convert<datalog::Expression> for Expression {
68    fn convert(&self, symbols: &mut SymbolTable) -> datalog::Expression {
69        datalog::Expression {
70            ops: self.ops.iter().map(|op| op.convert(symbols)).collect(),
71        }
72    }
73
74    fn convert_from(e: &datalog::Expression, symbols: &SymbolTable) -> Result<Self, error::Format> {
75        Ok(Expression {
76            ops: e
77                .ops
78                .iter()
79                .map(|op| Op::convert_from(op, symbols))
80                .collect::<Result<Vec<_>, error::Format>>()?,
81        })
82    }
83}
84
85impl AsRef<Expression> for Expression {
86    fn as_ref(&self) -> &Expression {
87        self
88    }
89}
90
91impl fmt::Display for Expression {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        let mut syms = default_symbol_table();
94        let expr = self.convert(&mut syms);
95        let s = expr.print(&syms).unwrap();
96        write!(f, "{}", s)
97    }
98}
99
100impl From<biscuit_parser::builder::Expression> for Expression {
101    fn from(e: biscuit_parser::builder::Expression) -> Self {
102        Expression {
103            ops: e.ops.into_iter().map(|op| op.into()).collect(),
104        }
105    }
106}
107
108/// Builder for an expression operation
109#[derive(Debug, Clone, PartialEq, Eq)]
110pub enum Op {
111    Value(Term),
112    Unary(Unary),
113    Binary(Binary),
114    Closure(Vec<String>, Vec<Op>),
115}
116
117impl Op {
118    pub(super) fn collect_parameters(&self, parameters: &mut HashMap<String, Option<Term>>) {
119        match self {
120            Op::Value(Term::Parameter(ref name)) => {
121                parameters.insert(name.to_owned(), None);
122            }
123            Op::Closure(_, ops) => {
124                for op in ops {
125                    op.collect_parameters(parameters);
126                }
127            }
128            _ => {}
129        }
130    }
131
132    pub(super) fn apply_parameters(self, parameters: &HashMap<String, Option<Term>>) -> Self {
133        match self {
134            Op::Value(Term::Parameter(ref name)) => {
135                if let Some(Some(t)) = parameters.get(name) {
136                    Op::Value(t.clone())
137                } else {
138                    self
139                }
140            }
141            Op::Value(_) => self,
142            Op::Unary(_) => self,
143            Op::Binary(_) => self,
144            Op::Closure(args, mut ops) => Op::Closure(
145                args,
146                ops.drain(..)
147                    .map(|op| op.apply_parameters(parameters))
148                    .collect(),
149            ),
150        }
151    }
152}
153
154impl Convert<datalog::Op> for Op {
155    fn convert(&self, symbols: &mut SymbolTable) -> datalog::Op {
156        match self {
157            Op::Value(t) => datalog::Op::Value(t.convert(symbols)),
158            Op::Unary(u) => datalog::Op::Unary(u.convert(symbols)),
159            Op::Binary(b) => datalog::Op::Binary(b.convert(symbols)),
160            Op::Closure(ps, os) => datalog::Op::Closure(
161                ps.iter().map(|p| symbols.insert(p) as u32).collect(),
162                os.iter().map(|o| o.convert(symbols)).collect(),
163            ),
164        }
165    }
166
167    fn convert_from(op: &datalog::Op, symbols: &SymbolTable) -> Result<Self, error::Format> {
168        Ok(match op {
169            datalog::Op::Value(t) => Op::Value(Term::convert_from(t, symbols)?),
170            datalog::Op::Unary(u) => Op::Unary(Unary::convert_from(u, symbols)?),
171            datalog::Op::Binary(b) => Op::Binary(Binary::convert_from(b, symbols)?),
172            datalog::Op::Closure(ps, os) => Op::Closure(
173                ps.iter()
174                    .map(|p| symbols.print_symbol(*p as u64))
175                    .collect::<Result<_, _>>()?,
176                os.iter()
177                    .map(|o| Op::convert_from(o, symbols))
178                    .collect::<Result<_, _>>()?,
179            ),
180        })
181    }
182}
183
184impl From<biscuit_parser::builder::Op> for Op {
185    fn from(op: biscuit_parser::builder::Op) -> Self {
186        match op {
187            biscuit_parser::builder::Op::Value(t) => Op::Value(t.into()),
188            biscuit_parser::builder::Op::Unary(u) => Op::Unary(u.into()),
189            biscuit_parser::builder::Op::Binary(b) => Op::Binary(b.into()),
190            biscuit_parser::builder::Op::Closure(ps, os) => {
191                Op::Closure(ps, os.into_iter().map(|o| o.into()).collect())
192            }
193        }
194    }
195}
196
197impl Convert<datalog::Unary> for Unary {
198    fn convert(&self, symbols: &mut SymbolTable) -> datalog::Unary {
199        match self {
200            Unary::Negate => datalog::Unary::Negate,
201            Unary::Parens => datalog::Unary::Parens,
202            Unary::Length => datalog::Unary::Length,
203            Unary::TypeOf => datalog::Unary::TypeOf,
204            Unary::Ffi(n) => datalog::Unary::Ffi(symbols.insert(n)),
205        }
206    }
207
208    fn convert_from(f: &datalog::Unary, symbols: &SymbolTable) -> Result<Self, error::Format> {
209        match f {
210            datalog::Unary::Negate => Ok(Unary::Negate),
211            datalog::Unary::Parens => Ok(Unary::Parens),
212            datalog::Unary::Length => Ok(Unary::Length),
213            datalog::Unary::TypeOf => Ok(Unary::TypeOf),
214            datalog::Unary::Ffi(i) => Ok(Unary::Ffi(symbols.print_symbol(*i)?)),
215        }
216    }
217}
218
219impl From<biscuit_parser::builder::Unary> for Unary {
220    fn from(unary: biscuit_parser::builder::Unary) -> Self {
221        match unary {
222            biscuit_parser::builder::Unary::Negate => Unary::Negate,
223            biscuit_parser::builder::Unary::Parens => Unary::Parens,
224            biscuit_parser::builder::Unary::Length => Unary::Length,
225            biscuit_parser::builder::Unary::TypeOf => Unary::TypeOf,
226            biscuit_parser::builder::Unary::Ffi(name) => Unary::Ffi(name),
227        }
228    }
229}
230
231impl Convert<datalog::Binary> for Binary {
232    fn convert(&self, symbols: &mut SymbolTable) -> datalog::Binary {
233        match self {
234            Binary::LessThan => datalog::Binary::LessThan,
235            Binary::GreaterThan => datalog::Binary::GreaterThan,
236            Binary::LessOrEqual => datalog::Binary::LessOrEqual,
237            Binary::GreaterOrEqual => datalog::Binary::GreaterOrEqual,
238            Binary::Equal => datalog::Binary::Equal,
239            Binary::Contains => datalog::Binary::Contains,
240            Binary::Prefix => datalog::Binary::Prefix,
241            Binary::Suffix => datalog::Binary::Suffix,
242            Binary::Regex => datalog::Binary::Regex,
243            Binary::Add => datalog::Binary::Add,
244            Binary::Sub => datalog::Binary::Sub,
245            Binary::Mul => datalog::Binary::Mul,
246            Binary::Div => datalog::Binary::Div,
247            Binary::And => datalog::Binary::And,
248            Binary::Or => datalog::Binary::Or,
249            Binary::Intersection => datalog::Binary::Intersection,
250            Binary::Union => datalog::Binary::Union,
251            Binary::BitwiseAnd => datalog::Binary::BitwiseAnd,
252            Binary::BitwiseOr => datalog::Binary::BitwiseOr,
253            Binary::BitwiseXor => datalog::Binary::BitwiseXor,
254            Binary::NotEqual => datalog::Binary::NotEqual,
255            Binary::HeterogeneousEqual => datalog::Binary::HeterogeneousEqual,
256            Binary::HeterogeneousNotEqual => datalog::Binary::HeterogeneousNotEqual,
257            Binary::LazyAnd => datalog::Binary::LazyAnd,
258            Binary::LazyOr => datalog::Binary::LazyOr,
259            Binary::All => datalog::Binary::All,
260            Binary::Any => datalog::Binary::Any,
261            Binary::Get => datalog::Binary::Get,
262            Binary::Ffi(n) => datalog::Binary::Ffi(symbols.insert(n)),
263            Binary::TryOr => datalog::Binary::TryOr,
264        }
265    }
266
267    fn convert_from(f: &datalog::Binary, symbols: &SymbolTable) -> Result<Self, error::Format> {
268        match f {
269            datalog::Binary::LessThan => Ok(Binary::LessThan),
270            datalog::Binary::GreaterThan => Ok(Binary::GreaterThan),
271            datalog::Binary::LessOrEqual => Ok(Binary::LessOrEqual),
272            datalog::Binary::GreaterOrEqual => Ok(Binary::GreaterOrEqual),
273            datalog::Binary::Equal => Ok(Binary::Equal),
274            datalog::Binary::Contains => Ok(Binary::Contains),
275            datalog::Binary::Prefix => Ok(Binary::Prefix),
276            datalog::Binary::Suffix => Ok(Binary::Suffix),
277            datalog::Binary::Regex => Ok(Binary::Regex),
278            datalog::Binary::Add => Ok(Binary::Add),
279            datalog::Binary::Sub => Ok(Binary::Sub),
280            datalog::Binary::Mul => Ok(Binary::Mul),
281            datalog::Binary::Div => Ok(Binary::Div),
282            datalog::Binary::And => Ok(Binary::And),
283            datalog::Binary::Or => Ok(Binary::Or),
284            datalog::Binary::Intersection => Ok(Binary::Intersection),
285            datalog::Binary::Union => Ok(Binary::Union),
286            datalog::Binary::BitwiseAnd => Ok(Binary::BitwiseAnd),
287            datalog::Binary::BitwiseOr => Ok(Binary::BitwiseOr),
288            datalog::Binary::BitwiseXor => Ok(Binary::BitwiseXor),
289            datalog::Binary::NotEqual => Ok(Binary::NotEqual),
290            datalog::Binary::HeterogeneousEqual => Ok(Binary::HeterogeneousEqual),
291            datalog::Binary::HeterogeneousNotEqual => Ok(Binary::HeterogeneousNotEqual),
292            datalog::Binary::LazyAnd => Ok(Binary::LazyAnd),
293            datalog::Binary::LazyOr => Ok(Binary::LazyOr),
294            datalog::Binary::All => Ok(Binary::All),
295            datalog::Binary::Any => Ok(Binary::Any),
296            datalog::Binary::Get => Ok(Binary::Get),
297            datalog::Binary::Ffi(i) => Ok(Binary::Ffi(symbols.print_symbol(*i)?)),
298            datalog::Binary::TryOr => Ok(Binary::TryOr),
299        }
300    }
301}
302
303impl From<biscuit_parser::builder::Binary> for Binary {
304    fn from(binary: biscuit_parser::builder::Binary) -> Self {
305        match binary {
306            biscuit_parser::builder::Binary::LessThan => Binary::LessThan,
307            biscuit_parser::builder::Binary::GreaterThan => Binary::GreaterThan,
308            biscuit_parser::builder::Binary::LessOrEqual => Binary::LessOrEqual,
309            biscuit_parser::builder::Binary::GreaterOrEqual => Binary::GreaterOrEqual,
310            biscuit_parser::builder::Binary::Equal => Binary::Equal,
311            biscuit_parser::builder::Binary::Contains => Binary::Contains,
312            biscuit_parser::builder::Binary::Prefix => Binary::Prefix,
313            biscuit_parser::builder::Binary::Suffix => Binary::Suffix,
314            biscuit_parser::builder::Binary::Regex => Binary::Regex,
315            biscuit_parser::builder::Binary::Add => Binary::Add,
316            biscuit_parser::builder::Binary::Sub => Binary::Sub,
317            biscuit_parser::builder::Binary::Mul => Binary::Mul,
318            biscuit_parser::builder::Binary::Div => Binary::Div,
319            biscuit_parser::builder::Binary::And => Binary::And,
320            biscuit_parser::builder::Binary::Or => Binary::Or,
321            biscuit_parser::builder::Binary::Intersection => Binary::Intersection,
322            biscuit_parser::builder::Binary::Union => Binary::Union,
323            biscuit_parser::builder::Binary::BitwiseAnd => Binary::BitwiseAnd,
324            biscuit_parser::builder::Binary::BitwiseOr => Binary::BitwiseOr,
325            biscuit_parser::builder::Binary::BitwiseXor => Binary::BitwiseXor,
326            biscuit_parser::builder::Binary::NotEqual => Binary::NotEqual,
327            biscuit_parser::builder::Binary::HeterogeneousEqual => Binary::HeterogeneousEqual,
328            biscuit_parser::builder::Binary::HeterogeneousNotEqual => Binary::HeterogeneousNotEqual,
329            biscuit_parser::builder::Binary::LazyAnd => Binary::LazyAnd,
330            biscuit_parser::builder::Binary::LazyOr => Binary::LazyOr,
331            biscuit_parser::builder::Binary::All => Binary::All,
332            biscuit_parser::builder::Binary::Any => Binary::Any,
333            biscuit_parser::builder::Binary::Get => Binary::Get,
334            biscuit_parser::builder::Binary::Ffi(name) => Binary::Ffi(name),
335            biscuit_parser::builder::Binary::TryOr => Binary::TryOr,
336        }
337    }
338}