1use crate::{FactorGene, FactorKind, PgmChromosome, factor};
2use radiate_core::{BatchFitnessFunction, Genotype, fitness::FitnessFunction};
3use radiate_utils::Value;
4
/// A table of discrete observations used to score a PGM chromosome.
///
/// Each entry of `rows` is one observation. As consumed by `PgmLogLik`, a row is indexed by
/// variable id and holds that variable's observed state, or `None` when it is unobserved.
#[derive(Clone, Debug)]
pub struct PgmDataSet {
    pub rows: Vec<Vec<Option<usize>>>,
}

impl PgmDataSet {
    /// Wraps `rows` as a dataset (takes ownership, no copying).
    pub fn new(rows: Vec<Vec<Option<usize>>>) -> Self {
        Self { rows }
    }

    /// Number of rows (observations) in the dataset.
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// Returns `true` when the dataset contains no rows.
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }
}
19
/// Fitness function that scores a `PgmChromosome` against raw data rows by delegating to
/// `factor::neg_mean_loglik`.
///
/// NOTE(review): rows presumably use the same layout as `PgmDataSet` (indexed by variable,
/// `None` = unobserved); confirm against `factor::neg_mean_loglik`.
#[derive(Clone, Debug)]
pub struct PgmNll {
    pub data: Vec<Vec<Option<usize>>>,
}
24
25impl FitnessFunction<PgmChromosome, f32> for PgmNll {
26 fn evaluate(&self, chrom: PgmChromosome) -> f32 {
27 factor::neg_mean_loglik(&chrom, &self.data).unwrap_or(f32::INFINITY) }
29}
30
31#[derive(Clone)]
32pub struct PgmLogLik {
33 data: PgmDataSet,
34}
35
36impl PgmLogLik {
37 pub fn new(data: PgmDataSet) -> Self {
38 Self { data }
39 }
40
41 fn eval_logp_factor(
42 &self,
43 chrom: &PgmChromosome,
44 factor: &FactorGene,
45 row: &[Option<usize>],
46 ) -> f32 {
47 match factor.kind {
48 FactorKind::Logp => {
49 let mut idxs = Vec::with_capacity(factor.scope.len());
50 for &var_spec in &factor.scope {
51 let Some(s) = row[var_spec.0 as usize] else {
52 return 0.0;
53 };
54 idxs.push(s.min(chrom.vars[var_spec.0 as usize].card.saturating_sub(1)));
55 }
56
57 logprob_table_eval(&factor.params, &idxs)
58 }
59 }
60 }
61
62 pub fn loglik(&self, chrom: &PgmChromosome) -> f32 {
63 let mut ll = 0.0;
64 for row in &self.data.rows {
65 for f in &chrom.factors {
66 ll += self.eval_logp_factor(chrom, f, row);
67 }
68 }
69 ll
70 }
71
72 pub fn neg_mean_loglik(&self, chrom: &PgmChromosome) -> f32 {
73 let n = self.data.len().max(1) as f32;
74 -(self.loglik(chrom) / n)
75 }
76}
77
78impl FitnessFunction<PgmChromosome, f32> for PgmLogLik {
79 #[inline]
80 fn evaluate(&self, input: PgmChromosome) -> f32 {
81 self.neg_mean_loglik(&input)
82 }
83}
84
85impl<'a> FitnessFunction<&'a Genotype<PgmChromosome>, f32> for PgmLogLik {
86 #[inline]
87 fn evaluate(&self, input: &'a Genotype<PgmChromosome>) -> f32 {
88 self.neg_mean_loglik(&input[0])
89 }
90}
91
92impl BatchFitnessFunction<PgmChromosome, f32> for PgmLogLik {
93 #[inline]
94 fn evaluate(&self, inputs: Vec<PgmChromosome>) -> Vec<f32> {
95 inputs
96 .into_iter()
97 .map(|c| self.neg_mean_loglik(&c))
98 .collect()
99 }
100}
101
102impl<'a> BatchFitnessFunction<&'a Genotype<PgmChromosome>, f32> for PgmLogLik {
103 #[inline]
104 fn evaluate(&self, inputs: Vec<&'a Genotype<PgmChromosome>>) -> Vec<f32> {
105 inputs
106 .into_iter()
107 .map(|g| self.neg_mean_loglik(&g[0]))
108 .collect()
109 }
110}
111
112fn logprob_table_eval(val: &Value<f32>, idxs: &[usize]) -> f32 {
113 let Value::Array {
114 values,
115 shape,
116 strides,
117 } = val
118 else {
119 return 0.0;
120 };
121
122 let rank = shape.rank();
123 if idxs.len() != rank {
124 return 0.0;
125 }
126 if rank == 0 {
127 return 0.0;
128 }
129
130 let child_axis = rank - 1;
131 let child = idxs[child_axis];
132 let child_states = shape.dim_at(child_axis).max(1);
133
134 let mut base = 0usize;
136 for i in 0..child_axis {
137 let dim = shape.dim_at(i).max(1);
138 let idx = idxs[i].min(dim - 1);
139 base = base.saturating_add(idx.saturating_mul(strides.stride_at(i)));
140 }
141
142 let mut max_logit = f32::NEG_INFINITY;
144 for k in 0..child_states {
145 let pos = base.saturating_add(k.saturating_mul(strides.stride_at(child_axis)));
146 let s = values.get(pos).copied().unwrap_or(0.0);
147 max_logit = max_logit.max(s);
148 }
149 if !max_logit.is_finite() {
150 return 0.0;
151 }
152
153 let mut sum_exp = 0.0f32;
155 for k in 0..child_states {
156 let pos = base.saturating_add(k.saturating_mul(strides.stride_at(child_axis)));
157 let s = values.get(pos).copied().unwrap_or(0.0);
158 sum_exp += (s - max_logit).exp();
159 }
160 let lse = max_logit + sum_exp.ln();
161
162 let child = child.min(child_states.saturating_sub(1));
163 let child_pos = base.saturating_add(child.saturating_mul(strides.stride_at(child_axis)));
164 let child_logit = values.get(child_pos).copied().unwrap_or(0.0);
165
166 child_logit - lse
167}