Skip to main content

oxicuda_seq/memm/
memm.rs

1//! Maximum-Entropy Markov Model implementation.
2
3use crate::error::{SeqError, SeqResult};
4
/// Maximum-entropy Markov model: per-current-label softmax conditioned on the
/// previous label and the current feature vector.
///
/// * `weights[prev * n_labels * n_features + cur * n_features + f]` — weight of
///   feature `f` for the transition `prev -> cur`.
///
/// All fields are public, so a value built or mutated by hand is not guaranteed
/// to keep `weights.len()` consistent with `n_labels` and `n_features`.
#[derive(Debug, Clone)]
pub struct Memm {
    /// Number of labels; label indices passed to or returned by the model are
    /// in `0..n_labels`.
    pub n_labels: usize,
    /// Number of features in the feature vector of a single time step.
    pub n_features: usize,
    /// Flat transition weight table, indexed as described in the type-level
    /// docs; `zeros` allocates `n_labels * n_labels * n_features` entries.
    pub weights: Vec<f64>,
    /// Label used as the "previous" label when decoding the first time step
    /// (`zeros` sets it to `0`).
    pub start_label: usize,
}
17
18impl Memm {
19    /// Zero-initialised MEMM.
20    pub fn zeros(n_labels: usize, n_features: usize) -> SeqResult<Self> {
21        if n_labels == 0 || n_features == 0 {
22            return Err(SeqError::InvalidConfiguration(
23                "n_labels and n_features must be > 0".to_string(),
24            ));
25        }
26        Ok(Self {
27            n_labels,
28            n_features,
29            weights: vec![0.0; n_labels * n_labels * n_features],
30            start_label: 0,
31        })
32    }
33
34    /// Compute `p(cur | prev, x)` for all `cur` given `prev` and `x`.
35    pub fn class_probs(&self, prev: usize, x: &[f64]) -> SeqResult<Vec<f64>> {
36        if prev >= self.n_labels {
37            return Err(SeqError::IndexOutOfBounds {
38                index: prev,
39                len: self.n_labels,
40            });
41        }
42        if x.len() != self.n_features {
43            return Err(SeqError::ShapeMismatch {
44                expected: self.n_features,
45                got: x.len(),
46            });
47        }
48        let mut logits = vec![0.0; self.n_labels];
49        let base = prev * self.n_labels * self.n_features;
50        for cur in 0..self.n_labels {
51            let row =
52                &self.weights[base + cur * self.n_features..base + (cur + 1) * self.n_features];
53            let s: f64 = row.iter().zip(x.iter()).map(|(w, v)| w * v).sum();
54            logits[cur] = s;
55        }
56        let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
57        let exps: Vec<f64> = logits.iter().map(|l| (l - m).exp()).collect();
58        let z: f64 = exps.iter().sum();
59        Ok(exps.iter().map(|e| e / z).collect())
60    }
61
62    /// Greedy decode: pick argmax at each step.
63    pub fn decode_greedy(&self, x: &[f64]) -> SeqResult<Vec<usize>> {
64        if x.is_empty() {
65            return Err(SeqError::EmptyInput);
66        }
67        if x.len() % self.n_features != 0 {
68            return Err(SeqError::DimensionMismatch {
69                a: x.len(),
70                b: self.n_features,
71            });
72        }
73        let t_max = x.len() / self.n_features;
74        let mut path = Vec::with_capacity(t_max);
75        let mut prev = self.start_label;
76        for t in 0..t_max {
77            let probs =
78                self.class_probs(prev, &x[t * self.n_features..(t + 1) * self.n_features])?;
79            let (best, _) =
80                probs
81                    .iter()
82                    .enumerate()
83                    .fold((0usize, f64::NEG_INFINITY), |(bi, bv), (i, &v)| {
84                        if v > bv { (i, v) } else { (bi, bv) }
85                    });
86            path.push(best);
87            prev = best;
88        }
89        Ok(path)
90    }
91
92    /// Beam decode of width `beam`.
93    pub fn decode_beam(&self, x: &[f64], beam: usize) -> SeqResult<Vec<usize>> {
94        if beam == 0 {
95            return Err(SeqError::InvalidConfiguration(
96                "beam width must be > 0".to_string(),
97            ));
98        }
99        if x.is_empty() {
100            return Err(SeqError::EmptyInput);
101        }
102        if x.len() % self.n_features != 0 {
103            return Err(SeqError::DimensionMismatch {
104                a: x.len(),
105                b: self.n_features,
106            });
107        }
108        let t_max = x.len() / self.n_features;
109        // Each beam item: (log_prob, path)
110        let mut beam_items: Vec<(f64, Vec<usize>)> = vec![(0.0, Vec::new())];
111        for t in 0..t_max {
112            let mut new_items: Vec<(f64, Vec<usize>)> =
113                Vec::with_capacity(beam_items.len() * self.n_labels);
114            for (lp, path) in &beam_items {
115                let prev = path.last().copied().unwrap_or(self.start_label);
116                let probs =
117                    self.class_probs(prev, &x[t * self.n_features..(t + 1) * self.n_features])?;
118                for (cur, &p) in probs.iter().enumerate() {
119                    let logp = if p > 0.0 { p.ln() } else { f64::NEG_INFINITY };
120                    let mut new_path = path.clone();
121                    new_path.push(cur);
122                    new_items.push((lp + logp, new_path));
123                }
124            }
125            new_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
126            new_items.truncate(beam);
127            beam_items = new_items;
128        }
129        let best = beam_items
130            .into_iter()
131            .next()
132            .ok_or_else(|| SeqError::NumericalInstability("empty beam".to_string()))?;
133        Ok(best.1)
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Two labels, one feature: the transitions `0 -> 1` and `1 -> 0` carry
    /// large positive weights, so the most probable path alternates.
    fn alternating_model() -> Memm {
        let mut m = Memm::zeros(2, 1).expect("valid dimensions");
        // Layout: weights[prev * 2 + cur].
        m.weights = vec![0.0, 2.0, 3.0, 0.0];
        m
    }

    #[test]
    fn zeros_rejects_empty_dimensions() {
        assert!(matches!(
            Memm::zeros(0, 2),
            Err(SeqError::InvalidConfiguration(_))
        ));
        assert!(matches!(
            Memm::zeros(2, 0),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn class_probs_sum_to_one() {
        let m = Memm::zeros(3, 2).expect("ok");
        let p = m.class_probs(0, &[1.0, 1.0]).expect("ok");
        let s: f64 = p.iter().sum();
        assert!((s - 1.0).abs() < 1e-9);
        for &v in &p {
            assert!((v - 1.0 / 3.0).abs() < 1e-9);
        }
    }

    #[test]
    fn class_probs_rejects_bad_arguments() {
        let m = Memm::zeros(3, 2).expect("ok");
        assert!(matches!(
            m.class_probs(3, &[1.0, 1.0]),
            Err(SeqError::IndexOutOfBounds { index: 3, len: 3 })
        ));
        assert!(matches!(
            m.class_probs(0, &[1.0]),
            Err(SeqError::ShapeMismatch { expected: 2, got: 1 })
        ));
    }

    #[test]
    fn greedy_zero_weights() {
        let m = Memm::zeros(2, 2).expect("ok");
        let path = m.decode_greedy(&[1.0, 0.0, 0.0, 1.0]).expect("ok");
        // Uniform probabilities: every step is a tie, resolved to label 0.
        assert_eq!(path, vec![0, 0]);
    }

    #[test]
    fn beam_matches_greedy_zero_weights() {
        let m = Memm::zeros(2, 2).expect("ok");
        let g = m.decode_greedy(&[1.0, 0.0, 0.0, 1.0]).expect("ok");
        let b = m.decode_beam(&[1.0, 0.0, 0.0, 1.0], 4).expect("ok");
        assert_eq!(g, b);
    }

    #[test]
    fn decoders_follow_the_weights() {
        let m = alternating_model();
        let x = [1.0, 1.0, 1.0];
        assert_eq!(m.decode_greedy(&x).expect("ok"), vec![1, 0, 1]);
        assert_eq!(m.decode_beam(&x, 1).expect("ok"), vec![1, 0, 1]);
        assert_eq!(m.decode_beam(&x, 4).expect("ok"), vec![1, 0, 1]);
    }

    #[test]
    fn beam_finds_path_greedy_misses() {
        let mut m = Memm::zeros(2, 2).expect("ok");
        // index = prev * 4 + cur * 2 + f. Step 0 mildly prefers label 0, but
        // only label 1 leads to a near-certain second step.
        m.weights = vec![0.4, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0];
        let x = [1.0, 0.0, 0.0, 1.0];
        assert_eq!(m.decode_greedy(&x).expect("ok"), vec![0, 0]);
        assert_eq!(m.decode_beam(&x, 2).expect("ok"), vec![1, 0]);
    }

    #[test]
    fn decoders_reject_bad_input() {
        let m = Memm::zeros(2, 2).expect("ok");
        assert!(matches!(m.decode_greedy(&[]), Err(SeqError::EmptyInput)));
        assert!(matches!(m.decode_beam(&[], 2), Err(SeqError::EmptyInput)));
        assert!(matches!(
            m.decode_greedy(&[1.0, 2.0, 3.0]),
            Err(SeqError::DimensionMismatch { a: 3, b: 2 })
        ));
        assert!(matches!(
            m.decode_beam(&[1.0, 2.0, 3.0], 2),
            Err(SeqError::DimensionMismatch { a: 3, b: 2 })
        ));
        assert!(matches!(
            m.decode_beam(&[1.0, 2.0], 0),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }
}
166}