polar_core/
rules.rs

1use std::collections::{BTreeSet, HashMap};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use super::sources::{Context, Source, SourceInfo};
7use super::terms::*;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10pub struct Parameter {
11    pub parameter: Term,
12    pub specializer: Option<Term>,
13}
14
15impl Parameter {
16    pub fn is_ground(&self) -> bool {
17        self.specializer.is_none() && self.parameter.value().is_ground()
18    }
19}
20
21#[derive(Debug, Clone, Deserialize, Serialize)]
22pub struct Rule {
23    pub name: Symbol,
24    pub params: Vec<Parameter>,
25    pub body: Term,
26    #[serde(skip, default = "SourceInfo::ffi")]
27    pub source_info: SourceInfo,
28    // TODO @patrickod: refactor Rule into Rule & RuleType structs
29    // `required` is used exclusively with rule *types* and not normal rules.
30    pub required: bool,
31}
32
33impl PartialEq for Rule {
34    fn eq(&self, other: &Self) -> bool {
35        self.name == other.name
36            && self.params.len() == other.params.len()
37            && self.params == other.params
38            && self.body == other.body
39    }
40}
41
42impl Rule {
43    pub fn is_ground(&self) -> bool {
44        self.params.iter().all(|p| p.is_ground())
45    }
46
47    pub(crate) fn parsed_context(&self) -> Option<&Context> {
48        if let SourceInfo::Parser(context) = &self.source_info {
49            Some(context)
50        } else {
51            None
52        }
53    }
54
55    pub fn new_from_test(name: Symbol, params: Vec<Parameter>, body: Term) -> Self {
56        Self {
57            name,
58            params,
59            body,
60            source_info: SourceInfo::Test,
61            required: false,
62        }
63    }
64
65    /// Creates a new rule from the parser
66    pub fn new_from_parser(
67        source: Arc<Source>,
68        left: usize,
69        right: usize,
70        name: Symbol,
71        params: Vec<Parameter>,
72        body: Term,
73    ) -> Self {
74        Self {
75            name,
76            params,
77            body,
78            source_info: SourceInfo::parser(source, left, right),
79            required: false,
80        }
81    }
82}
83
84// TODO: should this be a Set of Rules? Do we currently check for duplicate rules?
85pub struct RuleTypes(HashMap<Symbol, Vec<Rule>>);
86
87impl Default for RuleTypes {
88    fn default() -> Self {
89        let mut rule_types = Self(HashMap::new());
90        rule_types.add_default_rule_types();
91        rule_types
92    }
93}
94
95impl RuleTypes {
96    fn add_default_rule_types(&mut self) {
97        // type has_permission(actor: Actor, permission: String, resource: Resource);
98        self.add(rule!("has_permission", ["actor"; instance!(sym!("Actor")), "_permission"; instance!(sym!("String")), "resource"; instance!(sym!("Resource"))]));
99        // type allow(actor, action, resource);
100        self.add(rule!(
101            "allow",
102            [sym!("actor"), sym!("_action"), sym!("resource")]
103        ));
104        // type allow_field(actor, action, resource, field);
105        self.add(rule!(
106            "allow_field",
107            [
108                sym!("actor"),
109                sym!("action"),
110                sym!("resource"),
111                sym!("field")
112            ]
113        ));
114        // type allow_request(actor, request);"#;
115        self.add(rule!("allow_request", [sym!("actor"), sym!("request")]));
116    }
117
118    pub fn get(&self, name: &Symbol) -> Option<&Vec<Rule>> {
119        self.0.get(name)
120    }
121
122    pub fn add(&mut self, rule_type: Rule) {
123        let name = rule_type.name.clone();
124        // get rule types with this rule name
125        let rule_types = self.0.entry(name).or_default();
126        rule_types.push(rule_type);
127    }
128
129    pub fn reset(&mut self) {
130        self.0.clear();
131        self.add_default_rule_types()
132    }
133
134    pub fn required_rule_types(&self) -> Vec<&Rule> {
135        self.0
136            .values()
137            .flatten()
138            .filter(|rule_type| rule_type.required)
139            .collect()
140    }
141}
142
143pub type Rules = Vec<Arc<Rule>>;
144
145type RuleSet = BTreeSet<u64>;
146
147#[derive(Clone, Default, Debug)]
148struct RuleIndex {
149    rules: RuleSet,
150    index: HashMap<Option<Value>, RuleIndex>,
151}
152
153impl RuleIndex {
154    pub fn index_rule(&mut self, rule_id: u64, params: &[Parameter], i: usize) {
155        if i < params.len() {
156            self.index
157                .entry({
158                    if params[i].is_ground() {
159                        Some(params[i].parameter.value().clone())
160                    } else {
161                        None
162                    }
163                })
164                .or_default()
165                .index_rule(rule_id, params, i + 1);
166        } else {
167            self.rules.insert(rule_id);
168        }
169    }
170
171    #[allow(clippy::comparison_chain)]
172    pub fn get_applicable_rules(&self, args: &[Term], i: usize) -> RuleSet {
173        if i < args.len() {
174            // Check this argument and recurse on the rest.
175            let filter_next_args =
176                |index: &RuleIndex| -> RuleSet { index.get_applicable_rules(args, i + 1) };
177            let arg = args[i].value();
178            if arg.is_ground() {
179                // Check the index for a ground argument.
180                let mut ruleset = self
181                    .index
182                    .get(&Some(arg.clone()))
183                    .map(filter_next_args)
184                    .unwrap_or_default();
185
186                // Extend for a variable parameter.
187                if let Some(index) = self.index.get(&None) {
188                    ruleset.extend(filter_next_args(index));
189                }
190                ruleset
191            } else {
192                // Accumulate all indexed arguments.
193                self.index.values().fold(
194                    RuleSet::default(),
195                    |mut result: RuleSet, index: &RuleIndex| {
196                        result.extend(filter_next_args(index));
197                        result
198                    },
199                )
200            }
201        } else {
202            // No more arguments.
203            self.rules.clone()
204        }
205    }
206}
207
208#[derive(Clone)]
209pub struct GenericRule {
210    pub name: Symbol,
211    pub rules: HashMap<u64, Arc<Rule>>,
212    index: RuleIndex,
213    next_rule_id: u64,
214}
215
216impl GenericRule {
217    pub fn new(name: Symbol, rules: Rules) -> Self {
218        let mut generic_rule = Self {
219            name,
220            rules: Default::default(),
221            index: Default::default(),
222            next_rule_id: 0,
223        };
224
225        for rule in rules {
226            generic_rule.add_rule(rule);
227        }
228
229        generic_rule
230    }
231
232    pub fn add_rule(&mut self, rule: Arc<Rule>) {
233        let rule_id = self.next_rule_id();
234
235        assert!(
236            self.rules.insert(rule_id, rule.clone()).is_none(),
237            "Rule id already used."
238        );
239        self.index.index_rule(rule_id, &rule.params[..], 0);
240    }
241
242    #[allow(clippy::ptr_arg)]
243    pub fn get_applicable_rules(&self, args: &TermList) -> Rules {
244        self.index
245            .get_applicable_rules(args, 0)
246            .iter()
247            .map(|id| self.rules.get(id).expect("Rule missing"))
248            .cloned()
249            .collect()
250    }
251
252    fn next_rule_id(&mut self) -> u64 {
253        let v = self.next_rule_id;
254        self.next_rule_id += 1;
255        v
256    }
257}
258
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::polar::Polar;

    /// Five three-argument `f` rules exercising ground, variable and
    /// dictionary parameters.
    const POLICY: &str = r#"
            f(1, 1, "x");
            f(1, 1, "y");
            f(1, x, "y") if x = 2;
            f(1, 2, {b: "y"});
            f(1, 3, {c: "z"});
        "#;

    /// Returns a `Polar` instance with `POLICY` loaded.
    fn load_policy() -> Polar {
        let polar = Polar::new();
        polar.load_str(POLICY).unwrap();
        polar
    }

    #[test]
    fn test_rule_index() {
        let polar = load_policy();
        let kb = polar.kb.read().unwrap();
        let generic_rule = kb.get_generic_rule(&sym!("f")).unwrap();
        let index = &generic_rule.index;
        // Every rule has three parameters, so nothing ends at the root.
        assert!(index.rules.is_empty());

        fn keys(index: &RuleIndex) -> HashSet<Option<Value>> {
            index.index.keys().cloned().collect()
        }

        let mut args = HashSet::<Option<Value>>::new();

        // First parameter: every rule starts with `1`.
        args.insert(Some(value!(1)));
        assert_eq!(args, keys(index));

        // Second parameter: ground 1, 2, 3 plus the variable `x`.
        args.clear();
        args.insert(None); // x
        args.insert(Some(value!(1)));
        args.insert(Some(value!(2)));
        args.insert(Some(value!(3)));
        let index1 = index.index.get(&Some(value!(1))).unwrap();
        assert_eq!(args, keys(index1));

        args.clear();
        args.insert(Some(value!("x")));
        args.insert(Some(value!("y")));
        let index11 = index1.index.get(&Some(value!(1))).unwrap();
        assert_eq!(args, keys(index11));

        args.remove(&Some(value!("x")));
        let index1_ = index1.index.get(&None).unwrap();
        assert_eq!(args, keys(index1_));

        args.clear();
        args.insert(Some(value!(btreemap! {sym!("b") => term!("y")})));
        let index12 = index1.index.get(&Some(value!(2))).unwrap();
        assert_eq!(args, keys(index12));

        args.clear();
        args.insert(Some(value!(btreemap! {sym!("c") => term!("z")})));
        let index13 = index1.index.get(&Some(value!(3))).unwrap();
        assert_eq!(args, keys(index13));
    }

    #[test]
    fn test_applicable_rules() {
        let polar = load_policy();
        let kb = polar.kb.read().unwrap();
        let generic_rule = kb.get_generic_rule(&sym!("f")).unwrap();
        let count = |args: TermList| generic_rule.get_applicable_rules(&args).len();

        // Only the exact ground match.
        assert_eq!(count(vec![term!(1), term!(1), term!("x")]), 1);
        // The exact ground match plus the rule with a variable second parameter.
        assert_eq!(count(vec![term!(1), term!(1), term!("y")]), 2);
        // Dictionary parameters are matched by value.
        assert_eq!(
            count(vec![
                term!(1),
                term!(2),
                term!(btreemap! {sym!("b") => term!("y")})
            ]),
            1
        );
        assert_eq!(
            count(vec![
                term!(1),
                term!(3),
                term!(btreemap! {sym!("b") => term!("y")})
            ]),
            0
        );
        // No rule has a first parameter other than `1`.
        assert_eq!(count(vec![term!(2), term!(1), term!("x")]), 0);
        // Arity mismatch: all rules take three parameters.
        assert_eq!(count(vec![term!(1), term!(1)]), 0);
    }
}