1use std::collections::{BTreeSet, HashMap};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use super::sources::{Context, Source, SourceInfo};
7use super::terms::*;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10pub struct Parameter {
11 pub parameter: Term,
12 pub specializer: Option<Term>,
13}
14
15impl Parameter {
16 pub fn is_ground(&self) -> bool {
17 self.specializer.is_none() && self.parameter.value().is_ground()
18 }
19}
20
/// A rule: a name, its parameters, and a body term, plus bookkeeping about
/// where it came from.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Rule {
    /// Rule name; `RuleTypes` and `GenericRule` group rules by this name.
    pub name: Symbol,
    /// Parameters in positional order.
    pub params: Vec<Parameter>,
    pub body: Term,
    /// Origin of the rule. Skipped by serde; a deserialized rule gets
    /// `SourceInfo::ffi()` instead.
    #[serde(skip, default = "SourceInfo::ffi")]
    pub source_info: SourceInfo,
    /// `RuleTypes::required_rule_types` selects rule types with this flag set;
    /// `new_from_test` and `new_from_parser` both initialise it to `false`.
    /// NOTE(review): presumably marks rule types a policy must implement — confirm.
    pub required: bool,
}
32
33impl PartialEq for Rule {
34 fn eq(&self, other: &Self) -> bool {
35 self.name == other.name
36 && self.params.len() == other.params.len()
37 && self.params == other.params
38 && self.body == other.body
39 }
40}
41
42impl Rule {
43 pub fn is_ground(&self) -> bool {
44 self.params.iter().all(|p| p.is_ground())
45 }
46
47 pub(crate) fn parsed_context(&self) -> Option<&Context> {
48 if let SourceInfo::Parser(context) = &self.source_info {
49 Some(context)
50 } else {
51 None
52 }
53 }
54
55 pub fn new_from_test(name: Symbol, params: Vec<Parameter>, body: Term) -> Self {
56 Self {
57 name,
58 params,
59 body,
60 source_info: SourceInfo::Test,
61 required: false,
62 }
63 }
64
65 pub fn new_from_parser(
67 source: Arc<Source>,
68 left: usize,
69 right: usize,
70 name: Symbol,
71 params: Vec<Parameter>,
72 body: Term,
73 ) -> Self {
74 Self {
75 name,
76 params,
77 body,
78 source_info: SourceInfo::parser(source, left, right),
79 required: false,
80 }
81 }
82}
83
/// Registry of rule types grouped by rule name. Several rule types may share a
/// name; `add` appends to the per-name list, so each list keeps insertion order.
pub struct RuleTypes(HashMap<Symbol, Vec<Rule>>);
86
87impl Default for RuleTypes {
88 fn default() -> Self {
89 let mut rule_types = Self(HashMap::new());
90 rule_types.add_default_rule_types();
91 rule_types
92 }
93}
94
95impl RuleTypes {
96 fn add_default_rule_types(&mut self) {
97 self.add(rule!("has_permission", ["actor"; instance!(sym!("Actor")), "_permission"; instance!(sym!("String")), "resource"; instance!(sym!("Resource"))]));
99 self.add(rule!(
101 "allow",
102 [sym!("actor"), sym!("_action"), sym!("resource")]
103 ));
104 self.add(rule!(
106 "allow_field",
107 [
108 sym!("actor"),
109 sym!("action"),
110 sym!("resource"),
111 sym!("field")
112 ]
113 ));
114 self.add(rule!("allow_request", [sym!("actor"), sym!("request")]));
116 }
117
118 pub fn get(&self, name: &Symbol) -> Option<&Vec<Rule>> {
119 self.0.get(name)
120 }
121
122 pub fn add(&mut self, rule_type: Rule) {
123 let name = rule_type.name.clone();
124 let rule_types = self.0.entry(name).or_default();
126 rule_types.push(rule_type);
127 }
128
129 pub fn reset(&mut self) {
130 self.0.clear();
131 self.add_default_rule_types()
132 }
133
134 pub fn required_rule_types(&self) -> Vec<&Rule> {
135 self.0
136 .values()
137 .flatten()
138 .filter(|rule_type| rule_type.required)
139 .collect()
140 }
141}
142
/// A list of shared rules; `GenericRule::new` accepts one and
/// `GenericRule::get_applicable_rules` returns one.
pub type Rules = Vec<Arc<Rule>>;
144
/// A set of rule ids. Being a `BTreeSet`, it iterates ids in ascending order.
type RuleSet = BTreeSet<u64>;
146
/// Tree index over rule parameters, with one level per parameter position.
///
/// At each level, a ground parameter value is keyed as `Some(value)` and any
/// other parameter as `None`. A rule's id is stored in the `rules` set of the
/// node reached after its last parameter.
#[derive(Clone, Default, Debug)]
struct RuleIndex {
    /// Ids of rules whose parameter list ends at this node.
    rules: RuleSet,
    /// Child nodes keyed by the parameter at this depth: `Some(value)` when the
    /// parameter is ground, `None` otherwise.
    index: HashMap<Option<Value>, RuleIndex>,
}
152
153impl RuleIndex {
154 pub fn index_rule(&mut self, rule_id: u64, params: &[Parameter], i: usize) {
155 if i < params.len() {
156 self.index
157 .entry({
158 if params[i].is_ground() {
159 Some(params[i].parameter.value().clone())
160 } else {
161 None
162 }
163 })
164 .or_default()
165 .index_rule(rule_id, params, i + 1);
166 } else {
167 self.rules.insert(rule_id);
168 }
169 }
170
171 #[allow(clippy::comparison_chain)]
172 pub fn get_applicable_rules(&self, args: &[Term], i: usize) -> RuleSet {
173 if i < args.len() {
174 let filter_next_args =
176 |index: &RuleIndex| -> RuleSet { index.get_applicable_rules(args, i + 1) };
177 let arg = args[i].value();
178 if arg.is_ground() {
179 let mut ruleset = self
181 .index
182 .get(&Some(arg.clone()))
183 .map(filter_next_args)
184 .unwrap_or_default();
185
186 if let Some(index) = self.index.get(&None) {
188 ruleset.extend(filter_next_args(index));
189 }
190 ruleset
191 } else {
192 self.index.values().fold(
194 RuleSet::default(),
195 |mut result: RuleSet, index: &RuleIndex| {
196 result.extend(filter_next_args(index));
197 result
198 },
199 )
200 }
201 } else {
202 self.rules.clone()
204 }
205 }
206}
207
/// All rules sharing one name, plus an index used to narrow down which of
/// them may apply to a given argument list.
#[derive(Clone)]
pub struct GenericRule {
    pub name: Symbol,
    /// Rules keyed by the id assigned to them in `add_rule`.
    pub rules: HashMap<u64, Arc<Rule>>,
    /// Parameter index over the ids in `rules`.
    index: RuleIndex,
    /// Id that the next added rule will receive; incremented by one per rule.
    next_rule_id: u64,
}
215
216impl GenericRule {
217 pub fn new(name: Symbol, rules: Rules) -> Self {
218 let mut generic_rule = Self {
219 name,
220 rules: Default::default(),
221 index: Default::default(),
222 next_rule_id: 0,
223 };
224
225 for rule in rules {
226 generic_rule.add_rule(rule);
227 }
228
229 generic_rule
230 }
231
232 pub fn add_rule(&mut self, rule: Arc<Rule>) {
233 let rule_id = self.next_rule_id();
234
235 assert!(
236 self.rules.insert(rule_id, rule.clone()).is_none(),
237 "Rule id already used."
238 );
239 self.index.index_rule(rule_id, &rule.params[..], 0);
240 }
241
242 #[allow(clippy::ptr_arg)]
243 pub fn get_applicable_rules(&self, args: &TermList) -> Rules {
244 self.index
245 .get_applicable_rules(args, 0)
246 .iter()
247 .map(|id| self.rules.get(id).expect("Rule missing"))
248 .cloned()
249 .collect()
250 }
251
252 fn next_rule_id(&mut self) -> u64 {
253 let v = self.next_rule_id;
254 self.next_rule_id += 1;
255 v
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::polar::Polar;

    /// Keys of one level of a `RuleIndex`, as a set.
    fn keys(index: &RuleIndex) -> HashSet<Option<Value>> {
        index.index.keys().cloned().collect()
    }

    /// Builds the expected key set from a list of keys.
    fn key_set(expected: Vec<Option<Value>>) -> HashSet<Option<Value>> {
        expected.into_iter().collect()
    }

    #[test]
    fn test_rule_index() {
        let polar = Polar::new();
        polar
            .load_str(
                r#"
                f(1, 1, "x");
                f(1, 1, "y");
                f(1, x, "y") if x = 2;
                f(1, 2, {b: "y"});
                f(1, 3, {c: "z"});
                "#,
            )
            .unwrap();

        let kb = polar.kb.read().unwrap();
        let generic_rule = kb.get_generic_rule(&sym!("f")).unwrap();
        let root = &generic_rule.index;
        // Every rule has three parameters, so nothing terminates at the root.
        assert!(root.rules.is_empty());

        // First position: every rule starts with the literal `1`.
        assert_eq!(key_set(vec![Some(value!(1))]), keys(root));

        // Second position: literals 1, 2, 3 plus the variable `x`, keyed as `None`.
        let under_1 = root.index.get(&Some(value!(1))).unwrap();
        assert_eq!(
            key_set(vec![
                None,
                Some(value!(1)),
                Some(value!(2)),
                Some(value!(3)),
            ]),
            keys(under_1)
        );

        let under_1_1 = under_1.index.get(&Some(value!(1))).unwrap();
        assert_eq!(
            key_set(vec![Some(value!("x")), Some(value!("y"))]),
            keys(under_1_1)
        );

        let under_1_var = under_1.index.get(&None).unwrap();
        assert_eq!(key_set(vec![Some(value!("y"))]), keys(under_1_var));

        let under_1_2 = under_1.index.get(&Some(value!(2))).unwrap();
        assert_eq!(
            key_set(vec![Some(value!(btreemap! {sym!("b") => term!("y")}))]),
            keys(under_1_2)
        );

        let under_1_3 = under_1.index.get(&Some(value!(3))).unwrap();
        assert_eq!(
            key_set(vec![Some(value!(btreemap! {sym!("c") => term!("z")}))]),
            keys(under_1_3)
        );
    }
}