1use std::{collections::HashMap, fmt};
6
7use crate::{
8 datalog::{self, SymbolTable},
9 error,
10 token::default_symbol_table,
11};
12
13use super::{Convert, Term};
14
/// Unary operator of a builder-level expression.
///
/// Each variant maps to the `datalog::Unary` variant of the same name through
/// the `Convert` implementation in this file.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Unary {
    Negate,
    Parens,
    Length,
    TypeOf,
    /// Operation identified by name; the name is interned in the symbol table
    /// when converting to `datalog::Unary`.
    /// NOTE(review): presumably a call to an externally provided function — confirm.
    Ffi(String),
}
24
/// Binary operator of a builder-level expression.
///
/// Each variant maps to the `datalog::Binary` variant of the same name through
/// the `Convert` implementation in this file.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Binary {
    LessThan,
    GreaterThan,
    LessOrEqual,
    GreaterOrEqual,
    Equal,
    Contains,
    Prefix,
    Suffix,
    Regex,
    Add,
    Sub,
    Mul,
    Div,
    And,
    Or,
    Intersection,
    Union,
    BitwiseAnd,
    BitwiseOr,
    BitwiseXor,
    NotEqual,
    HeterogeneousEqual,
    HeterogeneousNotEqual,
    LazyAnd,
    LazyOr,
    All,
    Any,
    Get,
    /// Operation identified by name; the name is interned in the symbol table
    /// when converting to `datalog::Binary`.
    /// NOTE(review): presumably a call to an externally provided function — confirm.
    Ffi(String),
    TryOr,
}
59
/// Builder-level expression, stored as a flat sequence of [`Op`]s.
///
/// NOTE(review): the ops are presumably laid out in postfix (stack machine)
/// order — confirm against the datalog evaluator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Expression {
    /// Operations making up the expression, in order.
    pub ops: Vec<Op>,
}
65impl Convert<datalog::Expression> for Expression {
68 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Expression {
69 datalog::Expression {
70 ops: self.ops.iter().map(|op| op.convert(symbols)).collect(),
71 }
72 }
73
74 fn convert_from(e: &datalog::Expression, symbols: &SymbolTable) -> Result<Self, error::Format> {
75 Ok(Expression {
76 ops: e
77 .ops
78 .iter()
79 .map(|op| Op::convert_from(op, symbols))
80 .collect::<Result<Vec<_>, error::Format>>()?,
81 })
82 }
83}
84
/// Identity conversion: an `Expression` can be used wherever an
/// `AsRef<Expression>` is expected.
impl AsRef<Expression> for Expression {
    /// Returns `self` unchanged.
    fn as_ref(&self) -> &Expression {
        self
    }
}
90
91impl fmt::Display for Expression {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 let mut syms = default_symbol_table();
94 let expr = self.convert(&mut syms);
95 let s = expr.print(&syms).unwrap();
96 write!(f, "{}", s)
97 }
98}
99
100impl From<biscuit_parser::builder::Expression> for Expression {
101 fn from(e: biscuit_parser::builder::Expression) -> Self {
102 Expression {
103 ops: e.ops.into_iter().map(|op| op.into()).collect(),
104 }
105 }
106}
107
/// A single element of an [`Expression`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Op {
    /// A term; may be a `Term::Parameter` placeholder that
    /// `apply_parameters` later substitutes.
    Value(Term),
    /// A unary operator.
    Unary(Unary),
    /// A binary operator.
    Binary(Binary),
    /// A closure: a list of names followed by the ops of its body. The names
    /// are interned in the symbol table on conversion to `datalog::Op`.
    /// NOTE(review): the names are presumably the closure's parameters — confirm.
    Closure(Vec<String>, Vec<Op>),
}
116
117impl Op {
118 pub(super) fn collect_parameters(&self, parameters: &mut HashMap<String, Option<Term>>) {
119 match self {
120 Op::Value(Term::Parameter(ref name)) => {
121 parameters.insert(name.to_owned(), None);
122 }
123 Op::Closure(_, ops) => {
124 for op in ops {
125 op.collect_parameters(parameters);
126 }
127 }
128 _ => {}
129 }
130 }
131
132 pub(super) fn apply_parameters(self, parameters: &HashMap<String, Option<Term>>) -> Self {
133 match self {
134 Op::Value(Term::Parameter(ref name)) => {
135 if let Some(Some(t)) = parameters.get(name) {
136 Op::Value(t.clone())
137 } else {
138 self
139 }
140 }
141 Op::Value(_) => self,
142 Op::Unary(_) => self,
143 Op::Binary(_) => self,
144 Op::Closure(args, mut ops) => Op::Closure(
145 args,
146 ops.drain(..)
147 .map(|op| op.apply_parameters(parameters))
148 .collect(),
149 ),
150 }
151 }
152}
153
154impl Convert<datalog::Op> for Op {
155 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Op {
156 match self {
157 Op::Value(t) => datalog::Op::Value(t.convert(symbols)),
158 Op::Unary(u) => datalog::Op::Unary(u.convert(symbols)),
159 Op::Binary(b) => datalog::Op::Binary(b.convert(symbols)),
160 Op::Closure(ps, os) => datalog::Op::Closure(
161 ps.iter().map(|p| symbols.insert(p) as u32).collect(),
162 os.iter().map(|o| o.convert(symbols)).collect(),
163 ),
164 }
165 }
166
167 fn convert_from(op: &datalog::Op, symbols: &SymbolTable) -> Result<Self, error::Format> {
168 Ok(match op {
169 datalog::Op::Value(t) => Op::Value(Term::convert_from(t, symbols)?),
170 datalog::Op::Unary(u) => Op::Unary(Unary::convert_from(u, symbols)?),
171 datalog::Op::Binary(b) => Op::Binary(Binary::convert_from(b, symbols)?),
172 datalog::Op::Closure(ps, os) => Op::Closure(
173 ps.iter()
174 .map(|p| symbols.print_symbol(*p as u64))
175 .collect::<Result<_, _>>()?,
176 os.iter()
177 .map(|o| Op::convert_from(o, symbols))
178 .collect::<Result<_, _>>()?,
179 ),
180 })
181 }
182}
183
184impl From<biscuit_parser::builder::Op> for Op {
185 fn from(op: biscuit_parser::builder::Op) -> Self {
186 match op {
187 biscuit_parser::builder::Op::Value(t) => Op::Value(t.into()),
188 biscuit_parser::builder::Op::Unary(u) => Op::Unary(u.into()),
189 biscuit_parser::builder::Op::Binary(b) => Op::Binary(b.into()),
190 biscuit_parser::builder::Op::Closure(ps, os) => {
191 Op::Closure(ps, os.into_iter().map(|o| o.into()).collect())
192 }
193 }
194 }
195}
196
197impl Convert<datalog::Unary> for Unary {
198 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Unary {
199 match self {
200 Unary::Negate => datalog::Unary::Negate,
201 Unary::Parens => datalog::Unary::Parens,
202 Unary::Length => datalog::Unary::Length,
203 Unary::TypeOf => datalog::Unary::TypeOf,
204 Unary::Ffi(n) => datalog::Unary::Ffi(symbols.insert(n)),
205 }
206 }
207
208 fn convert_from(f: &datalog::Unary, symbols: &SymbolTable) -> Result<Self, error::Format> {
209 match f {
210 datalog::Unary::Negate => Ok(Unary::Negate),
211 datalog::Unary::Parens => Ok(Unary::Parens),
212 datalog::Unary::Length => Ok(Unary::Length),
213 datalog::Unary::TypeOf => Ok(Unary::TypeOf),
214 datalog::Unary::Ffi(i) => Ok(Unary::Ffi(symbols.print_symbol(*i)?)),
215 }
216 }
217}
218
219impl From<biscuit_parser::builder::Unary> for Unary {
220 fn from(unary: biscuit_parser::builder::Unary) -> Self {
221 match unary {
222 biscuit_parser::builder::Unary::Negate => Unary::Negate,
223 biscuit_parser::builder::Unary::Parens => Unary::Parens,
224 biscuit_parser::builder::Unary::Length => Unary::Length,
225 biscuit_parser::builder::Unary::TypeOf => Unary::TypeOf,
226 biscuit_parser::builder::Unary::Ffi(name) => Unary::Ffi(name),
227 }
228 }
229}
230
231impl Convert<datalog::Binary> for Binary {
232 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Binary {
233 match self {
234 Binary::LessThan => datalog::Binary::LessThan,
235 Binary::GreaterThan => datalog::Binary::GreaterThan,
236 Binary::LessOrEqual => datalog::Binary::LessOrEqual,
237 Binary::GreaterOrEqual => datalog::Binary::GreaterOrEqual,
238 Binary::Equal => datalog::Binary::Equal,
239 Binary::Contains => datalog::Binary::Contains,
240 Binary::Prefix => datalog::Binary::Prefix,
241 Binary::Suffix => datalog::Binary::Suffix,
242 Binary::Regex => datalog::Binary::Regex,
243 Binary::Add => datalog::Binary::Add,
244 Binary::Sub => datalog::Binary::Sub,
245 Binary::Mul => datalog::Binary::Mul,
246 Binary::Div => datalog::Binary::Div,
247 Binary::And => datalog::Binary::And,
248 Binary::Or => datalog::Binary::Or,
249 Binary::Intersection => datalog::Binary::Intersection,
250 Binary::Union => datalog::Binary::Union,
251 Binary::BitwiseAnd => datalog::Binary::BitwiseAnd,
252 Binary::BitwiseOr => datalog::Binary::BitwiseOr,
253 Binary::BitwiseXor => datalog::Binary::BitwiseXor,
254 Binary::NotEqual => datalog::Binary::NotEqual,
255 Binary::HeterogeneousEqual => datalog::Binary::HeterogeneousEqual,
256 Binary::HeterogeneousNotEqual => datalog::Binary::HeterogeneousNotEqual,
257 Binary::LazyAnd => datalog::Binary::LazyAnd,
258 Binary::LazyOr => datalog::Binary::LazyOr,
259 Binary::All => datalog::Binary::All,
260 Binary::Any => datalog::Binary::Any,
261 Binary::Get => datalog::Binary::Get,
262 Binary::Ffi(n) => datalog::Binary::Ffi(symbols.insert(n)),
263 Binary::TryOr => datalog::Binary::TryOr,
264 }
265 }
266
267 fn convert_from(f: &datalog::Binary, symbols: &SymbolTable) -> Result<Self, error::Format> {
268 match f {
269 datalog::Binary::LessThan => Ok(Binary::LessThan),
270 datalog::Binary::GreaterThan => Ok(Binary::GreaterThan),
271 datalog::Binary::LessOrEqual => Ok(Binary::LessOrEqual),
272 datalog::Binary::GreaterOrEqual => Ok(Binary::GreaterOrEqual),
273 datalog::Binary::Equal => Ok(Binary::Equal),
274 datalog::Binary::Contains => Ok(Binary::Contains),
275 datalog::Binary::Prefix => Ok(Binary::Prefix),
276 datalog::Binary::Suffix => Ok(Binary::Suffix),
277 datalog::Binary::Regex => Ok(Binary::Regex),
278 datalog::Binary::Add => Ok(Binary::Add),
279 datalog::Binary::Sub => Ok(Binary::Sub),
280 datalog::Binary::Mul => Ok(Binary::Mul),
281 datalog::Binary::Div => Ok(Binary::Div),
282 datalog::Binary::And => Ok(Binary::And),
283 datalog::Binary::Or => Ok(Binary::Or),
284 datalog::Binary::Intersection => Ok(Binary::Intersection),
285 datalog::Binary::Union => Ok(Binary::Union),
286 datalog::Binary::BitwiseAnd => Ok(Binary::BitwiseAnd),
287 datalog::Binary::BitwiseOr => Ok(Binary::BitwiseOr),
288 datalog::Binary::BitwiseXor => Ok(Binary::BitwiseXor),
289 datalog::Binary::NotEqual => Ok(Binary::NotEqual),
290 datalog::Binary::HeterogeneousEqual => Ok(Binary::HeterogeneousEqual),
291 datalog::Binary::HeterogeneousNotEqual => Ok(Binary::HeterogeneousNotEqual),
292 datalog::Binary::LazyAnd => Ok(Binary::LazyAnd),
293 datalog::Binary::LazyOr => Ok(Binary::LazyOr),
294 datalog::Binary::All => Ok(Binary::All),
295 datalog::Binary::Any => Ok(Binary::Any),
296 datalog::Binary::Get => Ok(Binary::Get),
297 datalog::Binary::Ffi(i) => Ok(Binary::Ffi(symbols.print_symbol(*i)?)),
298 datalog::Binary::TryOr => Ok(Binary::TryOr),
299 }
300 }
301}
302
303impl From<biscuit_parser::builder::Binary> for Binary {
304 fn from(binary: biscuit_parser::builder::Binary) -> Self {
305 match binary {
306 biscuit_parser::builder::Binary::LessThan => Binary::LessThan,
307 biscuit_parser::builder::Binary::GreaterThan => Binary::GreaterThan,
308 biscuit_parser::builder::Binary::LessOrEqual => Binary::LessOrEqual,
309 biscuit_parser::builder::Binary::GreaterOrEqual => Binary::GreaterOrEqual,
310 biscuit_parser::builder::Binary::Equal => Binary::Equal,
311 biscuit_parser::builder::Binary::Contains => Binary::Contains,
312 biscuit_parser::builder::Binary::Prefix => Binary::Prefix,
313 biscuit_parser::builder::Binary::Suffix => Binary::Suffix,
314 biscuit_parser::builder::Binary::Regex => Binary::Regex,
315 biscuit_parser::builder::Binary::Add => Binary::Add,
316 biscuit_parser::builder::Binary::Sub => Binary::Sub,
317 biscuit_parser::builder::Binary::Mul => Binary::Mul,
318 biscuit_parser::builder::Binary::Div => Binary::Div,
319 biscuit_parser::builder::Binary::And => Binary::And,
320 biscuit_parser::builder::Binary::Or => Binary::Or,
321 biscuit_parser::builder::Binary::Intersection => Binary::Intersection,
322 biscuit_parser::builder::Binary::Union => Binary::Union,
323 biscuit_parser::builder::Binary::BitwiseAnd => Binary::BitwiseAnd,
324 biscuit_parser::builder::Binary::BitwiseOr => Binary::BitwiseOr,
325 biscuit_parser::builder::Binary::BitwiseXor => Binary::BitwiseXor,
326 biscuit_parser::builder::Binary::NotEqual => Binary::NotEqual,
327 biscuit_parser::builder::Binary::HeterogeneousEqual => Binary::HeterogeneousEqual,
328 biscuit_parser::builder::Binary::HeterogeneousNotEqual => Binary::HeterogeneousNotEqual,
329 biscuit_parser::builder::Binary::LazyAnd => Binary::LazyAnd,
330 biscuit_parser::builder::Binary::LazyOr => Binary::LazyOr,
331 biscuit_parser::builder::Binary::All => Binary::All,
332 biscuit_parser::builder::Binary::Any => Binary::Any,
333 biscuit_parser::builder::Binary::Get => Binary::Get,
334 biscuit_parser::builder::Binary::Ffi(name) => Binary::Ffi(name),
335 biscuit_parser::builder::Binary::TryOr => Binary::TryOr,
336 }
337 }
338}