1use crate::error::{SeqError, SeqResult};
4
/// Maximum-entropy Markov model for sequence labelling.
///
/// `weights` is a flattened `[prev][cur][feature]` tensor of length
/// `n_labels * n_labels * n_features`: the weight row that scores the transition
/// `prev -> cur` starts at index `(prev * n_labels + cur) * n_features` and is
/// `n_features` long.
#[derive(Debug, Clone, PartialEq)]
pub struct Memm {
    /// Number of output labels.
    pub n_labels: usize,
    /// Length of the feature vector supplied at each time step.
    pub n_features: usize,
    /// Flattened weight tensor; see the type-level docs for the layout.
    pub weights: Vec<f64>,
    /// Label used as the "previous" label when scoring the first time step.
    pub start_label: usize,
}
17
18impl Memm {
19 pub fn zeros(n_labels: usize, n_features: usize) -> SeqResult<Self> {
21 if n_labels == 0 || n_features == 0 {
22 return Err(SeqError::InvalidConfiguration(
23 "n_labels and n_features must be > 0".to_string(),
24 ));
25 }
26 Ok(Self {
27 n_labels,
28 n_features,
29 weights: vec![0.0; n_labels * n_labels * n_features],
30 start_label: 0,
31 })
32 }
33
34 pub fn class_probs(&self, prev: usize, x: &[f64]) -> SeqResult<Vec<f64>> {
36 if prev >= self.n_labels {
37 return Err(SeqError::IndexOutOfBounds {
38 index: prev,
39 len: self.n_labels,
40 });
41 }
42 if x.len() != self.n_features {
43 return Err(SeqError::ShapeMismatch {
44 expected: self.n_features,
45 got: x.len(),
46 });
47 }
48 let mut logits = vec![0.0; self.n_labels];
49 let base = prev * self.n_labels * self.n_features;
50 for cur in 0..self.n_labels {
51 let row =
52 &self.weights[base + cur * self.n_features..base + (cur + 1) * self.n_features];
53 let s: f64 = row.iter().zip(x.iter()).map(|(w, v)| w * v).sum();
54 logits[cur] = s;
55 }
56 let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
57 let exps: Vec<f64> = logits.iter().map(|l| (l - m).exp()).collect();
58 let z: f64 = exps.iter().sum();
59 Ok(exps.iter().map(|e| e / z).collect())
60 }
61
62 pub fn decode_greedy(&self, x: &[f64]) -> SeqResult<Vec<usize>> {
64 if x.is_empty() {
65 return Err(SeqError::EmptyInput);
66 }
67 if x.len() % self.n_features != 0 {
68 return Err(SeqError::DimensionMismatch {
69 a: x.len(),
70 b: self.n_features,
71 });
72 }
73 let t_max = x.len() / self.n_features;
74 let mut path = Vec::with_capacity(t_max);
75 let mut prev = self.start_label;
76 for t in 0..t_max {
77 let probs =
78 self.class_probs(prev, &x[t * self.n_features..(t + 1) * self.n_features])?;
79 let (best, _) =
80 probs
81 .iter()
82 .enumerate()
83 .fold((0usize, f64::NEG_INFINITY), |(bi, bv), (i, &v)| {
84 if v > bv { (i, v) } else { (bi, bv) }
85 });
86 path.push(best);
87 prev = best;
88 }
89 Ok(path)
90 }
91
92 pub fn decode_beam(&self, x: &[f64], beam: usize) -> SeqResult<Vec<usize>> {
94 if beam == 0 {
95 return Err(SeqError::InvalidConfiguration(
96 "beam width must be > 0".to_string(),
97 ));
98 }
99 if x.is_empty() {
100 return Err(SeqError::EmptyInput);
101 }
102 if x.len() % self.n_features != 0 {
103 return Err(SeqError::DimensionMismatch {
104 a: x.len(),
105 b: self.n_features,
106 });
107 }
108 let t_max = x.len() / self.n_features;
109 let mut beam_items: Vec<(f64, Vec<usize>)> = vec![(0.0, Vec::new())];
111 for t in 0..t_max {
112 let mut new_items: Vec<(f64, Vec<usize>)> =
113 Vec::with_capacity(beam_items.len() * self.n_labels);
114 for (lp, path) in &beam_items {
115 let prev = path.last().copied().unwrap_or(self.start_label);
116 let probs =
117 self.class_probs(prev, &x[t * self.n_features..(t + 1) * self.n_features])?;
118 for (cur, &p) in probs.iter().enumerate() {
119 let logp = if p > 0.0 { p.ln() } else { f64::NEG_INFINITY };
120 let mut new_path = path.clone();
121 new_path.push(cur);
122 new_items.push((lp + logp, new_path));
123 }
124 }
125 new_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
126 new_items.truncate(beam);
127 beam_items = new_items;
128 }
129 let best = beam_items
130 .into_iter()
131 .next()
132 .ok_or_else(|| SeqError::NumericalInstability("empty beam".to_string()))?;
133 Ok(best.1)
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Two labels, one feature. From the start label the first step slightly prefers
    /// label 0 (logits 0.1 vs 0.0), but label 1 is followed by a near-certain 1 -> 1
    /// transition (logits 0.0 vs 5.0). The best whole path for `x = [1, 1]` is therefore
    /// `[1, 1]`, while greedy decoding commits to `[0, 0]`.
    fn greedy_trap_model() -> Memm {
        Memm {
            n_labels: 2,
            n_features: 1,
            weights: vec![0.1, 0.0, 0.0, 5.0],
            start_label: 0,
        }
    }

    #[test]
    fn class_probs_sum_to_one() {
        let m = Memm::zeros(3, 2).expect("ok");
        let p = m.class_probs(0, &[1.0, 1.0]).expect("ok");
        let s: f64 = p.iter().sum();
        assert!((s - 1.0).abs() < 1e-9);
        for &v in &p {
            assert!((v - 1.0 / 3.0).abs() < 1e-9);
        }
    }

    #[test]
    fn class_probs_rejects_bad_arguments() {
        let m = Memm::zeros(3, 2).expect("ok");
        assert!(matches!(
            m.class_probs(3, &[1.0, 1.0]),
            Err(SeqError::IndexOutOfBounds { index: 3, len: 3 })
        ));
        assert!(matches!(
            m.class_probs(0, &[1.0]),
            Err(SeqError::ShapeMismatch { expected: 2, got: 1 })
        ));
    }

    #[test]
    fn greedy_zero_weights() {
        let m = Memm::zeros(2, 2).expect("ok");
        let path = m.decode_greedy(&[1.0, 0.0, 0.0, 1.0]).expect("ok");
        // Every step is uniform, and ties resolve to the lowest label index.
        assert_eq!(path, vec![0, 0]);
    }

    #[test]
    fn beam_matches_greedy_zero_weights() {
        let m = Memm::zeros(2, 2).expect("ok");
        let g = m.decode_greedy(&[1.0, 0.0, 0.0, 1.0]).expect("ok");
        let b = m.decode_beam(&[1.0, 0.0, 0.0, 1.0], 4).expect("ok");
        assert_eq!(g, b);
    }

    #[test]
    fn beam_recovers_path_greedy_misses() {
        let m = greedy_trap_model();
        assert_eq!(m.decode_greedy(&[1.0, 1.0]).expect("ok"), vec![0, 0]);
        assert_eq!(m.decode_beam(&[1.0, 1.0], 1).expect("ok"), vec![0, 0]);
        assert_eq!(m.decode_beam(&[1.0, 1.0], 2).expect("ok"), vec![1, 1]);
    }

    #[test]
    fn decoders_reject_invalid_input() {
        assert!(matches!(
            Memm::zeros(0, 2),
            Err(SeqError::InvalidConfiguration(_))
        ));
        let m = Memm::zeros(2, 2).expect("ok");
        assert!(matches!(m.decode_greedy(&[]), Err(SeqError::EmptyInput)));
        assert!(matches!(m.decode_beam(&[], 2), Err(SeqError::EmptyInput)));
        assert!(matches!(
            m.decode_greedy(&[1.0, 2.0, 3.0]),
            Err(SeqError::DimensionMismatch { a: 3, b: 2 })
        ));
        assert!(matches!(
            m.decode_beam(&[1.0, 2.0], 0),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }
}