biscuit_auth/datalog/
mod.rs

1//! Logic language implementation for checks
2use crate::builder::{CheckKind, Convert};
3use crate::error::Execution;
4use crate::time::Instant;
5use crate::token::{Scope, MIN_SCHEMA_VERSION};
6use crate::{builder, error};
7use std::collections::{BTreeSet, HashMap, HashSet};
8use std::convert::AsRef;
9use std::fmt;
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12mod expression;
13mod origin;
14mod symbol;
15pub use expression::*;
16pub use origin::*;
17pub use symbol::*;
18
/// A single Datalog term: either a variable, to be bound while unifying a
/// rule with facts, or a concrete value.
///
/// NOTE(review): the derived `PartialOrd`/`Ord` follow variant declaration
/// order, and terms are stored in `BTreeSet`s (see `Set`), so reordering the
/// variants changes ordering semantics.
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
pub enum Term {
    /// variable id, obtained from the symbol table (see `var`)
    Variable(u32),
    Integer(i64),
    /// string, stored as a `SymbolIndex` (presumably into a `SymbolTable`)
    Str(SymbolIndex),
    /// date as whole seconds since the UNIX epoch (see `date`)
    Date(u64),
    Bytes(Vec<u8>),
    Bool(bool),
    Set(BTreeSet<Term>),
}
29
30impl From<&Term> for Term {
31    fn from(i: &Term) -> Self {
32        match i {
33            Term::Variable(ref v) => Term::Variable(*v),
34            Term::Integer(ref i) => Term::Integer(*i),
35            Term::Str(ref s) => Term::Str(*s),
36            Term::Date(ref d) => Term::Date(*d),
37            Term::Bytes(ref b) => Term::Bytes(b.clone()),
38            Term::Bool(ref b) => Term::Bool(*b),
39            Term::Set(ref s) => Term::Set(s.clone()),
40        }
41    }
42}
43
44impl AsRef<Term> for Term {
45    fn as_ref(&self) -> &Term {
46        self
47    }
48}
49
/// A predicate: a name (a `SymbolIndex`) applied to an ordered list of terms.
///
/// In rules the terms may contain `Term::Variable`s. Fact predicates are
/// expected to be variable-free: `match_preds` rejects a fact whose term is a
/// variable.
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
pub struct Predicate {
    pub name: SymbolIndex,
    pub terms: Vec<Term>,
}
55
56impl Predicate {
57    pub fn new(name: SymbolIndex, terms: &[Term]) -> Predicate {
58        Predicate {
59            name,
60            terms: terms.to_vec(),
61        }
62    }
63}
64
65impl AsRef<Predicate> for Predicate {
66    fn as_ref(&self) -> &Predicate {
67        self
68    }
69}
70
/// A Datalog fact: a single predicate whose terms are expected to be concrete
/// values. A fact containing a variable never matches a rule (see
/// `match_preds`).
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Fact {
    pub predicate: Predicate,
}
75
76impl Fact {
77    pub fn new(name: SymbolIndex, terms: &[Term]) -> Fact {
78        Fact {
79            predicate: Predicate::new(name, terms),
80        }
81    }
82}
83
/// A Datalog rule.
///
/// When every predicate of `body` unifies with a fact and every expression in
/// `expressions` evaluates to `true`, `head` is instantiated with the bound
/// variables (see `Rule::apply`).
///
/// `scopes` holds the rule's scope annotations. `apply`, `find_match` and
/// `check_match_all` do not read it; they take an explicit `TrustedOrigins`
/// argument instead.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Rule {
    pub head: Predicate,
    pub body: Vec<Predicate>,
    pub expressions: Vec<Expression>,
    pub scopes: Vec<Scope>,
}
91
92impl AsRef<Expression> for Expression {
93    fn as_ref(&self) -> &Expression {
94        self
95    }
96}
97
/// A check: one or more rule queries plus the kind of check.
///
/// NOTE(review): how the queries are combined is not defined in this section.
/// `CheckKind::All` is treated as a v4-only feature by `get_schema_version`
/// and corresponds to `Rule::check_match_all` — confirm at the call site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Check {
    pub queries: Vec<Rule>,
    pub kind: CheckKind,
}
103
104impl fmt::Display for Fact {
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106        write!(f, "{}({:?})", self.predicate.name, self.predicate.terms)
107    }
108}
109
110impl Rule {
111    /// gather all of the variables used in that rule
112    fn variables_set(&self) -> HashSet<u32> {
113        self.body
114            .iter()
115            .flat_map(|pred| {
116                pred.terms.iter().filter_map(|id| match id {
117                    Term::Variable(i) => Some(*i),
118                    _ => None,
119                })
120            })
121            .collect::<HashSet<_>>()
122    }
123
124    pub fn apply<'a, IT>(
125        &'a self,
126        facts: IT,
127        rule_origin: usize,
128        symbols: &'a SymbolTable,
129    ) -> impl Iterator<Item = Result<(Origin, Fact), error::Expression>> + 'a
130    where
131        IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
132    {
133        let head = self.head.clone();
134        let variables = MatchedVariables::new(self.variables_set());
135
136        CombineIt::new(variables, &self.body, facts, symbols)
137        .map(move |(origin, variables)| {
138                    let mut temporary_symbols = TemporarySymbolTable::new(symbols);
139                    for e in self.expressions.iter() {
140                        match e.evaluate(&variables, &mut temporary_symbols) {
141                            Ok(Term::Bool(true)) => {}
142                            Ok(Term::Bool(false)) => return Ok((origin, variables, false)),
143                            Ok(_) => return Err(error::Expression::InvalidType),
144                            Err(e) => {
145                                //println!("expr returned {:?}", res);
146                                return Err(e);
147                            }
148                        }
149                    }
150            Ok((origin, variables, true))
151        }).filter_map(move |res/*(mut origin,h, expression_res)*/| {
152            match res {
153                Ok((mut origin,h , expression_res)) => {
154                    if expression_res {
155                    let mut p = head.clone();
156                    for index in 0..p.terms.len() {
157                        match &p.terms[index] {
158                            Term::Variable(i) => match h.get(i) {
159                              Some(val) => p.terms[index] = val.clone(),
160                              None => {
161                                println!("error: variables that appear in the head should appear in the body and constraints as well");
162                                return None;
163                              }
164                            },
165                            _ => continue,
166                        };
167                    }
168        
169                    origin.insert(rule_origin);
170                    Some(Ok((origin, Fact { predicate: p })))
171                } else {None}
172                },
173                Err(e) => Some(Err(e))
174            }
175          
176        })
177    }
178
179    pub fn find_match(
180        &self,
181        facts: &FactSet,
182        origin: usize,
183        scope: &TrustedOrigins,
184        symbols: &SymbolTable,
185    ) -> Result<bool, Execution> {
186        let fact_it = facts.iterator(scope);
187        let mut it = self.apply(fact_it, origin, symbols);
188
189        let next = it.next();
190        match next {
191            None => Ok(false),
192            Some(Ok(_)) => Ok(true),
193            Some(Err(e)) => Err(Execution::Expression(e))
194        }
195    }
196
197    pub fn check_match_all(
198        &self,
199        facts: &FactSet,
200        scope: &TrustedOrigins,
201        symbols: &SymbolTable,
202    ) -> Result<bool, Execution> {
203        let fact_it = facts.iterator(scope);
204        let variables = MatchedVariables::new(self.variables_set());
205        let mut found = false;
206
207        for (_, variables) in CombineIt::new(variables, &self.body, fact_it, symbols) {
208            found = true;
209
210            let mut temporary_symbols = TemporarySymbolTable::new(symbols);
211            for e in self.expressions.iter() {
212                match e.evaluate(&variables, &mut temporary_symbols) {
213                    Ok(Term::Bool(true)) => {}
214                    Ok(Term::Bool(false)) => {
215                        //println!("expr returned {:?}", res);
216                        return Ok(false);
217                    },
218                    Ok(_) => return Err(error::Execution::Expression(error::Expression::InvalidType)),
219                    Err(e) => {
220                        return Err(error::Execution::Expression(e));
221                    }
222                }
223            }
224        }
225
226        Ok(found)
227    }
228
229    // use this to translate rules and checks from token to authorizer world without translating
230    // to a builder Rule first, because the builder Rule can contain a public key, so we would
231    // need to loo up then retranslate that key, while the datalog rule does not need to know about
232    // the key (the scope is driven by the authorizer's side)
233    pub fn translate(
234        &self,
235        origin_symbols: &SymbolTable,
236        target_symbols: &mut SymbolTable,
237    ) -> Result<Self, error::Format> {
238        Ok(Rule {
239            head: builder::Predicate::convert_from(&self.head, origin_symbols)?
240                .convert(target_symbols),
241            body: self
242                .body
243                .iter()
244                .map(|p| {
245                    builder::Predicate::convert_from(p, origin_symbols)
246                        .map(|p| p.convert(target_symbols))
247                })
248                .collect::<Result<Vec<_>, _>>()?,
249            expressions: self
250                .expressions
251                .iter()
252                .map(|c| {
253                    builder::Expression::convert_from(c, origin_symbols)
254                        .map(|e| e.convert(target_symbols))
255                })
256                .collect::<Result<Vec<_>, _>>()?,
257            scopes: self
258                .scopes
259                .iter()
260                .map(|s| {
261                    builder::Scope::convert_from(s, origin_symbols)
262                        .map(|s| s.convert(target_symbols))
263                })
264                .collect::<Result<Vec<_>, _>>()?,
265        })
266    }
267
268    pub fn validate_variables(&self, symbols: &SymbolTable) -> Result<(), String> {
269        let mut head_variables: std::collections::HashSet<u32> = self
270            .head
271            .terms
272            .iter()
273            .filter_map(|term| match term {
274                Term::Variable(s) => Some(*s),
275                _ => None,
276            })
277            .collect();
278
279        for predicate in self.body.iter() {
280            for term in predicate.terms.iter() {
281                if let Term::Variable(v) = term {
282                    head_variables.remove(v);
283                    if head_variables.is_empty() {
284                        return Ok(());
285                    }
286                }
287            }
288        }
289
290        if head_variables.is_empty() {
291            Ok(())
292        } else {
293            Err(format!(
294                    "rule head contains variables that are not used in predicates of the rule's body: {}",
295                    head_variables
296                    .iter()
297                    .map(|s| format!("${}", symbols.print_symbol_default(*s as u64)))
298                    .collect::<Vec<_>>()
299                    .join(", ")
300                    ))
301        }
302    }
303}
304
/// recursive iterator for rule application
///
/// Enumerates every way the `predicates` can be unified with facts from
/// `all_facts`. For each way it yields the union of the matched facts'
/// origins together with the complete variable bindings.
///
/// This level matches `predicates[0]` against `current_facts`. Each match
/// spawns a nested `CombineIt` over the remaining predicates, stored in
/// `current_it`.
pub struct CombineIt<'a, IT> {
    /// bindings fixed so far (by this level's callers)
    variables: MatchedVariables,
    /// predicates still to be matched; the first one is handled by this level
    predicates: &'a [Predicate],
    /// full fact iterator, cloned to seed each nested level
    all_facts: IT,
    /// NOTE(review): only forwarded to nested `CombineIt`s in this file, never read directly
    symbols: &'a SymbolTable,
    /// candidate facts for `predicates[0]`, pre-filtered with `match_preds`
    current_facts: Box<dyn Iterator<Item = (&'a Origin, &'a Fact)> + 'a>,
    /// nested iterator over the remaining predicates, if one is active
    current_it: Option<Box<dyn Iterator<Item = (Origin, HashMap<u32, Term>)> + 'a>>,
}
314
315impl<'a, IT> CombineIt<'a, IT>
316where
317    IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
318{
319    pub fn new(
320        variables: MatchedVariables,
321        predicates: &'a [Predicate],
322        facts: IT,
323        symbols: &'a SymbolTable,
324    ) -> Self {
325        let current_facts: Box<dyn Iterator<Item = (&'a Origin, &'a Fact)> + 'a> =
326            if predicates.is_empty() {
327                Box::new(facts.clone())
328            } else {
329                let p = predicates[0].clone();
330                Box::new(
331                    facts
332                        .clone()
333                        .filter(move |fact| match_preds(&p, &fact.1.predicate)),
334                )
335            };
336
337        CombineIt {
338            variables,
339            predicates,
340            all_facts: facts,
341            symbols,
342            current_facts,
343            current_it: None,
344        }
345    }
346}
347
/// Yields one `(origin, bindings)` pair per way of unifying all remaining
/// `predicates` with facts.
///
/// `origin` is the union of the origins of every fact used in that
/// unification. Expressions are NOT evaluated here; the callers
/// (`Rule::apply`, `Rule::check_match_all`) test them on each yielded item.
impl<'a, IT> Iterator for CombineIt<'a, IT>
where
    IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
    Self: 'a,
{
    type Item = (Origin, HashMap<u32, Term>);

    fn next(&mut self) -> Option<(Origin, HashMap<u32, Term>)> {
        // if we're the last iterator in the recursive chain, stop here
        if self.predicates.is_empty() {
            match self.variables.complete() {
                None => return None,
                // we got a complete set of variables; the caller tests the expressions
                Some(variables) => {
                    // if there were no predicates and the variables are complete,
                    // we should return a value, but only once. To prevent further
                    // successful calls, we replace the bindings with a set containing
                    // an unbound variable 0, so `complete()` returns None from now on
                    self.variables = MatchedVariables::new([0].into());
                    return Some((Origin::default(), variables));
                }
            }
        }

        loop {
            if self.current_it.is_none() {
                //fix the first predicate (non-empty: the empty case returned above)
                let pred = &self.predicates[0];

                loop {
                    if let Some((current_origin, current_fact)) = self.current_facts.next() {
                        // create a new MatchedVariables in which we fix variables we could unify
                        // from our first predicate and the current fact
                        let mut vars = self.variables.clone();
                        let mut match_terms = true;
                        for (key, id) in pred.terms.iter().zip(&current_fact.predicate.terms) {
                            if let (Term::Variable(k), id) = (key, id) {
                                // fails if `k` is already bound to a different value
                                if !vars.insert(*k, id) {
                                    match_terms = false;
                                }

                                if !match_terms {
                                    break;
                                }
                            }
                        }

                        // conflicting binding: try the next candidate fact
                        if !match_terms {
                            continue;
                        }

                        if self.predicates.len() == 1 {
                            match vars.complete() {
                                None => {
                                    // some variable is still unbound: try the next fact
                                    continue;
                                }
                                // complete set of variables for the last predicate:
                                // yield it (the caller tests the expressions)
                                Some(variables) => {
                                    return Some((current_origin.clone(), variables));
                                }
                            }
                        } else {
                            // create a new iterator with the matched variables, the rest of the predicates,
                            // and all of the facts; its results inherit this fact's origin
                            self.current_it = Some(Box::new(
                                CombineIt::new(
                                    vars,
                                    &self.predicates[1..],
                                    self.all_facts.clone(),
                                    self.symbols,
                                )
                                .map(move |(origin, variables)| {
                                    (origin.union(current_origin), variables)
                                }),
                            ));
                        }
                        break;
                    } else {
                        // no candidate facts left for the first predicate: exhausted
                        return None;
                    }
                }
            }

            if self.current_it.is_none() {
                break None;
            }

            // drain the nested iterator; once it is exhausted, drop it so the
            // next candidate fact for the first predicate is tried
            if let Some((origin, variables)) = self.current_it.as_mut().and_then(|it| it.next()) {
                break Some((origin, variables));
            } else {
                self.current_it = None;
            }
        }
    }
}
444
445#[derive(Debug, Clone, PartialEq, Eq)]
446pub struct MatchedVariables {
447    pub variables: HashMap<u32, Option<Term>>,
448}
449
450impl MatchedVariables {
451    pub fn new(import: HashSet<u32>) -> Self {
452        MatchedVariables {
453            variables: import.iter().map(|key| (*key, None)).collect(),
454        }
455    }
456
457    pub fn insert(&mut self, key: u32, value: &Term) -> bool {
458        match self.variables.get(&key) {
459            Some(None) => {
460                self.variables.insert(key, Some(value.clone()));
461                true
462            }
463            Some(Some(v)) => value == v,
464            None => false,
465        }
466    }
467
468    pub fn is_complete(&self) -> bool {
469        self.variables.values().all(|v| v.is_some())
470    }
471
472    pub fn complete(&self) -> Option<HashMap<u32, Term>> {
473        let mut result = HashMap::new();
474        for (k, v) in self.variables.iter() {
475            match v {
476                Some(value) => result.insert(*k, value.clone()),
477                None => return None,
478            };
479        }
480        Some(result)
481    }
482}
483
484pub fn fact<I: AsRef<Term>>(name: SymbolIndex, terms: &[I]) -> Fact {
485    Fact {
486        predicate: Predicate {
487            name,
488            terms: terms.iter().map(|id| id.as_ref().clone()).collect(),
489        },
490    }
491}
492
493pub fn pred<I: AsRef<Term>>(name: SymbolIndex, terms: &[I]) -> Predicate {
494    Predicate {
495        name,
496        terms: terms.iter().map(|id| id.as_ref().clone()).collect(),
497    }
498}
499
500pub fn rule<I: AsRef<Term>, P: AsRef<Predicate>>(
501    head_name: SymbolIndex,
502    head_terms: &[I],
503    predicates: &[P],
504) -> Rule {
505    Rule {
506        head: pred(head_name, head_terms),
507        body: predicates.iter().map(|p| p.as_ref().clone()).collect(),
508        expressions: Vec::new(),
509        scopes: vec![],
510    }
511}
512
513pub fn expressed_rule<I: AsRef<Term>, P: AsRef<Predicate>, C: AsRef<Expression>>(
514    head_name: SymbolIndex,
515    head_terms: &[I],
516    predicates: &[P],
517    expressions: &[C],
518) -> Rule {
519    Rule {
520        head: pred(head_name, head_terms),
521        body: predicates.iter().map(|p| p.as_ref().clone()).collect(),
522        expressions: expressions.iter().map(|c| c.as_ref().clone()).collect(),
523        scopes: vec![],
524    }
525}
526
/// Builds an integer term.
pub fn int(i: i64) -> Term {
    Term::Integer(i)
}
530
531/*pub fn string(s: &str) -> Term {
532    Term::Str(s.to_string())
533}*/
534
535pub fn date(t: &SystemTime) -> Term {
536    let dur = t.duration_since(UNIX_EPOCH).unwrap();
537    Term::Date(dur.as_secs())
538}
539
540pub fn var(syms: &mut SymbolTable, name: &str) -> Term {
541    let id = syms.insert(name);
542    Term::Variable(id as u32)
543}
544
545pub fn match_preds(rule_pred: &Predicate, fact_pred: &Predicate) -> bool {
546    rule_pred.name == fact_pred.name
547        && rule_pred.terms.len() == fact_pred.terms.len()
548        && rule_pred
549            .terms
550            .iter()
551            .zip(&fact_pred.terms)
552            .all(|(fid, pid)| match (fid, pid) {
553                // the fact should not contain variables
554                (_, Term::Variable(_)) => false,
555                (Term::Variable(_), _) => true,
556                (Term::Integer(i), Term::Integer(j)) => i == j,
557                (Term::Str(i), Term::Str(j)) => i == j,
558                (Term::Date(i), Term::Date(j)) => i == j,
559                (Term::Bytes(i), Term::Bytes(j)) => i == j,
560                (Term::Bool(i), Term::Bool(j)) => i == j,
561                (Term::Set(i), Term::Set(j)) => i == j,
562                _ => false,
563            })
564}
565
/// Datalog evaluation state: the known facts, the rules to apply to them, and
/// a running count of fixpoint iterations.
#[derive(Debug, Clone, Default)]
pub struct World {
    pub facts: FactSet,
    pub rules: RuleSet,
    /// iterations accumulated across calls to `run_with_limits`
    pub iterations: u64,
}
572
573impl World {
574    pub fn new() -> Self {
575        World::default()
576    }
577
578    pub fn add_fact(&mut self, origin: &Origin, fact: Fact) {
579        self.facts.insert(origin, fact);
580    }
581
582    pub fn add_rule(&mut self, origin: usize, scope: &TrustedOrigins, rule: Rule) {
583        self.rules.insert(origin, scope, rule);
584    }
585
586    pub fn run(&mut self, symbols: &SymbolTable) -> Result<(), crate::error::Execution> {
587        self.run_with_limits(symbols, RunLimits::default())
588    }
589
590    pub fn run_with_limits(
591        &mut self,
592        symbols: &SymbolTable,
593        limits: RunLimits,
594    ) -> Result<(), crate::error::Execution> {
595        let start = Instant::now();
596        let time_limit = start + limits.max_time;
597        let mut index = 0;
598
599        let res = loop {
600            let mut new_facts = FactSet::default();
601
602            for (scope, rules) in self.rules.inner.iter() {
603                let it = self.facts.iterator(scope);
604                for (origin, rule) in rules {
605                    for res in rule.apply(it.clone(), *origin, symbols) {
606                        match res {
607                            Ok((origin,fact)) => {
608                                new_facts.insert(&origin, fact);
609
610                            },
611                            Err(e)  => {
612                                return Err(Execution::Expression(e));
613                            }
614                        }
615                    }
616                    //println!("new_facts after applying {:?}:\n{:#?}", rule, new_facts);
617                }
618            }
619
620            let len = self.facts.len();
621            self.facts.merge(new_facts);
622            if self.facts.len() == len {
623                break Ok(());
624            }
625
626            index += 1;
627            if index == limits.max_iterations {
628                break Err(Execution::RunLimit( crate::error::RunLimit::TooManyIterations));
629            }
630
631            if self.facts.len() >= limits.max_facts as usize {
632                break Err(Execution::RunLimit(crate::error::RunLimit::TooManyFacts));
633            }
634
635            let now = Instant::now();
636            if now >= time_limit {
637                break Err(Execution::RunLimit(crate::error::RunLimit::Timeout));
638            }
639        };
640
641        self.iterations += index;
642
643        res
644    }
645
646    /*pub fn query(&self, pred: Predicate) -> Vec<&Fact> {
647        self.facts
648            .iter()
649            .filter(|f| {
650                f.predicate.name == pred.name
651                    && f.predicate.terms.iter().zip(&pred.terms).all(|(fid, pid)| {
652                        match (fid, pid) {
653                            //(Term::Symbol(_), Term::Variable(_)) => true,
654                            //(Term::Symbol(i), Term::Symbol(ref j)) => i == j,
655                            (_, Term::Variable(_)) => true,
656                            (Term::Integer(i), Term::Integer(ref j)) => i == j,
657                            (Term::Str(i), Term::Str(ref j)) => i == j,
658                            (Term::Date(i), Term::Date(ref j)) => i == j,
659                            (Term::Bytes(i), Term::Bytes(ref j)) => i == j,
660                            (Term::Bool(i), Term::Bool(ref j)) => i == j,
661                            (Term::Set(i), Term::Set(ref j)) => i == j,
662                            _ => false,
663                        }
664                    })
665            })
666            .collect::<Vec<_>>()
667    }*/
668
669    pub fn query_rule(
670        &self,
671        rule: Rule,
672        origin: usize,
673        scope: &TrustedOrigins,
674        symbols: &SymbolTable,
675    ) -> Result<FactSet, Execution> {
676        let mut new_facts = FactSet::default();
677        let it = self.facts.iterator(scope);
678        //new_facts.extend(rule.apply(it, origin, symbols));
679        for res in rule.apply(it.clone(), origin, symbols) {
680            match res {
681                Ok((origin,fact)) => {
682                    new_facts.insert(&origin, fact);
683
684                },
685                Err(e)  => {
686                    return Err(Execution::Expression(e));
687                }
688            }
689        }
690
691        Ok(new_facts)
692    }
693
694    pub fn query_match(
695        &self,
696        rule: Rule,
697        origin: usize,
698        scope: &TrustedOrigins,
699        symbols: &SymbolTable,
700    ) -> Result<bool, Execution> {
701        rule.find_match(&self.facts, origin, scope, symbols)
702    }
703
704    pub fn query_match_all(
705        &self,
706        rule: Rule,
707        scope: &TrustedOrigins,
708        symbols: &SymbolTable,
709    ) -> Result<bool, Execution> {
710        rule.check_match_all(&self.facts, scope, symbols)
711    }
712}
713
/// runtime limits for the Datalog engine
#[derive(Debug, Clone)]
pub struct RunLimits {
    /// maximum number of Datalog facts (memory usage)
    pub max_facts: u64,
    /// maximum number of iterations of the rules applications (prevents degenerate rules)
    pub max_iterations: u64,
    /// maximum execution time
    pub max_time: Duration,
}

