biscuit_auth/datalog/
mod.rs

1/*
2 * Copyright (c) 2019 Geoffroy Couprie <contact@geoffroycouprie.com> and Contributors to the Eclipse Foundation.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5//! Logic language implementation for checks
6use crate::builder::{CheckKind, Convert};
7use crate::error::Execution;
8use crate::time::Instant;
9use crate::token::{Scope, DATALOG_3_1, DATALOG_3_3, MIN_SCHEMA_VERSION};
10use crate::{builder, error};
11use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
12use std::convert::AsRef;
13use std::fmt;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15
16mod expression;
17mod origin;
18mod symbol;
19pub use expression::*;
20pub use origin::*;
21pub use symbol::*;
22
23#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
24pub enum Term {
25    Variable(u32),
26    Integer(i64),
27    Str(SymbolIndex),
28    Date(u64),
29    Bytes(Vec<u8>),
30    Bool(bool),
31    Set(BTreeSet<Term>),
32    Null,
33    Array(Vec<Term>),
34    Map(BTreeMap<MapKey, Term>),
35}
36
37#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
38pub enum MapKey {
39    Integer(i64),
40    Str(SymbolIndex),
41}
42
43impl From<&Term> for Term {
44    fn from(i: &Term) -> Self {
45        match i {
46            Term::Variable(ref v) => Term::Variable(*v),
47            Term::Integer(ref i) => Term::Integer(*i),
48            Term::Str(ref s) => Term::Str(*s),
49            Term::Date(ref d) => Term::Date(*d),
50            Term::Bytes(ref b) => Term::Bytes(b.clone()),
51            Term::Bool(ref b) => Term::Bool(*b),
52            Term::Set(ref s) => Term::Set(s.clone()),
53            Term::Null => Term::Null,
54            Term::Array(ref a) => Term::Array(a.clone()),
55            Term::Map(m) => Term::Map(m.clone()),
56        }
57    }
58}
59
60impl AsRef<Term> for Term {
61    fn as_ref(&self) -> &Term {
62        self
63    }
64}
65
66#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
67pub struct Predicate {
68    pub name: SymbolIndex,
69    pub terms: Vec<Term>,
70}
71
72impl Predicate {
73    pub fn new(name: SymbolIndex, terms: &[Term]) -> Predicate {
74        Predicate {
75            name,
76            terms: terms.to_vec(),
77        }
78    }
79}
80
81impl AsRef<Predicate> for Predicate {
82    fn as_ref(&self) -> &Predicate {
83        self
84    }
85}
86
87#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
88pub struct Fact {
89    pub predicate: Predicate,
90}
91
92impl Fact {
93    pub fn new(name: SymbolIndex, terms: &[Term]) -> Fact {
94        Fact {
95            predicate: Predicate::new(name, terms),
96        }
97    }
98}
99
100#[derive(Debug, Clone, Hash, PartialEq, Eq)]
101pub struct Rule {
102    pub head: Predicate,
103    pub body: Vec<Predicate>,
104    pub expressions: Vec<Expression>,
105    pub scopes: Vec<Scope>,
106}
107
108impl AsRef<Expression> for Expression {
109    fn as_ref(&self) -> &Expression {
110        self
111    }
112}
113
114#[derive(Debug, Clone, PartialEq, Eq)]
115pub struct Check {
116    pub queries: Vec<Rule>,
117    pub kind: CheckKind,
118}
119
120impl fmt::Display for Fact {
121    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
122        write!(f, "{}({:?})", self.predicate.name, self.predicate.terms)
123    }
124}
125
126impl Rule {
127    /// gather all of the variables used in that rule
128    fn variables_set(&self) -> HashSet<u32> {
129        self.body
130            .iter()
131            .flat_map(|pred| {
132                pred.terms.iter().filter_map(|id| match id {
133                    Term::Variable(i) => Some(*i),
134                    _ => None,
135                })
136            })
137            .collect::<HashSet<_>>()
138    }
139
140    pub fn apply<'a, IT>(
141        &'a self,
142        facts: IT,
143        rule_origin: usize,
144        symbols: &'a SymbolTable,
145        extern_funcs: &'a HashMap<String, ExternFunc>,
146    ) -> impl Iterator<Item = Result<(Origin, Fact), error::Expression>> + 'a
147    where
148        IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
149    {
150        let head = self.head.clone();
151        let variables = MatchedVariables::new(self.variables_set());
152
153        CombineIt::new(variables, &self.body, facts, symbols)
154        .map(move |(origin, variables)| {
155                    let mut temporary_symbols = TemporarySymbolTable::new(symbols);
156                    for e in self.expressions.iter() {
157                        match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) {
158                            Ok(Term::Bool(true)) => {}
159                            Ok(Term::Bool(false)) => return Ok((origin, variables, false)),
160                            Ok(_) => return Err(error::Expression::InvalidType),
161                            Err(e) => {
162                                //println!("expr returned {:?}", res);
163                                return Err(e);
164                            }
165                        }
166                    }
167            Ok((origin, variables, true))
168        }).filter_map(move |res/*(mut origin,h, expression_res)*/| {
169            match res {
170                Ok((mut origin,h , expression_res)) => {
171                    if expression_res {
172                    let mut p = head.clone();
173                    for index in 0..p.terms.len() {
174                        match &p.terms[index] {
175                            Term::Variable(i) => match h.get(i) {
176                              Some(val) => p.terms[index] = val.clone(),
177                              None => {
178                                println!("error: variables that appear in the head should appear in the body and constraints as well");
179                                return None;
180                              }
181                            },
182                            _ => continue,
183                        };
184                    }
185
186                    origin.insert(rule_origin);
187                    Some(Ok((origin, Fact { predicate: p })))
188                } else {None}
189                },
190                Err(e) => Some(Err(e))
191            }
192
193        })
194    }
195
196    pub fn find_match(
197        &self,
198        facts: &FactSet,
199        origin: usize,
200        scope: &TrustedOrigins,
201        symbols: &SymbolTable,
202        extern_funcs: &HashMap<String, ExternFunc>,
203    ) -> Result<bool, Execution> {
204        let fact_it = facts.iterator(scope);
205        let mut it = self.apply(fact_it, origin, symbols, extern_funcs);
206
207        let next = it.next();
208        match next {
209            None => Ok(false),
210            Some(Ok(_)) => Ok(true),
211            Some(Err(e)) => Err(Execution::Expression(e)),
212        }
213    }
214
215    pub fn check_match_all(
216        &self,
217        facts: &FactSet,
218        scope: &TrustedOrigins,
219        symbols: &SymbolTable,
220        extern_funcs: &HashMap<String, ExternFunc>,
221    ) -> Result<bool, Execution> {
222        let fact_it = facts.iterator(scope);
223        let variables = MatchedVariables::new(self.variables_set());
224        let mut found = false;
225
226        for (_, variables) in CombineIt::new(variables, &self.body, fact_it, symbols) {
227            found = true;
228
229            let mut temporary_symbols = TemporarySymbolTable::new(symbols);
230            for e in self.expressions.iter() {
231                match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) {
232                    Ok(Term::Bool(true)) => {}
233                    Ok(Term::Bool(false)) => {
234                        //println!("expr returned {:?}", res);
235                        return Ok(false);
236                    }
237                    Ok(_) => {
238                        return Err(error::Execution::Expression(error::Expression::InvalidType))
239                    }
240                    Err(e) => {
241                        return Err(error::Execution::Expression(e));
242                    }
243                }
244            }
245        }
246
247        Ok(found)
248    }
249
250    // use this to translate rules and checks from token to authorizer world without translating
251    // to a builder Rule first, because the builder Rule can contain a public key, so we would
252    // need to loo up then retranslate that key, while the datalog rule does not need to know about
253    // the key (the scope is driven by the authorizer's side)
254    pub fn translate(
255        &self,
256        origin_symbols: &SymbolTable,
257        target_symbols: &mut SymbolTable,
258    ) -> Result<Self, error::Format> {
259        Ok(Rule {
260            head: builder::Predicate::convert_from(&self.head, origin_symbols)?
261                .convert(target_symbols),
262            body: self
263                .body
264                .iter()
265                .map(|p| {
266                    builder::Predicate::convert_from(p, origin_symbols)
267                        .map(|p| p.convert(target_symbols))
268                })
269                .collect::<Result<Vec<_>, _>>()?,
270            expressions: self
271                .expressions
272                .iter()
273                .map(|c| {
274                    builder::Expression::convert_from(c, origin_symbols)
275                        .map(|e| e.convert(target_symbols))
276                })
277                .collect::<Result<Vec<_>, _>>()?,
278            scopes: self
279                .scopes
280                .iter()
281                .map(|s| {
282                    builder::Scope::convert_from(s, origin_symbols)
283                        .map(|s| s.convert(target_symbols))
284                })
285                .collect::<Result<Vec<_>, _>>()?,
286        })
287    }
288
289    pub fn validate_variables(&self, symbols: &SymbolTable) -> Result<(), String> {
290        let mut head_variables: std::collections::HashSet<u32> = self
291            .head
292            .terms
293            .iter()
294            .filter_map(|term| match term {
295                Term::Variable(s) => Some(*s),
296                _ => None,
297            })
298            .collect();
299
300        for predicate in self.body.iter() {
301            for term in predicate.terms.iter() {
302                if let Term::Variable(v) = term {
303                    head_variables.remove(v);
304                    if head_variables.is_empty() {
305                        return Ok(());
306                    }
307                }
308            }
309        }
310
311        if head_variables.is_empty() {
312            Ok(())
313        } else {
314            Err(format!(
315                    "rule head contains variables that are not used in predicates of the rule's body: {}",
316                    head_variables
317                    .iter()
318                    .map(|s| format!("${}", symbols.print_symbol_default(*s as u64)))
319                    .collect::<Vec<_>>()
320                    .join(", ")
321                    ))
322        }
323    }
324}
325
326/// recursive iterator for rule application
327pub struct CombineIt<'a, IT> {
328    variables: MatchedVariables,
329    predicates: &'a [Predicate],
330    all_facts: IT,
331    symbols: &'a SymbolTable,
332    current_facts: Box<dyn Iterator<Item = (&'a Origin, &'a Fact)> + 'a>,
333    current_it: Option<Box<dyn Iterator<Item = (Origin, HashMap<u32, Term>)> + 'a>>,
334}
335
336impl<'a, IT> CombineIt<'a, IT>
337where
338    IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
339{
340    pub fn new(
341        variables: MatchedVariables,
342        predicates: &'a [Predicate],
343        facts: IT,
344        symbols: &'a SymbolTable,
345    ) -> Self {
346        let current_facts: Box<dyn Iterator<Item = (&'a Origin, &'a Fact)> + 'a> =
347            if predicates.is_empty() {
348                Box::new(facts.clone())
349            } else {
350                let p = predicates[0].clone();
351                Box::new(
352                    facts
353                        .clone()
354                        .filter(move |fact| match_preds(&p, &fact.1.predicate)),
355                )
356            };
357
358        CombineIt {
359            variables,
360            predicates,
361            all_facts: facts,
362            symbols,
363            current_facts,
364            current_it: None,
365        }
366    }
367}
368
369impl<'a, IT> Iterator for CombineIt<'a, IT>
370where
371    IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
372    Self: 'a,
373{
374    type Item = (Origin, HashMap<u32, Term>);
375
376    fn next(&mut self) -> Option<(Origin, HashMap<u32, Term>)> {
377        // if we're the last iterator in the recursive chain, stop here
378        if self.predicates.is_empty() {
379            match self.variables.complete() {
380                None => return None,
381                // we got a complete set of variables, let's test the expressions
382                Some(variables) => {
383                    // if there were no predicates and expressions evaluated to true,
384                    // we should return a value, but only once. To prevent further
385                    // successful calls, we create a set of variables that cannot
386                    // possibly be completed, so the next call will fail
387                    self.variables = MatchedVariables::new([0].into());
388                    return Some((Origin::default(), variables));
389                }
390            }
391        }
392
393        loop {
394            if self.current_it.is_none() {
395                //fix the first predicate
396                let pred = &self.predicates[0];
397
398                loop {
399                    if let Some((current_origin, current_fact)) = self.current_facts.next() {
400                        // create a new MatchedVariables in which we fix variables we could unify
401                        // from our first predicate and the current fact
402                        let mut vars = self.variables.clone();
403                        let mut match_terms = true;
404                        for (key, id) in pred.terms.iter().zip(&current_fact.predicate.terms) {
405                            if let (Term::Variable(k), id) = (key, id) {
406                                if !vars.insert(*k, id) {
407                                    match_terms = false;
408                                }
409
410                                if !match_terms {
411                                    break;
412                                }
413                            }
414                        }
415
416                        if !match_terms {
417                            continue;
418                        }
419
420                        if self.predicates.len() == 1 {
421                            match vars.complete() {
422                                None => {
423                                    //println!("variables not complete, continue");
424                                    continue;
425                                }
426                                // we got a complete set of variables, let's test the expressions
427                                Some(variables) => {
428                                    return Some((current_origin.clone(), variables));
429                                }
430                            }
431                        } else {
432                            // create a new iterator with the matched variables, the rest of the predicates,
433                            // and all of the facts
434                            self.current_it = Some(Box::new(
435                                CombineIt::new(
436                                    vars,
437                                    &self.predicates[1..],
438                                    self.all_facts.clone(),
439                                    self.symbols,
440                                )
441                                .map(move |(origin, variables)| {
442                                    (origin.union(current_origin), variables)
443                                }),
444                            ));
445                        }
446                        break;
447                    } else {
448                        return None;
449                    }
450                }
451            }
452
453            if self.current_it.is_none() {
454                break None;
455            }
456
457            if let Some((origin, variables)) = self.current_it.as_mut().and_then(|it| it.next()) {
458                break Some((origin, variables));
459            } else {
460                self.current_it = None;
461            }
462        }
463    }
464}
465
466#[derive(Debug, Clone, PartialEq, Eq)]
467pub struct MatchedVariables {
468    pub variables: HashMap<u32, Option<Term>>,
469}
470
471impl MatchedVariables {
472    pub fn new(import: HashSet<u32>) -> Self {
473        MatchedVariables {
474            variables: import.iter().map(|key| (*key, None)).collect(),
475        }
476    }
477
478    pub fn insert(&mut self, key: u32, value: &Term) -> bool {
479        match self.variables.get(&key) {
480            Some(None) => {
481                self.variables.insert(key, Some(value.clone()));
482                true
483            }
484            Some(Some(v)) => value == v,
485            None => false,
486        }
487    }
488
489    pub fn is_complete(&self) -> bool {
490        self.variables.values().all(|v| v.is_some())
491    }
492
493    pub fn complete(&self) -> Option<HashMap<u32, Term>> {
494        let mut result = HashMap::new();
495        for (k, v) in self.variables.iter() {
496            match v {
497                Some(value) => result.insert(*k, value.clone()),
498                None => return None,
499            };
500        }
501        Some(result)
502    }
503}
504
505pub fn fact<I: AsRef<Term>>(name: SymbolIndex, terms: &[I]) -> Fact {
506    Fact {
507        predicate: Predicate {
508            name,
509            terms: terms.iter().map(|id| id.as_ref().clone()).collect(),
510        },
511    }
512}
513
514pub fn pred<I: AsRef<Term>>(name: SymbolIndex, terms: &[I]) -> Predicate {
515    Predicate {
516        name,
517        terms: terms.iter().map(|id| id.as_ref().clone()).collect(),
518    }
519}
520
521pub fn rule<I: AsRef<Term>, P: AsRef<Predicate>>(
522    head_name: SymbolIndex,
523    head_terms: &[I],
524    predicates: &[P],
525) -> Rule {
526    Rule {
527        head: pred(head_name, head_terms),
528        body: predicates.iter().map(|p| p.as_ref().clone()).collect(),
529        expressions: Vec::new(),
530        scopes: vec![],
531    }
532}
533
534pub fn expressed_rule<I: AsRef<Term>, P: AsRef<Predicate>, C: AsRef<Expression>>(
535    head_name: SymbolIndex,
536    head_terms: &[I],
537    predicates: &[P],
538    expressions: &[C],
539) -> Rule {
540    Rule {
541        head: pred(head_name, head_terms),
542        body: predicates.iter().map(|p| p.as_ref().clone()).collect(),
543        expressions: expressions.iter().map(|c| c.as_ref().clone()).collect(),
544        scopes: vec![],
545    }
546}
547
548pub fn int(i: i64) -> Term {
549    Term::Integer(i)
550}
551
552/*pub fn string(s: &str) -> Term {
553    Term::Str(s.to_string())
554}*/
555
556pub fn date(t: &SystemTime) -> Term {
557    let dur = t.duration_since(UNIX_EPOCH).unwrap();
558    Term::Date(dur.as_secs())
559}
560
561pub fn var(syms: &mut SymbolTable, name: &str) -> Term {
562    let id = syms.insert(name);
563    Term::Variable(id as u32)
564}
565
566pub fn match_preds(rule_pred: &Predicate, fact_pred: &Predicate) -> bool {
567    rule_pred.name == fact_pred.name
568        && rule_pred.terms.len() == fact_pred.terms.len()
569        && rule_pred
570            .terms
571            .iter()
572            .zip(&fact_pred.terms)
573            .all(|(fid, pid)| match (fid, pid) {
574                // the fact should not contain variables
575                (_, Term::Variable(_)) => false,
576                (Term::Variable(_), _) => true,
577                (Term::Integer(i), Term::Integer(j)) => i == j,
578                (Term::Str(i), Term::Str(j)) => i == j,
579                (Term::Date(i), Term::Date(j)) => i == j,
580                (Term::Bytes(i), Term::Bytes(j)) => i == j,
581                (Term::Bool(i), Term::Bool(j)) => i == j,
582                (Term::Null, Term::Null) => true,
583                (Term::Set(i), Term::Set(j)) => i == j,
584                (Term::Array(i), Term::Array(j)) => i == j,
585                (Term::Map(i), Term::Map(j)) => i == j,
586                _ => false,
587            })
588}
589
590#[derive(Debug, Clone, Default)]
591pub struct World {
592    pub facts: FactSet,
593    pub rules: RuleSet,
594    pub iterations: u64,
595    pub extern_funcs: HashMap<String, ExternFunc>,
596}
597
598impl World {
599    pub fn new() -> Self {
600        World::default()
601    }
602
603    pub fn add_fact(&mut self, origin: &Origin, fact: Fact) {
604        self.facts.insert(origin, fact);
605    }
606
607    pub fn add_rule(&mut self, origin: usize, scope: &TrustedOrigins, rule: Rule) {
608        self.rules.insert(origin, scope, rule);
609    }
610
611    pub fn run(&mut self, symbols: &SymbolTable) -> Result<(), crate::error::Execution> {
612        self.run_with_limits(symbols, RunLimits::default())
613    }
614
615    pub fn run_with_limits(
616        &mut self,
617        symbols: &SymbolTable,
618        limits: RunLimits,
619    ) -> Result<(), crate::error::Execution> {
620        let start = Instant::now();
621        let time_limit = start + limits.max_time;
622        let mut index = 0;
623
624        let res = loop {
625            let mut new_facts = FactSet::default();
626
627            for (scope, rules) in self.rules.inner.iter() {
628                let it = self.facts.iterator(scope);
629                for (origin, rule) in rules {
630                    for res in rule.apply(it.clone(), *origin, symbols, &self.extern_funcs) {
631                        match res {
632                            Ok((origin, fact)) => {
633                                new_facts.insert(&origin, fact);
634                            }
635                            Err(e) => {
636                                return Err(Execution::Expression(e));
637                            }
638                        }
639                    }
640                    //println!("new_facts after applying {:?}:\n{:#?}", rule, new_facts);
641                }
642            }
643
644            let len = self.facts.len();
645            self.facts.merge(new_facts);
646            if self.facts.len() == len {
647                break Ok(());
648            }
649
650            index += 1;
651            if index == limits.max_iterations {
652                break Err(Execution::RunLimit(
653                    crate::error::RunLimit::TooManyIterations,
654                ));
655            }
656
657            if self.facts.len() >= limits.max_facts as usize {
658                break Err(Execution::RunLimit(crate::error::RunLimit::TooManyFacts));
659            }
660
661            let now = Instant::now();
662            if now >= time_limit {
663                break Err(Execution::RunLimit(crate::error::RunLimit::Timeout));
664            }
665        };
666
667        self.iterations += index;
668
669        res
670    }
671
672    /*pub fn query(&self, pred: Predicate) -> Vec<&Fact> {
673        self.facts
674            .iter()
675            .filter(|f| {
676                f.predicate.name == pred.name
677                    && f.predicate.terms.iter().zip(&pred.terms).all(|(fid, pid)| {
678                        match (fid, pid) {
679                            //(Term::Symbol(_), Term::Variable(_)) => true,
680                            //(Term::Symbol(i), Term::Symbol(ref j)) => i == j,
681                            (_, Term::Variable(_)) => true,
682                            (Term::Integer(i), Term::Integer(ref j)) => i == j,
683                            (Term::Str(i), Term::Str(ref j)) => i == j,
684                            (Term::Date(i), Term::Date(ref j)) => i == j,
685                            (Term::Bytes(i), Term::Bytes(ref j)) => i == j,
686                            (Term::Bool(i), Term::Bool(ref j)) => i == j,
687                            (Term::Set(i), Term::Set(ref j)) => i == j,
688                            _ => false,
689                        }
690                    })
691            })
692            .collect::<Vec<_>>()
693    }*/
694
695    pub fn query_rule(
696        &self,
697        rule: Rule,
698        origin: usize,
699        scope: &TrustedOrigins,
700        symbols: &SymbolTable,
701    ) -> Result<FactSet, Execution> {
702        let mut new_facts = FactSet::default();
703        let it = self.facts.iterator(scope);
704        //new_facts.extend(rule.apply(it, origin, symbols));
705        for res in rule.apply(it.clone(), origin, symbols, &self.extern_funcs) {
706            match res {
707                Ok((origin, fact)) => {
708                    new_facts.insert(&origin, fact);
709                }
710                Err(e) => {
711                    return Err(Execution::Expression(e));
712                }
713            }
714        }
715
716        Ok(new_facts)
717    }
718
719    pub fn query_match(
720        &self,
721        rule: Rule,
722        origin: usize,
723        scope: &TrustedOrigins,
724        symbols: &SymbolTable,
725    ) -> Result<bool, Execution> {
726        rule.find_match(&self.facts, origin, scope, symbols, &self.extern_funcs)
727    }
728
729    pub fn query_match_all(
730        &self,
731        rule: Rule,
732        scope: &TrustedOrigins,
733        symbols: &SymbolTable,
734    ) -> Result<bool, Execution> {
735        rule.check_match_all(&self.facts, scope, symbols, &self.extern_funcs)
736    }
737}
738
/// Runtime limits for the Datalog engine.
///
/// The default limits are 1000 facts, 100 iterations and 1 millisecond.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RunLimits {
    /// Maximum number of Datalog facts (bounds memory usage).
    pub max_facts: u64,
    /// Maximum number of iterations of the rules applications (prevents degenerate rules).
    pub max_iterations: u64,
    /// Maximum execution time.
    pub max_time: Duration,
}

