1use radiate_core::random_provider;
2use radiate_utils::Value;
3
4mod alter;
5mod chromosome;
6mod codec;
7mod factor;
8mod fitness;
9mod kernel;
10mod var;
11
12pub use alter::{PgmParamMutator, PgmScopeMutator};
13pub use chromosome::{FactorGene, FactorKind, PgmChromosome};
14pub use codec::PgmCodec;
15pub use factor::{
16 DiscreteFactor, chromosome_factors, gene_to_discrete, joint_factor, loglik_evidence,
17 logp_evidence, logz, marginal_joint, marginal_ve, neg_mean_loglik,
18};
19pub use fitness::{PgmDataSet, PgmLogLik, PgmNll};
20pub use kernel::{CptKernel, FactorKernel, IsingKernel};
21pub use var::{VarId, VarSpec};
22
23pub(crate) fn sample_scope(num_vars: usize, max_scope: usize) -> Vec<VarId> {
24 let k = random_provider::range(1..max_scope.min(num_vars).max(1) + 1);
25
26 let mut picked = Vec::with_capacity(k);
27 while picked.len() < k {
28 let v = VarId(random_provider::range(0..num_vars) as u32);
29 if !picked.contains(&v) {
30 picked.push(v);
31 }
32 }
33
34 if picked.len() > 1 {
36 let child_pos = random_provider::range(0..picked.len());
37 let child = picked.remove(child_pos);
38 picked.push(child);
39 }
40
41 picked
42}
43
44pub(crate) fn logp_table_shape(vars: &[VarSpec], scope: &[VarId]) -> Vec<usize> {
45 scope
46 .iter()
47 .map(|&vid| vars[vid.0 as usize].card.max(1) as usize)
48 .collect()
49}
50
51pub(crate) fn init_logp_table(shape: &[usize]) -> Value<f32> {
52 Value::from((shape.to_vec(), |_| random_provider::range(-1.0..1.0)))
53}
54
/// Clamps `x` into `[lo, hi]`, mapping a NaN `x` to `0.0`.
///
/// The NaN case returns `0.0` unconditionally, even when `0.0` lies outside `[lo, hi]`.
///
/// # Panics
/// For a non-NaN `x`, panics (via `f32::clamp`) if `lo > hi` or if either bound is NaN.
pub fn clamp_f32(x: f32, lo: f32, hi: f32) -> f32 {
    if x.is_nan() {
        return 0.0;
    }
    x.clamp(lo, hi)
}
58
/// Numerically stable `ln(sum_i exp(xs[i]))`.
///
/// The maximum entry is factored out before exponentiating, so large inputs do not overflow.
///
/// Special cases:
/// - empty input returns `f32::NEG_INFINITY` (the log of an empty sum);
/// - if the maximum is infinite (every entry is `-inf`, or some entry is `+inf`), that
///   maximum is returned directly;
/// - NaN entries never become the maximum. A NaN next to a finite maximum makes the result
///   NaN, while an all-NaN input returns `-inf`.
#[inline]
pub fn logsumexp(xs: &[f32]) -> f32 {
    let peak = xs
        .iter()
        .fold(f32::NEG_INFINITY, |best, &x| if x > best { x } else { best });

    // Also covers empty input: the fold above leaves `peak` at -inf.
    if peak.is_infinite() {
        return peak;
    }

    let shifted_sum = xs.iter().fold(0.0f32, |acc, &x| acc + (x - peak).exp());
    peak + shifted_sum.ln()
}
79
/// Subtracts the log-sum-exp of `row` from every entry, in place.
///
/// Afterwards `sum(exp(row))` is 1 up to rounding, i.e. log-weights become log-probabilities.
///
/// `row` is left untouched when no finite normalizer exists: when it is empty, or when its
/// maximum is infinite (all entries `-inf`, or some entry `+inf`). NaN entries never become
/// the maximum. If NaNs sit alongside a finite maximum, the normalizer is NaN and every entry
/// becomes NaN.
#[inline]
pub fn log_normalize_in_place(row: &mut [f32]) {
    let peak = row
        .iter()
        .fold(f32::NEG_INFINITY, |best, &x| if x > best { x } else { best });

    // An empty row also lands here: the fold leaves `peak` at -inf.
    if peak.is_infinite() {
        return;
    }

    let shifted_sum = row.iter().fold(0.0f32, |acc, &x| acc + (x - peak).exp());
    let log_z = peak + shifted_sum.ln();
    row.iter_mut().for_each(|x| *x -= log_z);
}
109
/// Product of all entries of `xs`, using saturating multiplication at each step.
///
/// Empty input yields `1`. Saturation is applied left to right, so an intermediate product
/// clamps at `usize::MAX`, and a later `0` still brings the result to `0`.
#[inline]
pub fn prod_usize(xs: &[usize]) -> usize {
    xs.iter().copied().fold(1, usize::saturating_mul)
}
114
115pub fn variable_elimination(
122 mut factors: Vec<DiscreteFactor>,
123 elim_order: &[VarId],
124 card: &impl Fn(VarId) -> usize,
125) -> Result<DiscreteFactor, String> {
126 for &z in elim_order {
127 let mut with = Vec::new();
129 let mut without = Vec::new();
130
131 for f in factors.into_iter() {
132 if f.scope().contains(&z) {
133 with.push(f);
134 } else {
135 without.push(f);
136 }
137 }
138
139 if with.is_empty() {
141 factors = without;
142 continue;
143 }
144
145 let mut joint = with[0].clone();
147 for f in with.iter().skip(1) {
148 joint = joint.product(f, card)?;
149 }
150
151 let reduced = joint.marginalize(&[z])?;
153
154 without.push(reduced);
155 factors = without;
156 }
157
158 if factors.is_empty() {
160 return Err("no factors".into());
161 }
162 let mut joint = factors[0].clone();
163 for f in factors.iter().skip(1) {
164 joint = joint.product(f, card)?;
165 }
166 Ok(joint)
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::var::VarSpec;

    /// Absolute-tolerance float comparison: true when `|a - b| <= eps`.
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() <= eps
    }

    /// Checks `variable_elimination` against brute-force enumeration on a three-variable chain.
    ///
    /// The factors are over `{A}`, `{A, B}` and `{B, C}`, with all variables binary. Eliminating
    /// A and B must leave a factor over C alone. For each value of C, its log-value must match
    /// `logsumexp` of the summed factor log-values over every (a, b) assignment.
    #[test]
    fn ve_matches_bruteforce_small_chain() {
        let a = VarSpec::new(0, 2);
        let b = VarSpec::new(1, 2);
        let c = VarSpec::new(2, 2);

        // Cardinality lookup passed to `variable_elimination`: ids 0..=2 are the binary
        // variables above; any other id maps to 0.
        let card = |v: VarId| match v.0 {
            0 => 2,
            1 => 2,
            2 => 2,
            _ => 0,
        };

        // Factor over A, normalized with respect to VarId(0).
        // NOTE(review): the entries look like log-potentials (the brute-force check below adds
        // them and applies `logsumexp`) — confirm against `DiscreteFactor::new`.
        let mut p_a = DiscreteFactor::new(vec![a], vec![0.0, 1.0]).unwrap();
        p_a.normalize_rows(VarId(0)).unwrap();

        // Factor over (A, B), normalized with respect to VarId(1). The 2x2 table is symmetric,
        // so the test does not depend on which scope variable varies fastest in the flat layout.
        let mut p_ba = DiscreteFactor::new(
            vec![a, b],
            vec![
                2.0, 0.0, 0.0, 2.0, ],
        )
        .unwrap();
        p_ba.normalize_rows(VarId(1)).unwrap();

        // Factor over (B, C), normalized with respect to VarId(2); same symmetric table.
        let mut p_cb = DiscreteFactor::new(
            vec![b, c],
            vec![
                2.0, 0.0, 0.0, 2.0, ],
        )
        .unwrap();
        p_cb.normalize_rows(VarId(2)).unwrap();

        // Eliminate A then B; only C should remain in the result's scope.
        let ve = variable_elimination(
            vec![p_a.clone(), p_ba.clone(), p_cb.clone()],
            &[VarId(0), VarId(1)],
            &card,
        )
        .unwrap();
        assert_eq!(ve.scope(), &[VarId(2)]);

        // Brute force: for each c, collect the joint log-value of every (a, b) assignment
        // (indices passed in each factor's scope order), then combine with `logsumexp`.
        for ci in 0..2 {
            let mut acc = Vec::new();
            for ai in 0..2 {
                for bi in 0..2 {
                    let lp = p_a.log_value_aligned(&[ai])
                        + p_ba.log_value_aligned(&[ai, bi])
                        + p_cb.log_value_aligned(&[bi, ci]);
                    acc.push(lp);
                }
            }
            let want = crate::logsumexp(&acc);
            let got = ve.log_value_aligned(&[ci]);
            assert!(approx(got, want, 1e-5), "c={ci} got={got} want={want}");
        }
    }
}
245
246