radiate_pgm/
chromosome.rs1use crate::{VarId, VarSpec};
2use radiate_core::{Chromosome, Gene, Valid};
3use radiate_utils::Value;
4use std::sync::Arc;
5
/// Parameterization of a [`FactorGene`]'s parameter table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FactorKind {
    /// Table-valued factor whose shape comes from `logp_table_shape` and whose
    /// initial entries come from `init_logp_table`
    /// (presumably log-probabilities — confirm against those helpers).
    Logp,
}
14
15#[derive(Clone, Debug, PartialEq)]
16pub struct FactorGene {
17 pub scope: Vec<VarId>,
18 pub kind: FactorKind,
19 pub shape: Vec<usize>,
20 pub params: Value<f32>,
21}
22
23impl FactorGene {
24 #[inline]
25 pub fn resample_scope(&mut self, vars: &[VarSpec], max_scope: usize) {
26 let scope = super::sample_scope(vars.len(), max_scope);
27 let shape = super::logp_table_shape(vars, &scope);
28 let params = super::init_logp_table(&shape);
29
30 self.scope = scope;
31 self.shape = shape;
32 self.params = params;
33 }
34}
35
36impl Gene for FactorGene {
37 type Allele = Self;
38
39 fn allele(&self) -> &Self::Allele {
40 self
41 }
42
43 fn allele_mut(&mut self) -> &mut Self::Allele {
44 self
45 }
46
47 fn new_instance(&self) -> Self {
48 let params = match self.kind {
49 FactorKind::Logp => super::init_logp_table(&self.shape),
50 };
51
52 FactorGene {
53 scope: self.scope.clone(),
54 kind: self.kind.clone(),
55 shape: self.shape.clone(),
56 params,
57 }
58 }
59
60 fn with_allele(&self, allele: &Self::Allele) -> Self {
61 FactorGene {
62 scope: allele.scope.clone(),
63 kind: allele.kind.clone(),
64 shape: allele.shape.clone(),
65 params: allele.params.clone(),
66 }
67 }
68}
69
70impl Valid for FactorGene {
71 fn is_valid(&self) -> bool {
72 !self.scope.is_empty()
73 && self.shape.len() == self.scope.len()
74 && self.shape.iter().all(|&d| d >= 1)
75 }
76}
77
78#[derive(Clone, Debug, PartialEq)]
79pub struct PgmChromosome {
80 pub vars: Arc<[VarSpec]>,
81 pub factors: Vec<FactorGene>,
82}
83
84impl Chromosome for PgmChromosome {
85 type Gene = FactorGene;
86
87 fn as_slice(&self) -> &[FactorGene] {
88 &self.factors
89 }
90
91 fn as_mut_slice(&mut self) -> &mut [FactorGene] {
92 &mut self.factors
93 }
94}
95
96impl Valid for PgmChromosome {
97 fn is_valid(&self) -> bool {
98 let num_vars = self.vars.len();
99 for factor in &self.factors {
100 for &vid in &factor.scope {
101 let idx = vid.0 as usize;
102 if idx >= num_vars {
103 return false;
104 }
105 }
106
107 if factor.shape.len() != factor.scope.len() {
108 return false;
109 }
110
111 for (i, &vid) in factor.scope.iter().enumerate() {
112 let idx = vid.0 as usize;
113 let expected = self.vars[idx].card.max(1) as usize;
114 if factor.shape[i].max(1) != expected {
115 return false;
116 }
117 }
118 }
119
120 true
121 }
122}