polar_core/
kb.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::Write;
3use std::sync::Arc;
4
5pub use super::bindings::Bindings;
6use super::constants::Constants;
7use super::counter::Counter;
8use super::diagnostic::Diagnostic;
9use super::error::{invalid_state, PolarError, PolarResult, RuntimeError, ValidationError};
10use super::resource_block::{ResourceBlocks, ACTOR_UNION_NAME, RESOURCE_UNION_NAME};
11use super::rules::*;
12use super::terms::*;
13use super::validations::check_undefined_rule_calls;
14
/// Outcome of comparing a rule (or one of its parameters) against a rule type.
enum RuleParamMatch {
    /// The rule satisfies the rule type.
    True,
    /// The rule does not satisfy the rule type; the payload explains why.
    False(String),
}

#[cfg(test)]
impl RuleParamMatch {
    /// Test-only convenience: `true` exactly for the `True` variant.
    fn is_true(&self) -> bool {
        matches!(self, Self::True)
    }
}
26
/// Store of everything loaded into the engine: registered constants and
/// classes, rules, rule types, inline queries and resource block bookkeeping.
#[derive(Default)]
pub struct KnowledgeBase {
    /// Registered constants (including classes), looked up by symbol. Populated
    /// through `register_constant`.
    constants: Constants,
    /// Map of class name -> MRO list where the MRO list is a list of class instance IDs
    pub mro: HashMap<Symbol, Vec<u64>>,

    /// Map from contents to filename for files loaded into the KB.
    loaded_content: HashMap<String, String>,

    /// Generic rules keyed by rule name; `add_rule` groups same-named rules
    /// under one `GenericRule`.
    rules: HashMap<Symbol, GenericRule>,
    /// Rule types that loaded rules are validated against.
    rule_types: RuleTypes,
    /// For symbols returned from gensym.
    gensym_counter: Counter,
    /// For call IDs, instance IDs, symbols, etc.
    id_counter: Counter,
    /// Inline query terms; emptied by `clear_rules`.
    // NOTE(review): presumably collected while loading policy sources -- confirm with the loader.
    pub inline_queries: Vec<Term>,

