Skip to main content

radiate_pgm/
chromosome.rs

1use crate::{VarId, VarSpec};
2use radiate_core::{Chromosome, Gene, Valid};
3use radiate_utils::Value;
4use std::sync::Arc;
5
// `DiscreteFactor` is the ground-truth mathematical object used for inference and correctness
// checks; the evolving genome (`FactorGene`) is only a parameter container. The two are bridged
// by a single conversion step:
//     FactorGene (scope + params: Value<f32>)  -->  DiscreteFactor (scope + logp: Vec<f32>)
9
/// Parameterization used for a [`FactorGene`]'s `params`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FactorKind {
    /// Parameters form a log-probability table whose dimensions are given by
    /// the gene's `shape`.
    Logp,
}
14
15#[derive(Clone, Debug, PartialEq)]
16pub struct FactorGene {
17    pub scope: Vec<VarId>,
18    pub kind: FactorKind,
19    pub shape: Vec<usize>,
20    pub params: Value<f32>,
21}
22
23impl FactorGene {
24    #[inline]
25    pub fn resample_scope(&mut self, vars: &[VarSpec], max_scope: usize) {
26        let scope = super::sample_scope(vars.len(), max_scope);
27        let shape = super::logp_table_shape(vars, &scope);
28        let params = super::init_logp_table(&shape);
29
30        self.scope = scope;
31        self.shape = shape;
32        self.params = params;
33    }
34}
35
36impl Gene for FactorGene {
37    type Allele = Self;
38
39    fn allele(&self) -> &Self::Allele {
40        self
41    }
42
43    fn allele_mut(&mut self) -> &mut Self::Allele {
44        self
45    }
46
47    fn new_instance(&self) -> Self {
48        let params = match self.kind {
49            FactorKind::Logp => super::init_logp_table(&self.shape),
50        };
51
52        FactorGene {
53            scope: self.scope.clone(),
54            kind: self.kind.clone(),
55            shape: self.shape.clone(),
56            params,
57        }
58    }
59
60    fn with_allele(&self, allele: &Self::Allele) -> Self {
61        FactorGene {
62            scope: allele.scope.clone(),
63            kind: allele.kind.clone(),
64            shape: allele.shape.clone(),
65            params: allele.params.clone(),
66        }
67    }
68}
69
70impl Valid for FactorGene {
71    fn is_valid(&self) -> bool {
72        !self.scope.is_empty()
73            && self.shape.len() == self.scope.len()
74            && self.shape.iter().all(|&d| d >= 1)
75    }
76}
77
78#[derive(Clone, Debug, PartialEq)]
79pub struct PgmChromosome {
80    pub vars: Arc<[VarSpec]>,
81    pub factors: Vec<FactorGene>,
82}
83
84impl Chromosome for PgmChromosome {
85    type Gene = FactorGene;
86
87    fn as_slice(&self) -> &[FactorGene] {
88        &self.factors
89    }
90
91    fn as_mut_slice(&mut self) -> &mut [FactorGene] {
92        &mut self.factors
93    }
94}
95
96impl Valid for PgmChromosome {
97    fn is_valid(&self) -> bool {
98        let num_vars = self.vars.len();
99        for factor in &self.factors {
100            for &vid in &factor.scope {
101                let idx = vid.0 as usize;
102                if idx >= num_vars {
103                    return false;
104                }
105            }
106
107            if factor.shape.len() != factor.scope.len() {
108                return false;
109            }
110
111            for (i, &vid) in factor.scope.iter().enumerate() {
112                let idx = vid.0 as usize;
113                let expected = self.vars[idx].card.max(1) as usize;
114                if factor.shape[i].max(1) != expected {
115                    return false;
116                }
117            }
118        }
119
120        true
121    }
122}