1use crate::builder::{CheckKind, Convert};
7use crate::error::Execution;
8use crate::time::Instant;
9use crate::token::{Scope, DATALOG_3_1, DATALOG_3_3, MIN_SCHEMA_VERSION};
10use crate::{builder, error};
11use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
12use std::convert::AsRef;
13use std::fmt;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15
16mod expression;
17mod origin;
18mod symbol;
19pub use expression::*;
20pub use origin::*;
21pub use symbol::*;
22
/// A datalog term: either a variable or a concrete value.
///
/// NOTE: the derived `PartialOrd`/`Ord` follow variant declaration order, so
/// reordering the variants changes how terms sort (e.g. inside `Term::Set`).
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
pub enum Term {
    /// Variable, identified by the symbol id returned by `SymbolTable::insert` (see `var`).
    Variable(u32),
    Integer(i64),
    /// String, referenced through its `SymbolIndex`.
    Str(SymbolIndex),
    /// Date as whole seconds since the Unix epoch (see `date`).
    Date(u64),
    Bytes(Vec<u8>),
    Bool(bool),
    Set(BTreeSet<Term>),
    Null,
    Array(Vec<Term>),
    Map(BTreeMap<MapKey, Term>),
}

/// Key of a `Term::Map`: an integer or a string symbol.
#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
pub enum MapKey {
    Integer(i64),
    Str(SymbolIndex),
}
42
43impl From<&Term> for Term {
44 fn from(i: &Term) -> Self {
45 match i {
46 Term::Variable(ref v) => Term::Variable(*v),
47 Term::Integer(ref i) => Term::Integer(*i),
48 Term::Str(ref s) => Term::Str(*s),
49 Term::Date(ref d) => Term::Date(*d),
50 Term::Bytes(ref b) => Term::Bytes(b.clone()),
51 Term::Bool(ref b) => Term::Bool(*b),
52 Term::Set(ref s) => Term::Set(s.clone()),
53 Term::Null => Term::Null,
54 Term::Array(ref a) => Term::Array(a.clone()),
55 Term::Map(m) => Term::Map(m.clone()),
56 }
57 }
58}
59
60impl AsRef<Term> for Term {
61 fn as_ref(&self) -> &Term {
62 self
63 }
64}
65
66#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)]
67pub struct Predicate {
68 pub name: SymbolIndex,
69 pub terms: Vec<Term>,
70}
71
72impl Predicate {
73 pub fn new(name: SymbolIndex, terms: &[Term]) -> Predicate {
74 Predicate {
75 name,
76 terms: terms.to_vec(),
77 }
78 }
79}
80
81impl AsRef<Predicate> for Predicate {
82 fn as_ref(&self) -> &Predicate {
83 self
84 }
85}
86
87#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
88pub struct Fact {
89 pub predicate: Predicate,
90}
91
92impl Fact {
93 pub fn new(name: SymbolIndex, terms: &[Term]) -> Fact {
94 Fact {
95 predicate: Predicate::new(name, terms),
96 }
97 }
98}
99
100#[derive(Debug, Clone, Hash, PartialEq, Eq)]
101pub struct Rule {
102 pub head: Predicate,
103 pub body: Vec<Predicate>,
104 pub expressions: Vec<Expression>,
105 pub scopes: Vec<Scope>,
106}
107
108impl AsRef<Expression> for Expression {
109 fn as_ref(&self) -> &Expression {
110 self
111 }
112}
113
114#[derive(Debug, Clone, PartialEq, Eq)]
115pub struct Check {
116 pub queries: Vec<Rule>,
117 pub kind: CheckKind,
118}
119
120impl fmt::Display for Fact {
121 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
122 write!(f, "{}({:?})", self.predicate.name, self.predicate.terms)
123 }
124}
125
126impl Rule {
127 fn variables_set(&self) -> HashSet<u32> {
129 self.body
130 .iter()
131 .flat_map(|pred| {
132 pred.terms.iter().filter_map(|id| match id {
133 Term::Variable(i) => Some(*i),
134 _ => None,
135 })
136 })
137 .collect::<HashSet<_>>()
138 }
139
140 pub fn apply<'a, IT>(
141 &'a self,
142 facts: IT,
143 rule_origin: usize,
144 symbols: &'a SymbolTable,
145 extern_funcs: &'a HashMap<String, ExternFunc>,
146 ) -> impl Iterator<Item = Result<(Origin, Fact), error::Expression>> + 'a
147 where
148 IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
149 {
150 let head = self.head.clone();
151 let variables = MatchedVariables::new(self.variables_set());
152
153 CombineIt::new(variables, &self.body, facts, symbols)
154 .map(move |(origin, variables)| {
155 let mut temporary_symbols = TemporarySymbolTable::new(symbols);
156 for e in self.expressions.iter() {
157 match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) {
158 Ok(Term::Bool(true)) => {}
159 Ok(Term::Bool(false)) => return Ok((origin, variables, false)),
160 Ok(_) => return Err(error::Expression::InvalidType),
161 Err(e) => {
162 return Err(e);
164 }
165 }
166 }
167 Ok((origin, variables, true))
168 }).filter_map(move |res| {
169 match res {
170 Ok((mut origin,h , expression_res)) => {
171 if expression_res {
172 let mut p = head.clone();
173 for index in 0..p.terms.len() {
174 match &p.terms[index] {
175 Term::Variable(i) => match h.get(i) {
176 Some(val) => p.terms[index] = val.clone(),
177 None => {
178 println!("error: variables that appear in the head should appear in the body and constraints as well");
179 return None;
180 }
181 },
182 _ => continue,
183 };
184 }
185
186 origin.insert(rule_origin);
187 Some(Ok((origin, Fact { predicate: p })))
188 } else {None}
189 },
190 Err(e) => Some(Err(e))
191 }
192
193 })
194 }
195
196 pub fn find_match(
197 &self,
198 facts: &FactSet,
199 origin: usize,
200 scope: &TrustedOrigins,
201 symbols: &SymbolTable,
202 extern_funcs: &HashMap<String, ExternFunc>,
203 ) -> Result<bool, Execution> {
204 let fact_it = facts.iterator(scope);
205 let mut it = self.apply(fact_it, origin, symbols, extern_funcs);
206
207 let next = it.next();
208 match next {
209 None => Ok(false),
210 Some(Ok(_)) => Ok(true),
211 Some(Err(e)) => Err(Execution::Expression(e)),
212 }
213 }
214
215 pub fn check_match_all(
216 &self,
217 facts: &FactSet,
218 scope: &TrustedOrigins,
219 symbols: &SymbolTable,
220 extern_funcs: &HashMap<String, ExternFunc>,
221 ) -> Result<bool, Execution> {
222 let fact_it = facts.iterator(scope);
223 let variables = MatchedVariables::new(self.variables_set());
224 let mut found = false;
225
226 for (_, variables) in CombineIt::new(variables, &self.body, fact_it, symbols) {
227 found = true;
228
229 let mut temporary_symbols = TemporarySymbolTable::new(symbols);
230 for e in self.expressions.iter() {
231 match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) {
232 Ok(Term::Bool(true)) => {}
233 Ok(Term::Bool(false)) => {
234 return Ok(false);
236 }
237 Ok(_) => {
238 return Err(error::Execution::Expression(error::Expression::InvalidType))
239 }
240 Err(e) => {
241 return Err(error::Execution::Expression(e));
242 }
243 }
244 }
245 }
246
247 Ok(found)
248 }
249
250 pub fn translate(
255 &self,
256 origin_symbols: &SymbolTable,
257 target_symbols: &mut SymbolTable,
258 ) -> Result<Self, error::Format> {
259 Ok(Rule {
260 head: builder::Predicate::convert_from(&self.head, origin_symbols)?
261 .convert(target_symbols),
262 body: self
263 .body
264 .iter()
265 .map(|p| {
266 builder::Predicate::convert_from(p, origin_symbols)
267 .map(|p| p.convert(target_symbols))
268 })
269 .collect::<Result<Vec<_>, _>>()?,
270 expressions: self
271 .expressions
272 .iter()
273 .map(|c| {
274 builder::Expression::convert_from(c, origin_symbols)
275 .map(|e| e.convert(target_symbols))
276 })
277 .collect::<Result<Vec<_>, _>>()?,
278 scopes: self
279 .scopes
280 .iter()
281 .map(|s| {
282 builder::Scope::convert_from(s, origin_symbols)
283 .map(|s| s.convert(target_symbols))
284 })
285 .collect::<Result<Vec<_>, _>>()?,
286 })
287 }
288
289 pub fn validate_variables(&self, symbols: &SymbolTable) -> Result<(), String> {
290 let mut head_variables: std::collections::HashSet<u32> = self
291 .head
292 .terms
293 .iter()
294 .filter_map(|term| match term {
295 Term::Variable(s) => Some(*s),
296 _ => None,
297 })
298 .collect();
299
300 for predicate in self.body.iter() {
301 for term in predicate.terms.iter() {
302 if let Term::Variable(v) = term {
303 head_variables.remove(v);
304 if head_variables.is_empty() {
305 return Ok(());
306 }
307 }
308 }
309 }
310
311 if head_variables.is_empty() {
312 Ok(())
313 } else {
314 Err(format!(
315 "rule head contains variables that are not used in predicates of the rule's body: {}",
316 head_variables
317 .iter()
318 .map(|s| format!("${}", symbols.print_symbol_default(*s as u64)))
319 .collect::<Vec<_>>()
320 .join(", ")
321 ))
322 }
323 }
324}
325
326pub struct CombineIt<'a, IT> {
328 variables: MatchedVariables,
329 predicates: &'a [Predicate],
330 all_facts: IT,
331 symbols: &'a SymbolTable,
332 current_facts: Box<dyn Iterator<Item = (&'a Origin, &'a Fact)> + 'a>,
333 current_it: Option<Box<dyn Iterator<Item = (Origin, HashMap<u32, Term>)> + 'a>>,
334}
335
336impl<'a, IT> CombineIt<'a, IT>
337where
338 IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
339{
340 pub fn new(
341 variables: MatchedVariables,
342 predicates: &'a [Predicate],
343 facts: IT,
344 symbols: &'a SymbolTable,
345 ) -> Self {
346 let current_facts: Box<dyn Iterator<Item = (&'a Origin, &'a Fact)> + 'a> =
347 if predicates.is_empty() {
348 Box::new(facts.clone())
349 } else {
350 let p = predicates[0].clone();
351 Box::new(
352 facts
353 .clone()
354 .filter(move |fact| match_preds(&p, &fact.1.predicate)),
355 )
356 };
357
358 CombineIt {
359 variables,
360 predicates,
361 all_facts: facts,
362 symbols,
363 current_facts,
364 current_it: None,
365 }
366 }
367}
368
369impl<'a, IT> Iterator for CombineIt<'a, IT>
370where
371 IT: Iterator<Item = (&'a Origin, &'a Fact)> + Clone + 'a,
372 Self: 'a,
373{
374 type Item = (Origin, HashMap<u32, Term>);
375
376 fn next(&mut self) -> Option<(Origin, HashMap<u32, Term>)> {
377 if self.predicates.is_empty() {
379 match self.variables.complete() {
380 None => return None,
381 Some(variables) => {
383 self.variables = MatchedVariables::new([0].into());
388 return Some((Origin::default(), variables));
389 }
390 }
391 }
392
393 loop {
394 if self.current_it.is_none() {
395 let pred = &self.predicates[0];
397
398 loop {
399 if let Some((current_origin, current_fact)) = self.current_facts.next() {
400 let mut vars = self.variables.clone();
403 let mut match_terms = true;
404 for (key, id) in pred.terms.iter().zip(¤t_fact.predicate.terms) {
405 if let (Term::Variable(k), id) = (key, id) {
406 if !vars.insert(*k, id) {
407 match_terms = false;
408 }
409
410 if !match_terms {
411 break;
412 }
413 }
414 }
415
416 if !match_terms {
417 continue;
418 }
419
420 if self.predicates.len() == 1 {
421 match vars.complete() {
422 None => {
423 continue;
425 }
426 Some(variables) => {
428 return Some((current_origin.clone(), variables));
429 }
430 }
431 } else {
432 self.current_it = Some(Box::new(
435 CombineIt::new(
436 vars,
437 &self.predicates[1..],
438 self.all_facts.clone(),
439 self.symbols,
440 )
441 .map(move |(origin, variables)| {
442 (origin.union(current_origin), variables)
443 }),
444 ));
445 }
446 break;
447 } else {
448 return None;
449 }
450 }
451 }
452
453 if self.current_it.is_none() {
454 break None;
455 }
456
457 if let Some((origin, variables)) = self.current_it.as_mut().and_then(|it| it.next()) {
458 break Some((origin, variables));
459 } else {
460 self.current_it = None;
461 }
462 }
463 }
464}
465
466#[derive(Debug, Clone, PartialEq, Eq)]
467pub struct MatchedVariables {
468 pub variables: HashMap<u32, Option<Term>>,
469}
470
471impl MatchedVariables {
472 pub fn new(import: HashSet<u32>) -> Self {
473 MatchedVariables {
474 variables: import.iter().map(|key| (*key, None)).collect(),
475 }
476 }
477
478 pub fn insert(&mut self, key: u32, value: &Term) -> bool {
479 match self.variables.get(&key) {
480 Some(None) => {
481 self.variables.insert(key, Some(value.clone()));
482 true
483 }
484 Some(Some(v)) => value == v,
485 None => false,
486 }
487 }
488
489 pub fn is_complete(&self) -> bool {
490 self.variables.values().all(|v| v.is_some())
491 }
492
493 pub fn complete(&self) -> Option<HashMap<u32, Term>> {
494 let mut result = HashMap::new();
495 for (k, v) in self.variables.iter() {
496 match v {
497 Some(value) => result.insert(*k, value.clone()),
498 None => return None,
499 };
500 }
501 Some(result)
502 }
503}
504
505pub fn fact<I: AsRef<Term>>(name: SymbolIndex, terms: &[I]) -> Fact {
506 Fact {
507 predicate: Predicate {
508 name,
509 terms: terms.iter().map(|id| id.as_ref().clone()).collect(),
510 },
511 }
512}
513
514pub fn pred<I: AsRef<Term>>(name: SymbolIndex, terms: &[I]) -> Predicate {
515 Predicate {
516 name,
517 terms: terms.iter().map(|id| id.as_ref().clone()).collect(),
518 }
519}
520
521pub fn rule<I: AsRef<Term>, P: AsRef<Predicate>>(
522 head_name: SymbolIndex,
523 head_terms: &[I],
524 predicates: &[P],
525) -> Rule {
526 Rule {
527 head: pred(head_name, head_terms),
528 body: predicates.iter().map(|p| p.as_ref().clone()).collect(),
529 expressions: Vec::new(),
530 scopes: vec![],
531 }
532}
533
534pub fn expressed_rule<I: AsRef<Term>, P: AsRef<Predicate>, C: AsRef<Expression>>(
535 head_name: SymbolIndex,
536 head_terms: &[I],
537 predicates: &[P],
538 expressions: &[C],
539) -> Rule {
540 Rule {
541 head: pred(head_name, head_terms),
542 body: predicates.iter().map(|p| p.as_ref().clone()).collect(),
543 expressions: expressions.iter().map(|c| c.as_ref().clone()).collect(),
544 scopes: vec![],
545 }
546}
547
548pub fn int(i: i64) -> Term {
549 Term::Integer(i)
550}
551
552pub fn date(t: &SystemTime) -> Term {
557 let dur = t.duration_since(UNIX_EPOCH).unwrap();
558 Term::Date(dur.as_secs())
559}
560
561pub fn var(syms: &mut SymbolTable, name: &str) -> Term {
562 let id = syms.insert(name);
563 Term::Variable(id as u32)
564}
565
566pub fn match_preds(rule_pred: &Predicate, fact_pred: &Predicate) -> bool {
567 rule_pred.name == fact_pred.name
568 && rule_pred.terms.len() == fact_pred.terms.len()
569 && rule_pred
570 .terms
571 .iter()
572 .zip(&fact_pred.terms)
573 .all(|(fid, pid)| match (fid, pid) {
574 (_, Term::Variable(_)) => false,
576 (Term::Variable(_), _) => true,
577 (Term::Integer(i), Term::Integer(j)) => i == j,
578 (Term::Str(i), Term::Str(j)) => i == j,
579 (Term::Date(i), Term::Date(j)) => i == j,
580 (Term::Bytes(i), Term::Bytes(j)) => i == j,
581 (Term::Bool(i), Term::Bool(j)) => i == j,
582 (Term::Null, Term::Null) => true,
583 (Term::Set(i), Term::Set(j)) => i == j,
584 (Term::Array(i), Term::Array(j)) => i == j,
585 (Term::Map(i), Term::Map(j)) => i == j,
586 _ => false,
587 })
588}
589
590#[derive(Debug, Clone, Default)]
591pub struct World {
592 pub facts: FactSet,
593 pub rules: RuleSet,
594 pub iterations: u64,
595 pub extern_funcs: HashMap<String, ExternFunc>,
596}
597
598impl World {
599 pub fn new() -> Self {
600 World::default()
601 }
602
603 pub fn add_fact(&mut self, origin: &Origin, fact: Fact) {
604 self.facts.insert(origin, fact);
605 }
606
607 pub fn add_rule(&mut self, origin: usize, scope: &TrustedOrigins, rule: Rule) {
608 self.rules.insert(origin, scope, rule);
609 }
610
611 pub fn run(&mut self, symbols: &SymbolTable) -> Result<(), crate::error::Execution> {
612 self.run_with_limits(symbols, RunLimits::default())
613 }
614
615 pub fn run_with_limits(
616 &mut self,
617 symbols: &SymbolTable,
618 limits: RunLimits,
619 ) -> Result<(), crate::error::Execution> {
620 let start = Instant::now();
621 let time_limit = start + limits.max_time;
622 let mut index = 0;
623
624 let res = loop {
625 let mut new_facts = FactSet::default();
626
627 for (scope, rules) in self.rules.inner.iter() {
628 let it = self.facts.iterator(scope);
629 for (origin, rule) in rules {
630 for res in rule.apply(it.clone(), *origin, symbols, &self.extern_funcs) {
631 match res {
632 Ok((origin, fact)) => {
633 new_facts.insert(&origin, fact);
634 }
635 Err(e) => {
636 return Err(Execution::Expression(e));
637 }
638 }
639 }
640 }
642 }
643
644 let len = self.facts.len();
645 self.facts.merge(new_facts);
646 if self.facts.len() == len {
647 break Ok(());
648 }
649
650 index += 1;
651 if index == limits.max_iterations {
652 break Err(Execution::RunLimit(
653 crate::error::RunLimit::TooManyIterations,
654 ));
655 }
656
657 if self.facts.len() >= limits.max_facts as usize {
658 break Err(Execution::RunLimit(crate::error::RunLimit::TooManyFacts));
659 }
660
661 let now = Instant::now();
662 if now >= time_limit {
663 break Err(Execution::RunLimit(crate::error::RunLimit::Timeout));
664 }
665 };
666
667 self.iterations += index;
668
669 res
670 }
671
672 pub fn query_rule(
696 &self,
697 rule: Rule,
698 origin: usize,
699 scope: &TrustedOrigins,
700 symbols: &SymbolTable,
701 ) -> Result<FactSet, Execution> {
702 let mut new_facts = FactSet::default();
703 let it = self.facts.iterator(scope);
704 for res in rule.apply(it.clone(), origin, symbols, &self.extern_funcs) {
706 match res {
707 Ok((origin, fact)) => {
708 new_facts.insert(&origin, fact);
709 }
710 Err(e) => {
711 return Err(Execution::Expression(e));
712 }
713 }
714 }
715
716 Ok(new_facts)
717 }
718
719 pub fn query_match(
720 &self,
721 rule: Rule,
722 origin: usize,
723 scope: &TrustedOrigins,
724 symbols: &SymbolTable,
725 ) -> Result<bool, Execution> {
726 rule.find_match(&self.facts, origin, scope, symbols, &self.extern_funcs)
727 }
728
729 pub fn query_match_all(
730 &self,
731 rule: Rule,
732 scope: &TrustedOrigins,
733 symbols: &SymbolTable,
734 ) -> Result<bool, Execution> {
735 rule.check_match_all(&self.facts, scope, symbols, &self.extern_funcs)
736 }
737}
738
/// Resource limits for `World::run_with_limits`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunLimits {
    /// The run fails with `TooManyFacts` once the world holds at least this many facts.
    pub max_facts: u64,
    /// The run fails with `TooManyIterations` after this many rounds that derived new facts.
    pub max_iterations: u64,
    /// Wall-clock budget; checked after each round, failing with `Timeout`.
    pub max_time: Duration,
}

