Skip to main content

radiate_pgm/factor/
discrete.rs

1use crate::var::{VarId, VarSpec};
2use crate::{log_normalize_in_place, logsumexp, prod_usize};
3
/// A factor over discrete variables, stored as a dense table of `f32` log-values.
///
/// `scope`, `dims` and `strides` are parallel vectors indexed by axis.
#[derive(Clone, Debug)]
pub struct DiscreteFactor {
    scope: Vec<VarId>,   // axis order
    dims: Vec<usize>,    // cardinalities by axis
    // Per-axis index strides. As produced by `compute_strides`, axis 0 has
    // stride 1 and each later axis has the product of the earlier
    // cardinalities, so the FIRST axis varies fastest (column-major order,
    // not row-major).
    strides: Vec<usize>,
    logp: Vec<f32>,      // contiguous log-table
}
11
12impl DiscreteFactor {
13    /// Build from explicit scope specs (order is preserved).
14    pub fn new(scope: Vec<VarSpec>, logp: Vec<f32>) -> Result<Self, String> {
15        let mut svars = Vec::with_capacity(scope.len());
16        let mut dims = Vec::with_capacity(scope.len());
17        for v in scope {
18            svars.push(v.id);
19            dims.push(v.card);
20        }
21
22        let strides = Self::compute_strides(&dims);
23        let expected = prod_usize(&dims);
24
25        if logp.len() != expected {
26            return Err(format!(
27                "logp length {} != expected {}",
28                logp.len(),
29                expected
30            ));
31        }
32
33        Ok(Self {
34            scope: svars,
35            dims,
36            strides,
37            logp,
38        })
39    }
40
41    /// Convenient: all zeros in prob-space => log(1) table (not normalized unless you normalize).
42    pub fn uniform(scope: Vec<VarSpec>) -> Result<Self, String> {
43        let dims: Vec<usize> = scope.iter().map(|v| v.card).collect();
44        let n = prod_usize(&dims);
45        Self::new(scope, vec![0.0; n])
46    }
47
48    pub fn scope(&self) -> &[VarId] {
49        &self.scope
50    }
51
52    pub fn dims(&self) -> &[usize] {
53        &self.dims
54    }
55
56    pub fn logp(&self) -> &[f32] {
57        &self.logp
58    }
59
60    #[inline]
61    fn compute_strides(dims: &[usize]) -> Vec<usize> {
62        let mut strides = vec![1usize; dims.len()];
63        let mut acc = 1usize;
64        for (i, &d) in dims.iter().enumerate() {
65            strides[i] = acc;
66            acc = acc.saturating_mul(d);
67        }
68        strides
69    }
70
71    /// Flatten an assignment (aligned with this factor's axis order) to an index.
72    #[inline]
73    pub fn index_of(&self, asg: &[usize]) -> usize {
74        debug_assert_eq!(asg.len(), self.scope.len());
75        let mut idx = 0usize;
76        for i in 0..asg.len() {
77            idx += (asg[i] as usize) * self.strides[i];
78        }
79        idx
80    }
81
82    /// Unflatten index -> assignment (aligned with axis order).
83    pub fn unflatten(&self, idx: usize) -> Vec<usize> {
84        let mut asg = vec![0usize; self.scope.len()];
85        for i in (0..self.scope.len()).rev() {
86            let d = self.dims[i] as usize;
87            let v = (idx / self.strides[i]) % d;
88            asg[i] = v as usize;
89        }
90        asg
91    }
92
93    #[inline]
94    pub fn log_value_aligned(&self, asg: &[usize]) -> f32 {
95        let idx = self.index_of(asg);
96        self.logp[idx]
97    }
98
99    /// Restrict (condition) the factor on evidence.
100    /// Evidence is a list of (VarId, state) pairs.
101    /// Returns a new factor with those variables fixed and removed from the scope.
102    pub fn restrict(&self, evidence: &[(VarId, usize)]) -> Result<Self, String> {
103        // Separate variables into fixed (evidence) and remaining (keep)
104        let mut keep = Vec::new();
105        for &v in &self.scope {
106            if !evidence.iter().any(|(ev, _)| *ev == v) {
107                keep.push(v);
108            }
109        }
110
111        // Build the new factor scope with only the kept variables
112        let keep_specs = keep
113            .iter()
114            .map(|&v| {
115                let ax = self.axis_of(v).unwrap();
116                VarSpec {
117                    id: v,
118                    card: self.dims[ax],
119                }
120            })
121            .collect::<Vec<_>>();
122
123        let mut out = DiscreteFactor::uniform(keep_specs)?;
124        out.logp.fill(f32::NEG_INFINITY);
125
126        // Build a base assignment with evidence values fixed
127        let mut base_asg = vec![0usize; self.scope.len()];
128        for (ev_var, ev_val) in evidence {
129            if let Some(ax) = self.axis_of(*ev_var) {
130                base_asg[ax] = *ev_val;
131            }
132        }
133
134        // For each assignment of the kept variables, copy the value from self
135        let out_len = out.logp.len();
136        for out_idx in 0..out_len {
137            let out_asg = out.unflatten(out_idx);
138
139            // Fill in the kept variable assignments
140            for (k, &v) in keep.iter().enumerate() {
141                let ax = self.axis_of(v).unwrap();
142                base_asg[ax] = out_asg[k];
143            }
144
145            out.logp[out_idx] = self.log_value_aligned(&base_asg);
146        }
147
148        Ok(out)
149    }
150
151    /// Reorder axes to `new_scope` (must be a permutation of current scope).
152    pub fn reorder(&self, new_scope: &[VarId]) -> Result<Self, String> {
153        if new_scope.len() != self.scope.len() {
154            return Err("new_scope length mismatch".into());
155        }
156        // map var -> old axis
157        let mut old_pos = std::collections::BTreeMap::<VarId, usize>::new();
158        for (i, &v) in self.scope.iter().enumerate() {
159            old_pos.insert(v, i);
160        }
161        let mut perm = Vec::with_capacity(new_scope.len());
162        let mut new_dims = Vec::with_capacity(new_scope.len());
163        for &v in new_scope {
164            let &p = old_pos
165                .get(&v)
166                .ok_or_else(|| "new_scope is not a permutation".to_string())?;
167            perm.push(p);
168            new_dims.push(self.dims[p]);
169        }
170
171        let new_strides = Self::compute_strides(&new_dims);
172        let new_len = prod_usize(&new_dims);
173        let mut new_logp = vec![f32::NEG_INFINITY; new_len];
174
175        // For each assignment in new space, map to old assignment and copy value.
176        for new_idx in 0..new_len {
177            // compute new assignment
178            let mut new_asg = vec![0usize; new_scope.len()];
179            for i in (0..new_scope.len()).rev() {
180                let d = new_dims[i] as usize;
181                let v = (new_idx / new_strides[i]) % d;
182                new_asg[i] = v as usize;
183            }
184            // old assignment in old axis order
185            let mut old_asg = vec![0usize; self.scope.len()];
186            for (new_axis, &old_axis) in perm.iter().enumerate() {
187                old_asg[old_axis] = new_asg[new_axis];
188            }
189            new_logp[new_idx] = self.log_value_aligned(&old_asg);
190        }
191
192        Ok(Self {
193            scope: new_scope.to_vec(),
194            dims: new_dims,
195            strides: new_strides,
196            logp: new_logp,
197        })
198    }
199
200    /// CPT-style normalization on the `child` axis: for every fixed parent assignment,
201    /// normalize over child states so that sum prob == 1.
202    pub fn normalize_rows(&mut self, child: VarId) -> Result<(), String> {
203        let axis = self
204            .axis_of(child)
205            .ok_or_else(|| "child not in scope".to_string())?;
206        let child_card = self.dims[axis] as usize;
207
208        // We will iterate all "rows" where row = varying child with parents fixed.
209        // For row-major strides: indices for a fixed parent assignment are spaced by stride[axis].
210        // But other axes also vary; easiest is: enumerate all assignments of non-child axes,
211        // then gather child slice.
212        let non_axes: Vec<usize> = (0..self.scope.len()).filter(|&i| i != axis).collect();
213        let non_dims: Vec<usize> = non_axes.iter().map(|&i| self.dims[i]).collect();
214        let non_strides = Self::compute_strides(&non_dims);
215        let rows = prod_usize(&non_dims);
216
217        let mut base_asg = vec![0usize; self.scope.len()];
218
219        for row_idx in 0..rows {
220            // decode non-child assignment into base_asg
221            for (k, &ax) in non_axes.iter().enumerate() {
222                let d = non_dims[k] as usize;
223                let v = (row_idx / non_strides[k]) % d;
224                base_asg[ax] = v;
225            }
226
227            // collect row over child
228            let mut row = vec![0.0f32; child_card];
229            for c in 0..child_card {
230                base_asg[axis] = c;
231                row[c] = self.log_value_aligned(&base_asg);
232            }
233
234            // normalize in log-space
235            log_normalize_in_place(&mut row);
236
237            // write back
238            for c in 0..child_card {
239                base_asg[axis] = c;
240                let idx = self.index_of(&base_asg);
241                self.logp[idx] = row[c];
242            }
243        }
244        Ok(())
245    }
246
247    /// Sum out the given vars.
248    pub fn marginalize(&self, elim: &[VarId]) -> Result<Self, String> {
249        // keep vars not eliminated
250        let mut keep = Vec::new();
251        for &v in &self.scope {
252            if !elim.contains(&v) {
253                keep.push(v);
254            }
255        }
256        // if nothing kept => scalar factor
257        let keep_specs = keep
258            .iter()
259            .map(|&v| {
260                let ax = self.axis_of(v).unwrap();
261                VarSpec {
262                    id: v,
263                    card: self.dims[ax],
264                }
265            })
266            .collect::<Vec<_>>();
267
268        let mut out = DiscreteFactor::uniform(keep_specs)?;
269        // out.logp will be overwritten by logsumexp accumulations; init -inf
270        out.logp.fill(f32::NEG_INFINITY);
271
272        // Map out assignment -> collect all matching self assignments across eliminated vars.
273        let out_len = out.logp.len();
274        for out_idx in 0..out_len {
275            let out_asg = out.unflatten(out_idx);
276
277            // Build a partial assignment in self-space for kept vars.
278            let mut base = vec![0usize; self.scope.len()];
279            for (k, &v) in keep.iter().enumerate() {
280                let ax = self.axis_of(v).unwrap();
281                base[ax] = out_asg[k];
282            }
283
284            // enumerate eliminated assignments
285            let elim_axes: Vec<usize> = self
286                .scope
287                .iter()
288                .enumerate()
289                .filter(|(_, v)| elim.contains(v))
290                .map(|(i, _)| i)
291                .collect();
292            let elim_dims: Vec<usize> = elim_axes.iter().map(|&i| self.dims[i]).collect();
293            let elim_strides = Self::compute_strides(&elim_dims);
294            let elim_len = prod_usize(&elim_dims);
295
296            let mut buf = Vec::with_capacity(elim_len.max(1));
297            if elim_axes.is_empty() {
298                buf.push(self.log_value_aligned(&base));
299            } else {
300                for eidx in 0..elim_len {
301                    for (k, &ax) in elim_axes.iter().enumerate() {
302                        let d = elim_dims[k];
303                        let v = (eidx / elim_strides[k]) % d;
304                        base[ax] = v;
305                    }
306                    buf.push(self.log_value_aligned(&base));
307                }
308            }
309
310            out.logp[out_idx] = logsumexp(&buf);
311        }
312
313        Ok(out)
314    }
315
316    /// Product of two discrete factors: join scope, broadcast, add in log-space.
317    pub fn product(
318        &self,
319        rhs: &DiscreteFactor,
320        cards: &impl Fn(VarId) -> usize,
321    ) -> Result<Self, String> {
322        // union scope: keep self order then append rhs vars not present
323        let mut out_scope = self.scope.clone();
324        for &v in rhs.scope.iter() {
325            if !out_scope.contains(&v) {
326                out_scope.push(v);
327            }
328        }
329        // build VarSpec using provided card lookup
330        let out_specs = out_scope
331            .iter()
332            .map(|&v| VarSpec {
333                id: v,
334                card: cards(v),
335            })
336            .collect::<Vec<_>>();
337        let mut out = DiscreteFactor::uniform(out_specs)?;
338        out.logp.fill(f32::NEG_INFINITY);
339
340        // precompute axis maps
341        let self_map = out_scope
342            .iter()
343            .map(|&v| self.axis_of(v))
344            .collect::<Vec<_>>();
345        let rhs_map = out_scope
346            .iter()
347            .map(|&v| rhs.axis_of(v))
348            .collect::<Vec<_>>();
349
350        // enumerate out assignments
351        let out_len = out.logp.len();
352        for out_idx in 0..out_len {
353            let out_asg = out.unflatten(out_idx);
354
355            // build aligned assignments for each factor
356            let mut asg_a = vec![0usize; self.scope.len()];
357            let mut asg_b = vec![0usize; rhs.scope.len()];
358
359            for (out_axis, &val) in out_asg.iter().enumerate() {
360                if let Some(ax) = self_map[out_axis] {
361                    asg_a[ax] = val;
362                }
363                if let Some(ax) = rhs_map[out_axis] {
364                    asg_b[ax] = val;
365                }
366            }
367
368            let la = self.log_value_aligned(&asg_a);
369            let lb = rhs.log_value_aligned(&asg_b);
370            out.logp[out_idx] = la + lb;
371        }
372
373        Ok(out)
374    }
375
376    #[inline]
377    pub fn axis_of(&self, v: VarId) -> Option<usize> {
378        self.scope.iter().position(|&x| x == v)
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::var::VarSpec;

    /// True when `a` and `b` differ by at most `eps`.
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() <= eps
    }

    /// `index_of` must invert `unflatten` for every cell of a 2x3x4 table.
    #[test]
    fn indexing_roundtrip() {
        let specs = vec![VarSpec::new(0, 2), VarSpec::new(1, 3), VarSpec::new(2, 4)];
        let f = DiscreteFactor::uniform(specs).unwrap();

        for idx in 0..f.logp.len() {
            let idx2 = f.index_of(&f.unflatten(idx));
            assert_eq!(idx, idx2);
        }
    }

    /// Swapping the two axes must keep the value attached to each assignment.
    #[test]
    fn reorder_preserves_values() {
        // Scope [A, B] with dims [2, 3]; each cell holds its own flat index.
        let logp: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let f = DiscreteFactor::new(vec![VarSpec::new(0, 2), VarSpec::new(1, 3)], logp).unwrap();

        let (a, b) = (VarId(0), VarId(1));
        let g = f.reorder(&[b, a]).unwrap();

        // g is indexed (B, A); f is indexed (A, B).
        for bi in 0..3 {
            for ai in 0..2 {
                let gv = g.log_value_aligned(&[bi, ai]);
                let fv = f.log_value_aligned(&[ai, bi]);
                assert_eq!(gv, fv);
            }
        }
    }

    /// After normalizing over C, the probabilities for each fixed A sum to 1.
    #[test]
    fn normalize_rows_makes_rows_sum_to_1() {
        // Factor over [A, C] with A card 2 and C card 3. The logits are
        // arbitrary; the check below does not depend on how they are laid out.
        let specs = vec![VarSpec::new(0, 2), VarSpec::new(1, 3)];
        let logits = vec![0.0, 1.0, 2.0, 2.0, 1.0, 0.0];
        let mut f = DiscreteFactor::new(specs, logits).unwrap();
        f.normalize_rows(VarId(1)).unwrap();

        for a in 0..2usize {
            let mut s = 0.0f32;
            for c in 0..3usize {
                s += f.log_value_aligned(&[a, c]).exp();
            }
            assert!(approx(s, 1.0, 1e-5), "sum={s}");
        }
    }

    /// Summing out B must give, for each A, the logsumexp over B of f(A, B).
    #[test]
    fn marginalize_identity() {
        let logp: Vec<f32> = (0..6).map(|i| (i as f32) * 0.1).collect();
        let f = DiscreteFactor::new(vec![VarSpec::new(0, 2), VarSpec::new(1, 3)], logp).unwrap();

        let g = f.marginalize(&[VarId(1)]).unwrap(); // sum out B
        assert_eq!(g.dims(), &[2]);

        for a in 0..2usize {
            let row: Vec<f32> = (0..3usize)
                .map(|b| f.log_value_aligned(&[a, b]))
                .collect();
            let want = crate::logsumexp(&row);
            let got = g.log_value_aligned(&[a]);
            assert!(approx(got, want, 1e-6), "a={a} got={got} want={want}");
        }
    }
}