fuzzy_expert/
inference.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3
4use num::Float;
5
6use crate::dsl::Expr;
7use crate::inputs::Inputs;
8use crate::math::{meshgrid, Axis, CollectMatrix, Matrix};
9use crate::ops::*;
10use crate::outputs::Outputs;
11use crate::rules::Rules;
12use crate::variable::{VariableKey, Variables};
13
14pub struct DecompInference {
15    and_op: AndOp,
16    or_op: OrOp,
17    comp_op: CompositionOp,
18    imp_op: ImplicationOp,
19    prod_link: ProductionLink,
20    defuzz_op: DefuzzificationOp,
21}
22
23impl DecompInference {
24    pub fn new(
25        and_op: AndOp,
26        or_op: OrOp,
27        comp_op: CompositionOp,
28        imp_op: ImplicationOp,
29        prod_link: ProductionLink,
30        defuzz_op: DefuzzificationOp,
31    ) -> Self {
32        Self {
33            and_op,
34            or_op,
35            comp_op,
36            imp_op,
37            prod_link,
38            defuzz_op,
39        }
40    }
41
42    // TODO: Maybe can make vars and rules immutable if they're just starting state
43    // and everything else is calculated here
44    pub fn eval<T: Eq + Hash>(&self, vars: &mut Variables<T>, rules: &Rules<T>, inputs: &Inputs) -> Outputs {
45        // Convert Inputs to facts
46        // Converts input values to FIS facts
47        let mut fact_value = HashMap::with_capacity(inputs.0.len());
48        let mut fact_cf = HashMap::with_capacity(inputs.0.len());
49
50        for (key, input_value) in &inputs.0 {
51            // TODO: Fuzzy inputs
52            fact_value.insert(*key, *input_value);
53            fact_cf.insert(*key, 1.);
54        }
55
56        // #1 slowest section
57        // Fuzzificate Facts
58        // Convert crisp facts to membership functions
59        // let mut fact_types = HashMap::with_capacity(fact_value.len());
60        let mut fact_values = HashMap::with_capacity(inputs.0.len());
61
62        for (key, fact_value) in fact_value {
63            // py: (float, int?)
64            // This will eventually be an enum match
65            if true {
66                // Fuzzificate Crisp Fact
67                // TODO: Get rid of the mutability?
68                let var = &mut vars.0[key];
69
70                var.add_points_to_universe(Some(fact_value));
71                fact_values.insert(
72                    key,
73                    var.universe
74                        .iter()
75                        .map(|u| if *u == fact_value { 1.0f64 } else { 0. })
76                        .collect::<Vec<_>>(),
77                );
78                // py: list
79            } else {
80                // TODO: Fuzzificate fuzzy fact
81                unimplemented!();
82            }
83        }
84
85        // Compute Modified Premise Memberships
86        let mut modified_premise_memberships = HashMap::new();
87
88        for (i, rule) in rules.0.iter().enumerate() {
89            let premise = &rule.premise;
90
91            for (var_key, term, modifiers) in premise.propositions() {
92                let membership = vars.0[var_key].get_modified_membership(term, modifiers);
93                modified_premise_memberships.insert((i, var_key), membership);
94            }
95        }
96
97        // Compute Modified Consequence Memberships
98        let mut modified_consequence_memberships = HashMap::new();
99
100        // TODO: Can probably move the inner loop into the previous section's loop
101        for (i, rule) in rules.0.iter().enumerate() {
102            let consequence = &rule.consequence;
103
104            for (var_key, term, modifiers) in consequence.propositions() {
105                let membership = vars.0[var_key].get_modified_membership(term, modifiers);
106                modified_consequence_memberships.insert((i, var_key), membership);
107            }
108        }
109
110        // Compute Fuzzy Implication
111        let mut fuzzy_implications = HashMap::with_capacity(
112            rules.0.len() * modified_premise_memberships.len() * modified_consequence_memberships.len(),
113        );
114
115        for (i, premise_name) in modified_premise_memberships.keys() {
116            for (j, consequence_name) in modified_consequence_memberships.keys() {
117                // TODO: It'd be great if we didn't have to iterate through all other rules' consequences
118                // Maybe want to turn modified_consequence_memberships into Map<RuleId, Vec<VariableKey>>?
119                if *i != *j {
120                    continue;
121                }
122
123                let premise_membership = &modified_premise_memberships[&(*i, *premise_name)];
124                let consequence_membership = &modified_consequence_memberships[&(*i, *consequence_name)];
125                let (v, u) = meshgrid(
126                    consequence_membership.iter().copied(),
127                    premise_membership.iter().copied(),
128                );
129                let shape = v.shape();
130
131                fuzzy_implications.insert(
132                    (i, *premise_name, *consequence_name),
133                    self.imp_op.call(u, v).collect_matrix(shape),
134                );
135            }
136        }
137
138        // #2 slowest section
139        // Compute Fuzzy Composition
140        let mut fuzzy_compositions = HashMap::with_capacity(fuzzy_implications.len());
141
142        for (i, premise_name) in modified_premise_memberships.keys() {
143            for (j, consequence_name) in modified_consequence_memberships.keys() {
144                if *i != *j {
145                    continue;
146                }
147
148                let implication = &fuzzy_implications[&(i, *premise_name, *consequence_name)];
149                let fact_values = &fact_values[premise_name];
150                let n_dim = fact_values.len();
151                let fact_value = Matrix::new(fact_values.to_owned(), (n_dim, 1));
152                let fact_value = fact_value.tile((1, implication.shape().1));
153                let shape = fact_value.shape();
154
155                debug_assert_eq!(shape, implication.shape());
156
157                let composition = match self.comp_op {
158                    CompositionOp::MaxMin => ProductionLink::Min
159                        .call(fact_value, implication)
160                        .collect_matrix(shape)
161                        .max(Axis::Column),
162                    CompositionOp::MaxProd => unimplemented!("fact_value * implication"), // TODO: Matrix::from_mul(fact_value, implication)?
163                };
164
165                fuzzy_compositions.insert((*i, *premise_name, *consequence_name), composition);
166            }
167        }
168
169        // Combine Antecedents
170        let mut combined_compositions = HashMap::new();
171
172        for (i, rule) in rules.0.iter().enumerate() {
173            for (j, consequence_name) in modified_consequence_memberships.keys() {
174                if i != *j {
175                    continue;
176                }
177
178                // Originally tried to write this without collecting vecs at each layer, but
179                // recursive iterators are more or less impossible to write... Even with dyn trait.
180                // Not sure how to make this more performant
181                fn combine<T, F: Float>(
182                    expr: &Expr<T>,
183                    fuzzy_compositions: &HashMap<(usize, VariableKey, VariableKey), Vec<F>>,
184                    consequence_name: VariableKey,
185                    and_op: AndOp,
186                    or_op: OrOp,
187                    rule_id: usize,
188                ) -> Vec<F> {
189                    match expr {
190                        Expr::Is(var_key, _) => fuzzy_compositions[&(rule_id, *var_key, consequence_name)].clone(),
191                        Expr::And(expr, expr2) => {
192                            let left = combine(expr, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
193                            let right = combine(expr2, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
194
195                            and_op.call(left, right).into_iter().collect()
196                        },
197                        Expr::Or(expr, expr2) => {
198                            let left = combine(expr, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
199                            let right = combine(expr2, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
200
201                            or_op.call(left, right).into_iter().collect()
202                        },
203                    }
204                }
205
206                let combined_composition = combine(
207                    &rule.premise,
208                    &fuzzy_compositions,
209                    *consequence_name,
210                    self.and_op,
211                    self.or_op,
212                    i,
213                );
214
215                combined_compositions.insert((i, *consequence_name), combined_composition);
216            }
217        }
218
219        // Compute Rule Inferred CF
220        let mut inferred_cf = HashMap::new();
221
222        for (i, rule) in rules.0.iter().enumerate() {
223            fn calc_cf<T>(expr: &Expr<T>, fact_cf: &HashMap<VariableKey, f64>) -> f64 {
224                match expr {
225                    Expr::Is(var_key, _) => fact_cf[var_key],
226                    Expr::And(expr, expr2) => {
227                        let left = calc_cf(expr, fact_cf);
228                        let right = calc_cf(expr2, fact_cf);
229
230                        f64::min(left, right)
231                    },
232                    Expr::Or(expr, expr2) => {
233                        let left = calc_cf(expr, fact_cf);
234                        let right = calc_cf(expr2, fact_cf);
235
236                        f64::max(left, right)
237                    },
238                }
239            }
240
241            let aggregated_premise_cf = calc_cf(&rule.premise, &fact_cf);
242
243            inferred_cf.insert(i, aggregated_premise_cf * rule.cf);
244        }
245
246        // Collect Rule Memberships
247        let mut collected_rule_memberships = HashMap::new();
248
249        for (i, rule) in rules.0.iter().enumerate() {
250            for (j, var_key) in combined_compositions.keys() {
251                if i != *j {
252                    continue;
253                }
254
255                // REVIEW: Is this even necessary?
256                // if !collected_rule_memberships.contains_key(var_key) {
257                //     let universe = &vars.0[*var_key].universe;
258                //     let min = universe.iter().copied().reduce(f64::min).unwrap();
259                //     let max = universe.iter().copied().reduce(f64::max).unwrap();
260                //     let mut var = VariableContraints::<T>::new(min..=max, std::iter::empty());
261                //     var.universe = universe.clone();
262
263                //     collected_rule_memberships.insert(*var_key, var);
264                // }
265
266                if inferred_cf[&i] >= rule.threshold_cf {
267                    collected_rule_memberships
268                        .entry(*var_key)
269                        .or_insert_with(Vec::new)
270                        .push(&*combined_compositions[&(i, *var_key)]);
271                }
272            }
273        }
274
275        // Aggregate Collected Memberships
276        let mut aggregated_memberships = HashMap::new();
277
278        for (var_key, memberships) in collected_rule_memberships {
279            let mut agg = Vec::new();
280
281            for window in memberships.windows(2) {
282                match &agg[..] {
283                    [] => {
284                        agg = self
285                            .prod_link
286                            .call(window[0].iter().copied(), window[1].iter().copied())
287                            .into_iter()
288                            .collect();
289                    },
290                    [..] => {
291                        agg = self
292                            .prod_link
293                            .call(agg, window[1].iter().copied())
294                            .into_iter()
295                            .collect();
296                    },
297                }
298            }
299
300            aggregated_memberships.insert(var_key, agg);
301        }
302
303        let mut final_inferred_cf = 0.;
304
305        // Aggregate Production CF
306        for (i, _rule) in rules.0.iter().enumerate() {
307            final_inferred_cf = f64::max(final_inferred_cf, inferred_cf[&i]);
308        }
309
310        // Defuzzificate
311        let mut defuzzificated_inferred_memberships = HashMap::new();
312
313        for (var_key, aggregated_membership) in aggregated_memberships {
314            let var = &vars.0[var_key];
315
316            if aggregated_membership.iter().copied().sum::<f64>() == 0. {
317                let mean = var.universe.iter().copied().sum::<f64>() / var.universe.len() as f64;
318
319                defuzzificated_inferred_memberships.insert(var_key, mean);
320            } else {
321                let defuzzed = self.defuzz_op.call(&var.universe, &aggregated_membership);
322
323                defuzzificated_inferred_memberships.insert(var_key, defuzzed);
324            }
325        }
326
327        Outputs::new(defuzzificated_inferred_memberships, final_inferred_cf)
328    }
329}
330
#[test]
fn test_bank_loan() {
    use crate::terms::Terms;
    use fixed_map::Key;

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Score {
        High,
        Low,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Ratio {
        Good,
        Bad,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Credit {
        Good,
        Bad,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Decision {
        Approve,
        Reject,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
    enum VarTerms {
        Score(Score),
        Ratio(Ratio),
        Credit(Credit),
        Decision(Decision),
    }

    // Each term enum converts into the `VarTerms` variant of the same name.
    macro_rules! impl_from_terms {
        ($($variant:ident),+) => {
            $(
                impl From<$variant> for VarTerms {
                    fn from(term: $variant) -> Self {
                        Self::$variant(term)
                    }
                }
            )+
        };
    }

    impl_from_terms!(Score, Ratio, Credit, Decision);

    let model = DecompInference::new(
        AndOp::Min,
        OrOp::Max,
        CompositionOp::MaxMin,
        ImplicationOp::Rc,
        ProductionLink::Max,
        DefuzzificationOp::Cog,
    );

    // Builds the bank-loan knowledge base and evaluates it for one fixed applicant.
    // `eval` mutates the variables' universes, so every call starts from fresh variables.
    // Returns (inferred decision value, inferred CF).
    let evaluate = |include_reject_rule: bool| {
        let mut score_terms = Terms::new();
        let mut ratio_terms = Terms::new();
        let mut credit_terms = Terms::new();
        let mut decision_terms = Terms::new();

        score_terms.insert(Score::High, &[(175.0, 0.0), (180., 0.2), (185., 0.7), (190., 1.)]);
        score_terms.insert(
            Score::Low,
            &[(155.0, 1.0), (160., 0.8), (165., 0.5), (170., 0.2), (175., 0.)],
        );
        ratio_terms.insert(Ratio::Good, &[(0.3, 1.0), (0.4, 0.7), (0.41, 0.3), (0.42, 0.)]);
        ratio_terms.insert(Ratio::Bad, &[(0.44, 0.), (0.45, 0.3), (0.5, 0.7), (0.7, 1.)]);
        credit_terms.insert(Credit::Good, &[(2.0, 1.0), (3., 0.7), (4., 0.3), (5., 0.)]);
        credit_terms.insert(Credit::Bad, &[(5., 0.), (6., 0.3), (7., 0.7), (8., 1.)]);
        decision_terms.insert(Decision::Approve, &[(5.0, 0.0), (6., 0.3), (7., 0.7), (8., 1.)]);
        decision_terms.insert(Decision::Reject, &[(2., 1.), (3., 0.7), (4., 0.3), (5., 0.)]);

        let mut vars = Variables::<VarTerms>::new();
        let score = vars.add(150. ..=200., score_terms);
        let ratio = vars.add(0.1..=1., ratio_terms);
        let credit = vars.add(0. ..=10., credit_terms);
        let decision = vars.add(0. ..=10., decision_terms);
        let mut rules = Rules::new();

        rules.add(
            score
                .is(Score::High)
                .and2(ratio.is(Ratio::Good), credit.is(Credit::Good)),
            decision.is(Decision::Approve),
        );

        if include_reject_rule {
            rules.add(
                score
                    .is(Score::Low)
                    .and(ratio.is(Ratio::Bad))
                    .or(credit.is(Credit::Bad)),
                decision.is(Decision::Reject),
            );
        }

        let mut inputs = Inputs::new();

        inputs.add(score, 190.);
        inputs.add(ratio, 0.39);
        inputs.add(credit, 1.5);

        let outputs = model.eval(&mut vars, &rules, &inputs);

        (outputs.get_inferred_membership(decision), outputs.inferred_cf())
    };

    // A centroid is the result of float arithmetic; compare it with a tight tolerance.
    let assert_close = |actual: Option<f64>, expected: f64| {
        let actual = actual.expect("the decision variable should have an inferred value");
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    };

    let (decision_value, inferred_cf) = evaluate(true);
    assert_close(decision_value, 8.010492631084489);
    assert_eq!(inferred_cf, 1.);

    // Regression: with only the approve rule, its membership must be used as-is rather than
    // replaced by the universe mean. The reject rule does not match this applicant, so
    // dropping it should leave the centroid unchanged.
    let (decision_value, inferred_cf) = evaluate(false);
    assert_close(decision_value, 8.010492631084489);
    assert_eq!(inferred_cf, 1.);
}