mollendorff_forge/bayesian/
inference.rs1use super::config::{BayesianConfig, BayesianNode, NodeType};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11pub struct Factor {
12 pub variables: Vec<String>,
14 pub cardinalities: Vec<usize>,
16 pub values: Vec<f64>,
18}
19
20impl Factor {
21 #[must_use]
23 pub fn from_node(name: &str, node: &BayesianNode, config: &BayesianConfig) -> Self {
24 let mut variables = vec![name.to_string()];
25 let mut cardinalities = vec![node.states.len()];
26
27 for parent in &node.parents {
29 if let Some(parent_node) = config.nodes.get(parent) {
30 variables.push(parent.clone());
31 cardinalities.push(parent_node.states.len());
32 }
33 }
34
35 let total_size: usize = cardinalities.iter().product();
37 let mut values = vec![0.0; total_size];
38
39 if node.is_root() {
40 values.clone_from(&node.prior);
42 } else {
43 if let Some(parent_node) = config.nodes.get(&node.parents[0]) {
46 let parent_card = parent_node.states.len();
47
48 for (i, val) in values.iter_mut().enumerate().take(total_size) {
49 let parent_idx = i % parent_card;
51 let state_idx = i / parent_card;
52
53 if parent_idx < parent_node.states.len() {
54 let parent_state = &parent_node.states[parent_idx];
55 if let Some(probs) = node.cpt.get(parent_state) {
56 if state_idx < probs.len() {
57 *val = probs[state_idx];
58 }
59 }
60 }
61 }
62 }
63 }
64
65 Self {
66 variables,
67 cardinalities,
68 values,
69 }
70 }
71
72 #[must_use]
74 pub fn multiply(&self, other: &Self) -> Self {
75 let mut new_variables = self.variables.clone();
77 let mut new_cardinalities = self.cardinalities.clone();
78
79 let mut other_indices: Vec<Option<usize>> = vec![None; other.variables.len()];
80
81 for (i, var) in other.variables.iter().enumerate() {
82 if let Some(pos) = self.variables.iter().position(|v| v == var) {
83 other_indices[i] = Some(pos);
84 } else {
85 new_variables.push(var.clone());
86 new_cardinalities.push(other.cardinalities[i]);
87 other_indices[i] = Some(new_variables.len() - 1);
88 }
89 }
90
91 let total_size: usize = new_cardinalities.iter().product();
92 let mut new_values = vec![0.0; total_size];
93
94 for (i, val) in new_values.iter_mut().enumerate() {
96 let indices = Self::decode_index(i, &new_cardinalities);
97
98 let self_idx =
100 Self::encode_index(&indices[..self.variables.len()], &self.cardinalities);
101
102 let other_idx_vec: Vec<usize> = other_indices
104 .iter()
105 .filter_map(|&idx| idx.map(|j| indices[j]))
106 .collect();
107 let other_idx = Self::encode_index(&other_idx_vec, &other.cardinalities);
108
109 let self_val = self.values.get(self_idx).copied().unwrap_or(0.0);
110 let other_val = other.values.get(other_idx).copied().unwrap_or(0.0);
111
112 *val = self_val * other_val;
113 }
114
115 Self {
116 variables: new_variables,
117 cardinalities: new_cardinalities,
118 values: new_values,
119 }
120 }
121
122 #[must_use]
124 pub fn marginalize(&self, var: &str) -> Self {
125 let Some(var_idx) = self.variables.iter().position(|v| v == var) else {
126 return self.clone();
127 };
128
129 let new_variables: Vec<String> = self
130 .variables
131 .iter()
132 .enumerate()
133 .filter(|(i, _)| *i != var_idx)
134 .map(|(_, v)| v.clone())
135 .collect();
136
137 let new_cardinalities: Vec<usize> = self
138 .cardinalities
139 .iter()
140 .enumerate()
141 .filter(|(i, _)| *i != var_idx)
142 .map(|(_, c)| *c)
143 .collect();
144
145 if new_variables.is_empty() {
146 return Self {
148 variables: vec![],
149 cardinalities: vec![],
150 values: vec![self.values.iter().sum()],
151 };
152 }
153
154 let total_size: usize = new_cardinalities.iter().product();
155 let mut new_values = vec![0.0; total_size];
156
157 for i in 0..self.values.len() {
158 let indices = Self::decode_index(i, &self.cardinalities);
159
160 let new_idx_vec: Vec<usize> = indices
162 .iter()
163 .enumerate()
164 .filter(|(j, _)| *j != var_idx)
165 .map(|(_, idx)| *idx)
166 .collect();
167
168 let new_idx = if new_idx_vec.is_empty() {
169 0
170 } else {
171 Self::encode_index(&new_idx_vec, &new_cardinalities)
172 };
173
174 new_values[new_idx] += self.values[i];
175 }
176
177 Self {
178 variables: new_variables,
179 cardinalities: new_cardinalities,
180 values: new_values,
181 }
182 }
183
184 pub fn normalize(&mut self) {
186 let sum: f64 = self.values.iter().sum();
187 if sum > 0.0 {
188 for v in &mut self.values {
189 *v /= sum;
190 }
191 }
192 }
193
194 fn decode_index(mut idx: usize, cardinalities: &[usize]) -> Vec<usize> {
196 let mut indices = vec![0; cardinalities.len()];
197 for i in (0..cardinalities.len()).rev() {
198 indices[i] = idx % cardinalities[i];
199 idx /= cardinalities[i];
200 }
201 indices
202 }
203
204 fn encode_index(indices: &[usize], cardinalities: &[usize]) -> usize {
206 let mut idx = 0;
207 let mut multiplier = 1;
208 for i in (0..indices.len()).rev() {
209 idx += indices[i] * multiplier;
210 multiplier *= cardinalities.get(i).copied().unwrap_or(1);
211 }
212 idx
213 }
214
215 #[must_use]
217 pub fn get_probability(&self, assignment: &HashMap<String, usize>) -> f64 {
218 let indices: Vec<usize> = self
219 .variables
220 .iter()
221 .map(|v| assignment.get(v).copied().unwrap_or(0))
222 .collect();
223 let idx = Self::encode_index(&indices, &self.cardinalities);
224 self.values.get(idx).copied().unwrap_or(0.0)
225 }
226}
227
228pub struct BeliefPropagation {
230 config: BayesianConfig,
231 factors: Vec<Factor>,
232}
233
234impl BeliefPropagation {
235 pub fn new(config: BayesianConfig) -> Result<Self, String> {
241 config.validate()?;
242
243 let mut factors = Vec::new();
245 for (name, node) in &config.nodes {
246 if node.node_type == NodeType::Discrete {
247 factors.push(Factor::from_node(name, node, &config));
248 }
249 }
250
251 Ok(Self { config, factors })
252 }
253
254 pub fn query(&self, target: &str) -> Result<Vec<f64>, String> {
260 if !self.config.nodes.contains_key(target) {
261 return Err(format!("Variable '{target}' not found in network"));
262 }
263
264 let order = self.get_elimination_order(target);
266
267 let mut factors = self.factors.clone();
268
269 for var in order {
270 if var == target {
271 continue;
272 }
273
274 let (containing, remaining): (Vec<_>, Vec<_>) = factors
276 .into_iter()
277 .partition(|f| f.variables.contains(&var));
278
279 if containing.is_empty() {
280 factors = remaining;
281 continue;
282 }
283
284 let mut product = containing[0].clone();
286 for f in containing.iter().skip(1) {
287 product = product.multiply(f);
288 }
289
290 let marginal = product.marginalize(&var);
292
293 factors = remaining;
294 factors.push(marginal);
295 }
296
297 if factors.is_empty() {
299 return Err("No factors remaining".to_string());
300 }
301
302 let mut result = factors[0].clone();
303 for f in factors.iter().skip(1) {
304 result = result.multiply(f);
305 }
306
307 result.normalize();
309
310 if result.variables.len() == 1 && result.variables[0] == target {
313 let sum: f64 = result.values.iter().sum();
315 if sum > 0.0 {
316 Ok(result.values.iter().map(|v| v / sum).collect())
317 } else {
318 Ok(result.values.clone())
319 }
320 } else {
321 let mut final_result = result.clone();
323 for var in &result.variables {
324 if var != target {
325 final_result = final_result.marginalize(var);
326 }
327 }
328
329 let sum: f64 = final_result.values.iter().sum();
331 if sum > 0.0 {
332 Ok(final_result.values.iter().map(|v| v / sum).collect())
333 } else {
334 Ok(final_result.values)
335 }
336 }
337 }
338
339 pub fn query_with_evidence(
345 &self,
346 target: &str,
347 evidence: &HashMap<String, usize>,
348 ) -> Result<Vec<f64>, String> {
349 if !self.config.nodes.contains_key(target) {
350 return Err(format!("Variable '{target}' not found in network"));
351 }
352
353 let mut factors: Vec<Factor> = self
355 .factors
356 .iter()
357 .map(|f| Self::apply_evidence(f, evidence))
358 .collect();
359
360 let order = self.get_elimination_order(target);
362
363 for var in order {
364 if var == target || evidence.contains_key(&var) {
365 continue;
366 }
367
368 let (containing, remaining): (Vec<_>, Vec<_>) = factors
370 .into_iter()
371 .partition(|f| f.variables.contains(&var));
372
373 if containing.is_empty() {
374 factors = remaining;
375 continue;
376 }
377
378 let mut product = containing[0].clone();
380 for f in containing.iter().skip(1) {
381 product = product.multiply(f);
382 }
383
384 let marginal = product.marginalize(&var);
386
387 factors = remaining;
388 factors.push(marginal);
389 }
390
391 if factors.is_empty() {
393 return Err("No factors remaining".to_string());
394 }
395
396 let mut result = factors[0].clone();
397 for f in factors.iter().skip(1) {
398 result = result.multiply(f);
399 }
400
401 result.normalize();
403
404 if result.variables.len() == 1 && result.variables[0] == target {
407 let sum: f64 = result.values.iter().sum();
409 if sum > 0.0 {
410 Ok(result.values.iter().map(|v| v / sum).collect())
411 } else {
412 Ok(result.values.clone())
413 }
414 } else {
415 let mut final_result = result.clone();
417 for var in &result.variables {
418 if var != target {
419 final_result = final_result.marginalize(var);
420 }
421 }
422
423 let sum: f64 = final_result.values.iter().sum();
425 if sum > 0.0 {
426 Ok(final_result.values.iter().map(|v| v / sum).collect())
427 } else {
428 Ok(final_result.values)
429 }
430 }
431 }
432
433 fn apply_evidence(factor: &Factor, evidence: &HashMap<String, usize>) -> Factor {
435 let mut new_values = factor.values.clone();
436
437 for (i, val) in new_values.iter_mut().enumerate() {
438 let indices = Factor::decode_index(i, &factor.cardinalities);
439
440 for (var_idx, var) in factor.variables.iter().enumerate() {
441 if let Some(&ev_val) = evidence.get(var) {
442 if indices[var_idx] != ev_val {
443 *val = 0.0;
444 break;
445 }
446 }
447 }
448 }
449
450 Factor {
451 variables: factor.variables.clone(),
452 cardinalities: factor.cardinalities.clone(),
453 values: new_values,
454 }
455 }
456
457 fn get_elimination_order(&self, exclude: &str) -> Vec<String> {
459 let mut order = self.config.topological_order();
460 order.reverse();
461 order.retain(|v| v != exclude);
462 order
463 }
464
465 #[must_use]
467 pub const fn config(&self) -> &BayesianConfig {
468 &self.config
469 }
470}
471
#[cfg(test)]
mod inference_tests {
    use super::*;

