// radiate_pgm/fitness.rs

1use crate::{FactorGene, FactorKind, PgmChromosome, factor};
2use radiate_core::{BatchFitnessFunction, Genotype, fitness::FitnessFunction};
3use radiate_utils::Value;
4
/// A table of discrete observations.
///
/// Each entry of `rows` is one row; each cell of a row is an optional `usize`
/// value, where `None` stands for a cell with no value.
#[derive(Clone)]
pub struct PgmDataSet {
    /// The stored rows, in insertion order. Row widths are not validated.
    pub rows: Vec<Vec<Option<usize>>>,
}

impl PgmDataSet {
    /// Wraps `rows` as a data set without copying or validating them.
    pub fn new(rows: Vec<Vec<Option<usize>>>) -> Self {
        Self { rows }
    }

    /// Returns the number of rows in the data set.
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// Returns `true` when the data set contains no rows.
    ///
    /// Provided alongside [`len`](Self::len) so callers can test for emptiness
    /// without comparing against zero.
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }
}
19
/// Fitness function state: the observation rows a chromosome is scored against.
#[derive(Clone)]
pub struct PgmNll {
    /// Observation rows; each cell is an optional `usize`, `None` when the
    /// cell has no value.
    pub data: Vec<Vec<Option<usize>>>,
}
24
25impl FitnessFunction<PgmChromosome, f32> for PgmNll {
26    fn evaluate(&self, chrom: PgmChromosome) -> f32 {
27        factor::neg_mean_loglik(&chrom, &self.data).unwrap_or(f32::INFINITY) // if something goes off the rails
28    }
29}
30
31#[derive(Clone)]
32pub struct PgmLogLik {
33    data: PgmDataSet,
34}
35
36impl PgmLogLik {
37    pub fn new(data: PgmDataSet) -> Self {
38        Self { data }
39    }
40
41    fn eval_logp_factor(
42        &self,
43        chrom: &PgmChromosome,
44        factor: &FactorGene,
45        row: &[Option<usize>],
46    ) -> f32 {
47        match factor.kind {
48            FactorKind::Logp => {
49                let mut idxs = Vec::with_capacity(factor.scope.len());
50                for &var_spec in &factor.scope {
51                    let Some(s) = row[var_spec.0 as usize] else {
52                        return 0.0;
53                    };
54                    idxs.push(s.min(chrom.vars[var_spec.0 as usize].card.saturating_sub(1)));
55                }
56
57                logprob_table_eval(&factor.params, &idxs)
58            }
59        }
60    }
61
62    pub fn loglik(&self, chrom: &PgmChromosome) -> f32 {
63        let mut ll = 0.0;
64        for row in &self.data.rows {
65            for f in &chrom.factors {
66                ll += self.eval_logp_factor(chrom, f, row);
67            }
68        }
69        ll
70    }
71
72    pub fn neg_mean_loglik(&self, chrom: &PgmChromosome) -> f32 {
73        let n = self.data.len().max(1) as f32;
74        -(self.loglik(chrom) / n)
75    }
76}
77
78impl FitnessFunction<PgmChromosome, f32> for PgmLogLik {
79    #[inline]
80    fn evaluate(&self, input: PgmChromosome) -> f32 {
81        self.neg_mean_loglik(&input)
82    }
83}
84
85impl<'a> FitnessFunction<&'a Genotype<PgmChromosome>, f32> for PgmLogLik {
86    #[inline]
87    fn evaluate(&self, input: &'a Genotype<PgmChromosome>) -> f32 {
88        self.neg_mean_loglik(&input[0])
89    }
90}
91
92impl BatchFitnessFunction<PgmChromosome, f32> for PgmLogLik {
93    #[inline]
94    fn evaluate(&self, inputs: Vec<PgmChromosome>) -> Vec<f32> {
95        inputs
96            .into_iter()
97            .map(|c| self.neg_mean_loglik(&c))
98            .collect()
99    }
100}
101
102impl<'a> BatchFitnessFunction<&'a Genotype<PgmChromosome>, f32> for PgmLogLik {
103    #[inline]
104    fn evaluate(&self, inputs: Vec<&'a Genotype<PgmChromosome>>) -> Vec<f32> {
105        inputs
106            .into_iter()
107            .map(|g| self.neg_mean_loglik(&g[0]))
108            .collect()
109    }
110}
111
112fn logprob_table_eval(val: &Value<f32>, idxs: &[usize]) -> f32 {
113    let Value::Array {
114        values,
115        shape,
116        strides,
117    } = val
118    else {
119        return 0.0;
120    };
121
122    let rank = shape.rank();
123    if idxs.len() != rank {
124        return 0.0;
125    }
126    if rank == 0 {
127        return 0.0;
128    }
129
130    let child_axis = rank - 1;
131    let child = idxs[child_axis];
132    let child_states = shape.dim_at(child_axis).max(1);
133
134    // base offset with child fixed to 0
135    let mut base = 0usize;
136    for i in 0..child_axis {
137        let dim = shape.dim_at(i).max(1);
138        let idx = idxs[i].min(dim - 1);
139        base = base.saturating_add(idx.saturating_mul(strides.stride_at(i)));
140    }
141
142    // max logit
143    let mut max_logit = f32::NEG_INFINITY;
144    for k in 0..child_states {
145        let pos = base.saturating_add(k.saturating_mul(strides.stride_at(child_axis)));
146        let s = values.get(pos).copied().unwrap_or(0.0);
147        max_logit = max_logit.max(s);
148    }
149    if !max_logit.is_finite() {
150        return 0.0;
151    }
152
153    // logsumexp
154    let mut sum_exp = 0.0f32;
155    for k in 0..child_states {
156        let pos = base.saturating_add(k.saturating_mul(strides.stride_at(child_axis)));
157        let s = values.get(pos).copied().unwrap_or(0.0);
158        sum_exp += (s - max_logit).exp();
159    }
160    let lse = max_logit + sum_exp.ln();
161
162    let child = child.min(child_states.saturating_sub(1));
163    let child_pos = base.saturating_add(child.saturating_mul(strides.stride_at(child_axis)));
164    let child_logit = values.get(child_pos).copied().unwrap_or(0.0);
165
166    child_logit - lse
167}