Skip to main content

radiate_pgm/
kernel.rs

1use crate::factor::DiscreteFactor;
2use crate::var::VarSpec;
3
4pub trait FactorKernel {
5    fn param_count(&self) -> usize;
6    fn build(&self, params: &[f32]) -> Result<DiscreteFactor, String>;
7}
8
9/// CPT kernel: builds P(child | parents...) as a discrete table with scope [parents..., child]
10/// Params are logits in row-major over that scope.
11pub struct CptKernel {
12    pub parents: Vec<VarSpec>,
13    pub child: VarSpec,
14}
15
16impl FactorKernel for CptKernel {
17    fn param_count(&self) -> usize {
18        let mut dims = self.parents.iter().map(|v| v.card).collect::<Vec<usize>>();
19        dims.push(self.child.card);
20        crate::prod_usize(&dims)
21    }
22
23    fn build(&self, params: &[f32]) -> Result<DiscreteFactor, String> {
24        if params.len() != self.param_count() {
25            return Err(format!(
26                "params len {} != {}",
27                params.len(),
28                self.param_count()
29            ));
30        }
31        let mut scope = self.parents.clone();
32        scope.push(self.child);
33
34        let mut f = DiscreteFactor::new(scope, params.to_vec())?;
35        f.normalize_rows(self.child.id)?;
36
37        Ok(f)
38    }
39}
40
41/// Ising pairwise kernel over two binary vars: scope [a,b], card 2 each.
42/// Params: [h_a0, h_a1, h_b0, h_b1, J_same, J_diff] (simple template)
43pub struct IsingKernel {
44    pub a: VarSpec,
45    pub b: VarSpec,
46}
47
48impl FactorKernel for IsingKernel {
49    fn param_count(&self) -> usize {
50        6
51    }
52
53    fn build(&self, params: &[f32]) -> Result<DiscreteFactor, String> {
54        if self.a.card != 2 || self.b.card != 2 {
55            return Err("IsingKernel expects binary vars".into());
56        }
57
58        if params.len() != 6 {
59            return Err("IsingKernel expects 6 params".into());
60        }
61
62        let ha0 = params[0];
63        let ha1 = params[1];
64        let hb0 = params[2];
65        let hb1 = params[3];
66        let j_same = params[4];
67        let j_diff = params[5];
68
69        // logp(a,b) = h_a[a] + h_b[b] + (a==b ? j_same : j_diff)
70        // axis order [a,b], row-major => idx = a + 2*b? (with strides [1,2]) => idx = a + b*2
71        let mut logp = vec![0.0f32; 4];
72        for a in 0..2u32 {
73            for b in 0..2u32 {
74                let ha = if a == 0 { ha0 } else { ha1 };
75                let hb = if b == 0 { hb0 } else { hb1 };
76                let j = if a == b { j_same } else { j_diff };
77                let idx = (a as usize) + (b as usize) * 2;
78                logp[idx] = ha + hb + j;
79            }
80        }
81
82        DiscreteFactor::new(vec![self.a, self.b], logp)
83    }
84}