Skip to main content

radiate_pgm/
lib.rs

1use radiate_core::random_provider;
2use radiate_utils::Value;
3
4mod alter;
5mod chromosome;
6mod codec;
7mod factor;
8mod fitness;
9mod kernel;
10mod var;
11
12pub use alter::{PgmParamMutator, PgmScopeMutator};
13pub use chromosome::{FactorGene, FactorKind, PgmChromosome};
14pub use codec::PgmCodec;
15pub use factor::{
16    DiscreteFactor, chromosome_factors, gene_to_discrete, joint_factor, loglik_evidence,
17    logp_evidence, logz, marginal_joint, marginal_ve, neg_mean_loglik,
18};
19pub use fitness::{PgmDataSet, PgmLogLik, PgmNll};
20pub use kernel::{CptKernel, FactorKernel, IsingKernel};
21pub use var::{VarId, VarSpec};
22
23pub(crate) fn sample_scope(num_vars: usize, max_scope: usize) -> Vec<VarId> {
24    let k = random_provider::range(1..max_scope.min(num_vars).max(1) + 1);
25
26    let mut picked = Vec::with_capacity(k);
27    while picked.len() < k {
28        let v = VarId(random_provider::range(0..num_vars) as u32);
29        if !picked.contains(&v) {
30            picked.push(v);
31        }
32    }
33
34    // enforce "child last" convention for Logp when len>1
35    if picked.len() > 1 {
36        let child_pos = random_provider::range(0..picked.len());
37        let child = picked.remove(child_pos);
38        picked.push(child);
39    }
40
41    picked
42}
43
44pub(crate) fn logp_table_shape(vars: &[VarSpec], scope: &[VarId]) -> Vec<usize> {
45    scope
46        .iter()
47        .map(|&vid| vars[vid.0 as usize].card.max(1) as usize)
48        .collect()
49}
50
51pub(crate) fn init_logp_table(shape: &[usize]) -> Value<f32> {
52    Value::from((shape.to_vec(), |_| random_provider::range(-1.0..1.0)))
53}
54
/// Clamps `x` into `[lo, hi]`, mapping NaN to `0.0`.
///
/// A NaN input yields `0.0` regardless of the bounds (the bounds are not
/// consulted in that case). Any other input is passed to [`f32::clamp`].
///
/// # Panics
///
/// For non-NaN `x`, panics under the same conditions as [`f32::clamp`]
/// (`lo > hi`, or either bound is NaN).
pub fn clamp_f32(x: f32, lo: f32, hi: f32) -> f32 {
    if x.is_nan() {
        return 0.0;
    }
    x.clamp(lo, hi)
}
58
/// Numerically stable `ln(sum(exp(x)))` over `xs`.
///
/// The maximum is subtracted before exponentiating so large inputs do not
/// overflow. Returns `-inf` for an empty slice, and returns the maximum
/// itself when it is infinite (all entries `-inf`, or some entry `+inf`).
#[inline]
pub fn logsumexp(xs: &[f32]) -> f32 {
    // Running maximum; NaN entries never win the `>` comparison. An empty
    // slice leaves this at -inf, which the guard below returns directly.
    let peak = xs
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, |best, x| if x > best { x } else { best });

    if peak.is_infinite() {
        return peak;
    }

    // Left-to-right accumulation of the shifted exponentials.
    let total = xs.iter().fold(0.0f32, |acc, &x| acc + (x - peak).exp());
    peak + total.ln()
}
79
/// Normalizes `row` in log space, in place, so that `exp(row)` sums to 1.
///
/// The log-partition value `ln(sum(exp(row)))` is computed with the usual
/// max-shift trick and subtracted from every entry. The slice is left
/// untouched when it is empty or when its maximum is infinite (all entries
/// `-inf`, or some entry `+inf`).
#[inline]
pub fn log_normalize_in_place(row: &mut [f32]) {
    // Running maximum; stays at -inf for an empty slice.
    let peak = row
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, |best, x| if x > best { x } else { best });

    // Covers the empty slice, an all -inf row, and a row containing +inf.
    if peak.is_infinite() {
        return;
    }

    let total = row.iter().fold(0.0f32, |acc, &x| acc + (x - peak).exp());
    let log_z = peak + total.ln();

    row.iter_mut().for_each(|x| *x -= log_z);
}
109
/// Product of all entries in `xs`, saturating at `usize::MAX` instead of
/// overflowing. An empty slice yields 1.
#[inline]
pub fn prod_usize(xs: &[usize]) -> usize {
    let mut total = 1usize;
    for &dim in xs {
        total = total.saturating_mul(dim);
    }
    total
}
114
115/// A very small discrete-only VE evaluator:
116/// - multiply all factors that mention elim var
117/// - marginalize elim var
118/// - put result back
119///
120/// For correctness validation and N <= ~10ish factors/vars.
121pub fn variable_elimination(
122    mut factors: Vec<DiscreteFactor>,
123    elim_order: &[VarId],
124    card: &impl Fn(VarId) -> usize,
125) -> Result<DiscreteFactor, String> {
126    for &z in elim_order {
127        // split factors into those that contain z and those that don't
128        let mut with = Vec::new();
129        let mut without = Vec::new();
130
131        for f in factors.into_iter() {
132            if f.scope().contains(&z) {
133                with.push(f);
134            } else {
135                without.push(f);
136            }
137        }
138
139        // if no factor includes z, keep going
140        if with.is_empty() {
141            factors = without;
142            continue;
143        }
144
145        // multiply them all
146        let mut joint = with[0].clone();
147        for f in with.iter().skip(1) {
148            joint = joint.product(f, card)?;
149        }
150
151        // marginalize z out
152        let reduced = joint.marginalize(&[z])?;
153
154        without.push(reduced);
155        factors = without;
156    }
157
158    // multiply remaining factors into one joint
159    if factors.is_empty() {
160        return Err("no factors".into());
161    }
162    let mut joint = factors[0].clone();
163    for f in factors.iter().skip(1) {
164        joint = joint.product(f, card)?;
165    }
166    Ok(joint)
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::var::VarSpec;

    /// Absolute-tolerance comparison for the float assertions below.
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        let gap = (a - b).abs();
        gap <= eps
    }

    /// VE over a three-node chain must agree with explicit summation.
    #[test]
    fn ve_matches_bruteforce_small_chain() {
        // Chain A -> B -> C; every variable is created with cardinality 2.
        let var_a = VarSpec::new(0, 2);
        let var_b = VarSpec::new(1, 2);
        let var_c = VarSpec::new(2, 2);

        // Cardinality lookup: ids 0..=2 have 2 states, anything else 0.
        let card = |v: VarId| match v.0 {
            0..=2 => 2,
            _ => 0,
        };

        // Prior P(A): logits [0, 1], normalized by treating A as the "child"
        // of a CPT with no parents.
        let mut prior_a = DiscreteFactor::new(vec![var_a], vec![0.0, 1.0]).unwrap();
        prior_a.normalize_rows(VarId(0)).unwrap();

        // CPT P(B|A), scope [A, B]; rows are A=0 then A=1, columns B=0..1.
        let mut cpt_b_given_a =
            DiscreteFactor::new(vec![var_a, var_b], vec![2.0, 0.0, 0.0, 2.0]).unwrap();
        cpt_b_given_a.normalize_rows(VarId(1)).unwrap();

        // CPT P(C|B), scope [B, C]; rows are B=0 then B=1, columns C=0..1.
        let mut cpt_c_given_b =
            DiscreteFactor::new(vec![var_b, var_c], vec![2.0, 0.0, 0.0, 2.0]).unwrap();
        cpt_c_given_b.normalize_rows(VarId(2)).unwrap();

        // Marginal P(C), obtained by eliminating A and then B.
        let all_factors = vec![
            prior_a.clone(),
            cpt_b_given_a.clone(),
            cpt_c_given_b.clone(),
        ];
        let ve = variable_elimination(all_factors, &[VarId(0), VarId(1)], &card).unwrap();
        assert_eq!(ve.scope(), &[VarId(2)]);

        // Brute force: sum_{a,b} P(a) P(b|a) P(c|b), carried out in log space.
        for ci in 0..2 {
            let mut terms = Vec::with_capacity(4);
            for ai in 0..2 {
                for bi in 0..2 {
                    let log_joint = prior_a.log_value_aligned(&[ai])
                        + cpt_b_given_a.log_value_aligned(&[ai, bi])
                        + cpt_c_given_b.log_value_aligned(&[bi, ci]);
                    terms.push(log_joint);
                }
            }

            let want = logsumexp(&terms);
            let got = ve.log_value_aligned(&[ci]);
            assert!(approx(got, want, 1e-5), "c={ci} got={got} want={want}");
        }
    }
}
245
246// use radiate_pgm_skeleton::kernels::{CptKernel, IsingKernel, FactorKernel};
247// use radiate_pgm_skeleton::ve::variable_elimination;
248// use radiate_pgm_skeleton::var::{VarSpec, VarId};
249
250// fn main() -> Result<(), String> {
251//     let a = VarSpec::new(0, 2);
252//     let b = VarSpec::new(1, 2);
253
254//     let card = |v: VarId| match v.0 { 0 => 2, 1 => 2, _ => 0 };
255
256//     // Build an Ising factor on (A,B)
257//     let ising = IsingKernel { a, b };
258//     let f_ab = ising.build(&[0.0, 0.1, 0.0, 0.2, 0.5, -0.5])?;
259
260//     // Build a CPT P(A) as a CPT with no parents (scope [A])
261//     let prior = CptKernel { parents: vec![], child: a };
262//     let p_a = prior.build(&[0.0, 1.0])?;
263
264//     // Compute marginal P(B) by eliminating A from P(A)*f(A,B)
265//     let out = variable_elimination(vec![p_a, f_ab], &[VarId(0)], &card)?;
266//     println!("scope={:?}, logp={:?}", out.scope(), out.logp());
267//     Ok(())
268// }