ai_kit/infer/
mod.rs

//! The infer module implements basic forward-chaining inference by applying every applicable `Operation` (rule) to a collection of `Unify` facts, deriving new facts from each match.
2
3use constraints::ConstraintValue;
4use core::{Operation, Bindings, BindingsValue, Unify};
5use pedigree::{Origin, Pedigree, RenderType};
6use planner::{Goal, ConjunctivePlanner, PlanningConfig};
7use serde_json;
8use std;
9use std::collections::{BTreeMap, BTreeSet};
10use std::collections::HashMap;
11use std::marker::PhantomData;
12use utils;
13
14#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd, Serialize)]
15pub struct Negatable<B: BindingsValue, U: Unify<B>> {
16    content: U,
17    #[serde(default)]
18    is_negative: bool,
19    #[serde(default)]
20    _marker: PhantomData<B>,
21}
22
23impl<B, U> Eq for Negatable<B, U>
24    where B: BindingsValue,
25          U: Unify<B>
26{
27}
28
29impl<B, U> std::fmt::Display for Negatable<B, U>
30    where B: BindingsValue,
31          U: Unify<B>
32{
33    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34        write!(f, "{}", serde_json::to_string(&self).unwrap())
35    }
36}
37
38impl<B, U> Unify<B> for Negatable<B, U>
39    where B: BindingsValue,
40          U: Unify<B>
41{
42    fn unify(&self, other: &Self, bindings: &Bindings<B>) -> Option<Bindings<B>> {
43        self.content.unify(&other.content, bindings)
44    }
45    fn apply_bindings(&self, bindings: &Bindings<B>) -> Option<Self> {
46        self.content.apply_bindings(bindings).and_then(|bound_content| {
47            Some(Negatable {
48                content: bound_content,
49                is_negative: self.is_negative,
50                _marker: PhantomData,
51            })
52        })
53    }
54    fn variables(&self) -> Vec<String> {
55        self.content.variables()
56    }
57    fn rename_variables(&self, renamed_variables: &HashMap<String, String>) -> Self {
58        Negatable {
59            content: self.content.rename_variables(renamed_variables),
60            is_negative: self.is_negative,
61            _marker: PhantomData,
62        }
63    }
64    fn nil() -> Self {
65        Negatable {
66            content: U::nil(),
67            is_negative: false,
68            _marker: PhantomData,
69        }
70    }
71}
72
73#[derive(Clone, Debug, PartialEq)]
74pub struct OriginCache {
75    items: BTreeSet<Origin>,
76}
77
78impl OriginCache {
79    pub fn new() -> Self {
80        OriginCache { items: BTreeSet::new() }
81    }
82
83    pub fn has_item(&self, item: &Origin) -> bool {
84        self.items.contains(item)
85    }
86
87    pub fn insert_item_mut(&mut self, item: Origin) {
88        self.items.insert(item);
89    }
90}
91
92#[derive(Clone, Debug, PartialEq)]
93pub struct InferenceEngine<'a, T, U, A>
94    where T: 'a + ConstraintValue,
95          U: 'a + Unify<T>,
96          A: 'a + Operation<T, U>
97{
98    pub rules: Vec<(&'a String, &'a A)>,
99    pub facts: Vec<(&'a String, &'a U)>,
100    // Facts derived from this inference process
101    pub derived_facts: Vec<(String, U)>,
102    pub pedigree: Pedigree,
103    pub prefix: String,
104    // Used to check if an inference has already been performed,
105    // allowing us to short-circuit a potentially expensive unification process.
106    pub origin_cache: OriginCache,
107    _marker: PhantomData<T>,
108}
109
110impl<'a, T, U, A> InferenceEngine<'a, T, U, A>
111    where T: 'a + ConstraintValue,
112          U: 'a + Unify<T>,
113          A: 'a + Operation<T, U>
114{
115    pub fn new(prefix: String, rules: Vec<(&'a String, &'a A)>, facts: Vec<(&'a String, &'a U)>) -> Self {
116        InferenceEngine {
117            rules: rules,
118            facts: facts,
119            derived_facts: Vec::new(),
120            pedigree: Pedigree::new(),
121            prefix: prefix,
122            origin_cache: OriginCache::new(),
123            _marker: PhantomData,
124        }
125    }
126
127    pub fn all_facts(&'a self) -> Vec<(&'a String, &'a U)> {
128        self.derived_facts
129            .iter()
130            .map(|&(ref id, ref f)| (id, f))
131            .chain(self.facts.iter().map(|&(id, f)| (id, f)))
132            .collect()
133    }
134
135    pub fn chain_until_match(&self, max_iterations: usize, goal: &U) -> (Option<(U, String)>, Self) {
136        self.chain_until(max_iterations,
137                         &|f| goal.unify(f, &Bindings::new()).is_some())
138    }
139
140    pub fn chain_until(&self, max_iterations: usize, satisfied: &Fn(&U) -> bool) -> (Option<(U, String)>, Self) {
141        let mut engine = self.clone();
142        let mut target: Option<(U, String)> = None;
143        for _idx in 0..max_iterations {
144            for (fact, _bindings, origin) in engine.chain_forward().into_iter() {
145                let id = engine.construct_id(&fact);
146
147                if satisfied(&fact) {
148                    target = Some((fact.clone(), id.clone()));
149                }
150
151                engine.pedigree.insert_mut(id.clone(), origin);
152                engine.derived_facts.push((id, fact));
153            }
154            if target.is_some() {
155                break;
156            }
157        }
158        (target, engine)
159    }
160
161    pub fn chain_forward(&mut self) -> Vec<(U, Bindings<T>, Origin)> {
162        let mut origin_cache = self.origin_cache.clone();
163        let results = chain_forward(self.all_facts(), self.rules.clone(), &mut origin_cache);
164        self.origin_cache = origin_cache;
165        results
166    }
167
168    fn construct_id(&self, _fact: &U) -> String {
169        format!("{}-{}", self.prefix, self.derived_facts.len())
170    }
171
172    pub fn render_inference_tree(&'a self, id: &String, render_type: RenderType) -> String {
173        let all_facts_map: BTreeMap<&'a String, &'a U> = self.all_facts().into_iter().collect();
174        let rule_map: BTreeMap<&'a String, &'a A> = self.rules.clone().into_iter().collect();
175
176        let node_renderer = |x| {
177            all_facts_map.get(&x)
178                .and_then(|y| Some(format!("{}", y)))
179                .or_else(|| rule_map.get(&x).and_then(|y| Some(format!("{}", y))))
180                .unwrap_or(format!("{}?", x))
181        };
182
183        self.pedigree.render_inference_tree(id,
184                                            &node_renderer,
185                                            &node_renderer,
186                                            &|x, _y| x.clone(),
187                                            render_type)
188    }
189}
190
191pub fn chain_forward<T, U, A>(facts: Vec<(&String, &U)>, rules: Vec<(&String, &A)>, origin_cache: &mut OriginCache) -> Vec<(U, Bindings<T>, Origin)>
192    where T: ConstraintValue,
193          U: Unify<T>,
194          A: Operation<T, U>
195{
196    let mut derived_facts: Vec<(U, Bindings<T>, Origin)> = Vec::new();
197    let just_the_facts: Vec<&U> = facts.iter().map(|&(_id, u)| u).collect();
198
199    for (ref rule_id, ref rule) in rules.into_iter() {
200        let planner: ConjunctivePlanner<T, U, A> = ConjunctivePlanner::new(rule.input_patterns().into_iter().map(Goal::with_pattern).collect(),
201                                                                           &Bindings::new(),
202                                                                           &PlanningConfig::default(),
203                                                                           just_the_facts.clone(),
204                                                                           Vec::new());
205        let application_successful =
206            |(input_goals, bindings): (Vec<Goal<T, U, A>>, Bindings<T>)| -> Option<(Vec<Goal<T, U, A>>, Vec<U>, Bindings<T>)> {
207                let bound_input_goals: Vec<Goal<T, U, A>> =
208                    input_goals.iter().map(|input_goal| input_goal.apply_bindings(&bindings).expect("Should be applicable")).collect();
209                rule.apply_match(&bindings).and_then(|new_facts| Some((bound_input_goals, new_facts, bindings)))
210            };
211
212        for (matched_inputs, new_facts, bindings) in planner.filter_map(application_successful) {
213            let fact_ids: Vec<String> = extract_datum_indexes(&matched_inputs).iter().map(|idx| facts[*idx].0.clone()).collect();
214            let origin = Origin {
215                source_id: (*rule_id).clone(),
216                args: fact_ids,
217            };
218            if origin_cache.has_item(&origin) {
219                continue;
220            } else {
221                origin_cache.insert_item_mut(origin.clone());
222            }
223            for new_fact in new_facts {
224                if is_new_fact(&new_fact, &facts) {
225                    derived_facts.push((new_fact, bindings.clone(), origin.clone()))
226                }
227            }
228        }
229    }
230    derived_facts
231}
232
233pub fn chain_forward_with_negative_goals<T, IU, A>(facts: Vec<(&String, &Negatable<T, IU>)>,
234                                                   rules: Vec<(&String, &A)>,
235                                                   origin_cache: &mut OriginCache)
236                                                   -> Vec<(Negatable<T, IU>, Bindings<T>, Origin)>
237    where T: ConstraintValue,
238          IU: Unify<T>,
239          A: Operation<T, Negatable<T, IU>>
240{
241    let mut derived_facts: Vec<(Negatable<T, IU>, Bindings<T>, Origin)> = Vec::new();
242    let just_the_facts: Vec<&Negatable<T, IU>> = facts.iter().map(|&(_id, u)| u).collect();
243
244    for (ref rule_id, ref rule) in rules.into_iter() {
245        let (negative_inputs, positive_inputs): (Vec<Negatable<T, IU>>, Vec<Negatable<T, IU>>) =
246            rule.input_patterns().into_iter().partition(|input| input.is_negative);
247        let planner: ConjunctivePlanner<T, Negatable<T, IU>, A> =
248            ConjunctivePlanner::new(positive_inputs.into_iter().map(Goal::with_pattern).collect(),
249                                    &Bindings::new(),
250                                    &PlanningConfig::default(),
251                                    just_the_facts.clone().into_iter().collect(),
252                                    Vec::new());
253
254        let negative_patterns_are_satisfied = |(input_goals, bindings)| {
255            utils::map_while_some(&mut negative_inputs.iter(),
256                                  &|pattern| pattern.apply_bindings(&bindings))
257                .and_then(|bound_negative_patterns| if any_patterns_match(&bound_negative_patterns.iter().collect(), &just_the_facts) {
258                    None
259                } else {
260                    Some((input_goals, bindings))
261                })
262        };
263        let application_successful =
264            |(input_goals, bindings)| rule.apply_match(&bindings).and_then(|new_facts| Some((input_goals, new_facts, bindings)));
265
266        for (matched_inputs, new_facts, bindings) in planner.filter_map(negative_patterns_are_satisfied).filter_map(application_successful) {
267            let fact_ids: Vec<String> = extract_datum_indexes(&matched_inputs).iter().map(|idx| facts[*idx].0.clone()).collect();
268            let origin = Origin {
269                source_id: (*rule_id).clone(),
270                args: fact_ids,
271            };
272            if origin_cache.has_item(&origin) {
273                continue;
274            } else {
275                origin_cache.insert_item_mut(origin.clone());
276            }
277            for new_fact in new_facts {
278                if is_new_fact(&new_fact, &facts) {
279                    derived_facts.push((new_fact, bindings.clone(), origin.clone()))
280                }
281            }
282        }
283    }
284    derived_facts
285}
286
287fn any_patterns_match<B, U>(patterns: &Vec<&U>, patterns2: &Vec<&U>) -> bool
288    where B: BindingsValue,
289          U: Unify<B>
290{
291    let empty_bindings: Bindings<B> = Bindings::new();
292    patterns.iter().any(|patt| patterns2.iter().any(|f| f.unify(patt, &empty_bindings).is_some()))
293}
294
295fn extract_datum_indexes<T, U, A>(goals: &Vec<Goal<T, U, A>>) -> Vec<usize>
296    where T: ConstraintValue,
297          U: Unify<T>,
298          A: Operation<T, U>
299{
300    goals.iter().map(|goal| goal.unification_index.datum_idx().expect("Only datum idx should be here!")).collect()
301}
302
303fn is_new_fact<T, U>(f: &U, facts: &Vec<(&String, &U)>) -> bool
304    where T: ConstraintValue,
305          U: Unify<T>
306{
307    for &(_id, fact) in facts.iter() {
308        if fact.unify(f, &Bindings::new()).is_some() {
309            return false;
310        }
311    }
312    return true;
313}
314
315#[cfg(test)]
316mod tests;