impl Default for RunLimits {
    /// 1000 facts, 100 iterations, 1 millisecond.
    fn default() -> Self {
        let max_time = Duration::from_millis(1);
        Self {
            max_facts: 1000,
            max_iterations: 100,
            max_time,
        }
    }
}
759
760#[derive(Clone, Debug, Default)]
761pub struct FactSet {
762 pub(crate) inner: HashMap<Origin, HashSet<Fact>>,
763}
764
765impl FactSet {
766 pub fn insert(&mut self, origin: &Origin, fact: Fact) {
767 match self.inner.get_mut(origin) {
768 None => {
769 let mut set = HashSet::new();
770 set.insert(fact);
771 self.inner.insert(origin.clone(), set);
772 }
773 Some(set) => {
774 set.insert(fact);
775 }
776 }
777 }
778
779 pub fn len(&self) -> usize {
780 self.inner.values().fold(0, |acc, set| acc + set.len())
781 }
782
783 pub fn is_empty(&self) -> bool {
784 self.inner.values().all(|set| set.is_empty())
785 }
786
787 pub fn iterator<'a>(
788 &'a self,
789 block_ids: &'a TrustedOrigins,
790 ) -> impl Iterator<Item = (&'a Origin, &'a Fact)> + Clone {
791 self.inner
792 .iter()
793 .filter_map(move |(ids, facts)| {
794 if block_ids.contains(ids) {
795 Some(facts.iter().map(move |fact| (ids, fact)))
796 } else {
797 None
798 }
799 })
800 .flatten()
801 }
802
803 pub fn iter_all(&self) -> impl Iterator<Item = (&Origin, &Fact)> + Clone {
804 self.inner
805 .iter()
806 .flat_map(move |(ids, facts)| facts.iter().map(move |fact| (ids, fact)))
807 }
808
809 pub fn merge(&mut self, other: FactSet) {
810 for (origin, facts) in other.inner {
811 let entry = self.inner.entry(origin).or_default();
812 entry.extend(facts.into_iter());
813 }
814 }
815}
816
817impl Extend<(Origin, Fact)> for FactSet {
818 fn extend<T: IntoIterator<Item = (Origin, Fact)>>(&mut self, iter: T) {
819 for (origin, fact) in iter {
820 let entry = self.inner.entry(origin).or_default();
821 entry.insert(fact);
822 }
823 }
824}
825
826impl IntoIterator for FactSet {
827 type Item = (Origin, Fact);
828
829 type IntoIter = Box<dyn Iterator<Item = (Origin, Fact)>>;
830
831 fn into_iter(self) -> Self::IntoIter {
832 Box::new(
833 self.inner.into_iter().flat_map(move |(ids, facts)| {
834 facts.into_iter().map(move |fact| (ids.clone(), fact))
835 }),
836 )
837 }
838}
839
840#[derive(Clone, Debug, Default)]
841pub struct RuleSet {
842 pub inner: HashMap<TrustedOrigins, Vec<(usize, Rule)>>,
843}
844
845impl RuleSet {
846 pub fn insert(&mut self, origin: usize, scope: &TrustedOrigins, rule: Rule) {
847 match self.inner.get_mut(scope) {
848 None => {
849 self.inner.insert(scope.clone(), vec![(origin, rule)]);
850 }
851 Some(set) => {
852 set.push((origin, rule));
853 }
854 }
855 }
856
857 pub fn iter_all(&self) -> impl Iterator<Item = (&TrustedOrigins, &Rule)> + Clone {
858 self.inner
859 .iter()
860 .flat_map(move |(ids, rules)| rules.iter().map(move |(_, rule)| (ids, rule)))
861 }
862}
863
864pub struct SchemaVersion {
865 contains_scopes: bool,
866 contains_v3_1: bool,
867 contains_check_all: bool,
868 contains_v3_3: bool,
869}
870
871impl SchemaVersion {
872 pub fn version(&self) -> u32 {
873 if self.contains_v3_3 {
874 DATALOG_3_3
875 } else if self.contains_scopes || self.contains_v3_1 || self.contains_check_all {
876 DATALOG_3_1
877 } else {
878 MIN_SCHEMA_VERSION
879 }
880 }
881
882 pub fn check_compatibility(&self, version: u32) -> Result<(), error::Format> {
883 if version < DATALOG_3_1 {
884 if self.contains_scopes {
885 Err(error::Format::DeserializationError(
886 "scopes are only supported in datalog v3.1+".to_string(),
887 ))
888 } else if self.contains_v3_1 {
889 Err(error::Format::DeserializationError(
890 "bitwise operators and != are only supported in datalog v3.1+".to_string(),
891 ))
892 } else if self.contains_check_all {
893 Err(error::Format::DeserializationError(
894 "check all is only supported in datalog v3.1+".to_string(),
895 ))
896 } else {
897 Ok(())
898 }
899 } else if version < DATALOG_3_3 && self.contains_v3_3 {
900 Err(error::Format::DeserializationError(
901 "maps, arrays, null, closures are only supported in datalog v3.3+".to_string(),
902 ))
903 } else {
904 Ok(())
905 }
906 }
907}
908
909pub fn get_schema_version(
911 facts: &[Fact],
912 rules: &[Rule],
913 checks: &[Check],
914 scopes: &[Scope],
915) -> SchemaVersion {
916 let contains_scopes = !scopes.is_empty()
917 || rules.iter().any(|r: &Rule| !r.scopes.is_empty())
918 || checks
919 .iter()
920 .any(|c: &Check| c.queries.iter().any(|q| !q.scopes.is_empty()));
921
922 let mut contains_check_all = false;
923 let mut contains_v3_3 = false;
924
925 for c in checks.iter() {
926 if c.kind == CheckKind::All {
927 contains_check_all = true;
928 } else if c.kind == CheckKind::Reject {
929 contains_v3_3 = true;
930 }
931 }
932
933 let contains_v3_1 = rules.iter().any(|rule| contains_v3_1_op(&rule.expressions))
934 || checks.iter().any(|check| {
935 check
936 .queries
937 .iter()
938 .any(|query| contains_v3_1_op(&query.expressions))
939 });
940
941 if !contains_v3_3 {
943 contains_v3_3 = rules.iter().any(|rule| {
944 contains_v3_3_predicate(&rule.head)
945 || rule.body.iter().any(contains_v3_3_predicate)
946 || contains_v3_3_op(&rule.expressions)
947 }) || checks.iter().any(|check| {
948 check.queries.iter().any(|query| {
949 query.body.iter().any(contains_v3_3_predicate)
950 || contains_v3_3_op(&query.expressions)
951 })
952 });
953 }
954 if !contains_v3_3 {
955 contains_v3_3 = facts
956 .iter()
957 .any(|fact| contains_v3_3_predicate(&fact.predicate))
958 }
959
960 SchemaVersion {
961 contains_scopes,
962 contains_v3_1,
963 contains_check_all,
964 contains_v3_3,
965 }
966}
967
968pub fn contains_v3_1_op(expressions: &[Expression]) -> bool {
971 expressions.iter().any(|expression| {
972 expression.ops.iter().any(|op| {
973 if let Op::Binary(binary) = op {
974 match binary {
975 Binary::BitwiseAnd
976 | Binary::BitwiseOr
977 | Binary::BitwiseXor
978 | Binary::NotEqual => return true,
979 _ => return false,
980 }
981 }
982 false
983 })
984 })
985}
986
987fn contains_v3_3_op(expressions: &[Expression]) -> bool {
988 expressions.iter().any(|expression| {
989 expression.ops.iter().any(|op| match op {
990 Op::Value(term) => contains_v3_3_term(term),
991 Op::Closure(_, _) => true,
992 Op::Unary(unary) => matches!(unary, Unary::TypeOf | Unary::Ffi(_)),
993 Op::Binary(binary) => matches!(
994 binary,
995 Binary::HeterogeneousEqual
996 | Binary::HeterogeneousNotEqual
997 | Binary::LazyAnd
998 | Binary::LazyOr
999 | Binary::All
1000 | Binary::Any
1001 | Binary::Ffi(_)
1002 ),
1003 })
1004 })
1005}
1006
1007fn contains_v3_3_predicate(predicate: &Predicate) -> bool {
1008 predicate.terms.iter().any(contains_v3_3_term)
1009}
1010
1011fn contains_v3_3_term(term: &Term) -> bool {
1012 match term {
1013 Term::Null => true,
1014 Term::Set(s) => s.contains(&Term::Null),
1015 _ => false,
1016 }
1017}
1018
1019#[cfg(test)]
1020mod tests {
1021 use super::*;
1022 use std::time::Duration;
1023
1024 #[test]
1025 fn family() {
1026 let mut w = World::new();
1027 let mut syms = SymbolTable::new();
1028
1029 let a = syms.add("A");
1030 let b = syms.add("B");
1031 let c = syms.add("C");
1032 let d = syms.add("D");
1033 let e = syms.add("e");
1034 let parent = syms.insert("parent");
1035 let grandparent = syms.insert("grandparent");
1036
1037 w.add_fact(&[0].iter().collect(), fact(parent, &[&a, &b]));
1038 w.add_fact(&[0].iter().collect(), fact(parent, &[&b, &c]));
1039 w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &d]));
1040
1041 let r1 = rule(
1042 grandparent,
1043 &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
1044 &[
1045 pred(
1046 parent,
1047 &[var(&mut syms, "grandparent"), var(&mut syms, "parent")],
1048 ),
1049 pred(
1050 parent,
1051 &[var(&mut syms, "parent"), var(&mut syms, "grandchild")],
1052 ),
1053 ],
1054 );
1055
1056 println!("symbols: {:?}", syms);
1057 println!("testing r1: {}", syms.print_rule(&r1));
1058 let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms);
1059 println!("grandparents query_rules: {:?}", query_rule_result);
1060 println!("current facts: {:?}", w.facts);
1061
1062 let r2 = rule(
1063 grandparent,
1064 &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
1065 &[
1066 pred(
1067 parent,
1068 &[var(&mut syms, "grandparent"), var(&mut syms, "parent")],
1069 ),
1070 pred(
1071 parent,
1072 &[var(&mut syms, "parent"), var(&mut syms, "grandchild")],
1073 ),
1074 ],
1075 );
1076
1077 println!("adding r2: {}", syms.print_rule(&r2));
1078 w.add_rule(0, &[0].iter().collect(), r2);
1079
1080 w.run_with_limits(
1081 &syms,
1082 RunLimits {
1083 max_time: Duration::from_secs(10),
1084 ..Default::default()
1085 },
1086 )
1087 .unwrap();
1088
1089 println!("parents:");
1090 let res = w
1091 .query_rule(
1092 rule::<Term, Predicate>(
1093 parent,
1094 &[var(&mut syms, "parent"), var(&mut syms, "child")],
1095 &[pred(
1096 parent,
1097 &[var(&mut syms, "parent"), var(&mut syms, "child")],
1098 )],
1099 ),
1100 0,
1101 &[0].iter().collect(),
1102 &syms,
1103 )
1104 .unwrap();
1105
1106 for (origin, fact) in res.iterator(&[0].iter().collect()) {
1107 println!("\t{:?}\t{}", origin, syms.print_fact(fact));
1108 }
1109
1110 println!(
1111 "parents of B: {:?}",
1112 w.query_rule(
1113 rule::<&Term, Predicate>(
1114 parent,
1115 &[&var(&mut syms, "parent"), &b],
1116 &[pred(parent, &[&var(&mut syms, "parent"), &b])]
1117 ),
1118 0,
1119 &[0].iter().collect(),
1120 &syms,
1121 )
1122 );
1123 println!(
1124 "grandparents: {:?}",
1125 w.query_rule(
1126 rule::<Term, Predicate>(
1127 grandparent,
1128 &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
1129 &[pred(
1130 grandparent,
1131 &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")]
1132 )]
1133 ),
1134 0,
1135 &[0].iter().collect(),
1136 &syms,
1137 )
1138 );
1139 w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e]));
1140 w.run(&syms).unwrap();
1141 let res = w
1142 .query_rule(
1143 rule::<Term, Predicate>(
1144 grandparent,
1145 &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
1146 &[pred(
1147 grandparent,
1148 &[var(&mut syms, "grandparent"), var(&mut syms, "grandchild")],
1149 )],
1150 ),
1151 0,
1152 &[0].iter().collect(),
1153 &syms,
1154 )
1155 .unwrap();
1156 println!("grandparents after inserting parent(C, E): {:?}", res);
1157
1158 let res = res
1159 .iter_all()
1160 .map(|(_origin, fact)| fact)
1161 .cloned()
1162 .collect::<HashSet<_>>();
1163 let compared = (vec![
1164 fact(grandparent, &[&a, &c]),
1165 fact(grandparent, &[&b, &d]),
1166 fact(grandparent, &[&b, &e]),
1167 ])
1168 .drain(..)
1169 .collect::<HashSet<_>>();
1170 assert_eq!(res, compared);
1171
1172 }
1181
1182 #[test]
1183 fn numbers() {
1184 let mut w = World::new();
1185 let mut syms = SymbolTable::new();
1186
1187 let abc = syms.add("abc");
1188 let def = syms.add("def");
1189 let ghi = syms.add("ghi");
1190 let jkl = syms.add("jkl");
1191 let mno = syms.add("mno");
1192 let aaa = syms.add("AAA");
1193 let bbb = syms.add("BBB");
1194 let ccc = syms.add("CCC");
1195 let t1 = syms.insert("t1");
1196 let t2 = syms.insert("t2");
1197 let join = syms.insert("join");
1198
1199 w.add_fact(&[0].iter().collect(), fact(t1, &[&int(0), &abc]));
1200 w.add_fact(&[0].iter().collect(), fact(t1, &[&int(1), &def]));
1201 w.add_fact(&[0].iter().collect(), fact(t1, &[&int(2), &ghi]));
1202 w.add_fact(&[0].iter().collect(), fact(t1, &[&int(3), &jkl]));
1203 w.add_fact(&[0].iter().collect(), fact(t1, &[&int(4), &mno]));
1204
1205 w.add_fact(&[0].iter().collect(), fact(t2, &[&int(0), &aaa, &int(0)]));
1206 w.add_fact(&[0].iter().collect(), fact(t2, &[&int(1), &bbb, &int(0)]));
1207 w.add_fact(&[0].iter().collect(), fact(t2, &[&int(2), &ccc, &int(1)]));
1208
1209 let res = w
1210 .query_rule(
1211 rule(
1212 join,
1213 &[var(&mut syms, "left"), var(&mut syms, "right")],
1214 &[
1215 pred(t1, &[var(&mut syms, "id"), var(&mut syms, "left")]),
1216 pred(
1217 t2,
1218 &[
1219 var(&mut syms, "t2_id"),
1220 var(&mut syms, "right"),
1221 var(&mut syms, "id"),
1222 ],
1223 ),
1224 ],
1225 ),
1226 0,
1227 &[0].iter().collect(),
1228 &syms,
1229 )
1230 .unwrap();
1231
1232 for (_, fact) in res.iter_all() {
1233 println!("\t{}", syms.print_fact(fact));
1234 }
1235
1236 let res2 = res
1237 .iter_all()
1238 .map(|(_origin, fact)| fact)
1239 .cloned()
1240 .collect::<HashSet<_>>();
1241 let compared = (vec![
1242 fact(join, &[&abc, &aaa]),
1243 fact(join, &[&abc, &bbb]),
1244 fact(join, &[&def, &ccc]),
1245 ])
1246 .drain(..)
1247 .collect::<HashSet<_>>();
1248 assert_eq!(res2, compared);
1249
1250 let res = w
1252 .query_rule(
1253 expressed_rule(
1254 join,
1255 &[var(&mut syms, "left"), var(&mut syms, "right")],
1256 &[
1257 pred(t1, &[var(&mut syms, "id"), var(&mut syms, "left")]),
1258 pred(
1259 t2,
1260 &[
1261 var(&mut syms, "t2_id"),
1262 var(&mut syms, "right"),
1263 var(&mut syms, "id"),
1264 ],
1265 ),
1266 ],
1267 &[Expression {
1268 ops: vec![
1269 Op::Value(var(&mut syms, "id")),
1270 Op::Value(Term::Integer(1)),
1271 Op::Binary(Binary::LessThan),
1272 ],
1273 }],
1274 ),
1275 0,
1276 &[0].iter().collect(),
1277 &syms,
1278 )
1279 .unwrap();
1280
1281 for (_, fact) in res.iter_all() {
1282 println!("\t{}", syms.print_fact(fact));
1283 }
1284
1285 let res2 = res
1286 .iter_all()
1287 .map(|(_origin, fact)| fact)
1288 .cloned()
1289 .collect::<HashSet<_>>();
1290 let compared = (vec![fact(join, &[&abc, &aaa]), fact(join, &[&abc, &bbb])])
1291 .drain(..)
1292 .collect::<HashSet<_>>();
1293 assert_eq!(res2, compared);
1294 }
1295
1296 #[test]
1297 fn str() {
1298 let mut w = World::new();
1299 let mut syms = SymbolTable::new();
1300
1301 let app_0 = syms.add("app_0");
1302 let app_1 = syms.add("app_1");
1303 let app_2 = syms.add("app_2");
1304 let route = syms.insert("route");
1305 let suff = syms.insert("route suffix");
1306 let example = syms.add("example.com");
1307 let test_com = syms.add("test.com");
1308 let test_fr = syms.add("test.fr");
1309 let www_example = syms.add("www.example.com");
1310 let mx_example = syms.add("mx.example.com");
1311
1312 w.add_fact(
1313 &[0].iter().collect(),
1314 fact(route, &[&int(0), &app_0, &example]),
1315 );
1316 w.add_fact(
1317 &[0].iter().collect(),
1318 fact(route, &[&int(1), &app_1, &test_com]),
1319 );
1320 w.add_fact(
1321 &[0].iter().collect(),
1322 fact(route, &[&int(2), &app_2, &test_fr]),
1323 );
1324 w.add_fact(
1325 &[0].iter().collect(),
1326 fact(route, &[&int(3), &app_0, &www_example]),
1327 );
1328 w.add_fact(
1329 &[0].iter().collect(),
1330 fact(route, &[&int(4), &app_1, &mx_example]),
1331 );
1332
1333 fn test_suffix(
1334 w: &World,
1335 syms: &mut SymbolTable,
1336 suff: SymbolIndex,
1337 route: SymbolIndex,
1338 suffix: &str,
1339 ) -> Vec<Fact> {
1340 let id_suff = syms.add(suffix);
1341 w.query_rule(
1342 expressed_rule(
1343 suff,
1344 &[var(syms, "app_id"), var(syms, "domain_name")],
1345 &[pred(
1346 route,
1347 &[
1348 var(syms, "route_id"),
1349 var(syms, "app_id"),
1350 var(syms, "domain_name"),
1351 ],
1352 )],
1353 &[Expression {
1354 ops: vec![
1355 Op::Value(var(syms, "domain_name")),
1356 Op::Value(id_suff),
1357 Op::Binary(Binary::Suffix),
1358 ],
1359 }],
1360 ),
1361 0,
1362 &[0].iter().collect(),
1363 &syms,
1364 )
1365 .unwrap()
1366 .iter_all()
1367 .map(|(_, fact)| fact.clone())
1368 .collect()
1369 }
1370
1371 let res = test_suffix(&w, &mut syms, suff, route, ".fr");
1372 for fact in &res {
1373 println!("\t{}", syms.print_fact(fact));
1374 }
1375
1376 let res2 = res.iter().cloned().collect::<HashSet<_>>();
1377 let compared = (vec![fact(suff, &[&app_2, &test_fr])])
1378 .drain(..)
1379 .collect::<HashSet<_>>();
1380 assert_eq!(res2, compared);
1381
1382 let res = test_suffix(&w, &mut syms, suff, route, "example.com");
1383 for fact in &res {
1384 println!("\t{}", syms.print_fact(fact));
1385 }
1386
1387 let res2 = res.iter().cloned().collect::<HashSet<_>>();
1388 let compared = (vec![
1389 fact(suff, &[&app_0, &example]),
1390 fact(suff, &[&app_0, &www_example]),
1391 fact(suff, &[&app_1, &mx_example]),
1392 ])
1393 .drain(..)
1394 .collect::<HashSet<_>>();
1395 assert_eq!(res2, compared);
1396 }
1397
1398 #[test]
1399 fn date_constraint() {
1400 let mut w = World::new();
1401 let mut syms = SymbolTable::new();
1402
1403 let t1 = SystemTime::now();
1404 println!("t1 = {:?}", t1);
1405 let t2 = t1 + Duration::from_secs(10);
1406 println!("t2 = {:?}", t2);
1407 let t3 = t2 + Duration::from_secs(30);
1408 println!("t3 = {:?}", t3);
1409
1410 let t2_timestamp = t2.duration_since(UNIX_EPOCH).unwrap().as_secs();
1411
1412 let abc = syms.add("abc");
1413 let def = syms.add("def");
1414 let x = syms.insert("x");
1415 let before = syms.insert("before");
1416 let after = syms.insert("after");
1417
1418 w.add_fact(&[0].iter().collect(), fact(x, &[&date(&t1), &abc]));
1419 w.add_fact(&[0].iter().collect(), fact(x, &[&date(&t3), &def]));
1420
1421 let r1 = expressed_rule(
1422 before,
1423 &[var(&mut syms, "date"), var(&mut syms, "val")],
1424 &[pred(x, &[var(&mut syms, "date"), var(&mut syms, "val")])],
1425 &[
1426 Expression {
1427 ops: vec![
1428 Op::Value(var(&mut syms, "date")),
1429 Op::Value(Term::Date(t2_timestamp)),
1430 Op::Binary(Binary::LessOrEqual),
1431 ],
1432 },
1433 Expression {
1434 ops: vec![
1435 Op::Value(var(&mut syms, "date")),
1436 Op::Value(Term::Date(0)),
1437 Op::Binary(Binary::GreaterOrEqual),
1438 ],
1439 },
1440 ],
1441 );
1442
1443 println!("testing r1: {}", syms.print_rule(&r1));
1444 let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
1445 for (_, fact) in res.iter_all() {
1446 println!("\t{}", syms.print_fact(fact));
1447 }
1448
1449 let res2 = res
1450 .iter_all()
1451 .map(|(_origin, fact)| fact)
1452 .cloned()
1453 .collect::<HashSet<_>>();
1454 let compared = (vec![fact(before, &[&date(&t1), &abc])])
1455 .drain(..)
1456 .collect::<HashSet<_>>();
1457 assert_eq!(res2, compared);
1458
1459 let r2 = expressed_rule(
1460 after,
1461 &[var(&mut syms, "date"), var(&mut syms, "val")],
1462 &[pred(x, &[var(&mut syms, "date"), var(&mut syms, "val")])],
1463 &[
1464 Expression {
1465 ops: vec![
1466 Op::Value(var(&mut syms, "date")),
1467 Op::Value(Term::Date(t2_timestamp)),
1468 Op::Binary(Binary::GreaterOrEqual),
1469 ],
1470 },
1471 Expression {
1472 ops: vec![
1473 Op::Value(var(&mut syms, "date")),
1474 Op::Value(Term::Date(0)),
1475 Op::Binary(Binary::GreaterOrEqual),
1476 ],
1477 },
1478 ],
1479 );
1480
1481 println!("testing r2: {}", syms.print_rule(&r2));
1482 let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();
1483 for (_, fact) in res.iter_all() {
1484 println!("\t{}", syms.print_fact(fact));
1485 }
1486
1487 let res2 = res
1488 .iter_all()
1489 .map(|(_, fact)| fact)
1490 .cloned()
1491 .collect::<HashSet<_>>();
1492 let compared = (vec![fact(after, &[&date(&t3), &def])])
1493 .drain(..)
1494 .collect::<HashSet<_>>();
1495 assert_eq!(res2, compared);
1496 }
1497
1498 #[test]
1499 fn set_constraint() {
1500 let mut w = World::new();
1501 let mut syms = SymbolTable::new();
1502
1503 let abc = syms.add("abc");
1504 let def = syms.add("def");
1505 let x = syms.insert("x");
1506 let int_set = syms.insert("int_set");
1507 let symbol_set = syms.insert("symbol_set");
1508 let string_set = syms.insert("string_set");
1509 let test = syms.add("test");
1510 let hello = syms.add("hello");
1511 let aaa = syms.add("zzz");
1512
1513 w.add_fact(&[0].iter().collect(), fact(x, &[&abc, &int(0), &test]));
1514 w.add_fact(&[0].iter().collect(), fact(x, &[&def, &int(2), &hello]));
1515
1516 let res = w
1517 .query_rule(
1518 expressed_rule(
1519 int_set,
1520 &[var(&mut syms, "sym"), var(&mut syms, "str")],
1521 &[pred(
1522 x,
1523 &[
1524 var(&mut syms, "sym"),
1525 var(&mut syms, "int"),
1526 var(&mut syms, "str"),
1527 ],
1528 )],
1529 &[Expression {
1530 ops: vec![
1531 Op::Value(Term::Set(
1532 [Term::Integer(0), Term::Integer(1)]
1533 .iter()
1534 .cloned()
1535 .collect(),
1536 )),
1537 Op::Value(var(&mut syms, "int")),
1538 Op::Binary(Binary::Contains),
1539 ],
1540 }],
1541 ),
1542 0,
1543 &[0].iter().collect(),
1544 &syms,
1545 )
1546 .unwrap();
1547
1548 for (_, fact) in res.iter_all() {
1549 println!("\t{}", syms.print_fact(fact));
1550 }
1551
1552 let res2 = res
1553 .iter_all()
1554 .map(|(_, fact)| fact)
1555 .cloned()
1556 .collect::<HashSet<_>>();
1557 let compared = (vec![fact(int_set, &[&abc, &test])])
1558 .drain(..)
1559 .collect::<HashSet<_>>();
1560 assert_eq!(res2, compared);
1561
1562 let abc_sym_id = syms.add("abc");
1563 let ghi_sym_id = syms.add("ghi");
1564
1565 let res = w
1566 .query_rule(
1567 expressed_rule(
1568 symbol_set,
1569 &[
1570 var(&mut syms, "symbol"),
1571 var(&mut syms, "int"),
1572 var(&mut syms, "str"),
1573 ],
1574 &[pred(
1575 x,
1576 &[
1577 var(&mut syms, "symbol"),
1578 var(&mut syms, "int"),
1579 var(&mut syms, "str"),
1580 ],
1581 )],
1582 &[Expression {
1583 ops: vec![
1584 Op::Value(Term::Set(
1585 [abc_sym_id, ghi_sym_id].iter().cloned().collect(),
1586 )),
1587 Op::Value(var(&mut syms, "symbol")),
1588 Op::Binary(Binary::Contains),
1589 Op::Unary(Unary::Negate),
1590 ],
1591 }],
1592 ),
1593 0,
1594 &[0].iter().collect(),
1595 &syms,
1596 )
1597 .unwrap();
1598
1599 for (_, fact) in res.iter_all() {
1600 println!("\t{}", syms.print_fact(fact));
1601 }
1602
1603 let res2 = res
1604 .iter_all()
1605 .map(|(_, fact)| fact)
1606 .cloned()
1607 .collect::<HashSet<_>>();
1608 let compared = (vec![fact(symbol_set, &[&def, &int(2), &hello])])
1609 .drain(..)
1610 .collect::<HashSet<_>>();
1611 assert_eq!(res2, compared);
1612
1613 let res = w
1614 .query_rule(
1615 expressed_rule(
1616 string_set,
1617 &[
1618 var(&mut syms, "sym"),
1619 var(&mut syms, "int"),
1620 var(&mut syms, "str"),
1621 ],
1622 &[pred(
1623 x,
1624 &[
1625 var(&mut syms, "sym"),
1626 var(&mut syms, "int"),
1627 var(&mut syms, "str"),
1628 ],
1629 )],
1630 &[Expression {
1631 ops: vec![
1632 Op::Value(Term::Set([test.clone(), aaa].iter().cloned().collect())),
1633 Op::Value(var(&mut syms, "str")),
1634 Op::Binary(Binary::Contains),
1635 ],
1636 }],
1637 ),
1638 0,
1639 &[0].iter().collect(),
1640 &syms,
1641 )
1642 .unwrap();
1643
1644 for (_, fact) in res.iter_all() {
1645 println!("\t{}", syms.print_fact(fact));
1646 }
1647
1648 let res2 = res
1649 .iter_all()
1650 .map(|(_, fact)| fact)
1651 .cloned()
1652 .collect::<HashSet<_>>();
1653 let compared = (vec![fact(string_set, &[&abc, &int(0), &test])])
1654 .drain(..)
1655 .collect::<HashSet<_>>();
1656 assert_eq!(res2, compared);
1657 }
1658
1659 #[test]
1660 fn resource() {
1661 let mut w = World::new();
1662 let mut syms = SymbolTable::new();
1663
1664 let resource = syms.insert("resource");
1665 let operation = syms.insert("operation");
1666 let right = syms.insert("right");
1667 let file1 = syms.add("file1");
1668 let file2 = syms.add("file2");
1669 let read = syms.add("read");
1670 let write = syms.add("write");
1671 let check1 = syms.insert("check1");
1672 let check2 = syms.insert("check2");
1673
1674 w.add_fact(&[0].iter().collect(), fact(resource, &[&file2]));
1675 w.add_fact(&[0].iter().collect(), fact(operation, &[&write]));
1676 w.add_fact(&[0].iter().collect(), fact(right, &[&file1, &read]));
1677 w.add_fact(&[0].iter().collect(), fact(right, &[&file2, &read]));
1678 w.add_fact(&[0].iter().collect(), fact(right, &[&file1, &write]));
1679
1680 let res = w
1681 .query_rule(
1682 rule(check1, &[&file1], &[pred(resource, &[&file1])]),
1683 0,
1684 &[0].iter().collect(),
1685 &syms,
1686 )
1687 .unwrap();
1688
1689 for (_, fact) in res.iter_all() {
1690 println!("\t{}", syms.print_fact(fact));
1691 }
1692
1693 assert!(res.is_empty());
1694
1695 let res = w
1696 .query_rule(
1697 rule(
1698 check2,
1699 &[Term::Variable(0)],
1700 &[
1701 pred(resource, &[&Term::Variable(0)]),
1702 pred(operation, &[&read]),
1703 pred(right, &[&Term::Variable(0), &read]),
1704 ],
1705 ),
1706 0,
1707 &[0].iter().collect(),
1708 &syms,
1709 )
1710 .unwrap();
1711
1712 for (_, fact) in res.iter_all() {
1713 println!("\t{}", syms.print_fact(fact));
1714 }
1715
1716 assert!(res.is_empty());
1717 }
1718
1719 #[test]
1720 fn int_expr() {
1721 let mut w = World::new();
1722 let mut syms = SymbolTable::new();
1723
1724 let abc = syms.add("abc");
1725 let def = syms.add("def");
1726 let x = syms.insert("x");
1727 let less_than = syms.insert("less_than");
1728
1729 w.add_fact(&[0].iter().collect(), fact(x, &[&int(-2), &abc]));
1730 w.add_fact(&[0].iter().collect(), fact(x, &[&int(0), &def]));
1731
1732 let r1 = expressed_rule(
1733 less_than,
1734 &[var(&mut syms, "nb"), var(&mut syms, "val")],
1735 &[pred(x, &[var(&mut syms, "nb"), var(&mut syms, "val")])],
1736 &[Expression {
1737 ops: vec![
1738 Op::Value(Term::Integer(5)),
1739 Op::Value(Term::Integer(-4)),
1740 Op::Binary(Binary::Add),
1741 Op::Value(Term::Integer(-1)),
1742 Op::Binary(Binary::Mul),
1743 Op::Value(var(&mut syms, "nb")),
1744 Op::Binary(Binary::LessThan),
1745 ],
1746 }],
1747 );
1748
1749 println!("world:\n{}\n", syms.print_world(&w));
1750 println!("\ntesting r1: {}\n", syms.print_rule(&r1));
1751 let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
1752 for (_, fact) in res.iter_all() {
1753 println!("\t{}", syms.print_fact(fact));
1754 }
1755
1756 let res2 = res
1757 .iter_all()
1758 .map(|(_, fact)| fact)
1759 .cloned()
1760 .collect::<HashSet<_>>();
1761 println!("got res: {:?}", res2);
1762 let compared = (vec![fact(less_than, &[&int(0), &def])])
1763 .drain(..)
1764 .collect::<HashSet<_>>();
1765 assert_eq!(res2, compared);
1766 }
1767
1768 #[test]
1769 fn unbound_variables() {
1770 let mut w = World::new();
1771 let mut syms = SymbolTable::new();
1772
1773 let operation = syms.insert("operation");
1774 let check = syms.insert("check");
1775 let read = syms.add("read");
1776 let write = syms.add("write");
1777 let unbound = var(&mut syms, "unbound");
1778 let any1 = var(&mut syms, "any1");
1779 let any2 = var(&mut syms, "any2");
1780
1781 w.add_fact(&[0].iter().collect(), fact(operation, &[&write]));
1782
1783 let r1 = rule(
1784 operation,
1785 &[&unbound, &read],
1786 &[pred(operation, &[&any1, &any2])],
1787 );
1788 println!("world:\n{}\n", syms.print_world(&w));
1789 println!("\ntesting r1: {}\n", syms.print_rule(&r1));
1790 let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
1791
1792 println!("generated facts:");
1793 for (_, fact) in res.iter_all() {
1794 println!("\t{}", syms.print_fact(fact));
1795 }
1796
1797 assert!(res.len() == 0);
1798
1799 w.add_fact(&[0].iter().collect(), fact(operation, &[&unbound, &read]));
1803 let r2 = rule(check, &[&read], &[pred(operation, &[&read])]);
1804 println!("world:\n{}\n", syms.print_world(&w));
1805 println!("\ntesting r2: {}\n", syms.print_rule(&r2));
1806 let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();
1807
1808 println!("generated facts:");
1809 for (_, fact) in res.iter_all() {
1810 println!("\t{}", syms.print_fact(fact));
1811 }
1812 assert!(res.is_empty());
1813 }
1814}