/// Defaults: at most 1000 facts, 100 iterations and 1 millisecond of
/// execution time.
impl Default for RunLimits {
    fn default() -> Self {
        let max_time = Duration::from_millis(1);
        Self {
            max_time,
            max_iterations: 100,
            max_facts: 1000,
        }
    }
}
734
/// Set of facts partitioned by the `Origin` they are associated with.
///
/// `FactSet::iterator` only exposes the partitions whose origin is accepted
/// by a given `TrustedOrigins`.
#[derive(Clone, Debug, Default)]
pub struct FactSet {
    pub(crate) inner: HashMap<Origin, HashSet<Fact>>,
}
739
740impl FactSet {
741    pub fn insert(&mut self, origin: &Origin, fact: Fact) {
742        match self.inner.get_mut(origin) {
743            None => {
744                let mut set = HashSet::new();
745                set.insert(fact);
746                self.inner.insert(origin.clone(), set);
747            }
748            Some(set) => {
749                set.insert(fact);
750            }
751        }
752    }
753
754    pub fn len(&self) -> usize {
755        self.inner.values().fold(0, |acc, set| acc + set.len())
756    }
757
758    pub fn is_empty(&self) -> bool {
759        self.inner.values().all(|set| set.is_empty())
760    }
761
762    pub fn iterator<'a>(
763        &'a self,
764        block_ids: &'a TrustedOrigins,
765    ) -> impl Iterator<Item = (&Origin, &Fact)> + Clone {
766        self.inner
767            .iter()
768            .filter_map(move |(ids, facts)| {
769                if block_ids.contains(ids) {
770                    Some(facts.iter().map(move |fact| (ids, fact)))
771                } else {
772                    None
773                }
774            })
775            .flatten()
776    }
777
778    pub fn iter_all(&self) -> impl Iterator<Item = (&Origin, &Fact)> + Clone {
779        self.inner
780            .iter()
781            .flat_map(move |(ids, facts)| facts.iter().map(move |fact| (ids, fact)))
782    }
783
784    pub fn merge(&mut self, other: FactSet) {
785        for (origin, facts) in other.inner {
786            let entry = self.inner.entry(origin).or_default();
787            entry.extend(facts.into_iter());
788        }
789    }
790}
791
792impl Extend<(Origin, Fact)> for FactSet {
793    fn extend<T: IntoIterator<Item = (Origin, Fact)>>(&mut self, iter: T) {
794        for (origin, fact) in iter {
795            let entry = self.inner.entry(origin).or_default();
796            entry.insert(fact);
797        }
798    }
799}
800
801impl IntoIterator for FactSet {
802    type Item = (Origin, Fact);
803
804    type IntoIter = Box<dyn Iterator<Item = (Origin, Fact)>>;
805
806    fn into_iter(self) -> Self::IntoIter {
807        Box::new(
808            self.inner.into_iter().flat_map(move |(ids, facts)| {
809                facts.into_iter().map(move |fact| (ids.clone(), fact))
810            }),
811        )
812    }
813}
814
/// Rules grouped by the `TrustedOrigins` scope they are evaluated in.
///
/// Each entry also keeps the rule's origin id (`usize`). `World` passes that
/// id to `Rule::apply`, which adds it to the origin of every derived fact.
#[derive(Clone, Debug, Default)]
pub struct RuleSet {
    pub inner: HashMap<TrustedOrigins, Vec<(usize, Rule)>>,
}
819
820impl RuleSet {
821    pub fn insert(&mut self, origin: usize, scope: &TrustedOrigins, rule: Rule) {
822        match self.inner.get_mut(scope) {
823            None => {
824                self.inner.insert(scope.clone(), vec![(origin, rule)]);
825            }
826            Some(set) => {
827                set.push((origin, rule));
828            }
829        }
830    }
831
832    pub fn iter_all(&self) -> impl Iterator<Item = (&TrustedOrigins, &Rule)> + Clone {
833        self.inner
834            .iter()
835            .flat_map(move |(ids, rules)| rules.iter().map(move |(_, rule)| (ids, rule)))
836    }
837}
838
839pub struct SchemaVersion {
840    contains_scopes: bool,
841    contains_v4: bool,
842    contains_check_all: bool,
843}
844
845impl SchemaVersion {
846    pub fn version(&self) -> u32 {
847        if self.contains_scopes || self.contains_v4 || self.contains_check_all {
848            4
849        } else {
850            MIN_SCHEMA_VERSION
851        }
852    }
853
854    pub fn check_compatibility(&self, version: u32) -> Result<(), error::Format> {
855        if version < 4 {
856            if self.contains_scopes {
857                Err(error::Format::DeserializationError(
858                    "v3 blocks must not have scopes".to_string(),
859                ))
860            } else if self.contains_v4 {
861                Err(error::Format::DeserializationError(
862                    "v3 blocks must not have v4 operators (bitwise operators or !=)".to_string(),
863                ))
864            } else if self.contains_check_all {
865                Err(error::Format::DeserializationError(
866                    "v3 blocks must not have use all".to_string(),
867                ))
868            } else {
869                Ok(())
870            }
871        } else {
872            Ok(())
873        }
874    }
875}
876
877/// Determine the schema version given the elements of a block.
878pub fn get_schema_version(
879    _facts: &[Fact],
880    rules: &[Rule],
881    checks: &[Check],
882    scopes: &[Scope],
883) -> SchemaVersion {
884    let contains_scopes = !scopes.is_empty()
885        || rules.iter().any(|r: &Rule| !r.scopes.is_empty())
886        || checks
887            .iter()
888            .any(|c: &Check| c.queries.iter().any(|q| !q.scopes.is_empty()));
889
890    let contains_check_all = checks.iter().any(|c: &Check| c.kind == CheckKind::All);
891
892    let contains_v4 = rules.iter().any(|rule| contains_v4_op(&rule.expressions))
893        || checks.iter().any(|check| {
894            check
895                .queries
896                .iter()
897                .any(|query| contains_v4_op(&query.expressions))
898        });
899
900    SchemaVersion {
901        contains_scopes,
902        contains_v4,
903        contains_check_all,
904    }
905}
906
907/// Determine whether any of the expression contain a v4 operator.
908/// Bitwise operators and != are only supported in biscuits v4+
909pub fn contains_v4_op(expressions: &[Expression]) -> bool {
910    expressions.iter().any(|expression| {
911        expression.ops.iter().any(|op| {
912            if let Op::Binary(binary) = op {
913                match binary {
914                    Binary::BitwiseAnd
915                    | Binary::BitwiseOr
916                    | Binary::BitwiseXor
917                    | Binary::NotEqual => return true,
918                    _ => return false,
919                }
920            }
921            false
922        })
923    })
924}
925
// Unit tests for the datalog engine: fact insertion, rule evaluation
// (joins and expressions over integers, strings, dates and sets) and the
// handling of variables that only appear in a rule head.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Derives `grandparent` facts from `parent` facts, first through ad-hoc
    /// queries and then through a rule registered in the world. Also checks
    /// that a `parent` fact added after a first run is taken into account by
    /// the next run.
    #[test]
    fn family() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let a = syms.add("A");
        let b = syms.add("B");
        let c = syms.add("C");
        let d = syms.add("D");
        let e = syms.add("e");
        let parent = syms.insert("parent");
        let grandparent = syms.insert("grandparent");

        w.add_fact(&[0].iter().collect(), fact(parent, &[&a, &b]));
        w.add_fact(&[0].iter().collect(), fact(parent, &[&b, &c]));
        w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &d]));

        let r1 = rule(
            grandparent,
            &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
            &[
                pred(
                    parent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "parent")],
                ),
                pred(
                    parent,
                    &[var(&mut syms, "parent"), var(&mut syms, "grandchild")],
                ),
            ],
        );

        println!("symbols: {:?}", syms);
        println!("testing r1: {}", syms.print_rule(&r1));
        let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms);
        println!("grandparents query_rules: {:?}", query_rule_result);
        println!("current facts: {:?}", w.facts);

        let r2 = rule(
            grandparent,
            &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
            &[
                pred(
                    parent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "parent")],
                ),
                pred(
                    parent,
                    &[var(&mut syms, "parent"), var(&mut syms, "grandchild")],
                ),
            ],
        );

        println!("adding r2: {}", syms.print_rule(&r2));
        w.add_rule(0, &[0].iter().collect(), r2);

        // Evaluate the registered rules with a 10 second `max_time` limit.
        w.run_with_limits(
            &syms,
            RunLimits {
                max_time: Duration::from_secs(10),
                ..Default::default()
            },
        )
        .unwrap();

        println!("parents:");
        let res = w.query_rule(
            rule::<Term, Predicate>(
                parent,
                &[var(&mut syms, "parent"), var(&mut syms, "child")],
                &[pred(
                    parent,
                    &[var(&mut syms, "parent"), var(&mut syms, "child")],
                )],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (origin, fact) in res.iterator(&[0].iter().collect()) {
            println!("\t{:?}\t{}", origin, syms.print_fact(fact));
        }

        println!(
            "parents of B: {:?}",
            w.query_rule(
                rule::<&Term, Predicate>(
                    parent,
                    &[&var(&mut syms, "parent"), &b],
                    &[pred(parent, &[&var(&mut syms, "parent"), &b])]
                ),
                0,
                &[0].iter().collect(),
                &syms
            )
        );
        println!(
            "grandparents: {:?}",
            w.query_rule(
                rule::<Term, Predicate>(
                    grandparent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
                    &[pred(
                        grandparent,
                        &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")]
                    )]
                ),
                0,
                &[0].iter().collect(),
                &syms
            )
        );
        w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e]));
        w.run(&syms).unwrap();
        let res = w.query_rule(
            rule::<Term, Predicate>(
                grandparent,
                &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
                &[pred(
                    grandparent,
                    &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
                )],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();
        println!("grandparents after inserting parent(C, E): {:?}", res);

        let res = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![
            fact(grandparent, &[&a, &c]),
            fact(grandparent, &[&b, &d]),
            fact(grandparent, &[&b, &e]),
        ]
        .into_iter()
        .collect::<HashSet<_>>();
        assert_eq!(res, compared);
    }

    /// Joins `t1` and `t2` on a shared integer id, first without and then
    /// with an `id < 1` expression constraint.
    #[test]
    fn numbers() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let ghi = syms.add("ghi");
        let jkl = syms.add("jkl");
        let mno = syms.add("mno");
        let aaa = syms.add("AAA");
        let bbb = syms.add("BBB");
        let ccc = syms.add("CCC");
        let t1 = syms.insert("t1");
        let t2 = syms.insert("t2");
        let join = syms.insert("join");

        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(0), &abc]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(1), &def]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(2), &ghi]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(3), &jkl]));
        w.add_fact(&[0].iter().collect(), fact(t1, &[&int(4), &mno]));

        w.add_fact(&[0].iter().collect(), fact(t2, &[&int(0), &aaa, &int(0)]));
        w.add_fact(&[0].iter().collect(), fact(t2, &[&int(1), &bbb, &int(0)]));
        w.add_fact(&[0].iter().collect(), fact(t2, &[&int(2), &ccc, &int(1)]));

        let res = w.query_rule(
            rule(
                join,
                &[var(&mut syms, "left"), var(&mut syms, "right")],
                &[
                    pred(t1, &[var(&mut syms, "id"), var(&mut syms, "left")]),
                    pred(
                        t2,
                        &[
                            var(&mut syms, "t2_id"),
                            var(&mut syms, "right"),
                            var(&mut syms, "id"),
                        ],
                    ),
                ],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![
            fact(join, &[&abc, &aaa]),
            fact(join, &[&abc, &bbb]),
            fact(join, &[&def, &ccc]),
        ]
        .into_iter()
        .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        // test constraints
        let res = w.query_rule(
            expressed_rule(
                join,
                &[var(&mut syms, "left"), var(&mut syms, "right")],
                &[
                    pred(t1, &[var(&mut syms, "id"), var(&mut syms, "left")]),
                    pred(
                        t2,
                        &[
                            var(&mut syms, "t2_id"),
                            var(&mut syms, "right"),
                            var(&mut syms, "id"),
                        ],
                    ),
                ],
                &[Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "id")),
                        Op::Value(Term::Integer(1)),
                        Op::Binary(Binary::LessThan),
                    ],
                }],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![fact(join, &[&abc, &aaa]), fact(join, &[&abc, &bbb])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }

    /// Filters `route` facts with a `Binary::Suffix` expression on the
    /// domain name.
    #[test]
    fn str() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let app_0 = syms.add("app_0");
        let app_1 = syms.add("app_1");
        let app_2 = syms.add("app_2");
        let route = syms.insert("route");
        let suff = syms.insert("route suffix");
        let example = syms.add("example.com");
        let test_com = syms.add("test.com");
        let test_fr = syms.add("test.fr");
        let www_example = syms.add("www.example.com");
        let mx_example = syms.add("mx.example.com");

        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(0), &app_0, &example]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(1), &app_1, &test_com]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(2), &app_2, &test_fr]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(3), &app_0, &www_example]),
        );
        w.add_fact(
            &[0].iter().collect(),
            fact(route, &[&int(4), &app_1, &mx_example]),
        );

        /// Queries every `(app_id, domain_name)` pair from `route` whose
        /// domain name ends with `suffix`.
        fn test_suffix(
            w: &World,
            syms: &mut SymbolTable,
            suff: SymbolIndex,
            route: SymbolIndex,
            suffix: &str,
        ) -> Vec<Fact> {
            let id_suff = syms.add(suffix);
            w.query_rule(
                expressed_rule(
                    suff,
                    &[var(syms, "app_id"), var(syms, "domain_name")],
                    &[pred(
                        route,
                        &[
                            var(syms, "route_id"),
                            var(syms, "app_id"),
                            var(syms, "domain_name"),
                        ],
                    )],
                    &[Expression {
                        ops: vec![
                            Op::Value(var(syms, "domain_name")),
                            Op::Value(id_suff),
                            Op::Binary(Binary::Suffix),
                        ],
                    }],
                ),
                0,
                &[0].iter().collect(),
                syms,
            ).unwrap()
            .iter_all()
            .map(|(_, fact)| fact.clone())
            .collect()
        }

        let res = test_suffix(&w, &mut syms, suff, route, ".fr");
        for fact in &res {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res.iter().cloned().collect::<HashSet<_>>();
        let compared = vec![fact(suff, &[&app_2, &test_fr])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        let res = test_suffix(&w, &mut syms, suff, route, "example.com");
        for fact in &res {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res.iter().cloned().collect::<HashSet<_>>();
        let compared = vec![
            fact(suff, &[&app_0, &example]),
            fact(suff, &[&app_0, &www_example]),
            fact(suff, &[&app_1, &mx_example]),
        ]
        .into_iter()
        .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }

    /// Selects facts whose date is before or after a reference timestamp
    /// using `LessOrEqual` / `GreaterOrEqual` expressions.
    #[test]
    fn date_constraint() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let t1 = SystemTime::now();
        println!("t1 = {:?}", t1);
        let t2 = t1 + Duration::from_secs(10);
        println!("t2 = {:?}", t2);
        let t3 = t2 + Duration::from_secs(30);
        println!("t3 = {:?}", t3);

        let t2_timestamp = t2.duration_since(UNIX_EPOCH).unwrap().as_secs();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let x = syms.insert("x");
        let before = syms.insert("before");
        let after = syms.insert("after");

        w.add_fact(&[0].iter().collect(), fact(x, &[&date(&t1), &abc]));
        w.add_fact(&[0].iter().collect(), fact(x, &[&date(&t3), &def]));

        let r1 = expressed_rule(
            before,
            &[var(&mut syms, "date"), var(&mut syms, "val")],
            &[pred(x, &[var(&mut syms, "date"), var(&mut syms, "val")])],
            &[
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(t2_timestamp)),
                        Op::Binary(Binary::LessOrEqual),
                    ],
                },
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(0)),
                        Op::Binary(Binary::GreaterOrEqual),
                    ],
                },
            ],
        );

        println!("testing r1: {}", syms.print_rule(&r1));
        let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_origin, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![fact(before, &[&date(&t1), &abc])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        let r2 = expressed_rule(
            after,
            &[var(&mut syms, "date"), var(&mut syms, "val")],
            &[pred(x, &[var(&mut syms, "date"), var(&mut syms, "val")])],
            &[
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(t2_timestamp)),
                        Op::Binary(Binary::GreaterOrEqual),
                    ],
                },
                Expression {
                    ops: vec![
                        Op::Value(var(&mut syms, "date")),
                        Op::Value(Term::Date(0)),
                        Op::Binary(Binary::GreaterOrEqual),
                    ],
                },
            ],
        );

        println!("testing r2: {}", syms.print_rule(&r2));
        let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![fact(after, &[&date(&t3), &def])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }

    /// Exercises `Binary::Contains` on integer, symbol and string sets. The
    /// symbol set case is negated with `Unary::Negate`.
    #[test]
    fn set_constraint() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let x = syms.insert("x");
        let int_set = syms.insert("int_set");
        let symbol_set = syms.insert("symbol_set");
        let string_set = syms.insert("string_set");
        let test = syms.add("test");
        let hello = syms.add("hello");
        let zzz = syms.add("zzz");

        w.add_fact(&[0].iter().collect(), fact(x, &[&abc, &int(0), &test]));
        w.add_fact(&[0].iter().collect(), fact(x, &[&def, &int(2), &hello]));

        let res = w.query_rule(
            expressed_rule(
                int_set,
                &[var(&mut syms, "sym"), var(&mut syms, "str")],
                &[pred(
                    x,
                    &[
                        var(&mut syms, "sym"),
                        var(&mut syms, "int"),
                        var(&mut syms, "str"),
                    ],
                )],
                &[Expression {
                    ops: vec![
                        Op::Value(Term::Set(
                            [Term::Integer(0), Term::Integer(1)]
                                .iter()
                                .cloned()
                                .collect(),
                        )),
                        Op::Value(var(&mut syms, "int")),
                        Op::Binary(Binary::Contains),
                    ],
                }],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![fact(int_set, &[&abc, &test])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        let abc_sym_id = syms.add("abc");
        let ghi_sym_id = syms.add("ghi");

        let res = w.query_rule(
            expressed_rule(
                symbol_set,
                &[
                    var(&mut syms, "symbol"),
                    var(&mut syms, "int"),
                    var(&mut syms, "str"),
                ],
                &[pred(
                    x,
                    &[
                        var(&mut syms, "symbol"),
                        var(&mut syms, "int"),
                        var(&mut syms, "str"),
                    ],
                )],
                &[Expression {
                    ops: vec![
                        Op::Value(Term::Set(
                            [abc_sym_id, ghi_sym_id].iter().cloned().collect(),
                        )),
                        Op::Value(var(&mut syms, "symbol")),
                        Op::Binary(Binary::Contains),
                        Op::Unary(Unary::Negate),
                    ],
                }],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![fact(symbol_set, &[&def, &int(2), &hello])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);

        let res = w.query_rule(
            expressed_rule(
                string_set,
                &[
                    var(&mut syms, "sym"),
                    var(&mut syms, "int"),
                    var(&mut syms, "str"),
                ],
                &[pred(
                    x,
                    &[
                        var(&mut syms, "sym"),
                        var(&mut syms, "int"),
                        var(&mut syms, "str"),
                    ],
                )],
                &[Expression {
                    ops: vec![
                        Op::Value(Term::Set([test.clone(), zzz].iter().cloned().collect())),
                        Op::Value(var(&mut syms, "str")),
                        Op::Binary(Binary::Contains),
                    ],
                }],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        let compared = vec![fact(string_set, &[&abc, &int(0), &test])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }

    /// Queries whose body refers to facts absent from the world (resource
    /// `file1`, operation `read`) must produce no results.
    #[test]
    fn resource() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let resource = syms.insert("resource");
        let operation = syms.insert("operation");
        let right = syms.insert("right");
        let file1 = syms.add("file1");
        let file2 = syms.add("file2");
        let read = syms.add("read");
        let write = syms.add("write");
        let check1 = syms.insert("check1");
        let check2 = syms.insert("check2");

        w.add_fact(&[0].iter().collect(), fact(resource, &[&file2]));
        w.add_fact(&[0].iter().collect(), fact(operation, &[&write]));
        w.add_fact(&[0].iter().collect(), fact(right, &[&file1, &read]));
        w.add_fact(&[0].iter().collect(), fact(right, &[&file2, &read]));
        w.add_fact(&[0].iter().collect(), fact(right, &[&file1, &write]));

        let res = w.query_rule(
            rule(check1, &[&file1], &[pred(resource, &[&file1])]),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        assert!(res.is_empty());

        let res = w.query_rule(
            rule(
                check2,
                &[Term::Variable(0)],
                &[
                    pred(resource, &[&Term::Variable(0)]),
                    pred(operation, &[&read]),
                    pred(right, &[&Term::Variable(0), &read]),
                ],
            ),
            0,
            &[0].iter().collect(),
            &syms,
        ).unwrap();

        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        assert!(res.is_empty());
    }

    /// Evaluates the arithmetic expression `(5 + -4) * -1 < nb`, which only
    /// holds for the fact with `nb = 0`.
    #[test]
    fn int_expr() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let abc = syms.add("abc");
        let def = syms.add("def");
        let x = syms.insert("x");
        let less_than = syms.insert("less_than");

        w.add_fact(&[0].iter().collect(), fact(x, &[&int(-2), &abc]));
        w.add_fact(&[0].iter().collect(), fact(x, &[&int(0), &def]));

        let r1 = expressed_rule(
            less_than,
            &[var(&mut syms, "nb"), var(&mut syms, "val")],
            &[pred(x, &[var(&mut syms, "nb"), var(&mut syms, "val")])],
            &[Expression {
                ops: vec![
                    Op::Value(Term::Integer(5)),
                    Op::Value(Term::Integer(-4)),
                    Op::Binary(Binary::Add),
                    Op::Value(Term::Integer(-1)),
                    Op::Binary(Binary::Mul),
                    Op::Value(var(&mut syms, "nb")),
                    Op::Binary(Binary::LessThan),
                ],
            }],
        );

        println!("world:\n{}\n", syms.print_world(&w));
        println!("\ntesting r1: {}\n", syms.print_rule(&r1));
        let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        let res2 = res
            .iter_all()
            .map(|(_, fact)| fact)
            .cloned()
            .collect::<HashSet<_>>();
        println!("got res: {:?}", res2);
        let compared = vec![fact(less_than, &[&int(0), &def])]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(res2, compared);
    }

    /// A rule whose head uses a variable that does not appear in its body
    /// must not generate any fact.
    #[test]
    fn unbound_variables() {
        let mut w = World::new();
        let mut syms = SymbolTable::new();

        let operation = syms.insert("operation");
        let check = syms.insert("check");
        let read = syms.add("read");
        let write = syms.add("write");
        let unbound = var(&mut syms, "unbound");
        let any1 = var(&mut syms, "any1");
        let any2 = var(&mut syms, "any2");

        w.add_fact(&[0].iter().collect(), fact(operation, &[&write]));

        let r1 = rule(
            operation,
            &[&unbound, &read],
            &[pred(operation, &[&any1, &any2])],
        );
        println!("world:\n{}\n", syms.print_world(&w));
        println!("\ntesting r1: {}\n", syms.print_rule(&r1));
        let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();

        println!("generated facts:");
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }

        assert!(res.is_empty());

        // operation($unbound, "read") should not have been generated
        // in case it is generated though, verify that rule application
        // will not match it
        w.add_fact(&[0].iter().collect(), fact(operation, &[&unbound, &read]));
        let r2 = rule(check, &[&read], &[pred(operation, &[&read])]);
        println!("world:\n{}\n", syms.print_world(&w));
        println!("\ntesting r2: {}\n", syms.print_rule(&r2));
        let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();

        println!("generated facts:");
        for (_, fact) in res.iter_all() {
            println!("\t{}", syms.print_fact(fact));
        }
        assert!(res.is_empty());
    }
}