    /// Inference on this two-node network is exact up to rounding, so the
    /// tests compare against closed-form values with a tight tolerance.
    const TOLERANCE: f64 = 1e-9;

    /// Two-node network: `rain -> sprinkler`.
    fn create_simple_network() -> BayesianConfig {
        BayesianConfig::new("Sprinkler")
            .with_node(
                "rain",
                BayesianNode::discrete(vec!["no", "yes"]).with_prior(vec![0.8, 0.2]),
            )
            .with_node(
                "sprinkler",
                BayesianNode::discrete(vec!["off", "on"])
                    .with_parents(vec!["rain"])
                    .with_cpt_entry("no", vec![0.6, 0.4])
                    .with_cpt_entry("yes", vec![0.99, 0.01]),
            )
    }

    /// Asserts that `actual` matches `expected` entry by entry within `TOLERANCE`.
    fn assert_distribution(actual: &[f64], expected: &[f64], what: &str) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "{what}: expected {expected:?}, got {actual:?}"
        );
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < TOLERANCE,
                "{what}: expected {expected:?}, got {actual:?}"
            );
        }
    }

    #[test]
    fn test_prior_query() {
        let config = create_simple_network();
        let bp = BeliefPropagation::new(config).unwrap();

        let rain_probs = bp.query("rain").unwrap();
        assert_distribution(&rain_probs, &[0.8, 0.2], "P(rain)");
    }

    #[test]
    fn test_marginal_query() {
        let config = create_simple_network();
        let bp = BeliefPropagation::new(config).unwrap();

        let sprinkler_probs = bp.query("sprinkler").unwrap();

        // P(on) = P(on | no) P(no) + P(on | yes) P(yes)
        let expected_on = 0.4f64.mul_add(0.8, 0.01 * 0.2);
        assert_distribution(
            &sprinkler_probs,
            &[1.0 - expected_on, expected_on],
            "P(sprinkler)",
        );
    }

    #[test]
    fn test_evidence_query() {
        let config = create_simple_network();
        let bp = BeliefPropagation::new(config).unwrap();

        let mut evidence = HashMap::new();
        // State index 1 of "rain" is "yes".
        evidence.insert("rain".to_string(), 1);

        let probs = bp.query_with_evidence("sprinkler", &evidence).unwrap();
        assert_distribution(&probs, &[0.99, 0.01], "P(sprinkler | rain=yes)");
    }
}