1use crate::factor::DiscreteFactor;
2use crate::var::VarSpec;
3
4pub trait FactorKernel {
5 fn param_count(&self) -> usize;
6 fn build(&self, params: &[f32]) -> Result<DiscreteFactor, String>;
7}
8
9pub struct CptKernel {
12 pub parents: Vec<VarSpec>,
13 pub child: VarSpec,
14}
15
16impl FactorKernel for CptKernel {
17 fn param_count(&self) -> usize {
18 let mut dims = self.parents.iter().map(|v| v.card).collect::<Vec<usize>>();
19 dims.push(self.child.card);
20 crate::prod_usize(&dims)
21 }
22
23 fn build(&self, params: &[f32]) -> Result<DiscreteFactor, String> {
24 if params.len() != self.param_count() {
25 return Err(format!(
26 "params len {} != {}",
27 params.len(),
28 self.param_count()
29 ));
30 }
31 let mut scope = self.parents.clone();
32 scope.push(self.child);
33
34 let mut f = DiscreteFactor::new(scope, params.to_vec())?;
35 f.normalize_rows(self.child.id)?;
36
37 Ok(f)
38 }
39}
40
41pub struct IsingKernel {
44 pub a: VarSpec,
45 pub b: VarSpec,
46}
47
48impl FactorKernel for IsingKernel {
49 fn param_count(&self) -> usize {
50 6
51 }
52
53 fn build(&self, params: &[f32]) -> Result<DiscreteFactor, String> {
54 if self.a.card != 2 || self.b.card != 2 {
55 return Err("IsingKernel expects binary vars".into());
56 }
57
58 if params.len() != 6 {
59 return Err("IsingKernel expects 6 params".into());
60 }
61
62 let ha0 = params[0];
63 let ha1 = params[1];
64 let hb0 = params[2];
65 let hb1 = params[3];
66 let j_same = params[4];
67 let j_diff = params[5];
68
69 let mut logp = vec![0.0f32; 4];
72 for a in 0..2u32 {
73 for b in 0..2u32 {
74 let ha = if a == 0 { ha0 } else { ha1 };
75 let hb = if b == 0 { hb0 } else { hb1 };
76 let j = if a == b { j_same } else { j_diff };
77 let idx = (a as usize) + (b as usize) * 2;
78 logp[idx] = ha + hb + j;
79 }
80 }
81
82 DiscreteFactor::new(vec![self.a, self.b], logp)
83 }
84}