    /// Resource block bookkeeping.
    pub resource_blocks: ResourceBlocks,
}
49
50impl KnowledgeBase {
51    pub fn new() -> Self {
52        Self::default()
53    }
54
    /// Return a monotonically increasing integer ID.
    ///
    /// Wraps around at 52 bits of precision so that it can be safely
    /// coerced to an IEEE-754 double-float (f64).
    // NOTE(review): the wrap-around lives in `Counter::next`, which is not
    // visible in this file -- confirm there before relying on it.
    pub fn new_id(&self) -> u64 {
        self.id_counter.next()
    }
62
    /// Return a clone of the counter used by `new_id`.
    // NOTE(review): presumably the clone shares state with the original
    // counter -- confirm against `Counter`'s `Clone` impl.
    pub fn id_counter(&self) -> Counter {
        self.id_counter.clone()
    }
66
67    /// Generate a temporary variable prefix from a variable name.
68    pub fn temp_prefix(name: &str) -> String {
69        match name {
70            "_" => String::from(name),
71            _ => format!("_{}_", name),
72        }
73    }
74
75    /// Generate a new symbol.
76    pub fn gensym(&self, prefix: &str) -> Symbol {
77        let next = self.gensym_counter.next();
78        Symbol(format!("{}{}", Self::temp_prefix(prefix), next))
79    }
80
81    /// Add a generic rule to the knowledge base.
82    #[cfg(test)]
83    pub fn add_generic_rule(&mut self, rule: GenericRule) {
84        self.rules.insert(rule.name.clone(), rule);
85    }
86
87    pub fn add_rule(&mut self, rule: Rule) {
88        let generic_rule = self
89            .rules
90            .entry(rule.name.clone())
91            .or_insert_with(|| GenericRule::new(rule.name.clone(), vec![]));
92        generic_rule.add_rule(Arc::new(rule));
93    }
94
95    pub fn validate_rules(&self) -> Vec<Diagnostic> {
96        // Prior to #1310 these validations were not order dependent due to the
97        // use of static default rule types.
98        // Now that rule types are dynamically generated based on policy
99        // contents we validate types first to surface missing required rule
100        // implementations which would otherwise raise opaque "call to undefined rule"
101        // errors
102        let mut diagnostics = vec![];
103
104        if let Err(e) = self.validate_rule_types() {
105            diagnostics.push(e.into());
106        }
107
108        diagnostics.append(&mut check_undefined_rule_calls(self));
109
110        diagnostics
111    }
112
113    /// Validate that all rules loaded into the knowledge base are valid based on rule types.
114    fn validate_rule_types(&self) -> PolarResult<()> {
115        // For every rule, if there *is* a rule type, check that the rule matches the rule type.
116        for (rule_name, generic_rule) in &self.rules {
117            if let Some(types) = self.rule_types.get(rule_name) {
118                // If a type with the same name exists, then the parameters must match for each rule
119                for rule in generic_rule.rules.values() {
120                    let mut msg = "Must match one of the following rule types:\n".to_owned();
121
122                    let results = types
123                        .iter()
124                        .map(|rule_type| {
125                            self.rule_params_match(rule.as_ref(), rule_type)
126                                .map(|result| (result, rule_type))
127                        })
128                        .collect::<PolarResult<Vec<_>>>()?;
129                    let found_match = results.iter().any(|(result, rule_type)| match result {
130                        RuleParamMatch::True => true,
131                        RuleParamMatch::False(message) => {
132                            write!(
133                                msg,
134                                "\n{}\n\tFailed to match because: {}\n",
135                                rule_type, message
136                            )
137                            .unwrap();
138                            false
139                        }
140                    });
141                    if !found_match {
142                        let rule = Rule::clone(rule);
143                        return Err(ValidationError::InvalidRule { rule, msg }.into());
144                    }
145                }
146            }
147        }
148
149        // For every rule type that is *required*, see that there is at least one corresponding
150        // implementation.
151        for rule_type in self.rule_types.required_rule_types() {
152            if let Some(GenericRule { rules, .. }) = self.rules.get(&rule_type.name) {
153                let mut found_match = false;
154                for rule in rules.values() {
155                    found_match = self
156                        .rule_params_match(rule.as_ref(), rule_type)
157                        .map(|r| matches!(r, RuleParamMatch::True))?;
158                    if found_match {
159                        break;
160                    }
161                }
162                if !found_match {
163                    let rule_type = rule_type.clone();
164                    return Err(ValidationError::MissingRequiredRule { rule_type }.into());
165                }
166            } else {
167                let rule_type = rule_type.clone();
168                return Err(ValidationError::MissingRequiredRule { rule_type }.into());
169            }
170        }
171
172        Ok(())
173    }
174
175    /// Determine whether the fields of a rule parameter specializer match the fields of a type parameter specializer.
176    /// Rule fields match if they are a superset of type fields and all field values are equal.
177    // TODO: once field-level specializers are working this should be updated so
178    // that it recursively checks all fields match, rather than checking for
179    // equality
180    fn param_fields_match(&self, type_fields: &Dictionary, rule_fields: &Dictionary) -> bool {
181        return type_fields
182            .fields
183            .iter()
184            .map(|(k, type_value)| {
185                rule_fields
186                    .fields
187                    .get(k)
188                    .map(|rule_value| rule_value == type_value)
189                    .unwrap_or_else(|| false)
190            })
191            .all(|v| v);
192    }
193
194    /// Use MRO lists passed in from host library to determine if one `InstanceLiteral` pattern is
195    /// a subclass of another `InstanceLiteral` pattern. This function is used for Rule Type
196    /// validation.
197    fn check_rule_instance_is_subclass_of_rule_type_instance(
198        &self,
199        rule_instance: &InstanceLiteral,
200        rule_type_instance: &InstanceLiteral,
201        index: usize,
202    ) -> PolarResult<RuleParamMatch> {
203        // Get the unique ID of the prototype instance pattern class.
204        // TODO(gj): make actual term available here instead of constructing a fake test one.
205        let term = self.get_registered_class(&term!(rule_type_instance.tag.clone()))?;
206        if let Value::ExternalInstance(ExternalInstance { instance_id, .. }) = term.value() {
207            if let Some(rule_mro) = self.mro.get(&rule_instance.tag) {
208                if !rule_mro.contains(instance_id) {
209                    Ok(RuleParamMatch::False(format!(
210                        "Rule specializer {} on parameter {} must match rule type specializer {}",
211                        rule_instance.tag, index, rule_type_instance.tag
212                    )))
213                } else if !self
214                    .param_fields_match(&rule_type_instance.fields, &rule_instance.fields)
215                {
216                    Ok(RuleParamMatch::False(format!("Rule specializer {} on parameter {} did not match rule type specializer {} because the specializer fields did not match.", rule_instance, index, rule_type_instance)))
217                } else {
218                    Ok(RuleParamMatch::True)
219                }
220            } else {
221                // If `rule_instance.tag` were registered as a class, it would have an MRO.
222                Ok(RuleParamMatch::False(format!(
223                    "Rule specializer {} on parameter {} is not registered as a class.",
224                    rule_instance.tag, index
225                )))
226            }
227        } else {
228            Ok(RuleParamMatch::False(format!(
229                "Rule type specializer {} on parameter {} should be a registered class, but instead it's registered as a constant with value: {}",
230                rule_type_instance.tag, index, term
231            )))
232        }
233    }
234
235    /// Check that a rule parameter that has a pattern specializer matches a rule type parameter that has a pattern specializer.
236    fn check_pattern_param(
237        &self,
238        index: usize,
239        rule_pattern: &Pattern,
240        rule_type_pattern: &Pattern,
241    ) -> PolarResult<RuleParamMatch> {
242        Ok(match (rule_type_pattern, rule_pattern) {
243            (Pattern::Instance(rule_type_instance), Pattern::Instance(rule_instance)) => {
244                // if tags match, all rule type fields must match those in rule fields, otherwise false
245                if rule_type_instance.tag == rule_instance.tag {
246                    if self.param_fields_match(
247                        &rule_type_instance.fields,
248                        &rule_instance.fields,
249                    ) {
250                        RuleParamMatch::True
251                    } else {
252                        RuleParamMatch::False(format!("Rule specializer {} on parameter {} did not match rule type specializer {} because the specializer fields did not match.", rule_instance, index, rule_type_instance))
253                    }
254                } else if self.is_union(&term!(sym!(&rule_type_instance.tag.0))) {
255                    if self.is_union(&term!(sym!(&rule_instance.tag.0))) {
256                        // If both specializers are the same union, check fields.
257                        if rule_instance.tag == rule_type_instance.tag {
258                            if self.param_fields_match(
259                                &rule_type_instance.fields,
260                                &rule_instance.fields,
261                            ) {
262                                return Ok(RuleParamMatch::True);
263                            } else {
264                                return Ok(RuleParamMatch::False(format!("Rule specializer {} on parameter {} did not match rule type specializer {} because the specializer fields did not match.", rule_instance, index, rule_type_instance)));
265                            }
266                        } else {
267                            // TODO(gj): revisit when we have unions beyond Actor & Resource. Union
268                            // A matches union B if union A is a member of union B.
269                            return Ok(RuleParamMatch::False(format!("Rule specializer {} on parameter {} does not match rule type specializer {}", rule_instance.tag, index, rule_type_instance.tag)));
270                        }
271                    }
272
273                    let members = self.get_union_members(&term!(sym!(&rule_type_instance.tag.0)));
274                    // If the rule specializer is not a direct member of the union, we still need
275                    // to check if it's a subclass of any member of the union.
276                    if !members.contains(&term!(sym!(&rule_instance.tag.0))) {
277                        let mut success = false;
278                        for member in members {
279                            // Turn `member` into an `InstanceLiteral` by copying fields from
280                            // `rule_type_instance`.
281                            let rule_type_instance = InstanceLiteral {
282                                tag: member.as_symbol()?.clone(),
283                                fields: rule_type_instance.fields.clone()
284                            };
285                            match self.check_rule_instance_is_subclass_of_rule_type_instance(rule_instance, &rule_type_instance, index) {
286                                Ok(RuleParamMatch::True) if !success => success = true,
287                                Err(e) => return Err(e),
288                                _ => (),
289                            }
290                        }
291                        if !success {
292                            let mut err = format!("Rule specializer {} on parameter {} must be a member of rule type specializer {}", rule_instance.tag,index, rule_type_instance.tag);
293                            if rule_type_instance.tag.0 == ACTOR_UNION_NAME {
294                                write!(err, "
295
296\tPerhaps you meant to add an actor block to the top of your policy, like this:
297
298\t  actor {} {{}}", rule_instance.tag).unwrap();
299                            } else if rule_type_instance.tag.0 == RESOURCE_UNION_NAME {
300                                write!(err, "
301
302\tPerhaps you meant to add a resource block to your policy, like this:
303
304\t  resource {} {{ .. }}", rule_instance.tag).unwrap();
305
306                            }
307
308                            return Ok(RuleParamMatch::False(err));
309                        }
310                    }
311                    if !self.param_fields_match(&rule_type_instance.fields, &rule_instance.fields) {
312                        RuleParamMatch::False(format!("Rule specializer {} on parameter {} did not match rule type specializer {} because the specializer fields did not match.", rule_instance, index, rule_type_instance))
313                    } else {
314                        RuleParamMatch::True
315                    }
316                // If tags don't match, then rule specializer must be a subclass of rule type specializer
317                } else {
318                    self.check_rule_instance_is_subclass_of_rule_type_instance(rule_instance, rule_type_instance, index)?
319                }
320            }
321            (Pattern::Dictionary(rule_type_fields), Pattern::Dictionary(rule_fields))
322            | (Pattern::Dictionary(rule_type_fields), Pattern::Instance(InstanceLiteral { fields: rule_fields, .. })) => {
323                if self.param_fields_match(rule_type_fields, rule_fields) {
324                    RuleParamMatch::True
325                } else {
326                    RuleParamMatch::False(format!("Specializer mismatch on parameter {}. Rule specializer fields {:#?} do not match rule type specializer fields {:#?}.", index, rule_fields, rule_type_fields))
327                }
328            }
329            (
330                Pattern::Instance(InstanceLiteral {
331                    tag,
332                    fields: rule_type_fields,
333                }),
334                Pattern::Dictionary(rule_fields),
335            ) if tag == &sym!("Dictionary") => {
336                if self.param_fields_match(rule_type_fields, rule_fields) {
337                    RuleParamMatch::True
338                } else {
339                    RuleParamMatch::False(format!("Specializer mismatch on parameter {}. Rule specializer fields {:#?} do not match rule type specializer fields {:#?}.", index, rule_fields, rule_type_fields))
340                }
341            }
342            (_, _) => {
343                RuleParamMatch::False(format!("Mismatch on parameter {}. Rule parameter {:#?} does not match rule type parameter {:#?}.", index, rule_type_pattern, rule_pattern))
344            }
345        })
346    }
347
348    /// Check that a rule parameter that is a value matches a rule type parameter that is a value
349    fn check_value_param(
350        &self,
351        index: usize,
352        rule_value: &Value,
353        rule_type_value: &Value,
354        rule_type: &Rule,
355    ) -> PolarResult<RuleParamMatch> {
356        Ok(match (rule_type_value, rule_value) {
357            // List in rule head must be equal to or more specific than the list in the rule type head in order to match
358            (Value::List(rule_type_list), Value::List(rule_list)) => {
359                if has_rest_var(rule_type_list) {
360                    return Err(ValidationError::InvalidRuleType {
361                        rule_type: rule_type.clone(),
362                        msg: "Rule types cannot contain *rest variables.".to_string(),
363                    }
364                    .into());
365                }
366                if rule_type_list.iter().all(|t| rule_list.contains(t)) {
367                    RuleParamMatch::True
368                } else {
369                    RuleParamMatch::False(format!(
370                        "Invalid parameter {}. Rule type expected list {:#?}, got list {:#?}.",
371                        index, rule_type_list, rule_list
372                    ))
373                }
374            }
375            (Value::Dictionary(rule_type_fields), Value::Dictionary(rule_fields)) => {
376                if self.param_fields_match(rule_type_fields, rule_fields) {
377                    RuleParamMatch::True
378                } else {
379                    RuleParamMatch::False(format!("Invalid parameter {}. Rule type expected Dictionary with fields {:#?}, got Dictionary with fields {:#?}", index, rule_type_fields, rule_fields
380                        ))
381                }
382            }
383            (_, _) => {
384                if rule_type_value == rule_value {
385                    RuleParamMatch::True
386                } else {
387                    RuleParamMatch::False(format!(
388                        "Invalid parameter {}. Rule value {} != rule type value {}",
389                        index, rule_value, rule_type_value
390                    ))
391                }
392            }
393        })
394    }
    /// Check a single rule parameter against a rule type parameter.
    ///
    /// Dispatches on the 4-tuple (rule type parameter value, rule type specializer value, rule
    /// parameter value, rule specializer value). The arms are tried top to bottom and several of
    /// them overlap, so their order is significant.
    ///
    /// `index` is the parameter position used in failure messages and `rule_type` is forwarded
    /// to `check_value_param`. Mismatches are returned as `RuleParamMatch::False`; an error is
    /// returned for a rule value that cannot act as a specializer and for errors from the
    /// pattern/value helpers.
    fn check_param(
        &self,
        index: usize,
        rule_param: &Parameter,
        rule_type_param: &Parameter,
        rule_type: &Rule,
    ) -> PolarResult<RuleParamMatch> {
        Ok(
            match (
                rule_type_param.parameter.value(),
                rule_type_param.specializer.as_ref().map(Term::value),
                rule_param.parameter.value(),
                rule_param.specializer.as_ref().map(Term::value),
            ) {
                // Rule and rule type both have pattern specializers
                (
                    Value::Variable(_),
                    Some(Value::Pattern(rule_type_spec)),
                    Value::Variable(_),
                    Some(Value::Pattern(rule_spec)),
                ) => self.check_pattern_param(index, rule_spec, rule_type_spec)?,
                // RuleType has an instance pattern specializer but rule has no specializer
                (
                    Value::Variable(_),
                    Some(Value::Pattern(Pattern::Instance(InstanceLiteral { tag, .. }))),
                    Value::Variable(parameter),
                    None,
                ) => RuleParamMatch::False(format!(
                    "Parameter `{parameter}` expects a {tag} type constraint.

\t{parameter}: {tag}",
                    parameter = parameter,
                    tag = tag
                )),
                // RuleType has specializer but rule doesn't
                (Value::Variable(_), Some(rule_type_spec), Value::Variable(_), None) => {
                    RuleParamMatch::False(format!(
                        "Invalid rule parameter {}. Rule type expected {}",
                        index, rule_type_spec
                    ))
                }
                // Rule has value or value specializer, rule type has pattern specializer
                (
                    Value::Variable(_),
                    Some(Value::Pattern(rule_type_spec)),
                    Value::Variable(_),
                    Some(rule_value),
                )
                | (Value::Variable(_), Some(Value::Pattern(rule_type_spec)), rule_value, None) => {
                    match rule_type_spec {
                        // Rule type specializer is an instance pattern
                        Pattern::Instance(InstanceLiteral { .. }) => {
                            // Wrap the literal value in an instance pattern of the matching
                            // built-in class so it can be compared pattern-to-pattern.
                            let rule_spec = match rule_value {
                                Value::String(_) => instance!(sym!("String")),
                                Value::Number(Numeric::Integer(_)) => instance!(sym!("Integer")),
                                Value::Number(Numeric::Float(_)) => instance!(sym!("Float")),
                                Value::Boolean(_) => instance!(sym!("Boolean")),
                                Value::List(_) => instance!(sym!("List")),
                                Value::Dictionary(rule_fields) => {
                                    instance!(sym!("Dictionary"), rule_fields.clone().fields)
                                }
                                _ => {
                                    // TODO(gj): what type of value could this be? Will this get
                                    // past the parser or is it unreachable? Prior to #1356 we
                                    // could hit this branch with a `Value::Variable` if the
                                    // specializer in the rule head was parenthesized.
                                    return invalid_state(format!(
                                        "Value variant {} cannot be a specializer",
                                        rule_value
                                    ));
                                }
                            };
                            self.check_pattern_param(
                                index,
                                &Pattern::Instance(rule_spec),
                                rule_type_spec,
                            )?
                        }
                        // Rule type specializer is a dictionary pattern
                        Pattern::Dictionary(rule_type_fields) => {
                            if let Value::Dictionary(rule_fields) = rule_value {
                                if self.param_fields_match(rule_type_fields, rule_fields) {
                                    RuleParamMatch::True
                                } else {
                                    RuleParamMatch::False(format!("Invalid parameter {}. Rule type expected Dictionary with fields {}, got dictionary with fields {}.", index, rule_type_fields, rule_fields))
                                }
                            } else {
                                RuleParamMatch::False(format!(
                                    "Invalid parameter {}. Rule type expected Dictionary, got {}.",
                                    index, rule_value
                                ))
                            }
                        }
                    }
                }

                // Rule type has no specializer
                (Value::Variable(_), None, _, _) => RuleParamMatch::True,
                // Rule has value or value specializer, rule type has value specializer |
                // rule has value, rule type has value
                (
                    Value::Variable(_),
                    Some(rule_type_value),
                    Value::Variable(_),
                    Some(rule_value),
                )
                | (Value::Variable(_), Some(rule_type_value), rule_value, None)
                | (rule_type_value, None, rule_value, None) => {
                    self.check_value_param(index, rule_value, rule_type_value, rule_type)?
                }
                // Any remaining combination is reported as a mismatch.
                _ => RuleParamMatch::False(format!(
                    "Invalid parameter {}. Rule parameter {} does not match rule type parameter {}",
                    index, rule_param, rule_type_param
                )),
            },
        )
    }
513
514    /// Determine whether a `rule` matches a `rule_type` based on its parameters.
515    fn rule_params_match(&self, rule: &Rule, rule_type: &Rule) -> PolarResult<RuleParamMatch> {
516        if rule.params.len() != rule_type.params.len() {
517            return Ok(RuleParamMatch::False(format!(
518                "Different number of parameters. Rule has {} parameter(s) but rule type has {}.",
519                rule.params.len(),
520                rule_type.params.len()
521            )));
522        }
523        let mut failure_message = "".to_owned();
524        rule.params
525            .iter()
526            .zip(rule_type.params.iter())
527            .enumerate()
528            .map(|(i, (rule_param, rule_type_param))| {
529                self.check_param(i + 1, rule_param, rule_type_param, rule_type)
530            })
531            .collect::<PolarResult<Vec<RuleParamMatch>>>()
532            .map(|results| {
533                // TODO(gj): all() is short-circuiting -- do we want to gather up *all* failure
534                // messages instead of just the first one?
535                results.iter().all(|r| {
536                    if let RuleParamMatch::False(msg) = r {
537                        failure_message = msg.to_owned();
538                        false
539                    } else {
540                        true
541                    }
542                })
543            })
544            .map(|matched| {
545                if matched {
546                    RuleParamMatch::True
547                } else {
548                    RuleParamMatch::False(failure_message)
549                }
550            })
551    }
552
    /// Borrow the map of all generic rules, keyed by rule name.
    pub fn get_rules(&self) -> &HashMap<Symbol, GenericRule> {
        &self.rules
    }
556
    /// Look up the rule types registered under `name`, if any. Test-only.
    #[cfg(test)]
    pub fn get_rule_types(&self, name: &Symbol) -> Option<&Vec<Rule>> {
        self.rule_types.get(name)
    }
561
    /// Look up the generic rule registered under `name`, if any.
    pub fn get_generic_rule(&self, name: &Symbol) -> Option<&GenericRule> {
        self.rules.get(name)
    }
565
    /// Register `rule_type` with the knowledge base's rule types.
    pub fn add_rule_type(&mut self, rule_type: Rule) {
        self.rule_types.add(rule_type);
    }
569
570    /// Define a constant variable.
571    ///
572    /// Error on attempts to register the "union" types (Actor & Resource) since those types have
573    /// special meaning in policies that use resource blocks.
574    pub fn register_constant(&mut self, name: Symbol, value: Term) -> PolarResult<()> {
575        if name.0 == ACTOR_UNION_NAME || name.0 == RESOURCE_UNION_NAME {
576            return Err(RuntimeError::InvalidRegistration {
577                msg: format!("'{}' is a built-in specializer.", name),
578                sym: name,
579            }
580            .into());
581        }
582
583        if let Value::ExternalInstance(ExternalInstance {
584            class_id,
585            instance_id,
586            ..
587        }) = *value.value()
588        {
589            if class_id.map_or(false, |id| id == instance_id) {
590                // ExternalInstance values with matching class_id & instance_id represent *classes*
591                // whose class_id we want to index for later type checking & MRO resolution
592                self.constants.insert_class(name, value, instance_id)
593            } else {
594                // ExternalInstance values with differing `class_id` and
595                // `instance_id` represent *instances* of classes whose class_id
596                // should not be registered
597                self.constants.insert(name, value)
598            }
599        } else {
600            self.constants.insert(name, value)
601        }
602        Ok(())
603    }
604
    /// Return true if a constant with the given name has been defined.
    ///
    /// Classes are registered through the same store, so this is also true for
    /// registered class names.
    pub fn is_constant(&self, name: &Symbol) -> bool {
        self.constants.contains_key(name)
    }
609
    /// Getter for `constants` map without exposing it for mutation.
    ///
    /// Returns the symbol-to-term bindings held by the constants store.
    pub fn get_registered_constants(&self) -> &Bindings {
        &self.constants.symbol_to_term
    }
614
    /// Look up the symbol a class was registered under, given its class ID.
    /// Delegates to the constants store.
    pub(crate) fn get_symbol_for_class_id(&self, id: &u64) -> Option<&Symbol> {
        self.constants.get_symbol_for_class_id(id)
    }
618
    /// Look up the class ID registered for `symbol`, if any. Delegates to the
    /// constants store.
    pub(crate) fn get_class_id_for_symbol(&self, symbol: &Symbol) -> Option<&u64> {
        self.constants.get_class_id_for_symbol(symbol)
    }
622
623    // TODO(gj): currently no way to distinguish classes from other registered constants in the
624    // core, so it's up to callers to ensure this is only called with terms we expect to be
625    // registered as a _class_.
626    pub fn get_registered_class(&self, class: &Term) -> PolarResult<&Term> {
627        self.constants.get(class.as_symbol()?).ok_or_else(|| {
628            ValidationError::UnregisteredClass {
629                term: class.clone(),
630            }
631            .into()
632        })
633    }
634
635    /// Add the Method Resolution Order (MRO) list for a registered class.
636    /// The `mro` argument is a list of the `instance_id` associated with a registered class.
637    pub fn add_mro(&mut self, name: Symbol, mro: Vec<u64>) -> PolarResult<()> {
638        // Confirm name is a registered class
639        if !self.is_constant(&name) {
640            return invalid_state(format!("Cannot add MRO for unregistered class {}", name));
641        }
642        self.mro.insert(name, mro);
643        Ok(())
644    }
645
    /// Drop everything that came from loaded policies: rules, rule types, inline queries,
    /// the loaded-file bookkeeping and resource blocks.
    ///
    /// Registered constants, MRO lists and the ID counters are not touched here.
    pub fn clear_rules(&mut self) {
        self.rules.clear();
        self.rule_types.reset();
        self.inline_queries.clear();
        self.loaded_content.clear();
        self.resource_blocks.clear();
    }
653
654    // TODO(gj): Remove this fn & `FileLoading` error variant. These checks don't spark joy.
655    pub(crate) fn add_source(&mut self, filename: &str, contents: &str) -> PolarResult<()> {
656        let seen_filename = self.loaded_content.values().any(|name| name == filename);
657        match self.loaded_content.insert(contents.into(), filename.into()) {
658            Some(other_file) if other_file == filename => {
659                Err(format!("File {} has already been loaded.", filename))
660            }
661            Some(other_file) => Err(format!(
662                "A file with the same contents as {} named {} has already been loaded.",
663                filename, other_file
664            )),
665            _ if seen_filename => Err(format!(
666                "A file with the name {}, but different contents has already been loaded.",
667                filename
668            )),
669            _ => Ok(()),
670        }
671        .map_err(|msg| {
672            ValidationError::FileLoading {
673                filename: filename.into(),
674                contents: contents.into(),
675                msg,
676            }
677            .into()
678        })
679    }
680
681    /// Check that all relations declared across all resource blocks have been registered as
682    /// constants.
683    fn check_that_resource_block_relations_are_registered(&self) -> Vec<PolarError> {
684        self.resource_blocks
685            .relation_tuples()
686            .into_iter()
687            .filter_map(|(relation_type, _, _)| self.get_registered_class(relation_type).err())
688            .collect()
689    }
690
691    pub fn rewrite_shorthand_rules(&mut self) -> Vec<PolarError> {
692        let mut errors = vec![];
693
694        errors.append(&mut self.check_that_resource_block_relations_are_registered());
695
696        let mut rules = vec![];
697        for (resource_name, shorthand_rules) in &self.resource_blocks.shorthand_rules {
698            for shorthand_rule in shorthand_rules {
699                match shorthand_rule.as_rule(resource_name, &self.resource_blocks) {
700                    Ok(rule) => rules.push(rule),
701                    Err(error) => errors.push(error),
702                }
703            }
704        }
705
706        // Add the rewritten rules to the KB.
707        for rule in rules {
708            self.add_rule(rule);
709        }
710
711        errors
712    }
713
714    pub fn create_resource_specific_rule_types(&mut self) -> PolarResult<()> {
715        let mut rule_types_to_create = HashMap::new();
716
717        // TODO @patrickod refactor RuleTypes & split out
718        // RequiredRuleType struct to record the related
719        // shorthand rule and relation terms.
720
721        // Iterate through all resource block declarations and create
722        // non-required rule types for each relation declaration we observe.
723        //
724        // We create non-required rule types to gracefully account for the case
725        // where users have declared relations ahead of time that are used in
726        // rule or resource definitions.
727        for (subject, name, object) in self.resource_blocks.relation_tuples() {
728            rule_types_to_create.insert((subject, name, object), false);
729        }
730
731        // Iterate through resource block shorthand rules and create *required*
732        // rule types for each relation which is traversed in the rules.
733        for (object, shorthand_rules) in &self.resource_blocks.shorthand_rules {
734            for shorthand_rule in shorthand_rules {
735                // We create rule types from shorthand rules in the following scenarios...
736                match &shorthand_rule.body {
737                    // 1. When the the third "relation" term points to a related Resource. E.g.,
738                    //    `"admin" if "admin" on "parent";` where `relations = { parent: Org };`.
739                    (implier, Some((_, relation))) => {
740                        // First, create required rule type for relationship between `object` and
741                        // `subject`:
742                        //
743                        // resource Repo {
744                        //   roles = ["writer"];
745                        //   relations = { parent_org: Org };
746                        //
747                        //   "writer" if "admin" on "parent_org";
748                        // }
749                        //
750                        // (required) type has_relation(org: Org, "parent_org", repo: Repo);
751                        //
752                        // resource Org {
753                        //   roles = ["admin"];
754                        // }
755                        if let Ok(subject) = self
756                            .resource_blocks
757                            .get_relation_type_in_resource_block(relation, object)
758                        {
759                            rule_types_to_create.insert((subject, relation, object), true);
760
761                            // Then, if the "implier" term is declared as a relation on `subject`
762                            // (as opposed to a permission or role), create required rule type for
763                            // relationship between `related_subject` and `subject`:
764                            //
765                            // resource Repo {
766                            //   roles = ["writer"];
767                            //   relations = { parent_org: Org };
768                            //
769                            //   "writer" if "owner" on "parent_org";
770                            // }
771                            //
772                            // (required) type has_relation(org: Org, "parent_org", issue: Issue);
773                            //
774                            // resource Org {
775                            //   relations = { owner: User };
776                            // }
777                            //
778                            // (required) type has_relation(user: User, "owner", org: Org);
779                            if let Ok(related_subject) = self
780                                .resource_blocks
781                                .get_relation_type_in_resource_block(implier, subject)
782                            {
783                                rule_types_to_create
784                                    .insert((related_subject, implier, subject), true);
785                            }
786                        }
787                    }
788
789                    // 2. When the second "implier" term points to a related Actor. E.g., `"admin"
790                    //    if "owner";` where `relations = { owner: User };`. Technically, "implier"
791                    //    could be a related Resource, but that doesn't make much semantic sense.
792                    //    Related resources should be traversed via `"on"` clauses, which are
793                    //    captured in the above match arm.
794                    (implier, None) => {
795                        if let Ok(subject) = self
796                            .resource_blocks
797                            .get_relation_type_in_resource_block(implier, object)
798                        {
799                            rule_types_to_create.insert((subject, implier, object), true);
800                        }
801                    }
802                }
803            }
804        }
805
806        let mut rule_types = rule_types_to_create.into_iter().map(|((subject, relation, object), required)| {
807            let subject_specializer = pattern!(instance!(&subject.as_symbol()?.0));
808            let relation_name = relation.as_string()?;
809            let object_specializer = pattern!(instance!(&object.as_symbol()?.0));
810
811            let name = sym!("has_relation");
812            let mut params = args!("subject"; subject_specializer, relation_name, "object"; object_specializer);
813            params.reverse();
814            let body = term!(op!(And));
815            // Copy SourceInfo from implier or relation in shorthand rule.
816            let source_info = relation.source_info().clone();
817            Ok(Rule { name, params, body, source_info, required })
818        }).collect::<PolarResult<Vec<_>>>()?;
819
820        // If there are any Relation::Role declarations in *any* of our resource
821        // blocks then we want to add the `has_role` rule type.
822        if self.resource_blocks.has_roles() {
823            rule_types.push(
824                // TODO(gj): "Internal" SourceInfo variant.
825                // TODO(gj): Figure out if it's worth setting SourceInfo::Parser context for this
826                // `has_role` rule type we create. Best we could probably do at the moment is fetch
827                // a random role from self.resource_blocks.declarations and borrow its context.
828                rule!("has_role", ["actor"; instance!(ACTOR_UNION_NAME), "role"; instance!("String"), "resource"; instance!(RESOURCE_UNION_NAME)], true)
829            );
830        }
831
832        for rule_type in rule_types {
833            self.add_rule_type(rule_type.clone());
834        }
835        Ok(())
836    }
837
838    pub fn is_union(&self, maybe_union: &Term) -> bool {
839        (maybe_union.is_actor_union()) || (maybe_union.is_resource_union())
840    }
841
842    pub fn get_union_members(&self, union: &Term) -> &HashSet<Term> {
843        if union.is_actor_union() {
844            &self.resource_blocks.actors
845        } else if union.is_resource_union() {
846            &self.resource_blocks.resources
847        } else {
848            unreachable!()
849        }
850    }
851
852    pub fn has_rules(&self) -> bool {
853        !self.rules.is_empty()
854    }
855}
856
857#[cfg(test)]
858mod tests {
859    use super::*;
860
861    use crate::error::ValidationError::{FileLoading, InvalidRule};
862
863    #[test]
864    fn test_add_source_file_validation() {
865        fn expect_error(kb: &mut KnowledgeBase, name: &str, contents: &str, expected: &str) {
866            let err = kb.add_source(name, contents).unwrap_err();
867            let msg = match err.unwrap_validation() {
868                FileLoading { msg, .. } => msg,
869                e => panic!("Unexpected error: {}", e),
870            };
871            assert_eq!(msg, expected);
872        }
873
874        let mut kb = KnowledgeBase::new();
875        let contents1 = "f();";
876        let contents2 = "g();";
877        let filename1 = "f";
878        let filename2 = "g";
879
880        // Load source1.
881        kb.add_source(filename1, contents1).unwrap();
882
883        // Cannot load source1 a second time.
884        let expected = format!("File {} has already been loaded.", filename1);
885        expect_error(&mut kb, filename1, contents1, &expected);
886
887        // Cannot load source2 with the same name as source1 but different contents.
888        let expected = format!(
889            "A file with the name {}, but different contents has already been loaded.",
890            filename1
891        );
892        expect_error(&mut kb, filename1, contents2, &expected);
893
894        // Cannot load source3 with the same contents as source1 but a different name.
895        let expected = format!(
896            "A file with the same contents as {} named {} has already been loaded.",
897            filename2, filename1
898        );
899        expect_error(&mut kb, filename2, contents1, &expected);
900    }
901
902    #[test]
903    fn test_rule_params_match() {
904        let mut kb = KnowledgeBase::new();
905
906        let mut constant = |name: &str, instance_id: u64| {
907            kb.register_constant(
908                sym!(name),
909                term!(Value::ExternalInstance(ExternalInstance {
910                    instance_id,
911                    constructor: None,
912                    repr: None,
913                    class_repr: None,
914                    class_id: None,
915                })),
916            )
917            .unwrap();
918        };
919
920        constant("Fruit", 1);
921        constant("Citrus", 2);
922        constant("Orange", 3);
923        // NOTE: Foo doesn't need an MRO b/c it only appears as a rule type specializer; not a rule
924        // specializer.
925        constant("Foo", 4);
926
927        // NOTE: this is only required for these tests b/c we're bypassing the normal load process,
928        // where MROs are registered via FFI calls in the host language libraries.
929        // process.
930        constant("Integer", 5);
931        constant("Float", 6);
932        constant("String", 7);
933        constant("Boolean", 8);
934        constant("List", 9);
935        constant("Dictionary", 10);
936
937        kb.add_mro(sym!("Fruit"), vec![1]).unwrap();
938        // Citrus is a subclass of Fruit
939        kb.add_mro(sym!("Citrus"), vec![2, 1]).unwrap();
940        // Orange is a subclass of Citrus
941        kb.add_mro(sym!("Orange"), vec![3, 2, 1]).unwrap();
942
943        kb.add_mro(sym!("Integer"), vec![]).unwrap();
944        kb.add_mro(sym!("Float"), vec![]).unwrap();
945        kb.add_mro(sym!("String"), vec![]).unwrap();
946        kb.add_mro(sym!("Boolean"), vec![]).unwrap();
947        kb.add_mro(sym!("List"), vec![]).unwrap();
948        kb.add_mro(sym!("Dictionary"), vec![]).unwrap();
949
950        // BOTH PATTERN SPEC
951        // rule: f(x: Foo), rule_type: f(x: Foo) => PASS
952        assert!(kb
953            .rule_params_match(
954                &rule!("f", ["x"; instance!(sym!("Fruit"))]),
955                &rule!("f", ["x"; instance!(sym!("Fruit"))])
956            )
957            .unwrap()
958            .is_true());
959
960        // rule: f(x: Foo), rule_type: f(x: Bar) => FAIL if Foo is not subclass of Bar
961        assert!(!kb
962            .rule_params_match(
963                &rule!("f", ["x"; instance!(sym!("Fruit"))]),
964                &rule!("f", ["x"; instance!(sym!("Citrus"))])
965            )
966            .unwrap()
967            .is_true());
968
969        // rule: f(x: Foo), rule_type: f(x: Bar) => PASS if Foo is subclass of Bar
970        assert!(kb
971            .rule_params_match(
972                &rule!("f", ["x"; instance!(sym!("Citrus"))]),
973                &rule!("f", ["x"; instance!(sym!("Fruit"))])
974            )
975            .unwrap()
976            .is_true());
977        assert!(kb
978            .rule_params_match(
979                &rule!("f", ["x"; instance!(sym!("Orange"))]),
980                &rule!("f", ["x"; instance!(sym!("Fruit"))])
981            )
982            .unwrap()
983            .is_true());
984
985        // rule: f(x: Foo), rule_type: f(x: {id: 1}) => FAIL
986        assert!(!kb
987            .rule_params_match(
988                &rule!("f", ["x"; instance!(sym!("Foo"))]),
989                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}])
990            )
991            .unwrap()
992            .is_true());
993        // rule: f(x: Foo{id: 1}), rule_type: f(x: {id: 1}) => PASS
994        assert!(kb
995            .rule_params_match(
996                &rule!(
997                    "f",
998                    ["x"; instance!(sym!("Foo"), btreemap! {sym!("id") => term!(1)})]
999                ),
1000                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}])
1001            )
1002            .unwrap()
1003            .is_true());
1004        // rule: f(x: {id: 1}), rule_type: f(x: Foo{id: 1}) => FAIL
1005        assert!(!kb
1006            .rule_params_match(
1007                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1008                &rule!(
1009                    "f",
1010                    ["x"; instance!(sym!("Foo"), btreemap! {sym!("id") => term!(1)})]
1011                )
1012            )
1013            .unwrap()
1014            .is_true());
1015        // rule: f(x: {id: 1}), rule_type: f(x: {id: 1}) => PASS
1016        assert!(kb
1017            .rule_params_match(
1018                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1019                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}])
1020            )
1021            .unwrap()
1022            .is_true());
1023
1024        // RULE VALUE SPEC, TEMPLATE PATTERN SPEC
1025        // rule: f(x: 6), rule_type: f(x: Integer) => PASS
1026        assert!(kb
1027            .rule_params_match(
1028                &rule!("f", ["x"; value!(6)]),
1029                &rule!("f", ["x"; instance!(sym!("Integer"))])
1030            )
1031            .unwrap()
1032            .is_true());
1033
1034        // rule: f(x: 6), rule_type: f(x: Foo) => FAIL
1035        assert!(!kb
1036            .rule_params_match(
1037                &rule!("f", ["x"; value!(6)]),
1038                &rule!("f", ["x"; instance!(sym!("Foo"))])
1039            )
1040            .unwrap()
1041            .is_true());
1042        // rule: f(x: "string"), rule_type: f(x: Integer) => FAIL
1043        assert!(!kb
1044            .rule_params_match(
1045                &rule!("f", ["x"; value!("string")]),
1046                &rule!("f", ["x"; instance!(sym!("Integer"))])
1047            )
1048            .unwrap()
1049            .is_true());
1050        // rule: f(x: 6.0), rule_type: f(x: Float) => PASS
1051        assert!(kb
1052            .rule_params_match(
1053                &rule!("f", ["x"; value!(6.0)]),
1054                &rule!("f", ["x"; instance!(sym!("Float"))])
1055            )
1056            .unwrap()
1057            .is_true());
1058        // rule: f(x: 6.0), rule_type: f(x: Foo) => FAIL
1059        assert!(!kb
1060            .rule_params_match(
1061                &rule!("f", ["x"; value!(6.0)]),
1062                &rule!("f", ["x"; instance!(sym!("Foo"))])
1063            )
1064            .unwrap()
1065            .is_true());
1066        // rule: f(x: 6), rule_type: f(x: Float) => FAIL
1067        assert!(!kb
1068            .rule_params_match(
1069                &rule!("f", ["x"; value!(6)]),
1070                &rule!("f", ["x"; instance!(sym!("Float"))])
1071            )
1072            .unwrap()
1073            .is_true());
1074        // rule: f(x: "hi"), rule_type: f(x: String) => PASS
1075        assert!(kb
1076            .rule_params_match(
1077                &rule!("f", ["x"; value!("hi")]),
1078                &rule!("f", ["x"; instance!(sym!("String"))])
1079            )
1080            .unwrap()
1081            .is_true());
1082        // rule: f(x: "hi"), rule_type: f(x: Foo) => FAIL
1083        assert!(!kb
1084            .rule_params_match(
1085                &rule!("f", ["x"; value!("hi")]),
1086                &rule!("f", ["x"; instance!(sym!("Foo"))])
1087            )
1088            .unwrap()
1089            .is_true());
1090        // rule: f(x: 6), rule_type: f(x: String) => FAIL
1091        assert!(!kb
1092            .rule_params_match(
1093                &rule!("f", ["x"; value!(6)]),
1094                &rule!("f", ["x"; instance!(sym!("String"))])
1095            )
1096            .unwrap()
1097            .is_true());
1098        // Ensure primitive types cannot have fields
1099        // rule: f(x: "hello"), rule_type: f(x: String{id: 1}) => FAIL
1100        assert!(!kb
1101            .rule_params_match(
1102                &rule!("f", ["x"; value!("hello")]),
1103                &rule!(
1104                    "f",
1105                    ["x"; instance!(sym!("String"), btreemap! {sym!("id") => term!(1)})]
1106                )
1107            )
1108            .unwrap()
1109            .is_true());
1110        // rule: f(x: true), rule_type: f(x: Boolean) => PASS
1111        assert!(kb
1112            .rule_params_match(
1113                &rule!("f", ["x"; value!(true)]),
1114                &rule!("f", ["x"; instance!(sym!("Boolean"))])
1115            )
1116            .unwrap()
1117            .is_true());
1118        // rule: f(x: true), rule_type: f(x: Foo) => FAIL
1119        assert!(!kb
1120            .rule_params_match(
1121                &rule!("f", ["x"; value!(true)]),
1122                &rule!("f", ["x"; instance!(sym!("Foo"))])
1123            )
1124            .unwrap()
1125            .is_true());
1126        // rule: f(x: 6), rule_type: f(x: Boolean) => FAIL
1127        assert!(!kb
1128            .rule_params_match(
1129                &rule!("f", ["x"; value!(6)]),
1130                &rule!("f", ["x"; instance!(sym!("Boolean"))])
1131            )
1132            .unwrap()
1133            .is_true());
1134        // rule: f(x: [1, 2]), rule_type: f(x: List) => PASS
1135        assert!(kb
1136            .rule_params_match(
1137                &rule!("f", ["x"; value!([1, 2])]),
1138                &rule!("f", ["x"; instance!(sym!("List"))])
1139            )
1140            .unwrap()
1141            .is_true());
1142        // rule: f(x: [1, 2]), rule_type: f(x: Foo) => FAIL
1143        assert!(!kb
1144            .rule_params_match(
1145                &rule!("f", ["x"; value!([1, 2])]),
1146                &rule!("f", ["x"; instance!(sym!("Foo"))])
1147            )
1148            .unwrap()
1149            .is_true());
1150        // rule: f(x: 6), rule_type: f(x: List) => FAIL
1151        assert!(!kb
1152            .rule_params_match(
1153                &rule!("f", ["x"; value!(6)]),
1154                &rule!("f", ["x"; instance!(sym!("List"))])
1155            )
1156            .unwrap()
1157            .is_true());
1158        // rule: f(x: {id: 1}), rule_type: f(x: Dictionary) => PASS
1159        assert!(kb
1160            .rule_params_match(
1161                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1162                &rule!("f", ["x"; instance!(sym!("Dictionary"))])
1163            )
1164            .unwrap()
1165            .is_true());
1166        // rule: f(x: {id: 1}), rule_type: f(x: Foo) => FAIL
1167        assert!(!kb
1168            .rule_params_match(
1169                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1170                &rule!("f", ["x"; instance!(sym!("Foo"))])
1171            )
1172            .unwrap()
1173            .is_true());
1174        // rule: f({id: 1}), rule_type: f(x: Foo) => FAIL
1175        assert!(!kb
1176            .rule_params_match(
1177                &rule!("f", [btreemap! {sym!("id") => term!(1)}]),
1178                &rule!("f", ["x"; instance!(sym!("Foo"))])
1179            )
1180            .unwrap()
1181            .is_true());
1182        // rule: f(x: 6), rule_type: f(x: Dictionary) => FAIL
1183        assert!(!kb
1184            .rule_params_match(
1185                &rule!("f", ["x"; value!(6)]),
1186                &rule!("f", ["x"; instance!(sym!("Dictionary"))])
1187            )
1188            .unwrap()
1189            .is_true());
1190        // rule: f(x: {id: 1}), rule_type: f(x: Dictionary{id: 1}) => PASS
1191        assert!(kb
1192            .rule_params_match(
1193                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1194                &rule!(
1195                    "f",
1196                    ["x"; instance!(sym!("Dictionary"), btreemap! {sym!("id") => term!(1)})]
1197                )
1198            )
1199            .unwrap()
1200            .is_true());
1201
1202        // RULE PATTERN SPEC, TEMPLATE VALUE SPEC
1203        // always => FAIL
1204        assert!(!kb
1205            .rule_params_match(
1206                &rule!("f", ["x"; btreemap!(sym!("1") => term!(1))]),
1207                &rule!("f", ["x"; value!(1)])
1208            )
1209            .unwrap()
1210            .is_true());
1211
1212        // BOTH VALUE SPEC
1213        // Integer, String, Boolean: must be equal
1214        // rule: f(x: 1), rule_type: f(x: 1) => PASS
1215        assert!(kb
1216            .rule_params_match(&rule!("f", ["x"; value!(1)]), &rule!("f", ["x"; value!(1)]))
1217            .unwrap()
1218            .is_true());
1219        // rule: f(x: 1), rule_type: f(x: 2) => FAIL
1220        assert!(!kb
1221            .rule_params_match(&rule!("f", ["x"; value!(1)]), &rule!("f", ["x"; value!(2)]))
1222            .unwrap()
1223            .is_true());
1224        // rule: f(x: 1.0), rule_type: f(x: 1.0) => PASS
1225        assert!(kb
1226            .rule_params_match(
1227                &rule!("f", ["x"; value!(1.0)]),
1228                &rule!("f", ["x"; value!(1.0)])
1229            )
1230            .unwrap()
1231            .is_true());
1232        // rule: f(x: 1.0), rule_type: f(x: 2.0) => FAIL
1233        assert!(!kb
1234            .rule_params_match(
1235                &rule!("f", ["x"; value!(1.0)]),
1236                &rule!("f", ["x"; value!(2.0)])
1237            )
1238            .unwrap()
1239            .is_true());
1240        // rule: f(x: "hi"), rule_type: f(x: "hi") => PASS
1241        assert!(kb
1242            .rule_params_match(
1243                &rule!("f", ["x"; value!("hi")]),
1244                &rule!("f", ["x"; value!("hi")])
1245            )
1246            .unwrap()
1247            .is_true());
1248        // rule: f(x: "hi"), rule_type: f(x: "hello") => FAIL
1249        assert!(!kb
1250            .rule_params_match(
1251                &rule!("f", ["x"; value!("hi")]),
1252                &rule!("f", ["x"; value!("hello")])
1253            )
1254            .unwrap()
1255            .is_true());
1256        // rule: f(x: true), rule_type: f(x: true) => PASS
1257        assert!(kb
1258            .rule_params_match(
1259                &rule!("f", ["x"; value!(true)]),
1260                &rule!("f", ["x"; value!(true)])
1261            )
1262            .unwrap()
1263            .is_true());
1264        // rule: f(x: true), rule_type: f(x: false) => PASS
1265        assert!(!kb
1266            .rule_params_match(
1267                &rule!("f", ["x"; value!(true)]),
1268                &rule!("f", ["x"; value!(false)])
1269            )
1270            .unwrap()
1271            .is_true());
1272        // List: rule must be more specific than (superset of) rule_type
1273        // rule: f(x: [1,2,3]), rule_type: f(x: [1,2]) => PASS
1274        // TODO: I'm not sure this logic actually makes sense--it feels like
1275        // they should have to be an exact match
1276        assert!(kb
1277            .rule_params_match(
1278                &rule!("f", ["x"; value!([1, 2, 3])]),
1279                &rule!("f", ["x"; value!([1, 2])])
1280            )
1281            .unwrap()
1282            .is_true());
1283        // rule: f(x: [1,2]), rule_type: f(x: [1,2,3]) => FAIL
1284        assert!(!kb
1285            .rule_params_match(
1286                &rule!("f", ["x"; value!([1, 2])]),
1287                &rule!("f", ["x"; value!([1, 2, 3])])
1288            )
1289            .unwrap()
1290            .is_true());
1291        // test with *rest vars
1292        // rule: f(x: [1, 2, 3]), rule_type: f(x: [1, 2, *rest]) => PASS
1293        assert!(kb
1294            .rule_params_match(
1295                &rule!("f", ["x"; value!([1, 2])]),
1296                &rule!(
1297                    "f",
1298                    ["x"; value!([1, 2, Value::RestVariable(sym!("*_rest"))])]
1299                )
1300            )
1301            .is_err());
1302        // Dict: rule must be more specific than (superset of) rule_type
1303        // rule: f(x: {"id": 1, "name": "Dave"}), rule_type: f(x: {"id": 1}) => PASS
1304        assert!(kb
1305            .rule_params_match(
1306                &rule!(
1307                    "f",
1308                    ["x"; btreemap! {sym!("id") => term!(1), sym!("name") => term!(sym!("Dave"))}]
1309                ),
1310                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1311            )
1312            .unwrap()
1313            .is_true());
1314        // rule: f(x: {"id": 1}), rule_type: f(x: {"id": 1, "name": "Dave"}) => FAIL
1315        assert!(!kb
1316            .rule_params_match(
1317                &rule!("f", ["x"; btreemap! {sym!("id") => term!(1)}]),
1318                &rule!(
1319                    "f",
1320                    ["x"; btreemap! {sym!("id") => term!(1), sym!("name") => term!(sym!("Dave"))}]
1321                )
1322            )
1323            .unwrap()
1324            .is_true());
1325
1326        // RULE None SPEC TEMPLATE Some SPEC
1327        // always => FAIL
1328        assert!(!kb
1329            .rule_params_match(
1330                &rule!("f", [sym!("x")]),
1331                &rule!("f", ["x"; instance!(sym!("Foo"))])
1332            )
1333            .unwrap()
1334            .is_true());
1335
1336        // RULE Some SPEC TEMPLATE None SPEC
1337        // always => PASS
1338        assert!(kb
1339            .rule_params_match(
1340                &rule!("f", ["x"; instance!(sym!("Foo"))]),
1341                &rule!("f", [sym!("x")]),
1342            )
1343            .unwrap()
1344            .is_true());
1345    }
1346
1347    #[test]
1348    fn test_validate_rules() {
1349        let mut kb = KnowledgeBase::new();
1350        kb.register_constant(
1351            sym!("Fruit"),
1352            term!(Value::ExternalInstance(ExternalInstance {
1353                instance_id: 1,
1354                constructor: None,
1355                repr: None,
1356                class_repr: None,
1357                class_id: None,
1358            })),
1359        )
1360        .unwrap();
1361        kb.register_constant(
1362            sym!("Citrus"),
1363            term!(Value::ExternalInstance(ExternalInstance {
1364                instance_id: 2,
1365                constructor: None,
1366                repr: None,
1367                class_repr: None,
1368                class_id: None,
1369            })),
1370        )
1371        .unwrap();
1372        kb.register_constant(
1373            sym!("Orange"),
1374            term!(Value::ExternalInstance(ExternalInstance {
1375                instance_id: 3,
1376                constructor: None,
1377                repr: None,
1378                class_repr: None,
1379                class_id: None,
1380            })),
1381        )
1382        .unwrap();
1383        kb.add_mro(sym!("Fruit"), vec![1]).unwrap();
1384        // Citrus is a subclass of Fruit
1385        kb.add_mro(sym!("Citrus"), vec![2, 1]).unwrap();
1386        // Orange is a subclass of Citrus
1387        kb.add_mro(sym!("Orange"), vec![3, 2, 1]).unwrap();
1388
1389        // Rule type applies if it has the same name as a rule
1390        kb.add_rule_type(rule!("f", ["x"; instance!(sym!("Orange"))]));
1391        kb.add_rule(rule!("f", ["x"; instance!(sym!("Orange"))]));
1392        kb.add_rule(rule!("f", ["x"; instance!(sym!("Fruit"))]));
1393
1394        let diagnostics = kb.validate_rules();
1395        assert_eq!(diagnostics.len(), 1);
1396        let diagnostic = diagnostics.into_iter().next().unwrap();
1397        let error = diagnostic.unwrap_error().unwrap_validation();
1398        assert!(matches!(error, InvalidRule { .. }));
1399
1400        // Rule type does not apply if it doesn't have the same name as a rule
1401        kb.clear_rules();
1402        kb.add_rule_type(rule!("f", ["x"; instance!(sym!("Orange"))]));
1403        kb.add_rule(rule!("f", ["x"; instance!(sym!("Orange"))]));
1404        kb.add_rule(rule!("g", ["x"; instance!(sym!("Fruit"))]));
1405        assert!(kb.validate_rules().is_empty());
1406
1407        // Rule type does apply if it has the same name as a rule even if different arity
1408        kb.clear_rules();
1409        kb.add_rule_type(rule!("f", ["x"; instance!(sym!("Orange")), value!(1)]));
1410        kb.add_rule(rule!("f", ["x"; instance!(sym!("Orange"))]));
1411
1412        let diagnostic = kb.validate_rules().into_iter().next().unwrap();
1413        let error = diagnostic.unwrap_error().unwrap_validation();
1414        assert!(matches!(error, InvalidRule { .. }));
1415
1416        // Multiple templates can exist for the same name but only one needs to match
1417        kb.clear_rules();
1418        kb.add_rule_type(rule!("f", ["x"; instance!(sym!("Orange"))]));
1419        kb.add_rule_type(rule!("f", ["x"; instance!(sym!("Orange")), value!(1)]));
1420        kb.add_rule_type(rule!("f", ["x"; instance!(sym!("Fruit"))]));
1421        kb.add_rule(rule!("f", ["x"; instance!(sym!("Fruit"))]));
1422        assert!(kb.validate_rules().is_empty());
1423    }
1424
#[test]
fn test_rule_type_validation_errors_for_non_class_specializers() {
    /// Runs rule validation, insists that exactly one diagnostic comes back,
    /// and returns that diagnostic rendered as a string.
    fn sole_diagnostic(kb: &mut KnowledgeBase) -> String {
        let found = kb.validate_rules();
        assert_eq!(found.len(), 1);
        found.first().unwrap().to_string()
    }

    // Builds an external instance term that carries an instance ID and
    // leaves every optional field unset.
    let bare_instance = |instance_id| {
        term!(Value::ExternalInstance(ExternalInstance {
            instance_id,
            constructor: None,
            repr: None,
            class_repr: None,
            class_id: None,
        }))
    };

    // Message expected whenever the rule type's specializer is `String1`,
    // which is registered below as a plain string constant.
    let string1_is_a_constant = "Rule type specializer String1 on parameter 1 should be a registered class, but instead it's registered as a constant with value: \"not an external instance\"";

    let mut kb = KnowledgeBase::new();

    // --- Fixtures -------------------------------------------------------

    // Two constants whose values are strings rather than external instances.
    kb.register_constant(sym!("String1"), term!("not an external instance"))
        .unwrap();
    kb.register_constant(sym!("String2"), term!("also not an external instance"))
        .unwrap();

    // Two external instances for which no MRO is ever added.
    kb.register_constant(sym!("ExternalInstanceWithoutMRO1"), bare_instance(1))
        .unwrap();
    kb.register_constant(sym!("ExternalInstanceWithoutMRO2"), bare_instance(2))
        .unwrap();

    // Two external instances that each get an MRO containing their own ID.
    kb.register_constant(sym!("Class1"), bare_instance(3))
        .unwrap();
    kb.add_mro(sym!("Class1"), vec![3]).unwrap();
    kb.register_constant(sym!("Class2"), bare_instance(4))
        .unwrap();
    kb.add_mro(sym!("Class2"), vec![4]).unwrap();

    // --- Rule type and rule use the same kind of specializer ------------

    // Unregistered / identical names: accepted.
    kb.add_rule_type(rule!("f", ["_"; instance!("Unregistered")]));
    kb.add_rule(rule!("f", ["_"; instance!("Unregistered")]));
    assert!(kb.validate_rules().is_empty());

    // Unregistered / differing names: rejected.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("Unregistered1")]));
    kb.add_rule(rule!("f", ["_"; instance!("Unregistered2")]));
    assert_eq!(
        sole_diagnostic(&mut kb),
        "Unregistered class: Unregistered1"
    );

    // Non-instance constant / identical names: accepted.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("String1")]));
    kb.add_rule(rule!("f", ["_"; instance!("String1")]));
    assert!(kb.validate_rules().is_empty());

    // Non-instance constant / differing names: rejected.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("String1")]));
    kb.add_rule(rule!("f", ["_"; instance!("String2")]));
    let message = sole_diagnostic(&mut kb);
    assert!(message.contains(string1_is_a_constant), "{}", message);

    // External instance without MRO / identical names: accepted.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("ExternalInstanceWithoutMRO1")]));
    kb.add_rule(rule!("f", ["_"; instance!("ExternalInstanceWithoutMRO1")]));
    assert!(kb.validate_rules().is_empty());

    // External instance without MRO / differing names: rejected.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("ExternalInstanceWithoutMRO1")]));
    kb.add_rule(rule!("f", ["_"; instance!("ExternalInstanceWithoutMRO2")]));
    let message = sole_diagnostic(&mut kb);
    assert!(
        message.contains(
            "Rule specializer ExternalInstanceWithoutMRO2 on parameter 1 is not registered as a class."
        ),
        "{}",
        message
    );

    // Class / identical names: accepted.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("Class1")]));
    kb.add_rule(rule!("f", ["_"; instance!("Class1")]));
    assert!(kb.validate_rules().is_empty());

    // Class / differing names: rejected.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("Class1")]));
    kb.add_rule(rule!("f", ["_"; instance!("Class2")]));
    let message = sole_diagnostic(&mut kb);
    assert!(
        message.contains(
            "Rule specializer Class2 on parameter 1 must match rule type specializer Class1"
        ),
        "{}",
        message
    );

    // --- Rule type and rule use different kinds of specializer ----------

    // Rule type: unregistered. Rule: non-instance constant.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("Unregistered")]));
    kb.add_rule(rule!("f", ["_"; instance!("String1")]));
    assert_eq!(sole_diagnostic(&mut kb), "Unregistered class: Unregistered");

    // Rule type: non-instance constant. Rule: unregistered.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("String1")]));
    kb.add_rule(rule!("f", ["_"; instance!("Unregistered")]));
    let message = sole_diagnostic(&mut kb);
    assert!(message.contains(string1_is_a_constant), "{}", message);

    // Rule type: external instance without MRO. Rule: unregistered.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("ExternalInstanceWithoutMRO1")]));
    kb.add_rule(rule!("f", ["_"; instance!("Unregistered")]));
    let message = sole_diagnostic(&mut kb);
    assert!(
        message.contains(
            "Rule specializer Unregistered on parameter 1 is not registered as a class."
        ),
        "{}",
        message
    );

    // Rule type: external instance without MRO. Rule: class.
    kb.clear_rules();
    kb.add_rule_type(rule!("f", ["_"; instance!("ExternalInstanceWithoutMRO1")]));
    kb.add_rule(rule!("f", ["_"; instance!("Class1")]));
    let message = sole_diagnostic(&mut kb);
    assert!(
        message.contains(
            "Rule specializer Class1 on parameter 1 must match rule type specializer ExternalInstanceWithoutMRO1"
        ),
        "{}",
        message
    );
}
1586}