1use std::collections::HashMap;
2use std::hash::Hash;
3
4use num::Float;
5
6use crate::dsl::Expr;
7use crate::inputs::Inputs;
8use crate::math::{meshgrid, Axis, CollectMatrix, Matrix};
9use crate::ops::*;
10use crate::outputs::Outputs;
11use crate::rules::Rules;
12use crate::variable::{VariableKey, Variables};
13
14pub struct DecompInference {
15 and_op: AndOp,
16 or_op: OrOp,
17 comp_op: CompositionOp,
18 imp_op: ImplicationOp,
19 prod_link: ProductionLink,
20 defuzz_op: DefuzzificationOp,
21}
22
23impl DecompInference {
24 pub fn new(
25 and_op: AndOp,
26 or_op: OrOp,
27 comp_op: CompositionOp,
28 imp_op: ImplicationOp,
29 prod_link: ProductionLink,
30 defuzz_op: DefuzzificationOp,
31 ) -> Self {
32 Self {
33 and_op,
34 or_op,
35 comp_op,
36 imp_op,
37 prod_link,
38 defuzz_op,
39 }
40 }
41
42 pub fn eval<T: Eq + Hash>(&self, vars: &mut Variables<T>, rules: &Rules<T>, inputs: &Inputs) -> Outputs {
45 let mut fact_value = HashMap::with_capacity(inputs.0.len());
48 let mut fact_cf = HashMap::with_capacity(inputs.0.len());
49
50 for (key, input_value) in &inputs.0 {
51 fact_value.insert(*key, *input_value);
53 fact_cf.insert(*key, 1.);
54 }
55
56 let mut fact_values = HashMap::with_capacity(inputs.0.len());
61
62 for (key, fact_value) in fact_value {
63 if true {
66 let var = &mut vars.0[key];
69
70 var.add_points_to_universe(Some(fact_value));
71 fact_values.insert(
72 key,
73 var.universe
74 .iter()
75 .map(|u| if *u == fact_value { 1.0f64 } else { 0. })
76 .collect::<Vec<_>>(),
77 );
78 } else {
80 unimplemented!();
82 }
83 }
84
85 let mut modified_premise_memberships = HashMap::new();
87
88 for (i, rule) in rules.0.iter().enumerate() {
89 let premise = &rule.premise;
90
91 for (var_key, term, modifiers) in premise.propositions() {
92 let membership = vars.0[var_key].get_modified_membership(term, modifiers);
93 modified_premise_memberships.insert((i, var_key), membership);
94 }
95 }
96
97 let mut modified_consequence_memberships = HashMap::new();
99
100 for (i, rule) in rules.0.iter().enumerate() {
102 let consequence = &rule.consequence;
103
104 for (var_key, term, modifiers) in consequence.propositions() {
105 let membership = vars.0[var_key].get_modified_membership(term, modifiers);
106 modified_consequence_memberships.insert((i, var_key), membership);
107 }
108 }
109
110 let mut fuzzy_implications = HashMap::with_capacity(
112 rules.0.len() * modified_premise_memberships.len() * modified_consequence_memberships.len(),
113 );
114
115 for (i, premise_name) in modified_premise_memberships.keys() {
116 for (j, consequence_name) in modified_consequence_memberships.keys() {
117 if *i != *j {
120 continue;
121 }
122
123 let premise_membership = &modified_premise_memberships[&(*i, *premise_name)];
124 let consequence_membership = &modified_consequence_memberships[&(*i, *consequence_name)];
125 let (v, u) = meshgrid(
126 consequence_membership.iter().copied(),
127 premise_membership.iter().copied(),
128 );
129 let shape = v.shape();
130
131 fuzzy_implications.insert(
132 (i, *premise_name, *consequence_name),
133 self.imp_op.call(u, v).collect_matrix(shape),
134 );
135 }
136 }
137
138 let mut fuzzy_compositions = HashMap::with_capacity(fuzzy_implications.len());
141
142 for (i, premise_name) in modified_premise_memberships.keys() {
143 for (j, consequence_name) in modified_consequence_memberships.keys() {
144 if *i != *j {
145 continue;
146 }
147
148 let implication = &fuzzy_implications[&(i, *premise_name, *consequence_name)];
149 let fact_values = &fact_values[premise_name];
150 let n_dim = fact_values.len();
151 let fact_value = Matrix::new(fact_values.to_owned(), (n_dim, 1));
152 let fact_value = fact_value.tile((1, implication.shape().1));
153 let shape = fact_value.shape();
154
155 debug_assert_eq!(shape, implication.shape());
156
157 let composition = match self.comp_op {
158 CompositionOp::MaxMin => ProductionLink::Min
159 .call(fact_value, implication)
160 .collect_matrix(shape)
161 .max(Axis::Column),
162 CompositionOp::MaxProd => unimplemented!("fact_value * implication"), };
164
165 fuzzy_compositions.insert((*i, *premise_name, *consequence_name), composition);
166 }
167 }
168
169 let mut combined_compositions = HashMap::new();
171
172 for (i, rule) in rules.0.iter().enumerate() {
173 for (j, consequence_name) in modified_consequence_memberships.keys() {
174 if i != *j {
175 continue;
176 }
177
178 fn combine<T, F: Float>(
182 expr: &Expr<T>,
183 fuzzy_compositions: &HashMap<(usize, VariableKey, VariableKey), Vec<F>>,
184 consequence_name: VariableKey,
185 and_op: AndOp,
186 or_op: OrOp,
187 rule_id: usize,
188 ) -> Vec<F> {
189 match expr {
190 Expr::Is(var_key, _) => fuzzy_compositions[&(rule_id, *var_key, consequence_name)].clone(),
191 Expr::And(expr, expr2) => {
192 let left = combine(expr, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
193 let right = combine(expr2, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
194
195 and_op.call(left, right).into_iter().collect()
196 },
197 Expr::Or(expr, expr2) => {
198 let left = combine(expr, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
199 let right = combine(expr2, fuzzy_compositions, consequence_name, and_op, or_op, rule_id);
200
201 or_op.call(left, right).into_iter().collect()
202 },
203 }
204 }
205
206 let combined_composition = combine(
207 &rule.premise,
208 &fuzzy_compositions,
209 *consequence_name,
210 self.and_op,
211 self.or_op,
212 i,
213 );
214
215 combined_compositions.insert((i, *consequence_name), combined_composition);
216 }
217 }
218
219 let mut inferred_cf = HashMap::new();
221
222 for (i, rule) in rules.0.iter().enumerate() {
223 fn calc_cf<T>(expr: &Expr<T>, fact_cf: &HashMap<VariableKey, f64>) -> f64 {
224 match expr {
225 Expr::Is(var_key, _) => fact_cf[var_key],
226 Expr::And(expr, expr2) => {
227 let left = calc_cf(expr, fact_cf);
228 let right = calc_cf(expr2, fact_cf);
229
230 f64::min(left, right)
231 },
232 Expr::Or(expr, expr2) => {
233 let left = calc_cf(expr, fact_cf);
234 let right = calc_cf(expr2, fact_cf);
235
236 f64::max(left, right)
237 },
238 }
239 }
240
241 let aggregated_premise_cf = calc_cf(&rule.premise, &fact_cf);
242
243 inferred_cf.insert(i, aggregated_premise_cf * rule.cf);
244 }
245
246 let mut collected_rule_memberships = HashMap::new();
248
249 for (i, rule) in rules.0.iter().enumerate() {
250 for (j, var_key) in combined_compositions.keys() {
251 if i != *j {
252 continue;
253 }
254
255 if inferred_cf[&i] >= rule.threshold_cf {
267 collected_rule_memberships
268 .entry(*var_key)
269 .or_insert_with(Vec::new)
270 .push(&*combined_compositions[&(i, *var_key)]);
271 }
272 }
273 }
274
275 let mut aggregated_memberships = HashMap::new();
277
278 for (var_key, memberships) in collected_rule_memberships {
279 let mut agg = Vec::new();
280
281 for window in memberships.windows(2) {
282 match &agg[..] {
283 [] => {
284 agg = self
285 .prod_link
286 .call(window[0].iter().copied(), window[1].iter().copied())
287 .into_iter()
288 .collect();
289 },
290 [..] => {
291 agg = self
292 .prod_link
293 .call(agg, window[1].iter().copied())
294 .into_iter()
295 .collect();
296 },
297 }
298 }
299
300 aggregated_memberships.insert(var_key, agg);
301 }
302
303 let mut final_inferred_cf = 0.;
304
305 for (i, _rule) in rules.0.iter().enumerate() {
307 final_inferred_cf = f64::max(final_inferred_cf, inferred_cf[&i]);
308 }
309
310 let mut defuzzificated_inferred_memberships = HashMap::new();
312
313 for (var_key, aggregated_membership) in aggregated_memberships {
314 let var = &vars.0[var_key];
315
316 if aggregated_membership.iter().copied().sum::<f64>() == 0. {
317 let mean = var.universe.iter().copied().sum::<f64>() / var.universe.len() as f64;
318
319 defuzzificated_inferred_memberships.insert(var_key, mean);
320 } else {
321 let defuzzed = self.defuzz_op.call(&var.universe, &aggregated_membership);
322
323 defuzzificated_inferred_memberships.insert(var_key, defuzzed);
324 }
325 }
326
327 Outputs::new(defuzzificated_inferred_memberships, final_inferred_cf)
328 }
329}
330
#[test]
fn test_bank_loan() {
    use crate::terms::Terms;
    use fixed_map::Key;

    // One term enum per linguistic variable.
    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Score {
        High,
        Low,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Ratio {
        Good,
        Bad,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Credit {
        Good,
        Bad,
    }

    #[derive(Clone, Copy, Debug, Eq, Hash, Key, Ord, PartialEq, PartialOrd)]
    enum Decision {
        Approve,
        Reject,
    }

    // Union of all term enums; `Variables` is parameterised by this type.
    #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
    enum VarTerms {
        Score(Score),
        Ratio(Ratio),
        Credit(Credit),
        Decision(Decision),
    }

    // Each term enum converts into the `VarTerms` variant carrying the same name.
    macro_rules! impl_into_var_terms {
        ($($term:ident),* $(,)?) => {
            $(
                impl From<$term> for VarTerms {
                    fn from(term: $term) -> Self {
                        Self::$term(term)
                    }
                }
            )*
        };
    }

    impl_into_var_terms!(Decision, Score, Ratio, Credit);

    // Piecewise membership points (x, mu) for every term.
    let mut score_sets = Terms::new();
    let mut ratio_sets = Terms::new();
    let mut credit_sets = Terms::new();
    let mut decision_sets = Terms::new();

    score_sets.insert(Score::High, &[(175.0, 0.0), (180., 0.2), (185., 0.7), (190., 1.)]);
    score_sets.insert(
        Score::Low,
        &[(155.0, 1.0), (160., 0.8), (165., 0.5), (170., 0.2), (175., 0.)],
    );
    ratio_sets.insert(Ratio::Good, &[(0.3, 1.0), (0.4, 0.7), (0.41, 0.3), (0.42, 0.)]);
    ratio_sets.insert(Ratio::Bad, &[(0.44, 0.), (0.45, 0.3), (0.5, 0.7), (0.7, 1.)]);
    credit_sets.insert(Credit::Good, &[(2.0, 1.0), (3., 0.7), (4., 0.3), (5., 0.)]);
    credit_sets.insert(Credit::Bad, &[(5., 0.), (6., 0.3), (7., 0.7), (8., 1.)]);
    decision_sets.insert(Decision::Approve, &[(5.0, 0.0), (6., 0.3), (7., 0.7), (8., 1.)]);
    decision_sets.insert(Decision::Reject, &[(2., 1.), (3., 0.7), (4., 0.3), (5., 0.)]);

    // Linguistic variables with their universe ranges.
    let mut vars = Variables::<VarTerms>::new();
    let score = vars.add(150. ..=200., score_sets);
    let ratio = vars.add(0.1..=1., ratio_sets);
    let credit = vars.add(0. ..=10., credit_sets);
    let decision = vars.add(0. ..=10., decision_sets);

    // Rule base.
    let mut rules = Rules::new();

    rules.add(
        score
            .is(Score::High)
            .and2(ratio.is(Ratio::Good), credit.is(Credit::Good)),
        decision.is(Decision::Approve),
    );
    rules.add(
        score
            .is(Score::Low)
            .and(ratio.is(Ratio::Bad))
            .or(credit.is(Credit::Bad)),
        decision.is(Decision::Reject),
    );

    // Applicant data.
    let mut facts = Inputs::new();

    facts.add(score, 190.);
    facts.add(ratio, 0.39);
    facts.add(credit, 1.5);

    let engine = DecompInference::new(
        AndOp::Min,
        OrOp::Max,
        CompositionOp::MaxMin,
        ImplicationOp::Rc,
        ProductionLink::Max,
        DefuzzificationOp::Cog,
    );

    let result = engine.eval(&mut vars, &rules, &facts);

    assert_eq!(result.get_inferred_membership(decision), Some(8.010492631084489));
    assert_eq!(result.inferred_cf(), 1.);
}