1use crate::var::{VarId, VarSpec};
2use crate::{log_normalize_in_place, logsumexp, prod_usize};
3
4#[derive(Clone, Debug)]
5pub struct DiscreteFactor {
6 scope: Vec<VarId>, dims: Vec<usize>, strides: Vec<usize>, logp: Vec<f32>, }
11
12impl DiscreteFactor {
13 pub fn new(scope: Vec<VarSpec>, logp: Vec<f32>) -> Result<Self, String> {
15 let mut svars = Vec::with_capacity(scope.len());
16 let mut dims = Vec::with_capacity(scope.len());
17 for v in scope {
18 svars.push(v.id);
19 dims.push(v.card);
20 }
21
22 let strides = Self::compute_strides(&dims);
23 let expected = prod_usize(&dims);
24
25 if logp.len() != expected {
26 return Err(format!(
27 "logp length {} != expected {}",
28 logp.len(),
29 expected
30 ));
31 }
32
33 Ok(Self {
34 scope: svars,
35 dims,
36 strides,
37 logp,
38 })
39 }
40
41 pub fn uniform(scope: Vec<VarSpec>) -> Result<Self, String> {
43 let dims: Vec<usize> = scope.iter().map(|v| v.card).collect();
44 let n = prod_usize(&dims);
45 Self::new(scope, vec![0.0; n])
46 }
47
48 pub fn scope(&self) -> &[VarId] {
49 &self.scope
50 }
51
52 pub fn dims(&self) -> &[usize] {
53 &self.dims
54 }
55
56 pub fn logp(&self) -> &[f32] {
57 &self.logp
58 }
59
60 #[inline]
61 fn compute_strides(dims: &[usize]) -> Vec<usize> {
62 let mut strides = vec![1usize; dims.len()];
63 let mut acc = 1usize;
64 for (i, &d) in dims.iter().enumerate() {
65 strides[i] = acc;
66 acc = acc.saturating_mul(d);
67 }
68 strides
69 }
70
71 #[inline]
73 pub fn index_of(&self, asg: &[usize]) -> usize {
74 debug_assert_eq!(asg.len(), self.scope.len());
75 let mut idx = 0usize;
76 for i in 0..asg.len() {
77 idx += (asg[i] as usize) * self.strides[i];
78 }
79 idx
80 }
81
82 pub fn unflatten(&self, idx: usize) -> Vec<usize> {
84 let mut asg = vec![0usize; self.scope.len()];
85 for i in (0..self.scope.len()).rev() {
86 let d = self.dims[i] as usize;
87 let v = (idx / self.strides[i]) % d;
88 asg[i] = v as usize;
89 }
90 asg
91 }
92
93 #[inline]
94 pub fn log_value_aligned(&self, asg: &[usize]) -> f32 {
95 let idx = self.index_of(asg);
96 self.logp[idx]
97 }
98
99 pub fn restrict(&self, evidence: &[(VarId, usize)]) -> Result<Self, String> {
103 let mut keep = Vec::new();
105 for &v in &self.scope {
106 if !evidence.iter().any(|(ev, _)| *ev == v) {
107 keep.push(v);
108 }
109 }
110
111 let keep_specs = keep
113 .iter()
114 .map(|&v| {
115 let ax = self.axis_of(v).unwrap();
116 VarSpec {
117 id: v,
118 card: self.dims[ax],
119 }
120 })
121 .collect::<Vec<_>>();
122
123 let mut out = DiscreteFactor::uniform(keep_specs)?;
124 out.logp.fill(f32::NEG_INFINITY);
125
126 let mut base_asg = vec![0usize; self.scope.len()];
128 for (ev_var, ev_val) in evidence {
129 if let Some(ax) = self.axis_of(*ev_var) {
130 base_asg[ax] = *ev_val;
131 }
132 }
133
134 let out_len = out.logp.len();
136 for out_idx in 0..out_len {
137 let out_asg = out.unflatten(out_idx);
138
139 for (k, &v) in keep.iter().enumerate() {
141 let ax = self.axis_of(v).unwrap();
142 base_asg[ax] = out_asg[k];
143 }
144
145 out.logp[out_idx] = self.log_value_aligned(&base_asg);
146 }
147
148 Ok(out)
149 }
150
151 pub fn reorder(&self, new_scope: &[VarId]) -> Result<Self, String> {
153 if new_scope.len() != self.scope.len() {
154 return Err("new_scope length mismatch".into());
155 }
156 let mut old_pos = std::collections::BTreeMap::<VarId, usize>::new();
158 for (i, &v) in self.scope.iter().enumerate() {
159 old_pos.insert(v, i);
160 }
161 let mut perm = Vec::with_capacity(new_scope.len());
162 let mut new_dims = Vec::with_capacity(new_scope.len());
163 for &v in new_scope {
164 let &p = old_pos
165 .get(&v)
166 .ok_or_else(|| "new_scope is not a permutation".to_string())?;
167 perm.push(p);
168 new_dims.push(self.dims[p]);
169 }
170
171 let new_strides = Self::compute_strides(&new_dims);
172 let new_len = prod_usize(&new_dims);
173 let mut new_logp = vec![f32::NEG_INFINITY; new_len];
174
175 for new_idx in 0..new_len {
177 let mut new_asg = vec![0usize; new_scope.len()];
179 for i in (0..new_scope.len()).rev() {
180 let d = new_dims[i] as usize;
181 let v = (new_idx / new_strides[i]) % d;
182 new_asg[i] = v as usize;
183 }
184 let mut old_asg = vec![0usize; self.scope.len()];
186 for (new_axis, &old_axis) in perm.iter().enumerate() {
187 old_asg[old_axis] = new_asg[new_axis];
188 }
189 new_logp[new_idx] = self.log_value_aligned(&old_asg);
190 }
191
192 Ok(Self {
193 scope: new_scope.to_vec(),
194 dims: new_dims,
195 strides: new_strides,
196 logp: new_logp,
197 })
198 }
199
200 pub fn normalize_rows(&mut self, child: VarId) -> Result<(), String> {
203 let axis = self
204 .axis_of(child)
205 .ok_or_else(|| "child not in scope".to_string())?;
206 let child_card = self.dims[axis] as usize;
207
208 let non_axes: Vec<usize> = (0..self.scope.len()).filter(|&i| i != axis).collect();
213 let non_dims: Vec<usize> = non_axes.iter().map(|&i| self.dims[i]).collect();
214 let non_strides = Self::compute_strides(&non_dims);
215 let rows = prod_usize(&non_dims);
216
217 let mut base_asg = vec![0usize; self.scope.len()];
218
219 for row_idx in 0..rows {
220 for (k, &ax) in non_axes.iter().enumerate() {
222 let d = non_dims[k] as usize;
223 let v = (row_idx / non_strides[k]) % d;
224 base_asg[ax] = v;
225 }
226
227 let mut row = vec![0.0f32; child_card];
229 for c in 0..child_card {
230 base_asg[axis] = c;
231 row[c] = self.log_value_aligned(&base_asg);
232 }
233
234 log_normalize_in_place(&mut row);
236
237 for c in 0..child_card {
239 base_asg[axis] = c;
240 let idx = self.index_of(&base_asg);
241 self.logp[idx] = row[c];
242 }
243 }
244 Ok(())
245 }
246
247 pub fn marginalize(&self, elim: &[VarId]) -> Result<Self, String> {
249 let mut keep = Vec::new();
251 for &v in &self.scope {
252 if !elim.contains(&v) {
253 keep.push(v);
254 }
255 }
256 let keep_specs = keep
258 .iter()
259 .map(|&v| {
260 let ax = self.axis_of(v).unwrap();
261 VarSpec {
262 id: v,
263 card: self.dims[ax],
264 }
265 })
266 .collect::<Vec<_>>();
267
268 let mut out = DiscreteFactor::uniform(keep_specs)?;
269 out.logp.fill(f32::NEG_INFINITY);
271
272 let out_len = out.logp.len();
274 for out_idx in 0..out_len {
275 let out_asg = out.unflatten(out_idx);
276
277 let mut base = vec![0usize; self.scope.len()];
279 for (k, &v) in keep.iter().enumerate() {
280 let ax = self.axis_of(v).unwrap();
281 base[ax] = out_asg[k];
282 }
283
284 let elim_axes: Vec<usize> = self
286 .scope
287 .iter()
288 .enumerate()
289 .filter(|(_, v)| elim.contains(v))
290 .map(|(i, _)| i)
291 .collect();
292 let elim_dims: Vec<usize> = elim_axes.iter().map(|&i| self.dims[i]).collect();
293 let elim_strides = Self::compute_strides(&elim_dims);
294 let elim_len = prod_usize(&elim_dims);
295
296 let mut buf = Vec::with_capacity(elim_len.max(1));
297 if elim_axes.is_empty() {
298 buf.push(self.log_value_aligned(&base));
299 } else {
300 for eidx in 0..elim_len {
301 for (k, &ax) in elim_axes.iter().enumerate() {
302 let d = elim_dims[k];
303 let v = (eidx / elim_strides[k]) % d;
304 base[ax] = v;
305 }
306 buf.push(self.log_value_aligned(&base));
307 }
308 }
309
310 out.logp[out_idx] = logsumexp(&buf);
311 }
312
313 Ok(out)
314 }
315
316 pub fn product(
318 &self,
319 rhs: &DiscreteFactor,
320 cards: &impl Fn(VarId) -> usize,
321 ) -> Result<Self, String> {
322 let mut out_scope = self.scope.clone();
324 for &v in rhs.scope.iter() {
325 if !out_scope.contains(&v) {
326 out_scope.push(v);
327 }
328 }
329 let out_specs = out_scope
331 .iter()
332 .map(|&v| VarSpec {
333 id: v,
334 card: cards(v),
335 })
336 .collect::<Vec<_>>();
337 let mut out = DiscreteFactor::uniform(out_specs)?;
338 out.logp.fill(f32::NEG_INFINITY);
339
340 let self_map = out_scope
342 .iter()
343 .map(|&v| self.axis_of(v))
344 .collect::<Vec<_>>();
345 let rhs_map = out_scope
346 .iter()
347 .map(|&v| rhs.axis_of(v))
348 .collect::<Vec<_>>();
349
350 let out_len = out.logp.len();
352 for out_idx in 0..out_len {
353 let out_asg = out.unflatten(out_idx);
354
355 let mut asg_a = vec![0usize; self.scope.len()];
357 let mut asg_b = vec![0usize; rhs.scope.len()];
358
359 for (out_axis, &val) in out_asg.iter().enumerate() {
360 if let Some(ax) = self_map[out_axis] {
361 asg_a[ax] = val;
362 }
363 if let Some(ax) = rhs_map[out_axis] {
364 asg_b[ax] = val;
365 }
366 }
367
368 let la = self.log_value_aligned(&asg_a);
369 let lb = rhs.log_value_aligned(&asg_b);
370 out.logp[out_idx] = la + lb;
371 }
372
373 Ok(out)
374 }
375
376 #[inline]
377 pub fn axis_of(&self, v: VarId) -> Option<usize> {
378 self.scope.iter().position(|&x| x == v)
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::var::VarSpec;

    /// True when `a` and `b` differ by at most `eps` in absolute value.
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() <= eps
    }

    #[test]
    fn indexing_roundtrip() {
        let specs = vec![VarSpec::new(0, 2), VarSpec::new(1, 3), VarSpec::new(2, 4)];
        let f = DiscreteFactor::uniform(specs).unwrap();

        (0..f.logp.len()).for_each(|idx| {
            let asg = f.unflatten(idx);
            assert_eq!(idx, f.index_of(&asg));
        });
    }

    #[test]
    fn reorder_preserves_values() {
        let scope = vec![VarSpec::new(0, 2), VarSpec::new(1, 3)];
        let logp: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let f = DiscreteFactor::new(scope, logp).unwrap();

        let (a, b) = (VarId(0), VarId(1));
        let g = f.reorder(&[b, a]).unwrap();

        for bi in 0..3 {
            for ai in 0..2 {
                assert_eq!(
                    g.log_value_aligned(&[bi, ai]),
                    f.log_value_aligned(&[ai, bi])
                );
            }
        }
    }

    #[test]
    fn normalize_rows_makes_rows_sum_to_1() {
        let scope = vec![VarSpec::new(0, 2), VarSpec::new(1, 3)];
        let logp = vec![0.0, 1.0, 2.0, 2.0, 1.0, 0.0];
        let mut f = DiscreteFactor::new(scope, logp).unwrap();
        f.normalize_rows(VarId(1)).unwrap();

        for a in 0..2usize {
            let s: f32 = (0..3usize)
                .map(|c| f.log_value_aligned(&[a, c]).exp())
                .sum();
            assert!(approx(s, 1.0, 1e-5), "sum={s}");
        }
    }

    #[test]
    fn marginalize_identity() {
        let scope = vec![VarSpec::new(0, 2), VarSpec::new(1, 3)];
        let logp: Vec<f32> = (0..6).map(|i| (i as f32) * 0.1).collect();
        let f = DiscreteFactor::new(scope, logp).unwrap();

        let g = f.marginalize(&[VarId(1)]).unwrap();
        assert_eq!(g.dims(), &[2]);

        for a in 0..2usize {
            let row: Vec<f32> = (0..3usize)
                .map(|b| f.log_value_aligned(&[a, b]))
                .collect();
            let want = crate::logsumexp(&row);
            let got = g.log_value_aligned(&[a]);
            assert!(approx(got, want, 1e-6), "a={a} got={got} want={want}");
        }
    }
}