impl Default for RunLimits {
    fn default() -> Self {
        Self {
            max_facts: 1_000,
            max_iterations: 100,
            max_time: Duration::from_millis(1),
        }
    }
}
759
760#[derive(Clone, Debug, Default)]
761pub struct FactSet {
762    pub(crate) inner: HashMap<Origin, HashSet<Fact>>,
763}
764
765impl FactSet {
766    pub fn insert(&mut self, origin: &Origin, fact: Fact) {
767        match self.inner.get_mut(origin) {
768            None => {
769                let mut set = HashSet::new();
770                set.insert(fact);
771                self.inner.insert(origin.clone(), set);
772            }
773            Some(set) => {
774                set.insert(fact);
775            }
776        }
777    }
778
779    pub fn len(&self) -> usize {
780        self.inner.values().fold(0, |acc, set| acc + set.len())
781    }
782
783    pub fn is_empty(&self) -> bool {
784        self.inner.values().all(|set| set.is_empty())
785    }
786
787    pub fn iterator<'a>(
788        &'a self,
789        block_ids: &'a TrustedOrigins,
790    ) -> impl Iterator<Item = (&'a Origin, &'a Fact)> + Clone {
791        self.inner
792            .iter()
793            .filter_map(move |(ids, facts)| {
794                if block_ids.contains(ids) {
795                    Some(facts.iter().map(move |fact| (ids, fact)))
796                } else {
797                    None
798                }
799            })
800            .flatten()
801    }
802
803    pub fn iter_all(&self) -> impl Iterator<Item = (&Origin, &Fact)> + Clone {
804        self.inner
805            .iter()
806            .flat_map(move |(ids, facts)| facts.iter().map(move |fact| (ids, fact)))
807    }
808
809    pub fn merge(&mut self, other: FactSet) {
810        for (origin, facts) in other.inner {
811            let entry = self.inner.entry(origin).or_default();
812            entry.extend(facts.into_iter());
813        }
814    }
815}
816
817impl Extend<(Origin, Fact)> for FactSet {
818    fn extend<T: IntoIterator<Item = (Origin, Fact)>>(&mut self, iter: T) {
819        for (origin, fact) in iter {
820            let entry = self.inner.entry(origin).or_default();
821            entry.insert(fact);
822        }
823    }
824}
825
826impl IntoIterator for FactSet {
827    type Item = (Origin, Fact);
828
829    type IntoIter = Box<dyn Iterator<Item = (Origin, Fact)>>;
830
831    fn into_iter(self) -> Self::IntoIter {
832        Box::new(
833            self.inner.into_iter().flat_map(move |(ids, facts)| {
834                facts.into_iter().map(move |fact| (ids.clone(), fact))
835            }),
836        )
837    }
838}
839
840#[derive(Clone, Debug, Default)]
841pub struct RuleSet {
842    pub inner: HashMap<TrustedOrigins, Vec<(usize, Rule)>>,
843}
844
845impl RuleSet {
846    pub fn insert(&mut self, origin: usize, scope: &TrustedOrigins, rule: Rule) {
847        match self.inner.get_mut(scope) {
848            None => {
849                self.inner.insert(scope.clone(), vec![(origin, rule)]);
850            }
851            Some(set) => {
852                set.push((origin, rule));
853            }
854        }
855    }
856
857    pub fn iter_all(&self) -> impl Iterator<Item = (&TrustedOrigins, &Rule)> + Clone {
858        self.inner
859            .iter()
860            .flat_map(move |(ids, rules)| rules.iter().map(move |(_, rule)| (ids, rule)))
861    }
862}
863
864pub struct SchemaVersion {
865    contains_scopes: bool,
866    contains_v3_1: bool,
867    contains_check_all: bool,
868    contains_v3_3: bool,
869}
870
871impl SchemaVersion {
872    pub fn version(&self) -> u32 {
873        if self.contains_v3_3 {
874            DATALOG_3_3
875        } else if self.contains_scopes || self.contains_v3_1 || self.contains_check_all {
876            DATALOG_3_1
877        } else {
878            MIN_SCHEMA_VERSION
879        }
880    }
881
882    pub fn check_compatibility(&self, version: u32) -> Result<(), error::Format> {
883        if version < DATALOG_3_1 {
884            if self.contains_scopes {
885                Err(error::Format::DeserializationError(
886                    "scopes are only supported in datalog v3.1+".to_string(),
887                ))
888            } else if self.contains_v3_1 {
889                Err(error::Format::DeserializationError(
890                    "bitwise operators and != are only supported in datalog v3.1+".to_string(),
891                ))
892            } else if self.contains_check_all {
893                Err(error::Format::DeserializationError(
894                    "check all is only supported in datalog v3.1+".to_string(),
895                ))
896            } else {
897                Ok(())
898            }
899        } else if version < DATALOG_3_3 && self.contains_v3_3 {
900            Err(error::Format::DeserializationError(
901                "maps, arrays, null, closures are only supported in datalog v3.3+".to_string(),
902            ))
903        } else {
904            Ok(())
905        }
906    }
907}
908
909/// Determine the schema version given the elements of a block.
910pub fn get_schema_version(
911    facts: &[Fact],
912    rules: &[Rule],
913    checks: &[Check],
914    scopes: &[Scope],
915) -> SchemaVersion {
916    let contains_scopes = !scopes.is_empty()
917        || rules.iter().any(|r: &Rule| !r.scopes.is_empty())
918        || checks
919            .iter()
920            .any(|c: &Check| c.queries.iter().any(|q| !q.scopes.is_empty()));
921
922    let mut contains_check_all = false;
923    let mut contains_v3_3 = false;
924
925    for c in checks.iter() {
926        if c.kind == CheckKind::All {
927            contains_check_all = true;
928        } else if c.kind == CheckKind::Reject {
929            contains_v3_3 = true;
930        }
931    }
932
933    let contains_v3_1 = rules.iter().any(|rule| contains_v3_1_op(&rule.expressions))
934        || checks.iter().any(|check| {
935            check
936                .queries
937                .iter()
938                .any(|query| contains_v3_1_op(&query.expressions))
939        });
940
941    // null, heterogeneous equals, closures
942    if !contains_v3_3 {
943        contains_v3_3 = rules.iter().any(|rule| {
944            contains_v3_3_predicate(&rule.head)
945                || rule.body.iter().any(contains_v3_3_predicate)
946                || contains_v3_3_op(&rule.expressions)
947        }) || checks.iter().any(|check| {
948            check.queries.iter().any(|query| {
949                query.body.iter().any(contains_v3_3_predicate)
950                    || contains_v3_3_op(&query.expressions)
951            })
952        });
953    }
954    if !contains_v3_3 {
955        contains_v3_3 = facts
956            .iter()
957            .any(|fact| contains_v3_3_predicate(&fact.predicate))
958    }
959
960    SchemaVersion {
961        contains_scopes,
962        contains_v3_1,
963        contains_check_all,
964        contains_v3_3,
965    }
966}
967
968/// Determine whether any of the expression contain a v3.1 operator.
969/// Bitwise operators and != are only supported in biscuits v3.1+
970pub fn contains_v3_1_op(expressions: &[Expression]) -> bool {
971    expressions.iter().any(|expression| {
972        expression.ops.iter().any(|op| {
973            if let Op::Binary(binary) = op {
974                match binary {
975                    Binary::BitwiseAnd
976                    | Binary::BitwiseOr
977                    | Binary::BitwiseXor
978                    | Binary::NotEqual => return true,
979                    _ => return false,
980                }
981            }
982            false
983        })
984    })
985}
986
987fn contains_v3_3_op(expressions: &[Expression]) -> bool {
988    expressions.iter().any(|expression| {
989        expression.ops.iter().any(|op| match op {
990            Op::Value(term) => contains_v3_3_term(term),
991            Op::Closure(_, _) => true,
992            Op::Unary(unary) => matches!(unary, Unary::TypeOf | Unary::Ffi(_)),
993            Op::Binary(binary) => matches!(
994                binary,
995                Binary::HeterogeneousEqual
996                    | Binary::HeterogeneousNotEqual
997                    | Binary::LazyAnd
998                    | Binary::LazyOr
999                    | Binary::All
1000                    | Binary::Any
1001                    | Binary::Ffi(_)
1002            ),
1003        })
1004    })
1005}
1006
1007fn contains_v3_3_predicate(predicate: &Predicate) -> bool {
1008    predicate.terms.iter().any(contains_v3_3_term)
1009}
1010
1011fn contains_v3_3_term(term: &Term) -> bool {
1012    match term {
1013        Term::Null => true,
1014        Term::Set(s) => s.contains(&Term::Null),
1015        _ => false,
1016    }
1017}
1018
1019#[cfg(test)]
1020mod tests {
1021    use super::*;
1022    use std::time::Duration;
1023
    // Builds a `parent` chain A -> B -> C -> D, derives `grandparent` facts
    // through a rule, then adds `parent(C, e)` and checks the final set of
    // `grandparent` facts.
    #[test]
    fn family() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        // `add` results are used as fact terms, `insert` results as predicate
        // names.
        let a = syms.add("A");
        let b = syms.add("B");
        let c = syms.add("C");
        let d = syms.add("D");
        let e = syms.add("e");
        let parent = syms.insert("parent");
        let grandparent = syms.insert("grandparent");

        w.add_fact(&[0].iter().collect(), fact(parent, &[&a, &b]));
        w.add_fact(&[0].iter().collect(), fact(parent, &[&b, &c]));
        w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &d]));

        // grandparent($grandparent, $grandchild) <-
        //     parent($grandparent, $parent), parent($parent, $grandchild)
        let r1 = rule(
            grandparent,
            &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
            &[
                pred(
                    parent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "parent")],
                ),
                pred(
                    parent,
                    &[var(&mut syms, "parent"), var(&mut syms, "grandchild")],
                ),
            ],
        );

        // r1 is only queried here (result printed, not asserted); it is not
        // added to the world.
        println!("symbols: {:?}", syms);
        println!("testing r1: {}", syms.print_rule(&r1));
        let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms);
        println!("grandparents query_rules: {:?}", query_rule_result);
        println!("current facts: {:?}", w.facts);

        // Same rule as r1, this time registered in the world so that running
        // the world can use it.
        let r2 = rule(
            grandparent,
            &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
            &[
                pred(
                    parent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "parent")],
                ),
                pred(
                    parent,
                    &[var(&mut syms, "parent"), var(&mut syms, "grandchild")],
                ),
            ],
        );

        println!("adding r2: {}", syms.print_rule(&r2));
        w.add_rule(0, &[0].iter().collect(), r2);

        w.run_with_limits(
            &syms,
            RunLimits {
                max_time: Duration::from_secs(10),
                ..Default::default()
            },
        )
        .unwrap();

        // Informational queries: results are printed only.
        println!("parents:");
        let res = w
            .query_rule(
                rule::<Term, Predicate>(
                    parent,
                    &[var(&mut syms, "parent"), var(&mut syms, "child")],
                    &[pred(
                        parent,
                        &[var(&mut syms, "parent"), var(&mut syms, "child")],
                    )],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (origin, fact) in res.iterator(&[0].iter().collect()) {
            println!("\t{:?}\t{}", origin, syms.print_fact(fact));
        }

        println!(
            "parents of B: {:?}",
            w.query_rule(
                rule::<&Term, Predicate>(
                    parent,
                    &[&var(&mut syms, "parent"), &b],
                    &[pred(parent, &[&var(&mut syms, "parent"), &b])]
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
        );
        println!(
            "grandparents: {:?}",
            w.query_rule(
                rule::<Term, Predicate>(
                    grandparent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
                    &[pred(
                        grandparent,
                        &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")]
                    )]
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
        );
        // Add one more child of C and re-run so the registered rule derives
        // the additional grandparent fact.
        w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e]));
        w.run(&syms).unwrap();
        let res = w
            .query_rule(
                rule::<Term, Predicate>(
                    grandparent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
                    &[pred(
                        grandparent,
                        &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
                    )],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();
        println!("grandparents after inserting parent(C, E): {:?}", res);

        // Compare as sets: only the derived facts matter, not their origin
        // or order.
        let res = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![
            fact(grandparent, &[&a, &c]),
            fact(grandparent, &[&b, &d]),
            fact(grandparent, &[&b, &e]),
        ])
        .drain(..)
        .collect::<HashSet<_>>();
        assert_eq!(res, compared);

        // Disabled draft of a `siblings` rule, kept for reference.
        /*w.add_rule(rule("siblings", &[var("A"), var("B")], &[
          pred(parent, &[var(parent), var("A")]),
          pred(parent, &[var(parent), var("B")])
        ]));

        w.run();
        println!("siblings: {:#?}", w.query(pred("siblings", &[var("A"), var("B")])));
        */
    }
1181
    // Joins two predicates on an integer column, first without and then with
    // an integer constraint on the join key.
    #[test]
    fn numbers() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let ghi = syms.add("ghi");
        let jkl = syms.add("jkl");
        let mno = syms.add("mno");
        let aaa = syms.add("AAA");
        let bbb = syms.add("BBB");
        let ccc = syms.add("CCC");
        let t1 = syms.insert("t1");
        let t2 = syms.insert("t2");
        let join = syms.insert("join");

        // t1(id, name)
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(0), &abc]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(1), &def]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(2), &ghi]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(3), &jkl]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(4), &mno]));

        // t2(t2_id, name, id): the last column refers to a t1 id.
        w.add_fact(&[0].iter().collect(), fact(t2, &[&int(0), &aaa, &int(0)]));
        w.add_fact(&[0].iter().collect(), fact(t2, &[&int(1), &bbb, &int(0)]));
        w.add_fact(&[0].iter().collect(), fact(t2, &[&int(2), &ccc, &int(1)]));

        // join($left, $right) <- t1($id, $left), t2($t2_id, $right, $id)
        let res = w
            .query_rule(
                rule(
                    join,
                    &[var(&mut syms, "left"), var(&mut syms, "right")],
                    &[
                        pred(t1, &[var(&mut syms, "id"), var(&mut syms, "left")]),
                        pred(
                            t2,
                            &[
                                var(&mut syms, "t2_id"),
                                var(&mut syms, "right"),
                                var(&mut syms, "id"),
                            ],
                        ),
                    ],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // ids 0 and 1 are the only ones present on both sides.
        let res2 = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![
            fact(join, &[&abc, &aaa]),
            fact(join, &[&abc, &bbb]),
            fact(join, &[&def, &ccc]),
        ])
        .drain(..)
        .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        // test constraints
        // Same join, restricted by an expression over $id; the ops appear to
        // be in postfix order, i.e. `$id < 1`, which matches the expected
        // result below.
        let res = w
            .query_rule(
                expressed_rule(
                    join,
                    &[var(&mut syms, "left"), var(&mut syms, "right")],
                    &[
                        pred(t1, &[var(&mut syms, "id"), var(&mut syms, "left")]),
                        pred(
                            t2,
                            &[
                                var(&mut syms, "t2_id"),
                                var(&mut syms, "right"),
                                var(&mut syms, "id"),
                            ],
                        ),
                    ],
                    &[Expression {
                        ops: vec![
                            Op::Value(var(&mut syms, "id")),
                            Op::Value(Term::Integer(1)),
                            Op::Binary(Binary::LessThan),
                        ],
                    }],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // Only the id 0 rows remain.
        let res2 = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![fact(join, &[&abc, &aaa]), fact(join, &[&abc, &bbb])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }
1295
    // Filters `route` facts with the string `Suffix` operator applied to the
    // domain name.
    #[test]
    fn str() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let app_0 = syms.add("app_0");
        let app_1 = syms.add("app_1");
        let app_2 = syms.add("app_2");
        let route = syms.insert("route");
        let suff = syms.insert("route suffix");
        let example = syms.add("example.com");
        let test_com = syms.add("test.com");
        let test_fr = syms.add("test.fr");
        let www_example = syms.add("www.example.com");
        let mx_example = syms.add("mx.example.com");

        // route(route_id, app_id, domain_name)
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(0), &app_0, &example]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(1), &app_1, &test_com]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(2), &app_2, &test_fr]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(3), &app_0, &www_example]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(4), &app_1, &mx_example]),
        );

        // Queries `suff($app_id, $domain_name) <- route($route_id, $app_id,
        // $domain_name)` constrained by the `Suffix` operator on
        // `$domain_name` and `suffix`, and returns the generated facts.
        fn test_suffix(
            w: &World,
            syms: &mut SymbolTable,
            suff: SymbolIndex,
            route: SymbolIndex,
            suffix: &str,
        ) -> Vec<Fact> {
            // The suffix is registered in the symbol table so it can be used
            // as an expression value.
            let id_suff = syms.add(suffix);
            w.query_rule(
                expressed_rule(
                    suff,
                    &[var(syms, "app_id"), var(syms, "domain_name")],
                    &[pred(
                        route,
                        &[
                            var(syms, "route_id"),
                            var(syms, "app_id"),
                            var(syms, "domain_name"),
                        ],
                    )],
                    &[Expression {
                        ops: vec![
                            Op::Value(var(syms, "domain_name")),
                            Op::Value(id_suff),
                            Op::Binary(Binary::Suffix),
                        ],
                    }],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap()
            .iter_all()
            .map(|(_, fact)| fact.clone())
            .collect()
        }

        // Only test.fr ends with ".fr".
        let res = test_suffix(&w, &mut syms, suff, route, ".fr");
        for fact in &res {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res.iter().cloned().collect::<HashSet<_>>();
        let compared = (vec![fact(suff, &[&app_2, &test_fr])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        // The bare domain and both of its subdomains end with "example.com".
        let res = test_suffix(&w, &mut syms, suff, route, "example.com");
        for fact in &res {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res.iter().cloned().collect::<HashSet<_>>();
        let compared = (vec![
            fact(suff, &[&app_0, &example]),
            fact(suff, &[&app_0, &www_example]),
            fact(suff, &[&app_1, &mx_example]),
        ])
        .drain(..)
        .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }
1397
    // Checks date comparisons: facts dated before and after a pivot
    // timestamp are selected with `<=` and `>=` constraints.
    #[test]
    fn date_constraint() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        // t1 < t2 < t3; t2 is the pivot used in the constraints.
        let t1 = SystemTime::now();
        println!("t1 = {:?}", t1);
        let t2 = t1 + Duration::from_secs(10);
        println!("t2 = {:?}", t2);
        let t3 = t2 + Duration::from_secs(30);
        println!("t3 = {:?}", t3);

        // Seconds since the Unix epoch, as stored in `Term::Date`.
        let t2_timestamp = t2.duration_since(UNIX_EPOCH).unwrap().as_secs();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let x = syms.insert("x");
        let before = syms.insert("before");
        let after = syms.insert("after");

        w.add_fact(&[0].iter().collect(), fact(x, &[&date(&t1), &abc]));
        w.add_fact(&[0].iter().collect(), fact(x, &[&date(&t3), &def]));

        // before($date, $val) <- x($date, $val), with two expressions:
        // $date compared to t2 with LessOrEqual, and to 0 with GreaterOrEqual.
        let r1 = expressed_rule(
            before,
            &[var(&mut syms, "date"), var(&mut syms, "val")],
            &[pred(x, &[var(&mut syms, "date"), var(&mut syms, "val")])],
            &[
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(t2_timestamp)),
                        Op::Binary(Binary::LessOrEqual),
                    ],
                },
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(0)),
                        Op::Binary(Binary::GreaterOrEqual),
                    ],
                },
            ],
        );

        println!("testing r1: {}", syms.print_rule(&r1));
        let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // Only the fact dated t1 is at or before t2.
        let res2 = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![fact(before, &[&date(&t1), &abc])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        // after($date, $val) <- x($date, $val), with $date compared to t2
        // and to 0, both with GreaterOrEqual.
        let r2 = expressed_rule(
            after,
            &[var(&mut syms, "date"), var(&mut syms, "val")],
            &[pred(x, &[var(&mut syms, "date"), var(&mut syms, "val")])],
            &[
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(t2_timestamp)),
                        Op::Binary(Binary::GreaterOrEqual),
                    ],
                },
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(0)),
                        Op::Binary(Binary::GreaterOrEqual),
                    ],
                },
            ],
        );

        println!("testing r2: {}", syms.print_rule(&r2));
        let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // Only the fact dated t3 is at or after t2.
        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![fact(after, &[&date(&t3), &def])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }
1497
    // Checks the set `Contains` operator against integer and string columns,
    // including a negated membership test.
    #[test]
    fn set_constraint() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let x = syms.insert("x");
        let int_set = syms.insert("int_set");
        let symbol_set = syms.insert("symbol_set");
        let string_set = syms.insert("string_set");
        let test = syms.add("test");
        let hello = syms.add("hello");
        // NOTE(review): the variable is named `aaa` but holds the string "zzz".
        let aaa = syms.add("zzz");

        // x(sym, int, str)
        w.add_fact(&[0].iter().collect(), fact(x, &[&abc, &int(0), &test]));
        w.add_fact(&[0].iter().collect(), fact(x, &[&def, &int(2), &hello]));

        // int_set($sym, $str) <- x($sym, $int, $str), constrained by
        // `Contains` over the set {0, 1} and $int.
        let res = w
            .query_rule(
                expressed_rule(
                    int_set,
                    &[var(&mut syms, "sym"), var(&mut syms, "str")],
                    &[pred(
                        x,
                        &[
                            var(&mut syms, "sym"),
                            var(&mut syms, "int"),
                            var(&mut syms, "str"),
                        ],
                    )],
                    &[Expression {
                        ops: vec![
                            Op::Value(Term::Set(
                                [Term::Integer(0), Term::Integer(1)]
                                    .iter()
                                    .cloned()
                                    .collect(),
                            )),
                            Op::Value(var(&mut syms, "int")),
                            Op::Binary(Binary::Contains),
                        ],
                    }],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // Only the row with int 0 is in the set.
        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![fact(int_set, &[&abc, &test])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        let abc_sym_id = syms.add("abc");
        let ghi_sym_id = syms.add("ghi");

        // symbol_set: `Contains` over {"abc", "ghi"} and $symbol, followed by
        // `Negate`, so rows whose symbol is in the set are excluded.
        let res = w
            .query_rule(
                expressed_rule(
                    symbol_set,
                    &[
                        var(&mut syms, "symbol"),
                        var(&mut syms, "int"),
                        var(&mut syms, "str"),
                    ],
                    &[pred(
                        x,
                        &[
                            var(&mut syms, "symbol"),
                            var(&mut syms, "int"),
                            var(&mut syms, "str"),
                        ],
                    )],
                    &[Expression {
                        ops: vec![
                            Op::Value(Term::Set(
                                [abc_sym_id, ghi_sym_id].iter().cloned().collect(),
                            )),
                            Op::Value(var(&mut syms, "symbol")),
                            Op::Binary(Binary::Contains),
                            Op::Unary(Unary::Negate),
                        ],
                    }],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // "abc" is in the set, so only the "def" row passes.
        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![fact(symbol_set, &[&def, &int(2), &hello])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        // string_set: `Contains` over {"test", "zzz"} and $str.
        let res = w
            .query_rule(
                expressed_rule(
                    string_set,
                    &[
                        var(&mut syms, "sym"),
                        var(&mut syms, "int"),
                        var(&mut syms, "str"),
                    ],
                    &[pred(
                        x,
                        &[
                            var(&mut syms, "sym"),
                            var(&mut syms, "int"),
                            var(&mut syms, "str"),
                        ],
                    )],
                    &[Expression {
                        ops: vec![
                            Op::Value(Term::Set([test.clone(), aaa].iter().cloned().collect())),
                            Op::Value(var(&mut syms, "str")),
                            Op::Binary(Binary::Contains),
                        ],
                    }],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // Only the row whose string is "test" is in the set.
        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = (vec![fact(string_set, &[&abc, &int(0), &test])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }
1658
    // Checks that queries whose body cannot be satisfied by the known facts
    // produce no result.
    #[test]
    fn resource() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let resource = syms.insert("resource");
        let operation = syms.insert("operation");
        let right = syms.insert("right");
        let file1 = syms.add("file1");
        let file2 = syms.add("file2");
        let read = syms.add("read");
        let write = syms.add("write");
        let check1 = syms.insert("check1");
        let check2 = syms.insert("check2");

        // The request is a write on file2; rights exist for file1 (read,
        // write) and file2 (read only).
        w.add_fact(&[0].iter().collect(), fact(resource, &[&file2]));
        w.add_fact(&[0].iter().collect(), fact(operation, &[&write]));
        w.add_fact(&[0].iter().collect(), fact(right, &[&file1, &read]));
        w.add_fact(&[0].iter().collect(), fact(right, &[&file2, &read]));
        w.add_fact(&[0].iter().collect(), fact(right, &[&file1, &write]));

        // check1("file1") <- resource("file1"): the only resource fact is
        // file2.
        let res = w
            .query_rule(
                rule(check1, &[&file1], &[pred(resource, &[&file1])]),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        assert!(res.is_empty());

        // check2($0) <- resource($0), operation("read"), right($0, "read"):
        // the only operation fact is "write".
        let res = w
            .query_rule(
                rule(
                    check2,
                    &[Term::Variable(0)],
                    &[
                        pred(resource, &[&Term::Variable(0)]),
                        pred(operation, &[&read]),
                        pred(right, &[&Term::Variable(0), &read]),
                    ],
                ),
                0,
                &[0].iter().collect(),
                &syms,
            )
            .unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        assert!(res.is_empty());
    }
1718
    // Checks an arithmetic expression combined with a comparison against a
    // bound variable.
    #[test]
    fn int_expr() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let x = syms.insert("x");
        let less_than = syms.insert("less_than");

        w.add_fact(&[0].iter().collect(), fact(x, &[&int(-2), &abc]));
        w.add_fact(&[0].iter().collect(), fact(x, &[&int(0), &def]));

        // less_than($nb, $val) <- x($nb, $val), with the ops
        // 5, -4, Add, -1, Mul, $nb, LessThan. Read in postfix order this is
        // `(5 + -4) * -1 < $nb`, i.e. `-1 < $nb`, which matches the expected
        // result below.
        let r1 = expressed_rule(
            less_than,
            &[var(&mut syms, "nb"), var(&mut syms, "val")],
            &[pred(x, &[var(&mut syms, "nb"), var(&mut syms, "val")])],
            &[Expression {
                ops: vec![
                    Op::Value(Term::Integer(5)),
                    Op::Value(Term::Integer(-4)),
                    Op::Binary(Binary::Add),
                    Op::Value(Term::Integer(-1)),
                    Op::Binary(Binary::Mul),
                    Op::Value(var(&mut syms, "nb")),
                    Op::Binary(Binary::LessThan),
                ],
            }],
        );

        println!("world:\n{}\n", syms.print_world(&w));
        println!("\ntesting r1: {}\n", syms.print_rule(&r1));
        let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        // Only the row with nb = 0 is selected; nb = -2 is not.
        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        println!("got res: {:?}", res2);
        let compared = (vec![fact(less_than, &[&int(0), &def])])
            .drain(..)
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }
1767
1768    #[test]
1769    fn unbound_variables() {
1770        let mut w = World::new();
1771        let mut syms = SymbolTable::new();
1772
1773        let operation = syms.insert("operation");
1774        let check = syms.insert("check");
1775        let read = syms.add("read");
1776        let write = syms.add("write");
1777        let unbound = var(&mut syms, "unbound");
1778        let any1 = var(&mut syms, "any1");
1779        let any2 = var(&mut syms, "any2");
1780
1781        w.add_fact(&[0].iter().collect(), fact(operation, &[&write]));
1782
1783        let r1 = rule(
1784            operation,
1785            &[&unbound, &read],
1786            &[pred(operation, &[&any1, &any2])],
1787        );
1788        println!("world:\n{}\n", syms.print_world(&w));
1789        println!("\ntesting r1: {}\n", syms.print_rule(&r1));
1790        let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
1791
1792        println!("generated facts:");
1793        for (_, fact) in res.iter_all() {
1794            println!("\t{}", syms.print_fact(fact));
1795        }
1796
1797        assert!(res.len() == 0);
1798
1799        // operation($unbound, "read") should not have been generated
1800        // in case it is generated though, verify that rule application
1801        // will not match it
1802        w.add_fact(&[0].iter().collect(), fact(operation, &[&unbound, &read]));
1803        let r2 = rule(check, &[&read], &[pred(operation, &[&read])]);
1804        println!("world:\n{}\n", syms.print_world(&w));
1805        println!("\ntesting r2: {}\n", syms.print_rule(&r2));
1806        let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();
1807
1808        println!("generated facts:");
1809        for (_, fact) in res.iter_all() {
1810            println!("\t{}", syms.print_fact(fact));
1811        }
1812        assert!(res.is_empty());
1813    }
1814}