Skip to main content

oxicuda_seq/crf/
sgd.rs

1//! CRF training via SGD / AdaGrad.
2//!
3//! Linear-chain CRF with parameter vector θ = [emissions | transitions]:
4//!   - emissions: `[n_tags × n_features]` (row-major: tag is the outer index)
5//!   - transitions: `[n_tags × n_tags]` (row-major: `tr[i,j]` = prev-tag i → curr-tag j)
6//!
7//! Score: Σ_t (`w_emit[y_t]` · x_t) + Σ_{t≥1} `w_tr[y_{t-1}, y_t]`
8//!
9//! Gradient is computed via log-space forward-backward.
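//! AdaGrad: G_i += g_i²;  θ_i -= lr / sqrt(G_i + ε) * g_i  (minimising NLL).
//! Plain SGD (`adagrad = false`): θ_i -= lr * g_i. When `l2_reg > 0`, the term
//! `l2_reg * θ_i` is added to g_i before either update.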
11
12use crate::error::{SeqError, SeqResult};
13use crate::handle::LcgRng;
14
15// ─── logsumexp helper ────────────────────────────────────────────────────────
16
/// Numerically stable `ln(Σ_i exp(x_i))`.
///
/// Edge cases:
/// - empty slice, or every element `-∞` → `-∞` (log of an empty sum);
/// - any element NaN → NaN (never silently dropped);
/// - otherwise, any element `+∞` → `+∞` (avoids the NaN produced by `∞ - ∞`).
///
/// For finite inputs the result is computed as `m + ln Σ exp(x_i - m)` with
/// `m = max_i x_i`, which keeps every exponent `<= 0` and prevents overflow.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let mut m = f64::NEG_INFINITY;
    for &x in xs {
        if x.is_nan() {
            // `f64::max` ignores NaN, so check explicitly rather than letting an
            // all-NaN input collapse to -∞.
            return f64::NAN;
        }
        if x > m {
            m = x;
        }
    }
    if m.is_infinite() {
        // -∞: every term is exp(-∞) = 0 (or there are no terms).
        // +∞: the sum is infinite; computing `x - m` would give ∞ - ∞ = NaN.
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
26
27// ─── Configuration ───────────────────────────────────────────────────────────
28
/// Hyper-parameters controlling construction and training of a `CrfSgd` model.
#[derive(Debug, Clone)]
pub struct CrfSgdConfig {
    /// Size of the label set; valid tags are `0..n_tags`.
    pub n_tags: usize,
    /// Length of the feature vector supplied at every sequence position.
    pub n_features: usize,
    /// Number of full passes over the training set performed by `fit`.
    pub n_epochs: usize,
    /// Step size; when AdaGrad is enabled this is the base rate before
    /// per-parameter scaling.
    pub lr: f64,
    /// Coefficient of the L2 penalty on the weights.
    pub l2_reg: f64,
    /// `true` selects AdaGrad per-parameter learning rates, `false` plain SGD.
    pub adagrad: bool,
}
45
46// ─── CrfSgd ──────────────────────────────────────────────────────────────────
47
48/// Linear-chain CRF trained with SGD or AdaGrad.
49///
50/// Parameter layout:
51///   `weights[tag * n_features + feat]`  → emission weight for (tag, feat)
52///   `weights[n_tags * n_features + prev * n_tags + curr]`  → transition weight
53#[derive(Debug, Clone)]
54pub struct CrfSgd {
55    /// Full parameter vector `[n_tags × n_features + n_tags × n_tags]`.
56    pub weights: Vec<f64>,
57    /// Hyper-parameters.
58    config: CrfSgdConfig,
59    /// Accumulated squared gradients for AdaGrad.
60    adagrad_acc: Vec<f64>,
61}
62
63impl CrfSgd {
64    const ADAGRAD_EPS: f64 = 1e-8;
65
66    // ── Construction ─────────────────────────────────────────────────────────
67
68    /// Create and initialise a new CRF (weights ~ N(0, 0.1)).
69    pub fn new(config: CrfSgdConfig, rng: &mut LcgRng) -> SeqResult<Self> {
70        if config.n_tags == 0 {
71            return Err(SeqError::InvalidConfiguration("n_tags must be > 0".into()));
72        }
73        if config.n_features == 0 {
74            return Err(SeqError::InvalidConfiguration(
75                "n_features must be > 0".into(),
76            ));
77        }
78        let n_emit = config.n_tags * config.n_features;
79        let n_tr = config.n_tags * config.n_tags;
80        let n_params = n_emit + n_tr;
81        let weights: Vec<f64> = (0..n_params).map(|_| rng.next_normal() * 0.1).collect();
82        let adagrad_acc = vec![0.0f64; n_params];
83        Ok(Self {
84            weights,
85            config,
86            adagrad_acc,
87        })
88    }
89
90    // ── Index helpers ─────────────────────────────────────────────────────────
91
92    /// Index into the emission block.
93    #[inline]
94    fn emit_idx(&self, tag: usize, feat: usize) -> usize {
95        tag * self.config.n_features + feat
96    }
97
98    /// Index into the transition block.
99    #[inline]
100    fn tr_idx(&self, prev_tag: usize, curr_tag: usize) -> usize {
101        self.config.n_tags * self.config.n_features + prev_tag * self.config.n_tags + curr_tag
102    }
103
104    // ── Public accessors ──────────────────────────────────────────────────────
105
106    /// Emission weight for `(tag, feat)`.
107    pub fn emission_weight(&self, tag: usize, feat: usize) -> f64 {
108        self.weights[self.emit_idx(tag, feat)]
109    }
110
111    /// Transition weight for the `prev_tag → curr_tag` edge.
112    pub fn transition_weight(&self, prev_tag: usize, curr_tag: usize) -> f64 {
113        self.weights[self.tr_idx(prev_tag, curr_tag)]
114    }
115
116    // ── Score helpers ─────────────────────────────────────────────────────────
117
118    /// Emission score at position `t` for tag `j`.
119    #[inline]
120    fn emit_score(&self, j: usize, feat: &[f64]) -> f64 {
121        let base = j * self.config.n_features;
122        let mut s = 0.0;
123        for f in 0..self.config.n_features {
124            s += self.weights[base + f] * feat[f];
125        }
126        s
127    }
128
129    // ── Forward algorithm (log-partition) ─────────────────────────────────────
130
131    /// Compute the log-partition function Z for `features[seq_len][n_features]`.
132    pub fn log_partition(&self, features: &[Vec<f64>], seq_len: usize) -> SeqResult<f64> {
133        if seq_len == 0 {
134            return Err(SeqError::EmptyInput);
135        }
136        if features.len() < seq_len {
137            return Err(SeqError::ShapeMismatch {
138                expected: seq_len,
139                got: features.len(),
140            });
141        }
142        let n = self.config.n_tags;
143        let mut alpha = vec![f64::NEG_INFINITY; n];
144        // Initialise: α_0(j) = emit(j, x_0)
145        for j in 0..n {
146            alpha[j] = self.emit_score(j, &features[0]);
147        }
148        let mut tmp = vec![0.0f64; n];
149        for t in 1..seq_len {
150            let mut alpha_new = vec![f64::NEG_INFINITY; n];
151            for j in 0..n {
152                for i in 0..n {
153                    tmp[i] = alpha[i] + self.transition_weight(i, j);
154                }
155                alpha_new[j] = logsumexp(&tmp) + self.emit_score(j, &features[t]);
156            }
157            alpha = alpha_new;
158        }
159        Ok(logsumexp(&alpha))
160    }
161
162    // ── Forward-backward (all α, β tables) ───────────────────────────────────
163
164    fn forward_table(&self, features: &[Vec<f64>], seq_len: usize) -> Vec<Vec<f64>> {
165        let n = self.config.n_tags;
166        let mut table = vec![vec![f64::NEG_INFINITY; n]; seq_len];
167        for j in 0..n {
168            table[0][j] = self.emit_score(j, &features[0]);
169        }
170        let mut tmp = vec![0.0f64; n];
171        for t in 1..seq_len {
172            for j in 0..n {
173                for i in 0..n {
174                    tmp[i] = table[t - 1][i] + self.transition_weight(i, j);
175                }
176                table[t][j] = logsumexp(&tmp) + self.emit_score(j, &features[t]);
177            }
178        }
179        table
180    }
181
182    fn backward_table(&self, features: &[Vec<f64>], seq_len: usize) -> Vec<Vec<f64>> {
183        let n = self.config.n_tags;
184        let mut table = vec![vec![0.0f64; n]; seq_len]; // β_{T-1}(i) = 0 in log-space
185        let mut tmp = vec![0.0f64; n];
186        for t in (0..seq_len - 1).rev() {
187            for i in 0..n {
188                for j in 0..n {
189                    tmp[j] = self.transition_weight(i, j)
190                        + self.emit_score(j, &features[t + 1])
191                        + table[t + 1][j];
192                }
193                table[t][i] = logsumexp(&tmp);
194            }
195        }
196        table
197    }
198
199    // ── Gradient computation ──────────────────────────────────────────────────
200
201    /// Compute gradient (NLL direction, i.e. ascent on NLL = descent on LL) and
202    /// negative log-likelihood for one sequence.
203    ///
204    /// Returns `(nll, gradient)`.
205    fn gradient_one(&self, features: &[Vec<f64>], labels: &[usize]) -> SeqResult<(f64, Vec<f64>)> {
206        let seq_len = labels.len();
207        if seq_len == 0 {
208            return Err(SeqError::EmptyInput);
209        }
210        let n = self.config.n_tags;
211        let k = self.config.n_features;
212        let n_params = self.weights.len();
213
214        // Validate labels.
215        for (t, &y) in labels.iter().enumerate() {
216            if y >= n {
217                return Err(SeqError::IndexOutOfBounds { index: y, len: n });
218            }
219            if features[t].len() != k {
220                return Err(SeqError::ShapeMismatch {
221                    expected: k,
222                    got: features[t].len(),
223                });
224            }
225        }
226
227        // ── Log-partition (forward only) ──────────────────────────────────────
228        let alpha = self.forward_table(features, seq_len);
229        let log_z = logsumexp(&alpha[seq_len - 1]);
230
231        // ── Backward ──────────────────────────────────────────────────────────
232        let beta = self.backward_table(features, seq_len);
233
234        // ── Score of true path ────────────────────────────────────────────────
235        let mut score_true = self.emit_score(labels[0], &features[0]);
236        for t in 1..seq_len {
237            score_true += self.transition_weight(labels[t - 1], labels[t])
238                + self.emit_score(labels[t], &features[t]);
239        }
240        let nll = log_z - score_true;
241
242        // ── Gradient ──────────────────────────────────────────────────────────
243        // grad[θ] = E_model[f(x,y)] - f(x, y_true)
244        let mut grad = vec![0.0f64; n_params];
245
246        // Emission expected counts via γ_t(j) = α_t(j) + β_t(j) - log_z
247        for t in 0..seq_len {
248            let feat = &features[t];
249            for j in 0..n {
250                let log_gamma = alpha[t][j] + beta[t][j] - log_z;
251                let gamma = log_gamma.exp();
252                let base = self.emit_idx(j, 0);
253                for f in 0..k {
254                    grad[base + f] += gamma * feat[f];
255                }
256            }
257        }
258        // Subtract true emission counts
259        for t in 0..seq_len {
260            let feat = &features[t];
261            let j = labels[t];
262            let base = self.emit_idx(j, 0);
263            for f in 0..k {
264                grad[base + f] -= feat[f];
265            }
266        }
267
268        // Transition expected counts via ξ_{t}(i,j):
269        // log ξ_t(i,j) = α_t(i) + tr(i,j) + emit_{t+1}(j) + β_{t+1}(j) - log_z
270        for t in 0..seq_len - 1 {
271            for i in 0..n {
272                for j in 0..n {
273                    let log_xi = alpha[t][i]
274                        + self.transition_weight(i, j)
275                        + self.emit_score(j, &features[t + 1])
276                        + beta[t + 1][j]
277                        - log_z;
278                    let xi = log_xi.exp();
279                    grad[self.tr_idx(i, j)] += xi;
280                }
281            }
282        }
283        // Subtract true transition counts
284        for t in 1..seq_len {
285            let (i, j) = (labels[t - 1], labels[t]);
286            grad[self.tr_idx(i, j)] -= 1.0;
287        }
288
289        Ok((nll, grad))
290    }
291
292    // ── Parameter update ──────────────────────────────────────────────────────
293
294    /// Apply one SGD / AdaGrad step given a pre-computed gradient vector.
295    fn apply_update(&mut self, grad: &[f64]) {
296        let lr = self.config.lr;
297        let eps = Self::ADAGRAD_EPS;
298        let n_params = self.weights.len();
299
300        if self.config.adagrad {
301            for i in 0..n_params {
302                self.adagrad_acc[i] += grad[i] * grad[i];
303                let eff_lr = lr / (self.adagrad_acc[i] + eps).sqrt();
304                self.weights[i] -= eff_lr * grad[i];
305            }
306        } else {
307            for i in 0..n_params {
308                self.weights[i] -= lr * grad[i];
309            }
310        }
311    }
312
313    // ── Public training API ───────────────────────────────────────────────────
314
315    /// Compute gradient for one sequence and update weights in-place.
316    ///
317    /// Returns the negative log-likelihood for this sample.
318    pub fn update_one(&mut self, features: &[Vec<f64>], labels: &[usize]) -> SeqResult<f64> {
319        let (nll, mut grad) = self.gradient_one(features, labels)?;
320        // Add L2 regularisation gradient
321        let l2 = self.config.l2_reg;
322        if l2 > 0.0 {
323            for i in 0..self.weights.len() {
324                grad[i] += l2 * self.weights[i];
325            }
326        }
327        self.apply_update(&grad);
328        Ok(nll)
329    }
330
331    /// Train for `n_epochs`, returning average NLL per epoch.
332    pub fn fit(
333        &mut self,
334        all_features: &[Vec<Vec<f64>>],
335        all_labels: &[Vec<usize>],
336    ) -> SeqResult<Vec<f64>> {
337        if all_features.len() != all_labels.len() {
338            return Err(SeqError::LengthMismatch {
339                a: all_features.len(),
340                b: all_labels.len(),
341            });
342        }
343        let n_samples = all_features.len();
344        if n_samples == 0 {
345            return Err(SeqError::EmptyInput);
346        }
347        let mut epoch_losses = Vec::with_capacity(self.config.n_epochs);
348        for _epoch in 0..self.config.n_epochs {
349            let mut total_nll = 0.0;
350            for s in 0..n_samples {
351                total_nll += self.update_one(&all_features[s], &all_labels[s])?;
352            }
353            epoch_losses.push(total_nll / n_samples as f64);
354        }
355        Ok(epoch_losses)
356    }
357
358    // ── Decoding ──────────────────────────────────────────────────────────────
359
360    /// Viterbi decode — returns the best tag sequence for `features`.
361    pub fn decode(&self, features: &[Vec<f64>], seq_len: usize) -> SeqResult<Vec<usize>> {
362        if seq_len == 0 {
363            return Err(SeqError::EmptyInput);
364        }
365        if features.len() < seq_len {
366            return Err(SeqError::ShapeMismatch {
367                expected: seq_len,
368                got: features.len(),
369            });
370        }
371        let n = self.config.n_tags;
372        let mut viterbi = vec![f64::NEG_INFINITY; n];
373        let mut backptr = vec![vec![0usize; n]; seq_len];
374
375        // Initialise
376        for j in 0..n {
377            viterbi[j] = self.emit_score(j, &features[0]);
378        }
379
380        // Fill DP
381        for t in 1..seq_len {
382            let mut viterbi_new = vec![f64::NEG_INFINITY; n];
383            for j in 0..n {
384                let mut best_score = f64::NEG_INFINITY;
385                let mut best_prev = 0;
386                for i in 0..n {
387                    let s = viterbi[i] + self.transition_weight(i, j);
388                    if s > best_score {
389                        best_score = s;
390                        best_prev = i;
391                    }
392                }
393                viterbi_new[j] = best_score + self.emit_score(j, &features[t]);
394                backptr[t][j] = best_prev;
395            }
396            viterbi = viterbi_new;
397        }
398
399        // Find best last tag
400        let mut best_last = 0;
401        let mut best_val = f64::NEG_INFINITY;
402        for j in 0..n {
403            if viterbi[j] > best_val {
404                best_val = viterbi[j];
405                best_last = j;
406            }
407        }
408
409        // Backtrace
410        let mut path = vec![0usize; seq_len];
411        path[seq_len - 1] = best_last;
412        for t in (0..seq_len - 1).rev() {
413            path[t] = backptr[t + 1][path[t + 1]];
414        }
415        Ok(path)
416    }
417}
418
419// ─── Tests ───────────────────────────────────────────────────────────────────
420
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config(adagrad: bool) -> CrfSgdConfig {
        CrfSgdConfig {
            n_tags: 3,
            n_features: 4,
            n_epochs: 5,
            lr: 0.05,
            l2_reg: 1e-4,
            adagrad,
        }
    }

    fn make_crf(adagrad: bool) -> CrfSgd {
        let mut rng = LcgRng::new(42);
        CrfSgd::new(make_config(adagrad), &mut rng).expect("construction failed")
    }

    /// Deterministic length-3 sequence truncated to `n_features` columns.
    fn simple_data(n_tags: usize, n_features: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
        let features = vec![
            vec![1.0, 0.0, 0.5, -0.5],
            vec![0.0, 1.0, -0.5, 0.5],
            vec![0.5, 0.5, 0.0, 1.0],
        ];
        let features: Vec<Vec<f64>> = features
            .into_iter()
            .map(|f| f.into_iter().take(n_features).collect())
            .collect();
        let labels = vec![0, 1 % n_tags, 2 % n_tags];
        (features, labels)
    }

    /// Unnormalised score of a complete tag path, computed term by term.
    fn path_score(crf: &CrfSgd, features: &[Vec<f64>], path: &[usize]) -> f64 {
        let mut s = crf.emit_score(path[0], &features[0]);
        for t in 1..path.len() {
            s += crf.transition_weight(path[t - 1], path[t]) + crf.emit_score(path[t], &features[t]);
        }
        s
    }

    /// All `n_tags^len` tag sequences in lexicographic order.
    fn all_paths(n_tags: usize, len: usize) -> Vec<Vec<usize>> {
        (0..len).fold(vec![Vec::<usize>::new()], |paths, _| {
            paths
                .into_iter()
                .flat_map(|prefix| {
                    (0..n_tags).map(move |y| {
                        let mut path = prefix.clone();
                        path.push(y);
                        path
                    })
                })
                .collect::<Vec<_>>()
        })
    }

    #[test]
    fn weights_shape() {
        let crf = make_crf(false);
        assert_eq!(
            crf.weights.len(),
            3 * 4 + 3 * 3,
            "weights.len() should be n_tags*n_features + n_tags*n_tags"
        );
    }

    #[test]
    fn decode_output_len() {
        let crf = make_crf(false);
        let (features, _) = simple_data(3, 4);
        let seq_len = features.len();
        let path = crf.decode(&features, seq_len).expect("decode failed");
        assert_eq!(path.len(), seq_len);
    }

    #[test]
    fn decode_valid_tags() {
        let crf = make_crf(false);
        let (features, _) = simple_data(3, 4);
        let path = crf.decode(&features, features.len()).expect("decode failed");
        for &tag in &path {
            assert!(tag < 3, "decoded tag {tag} >= n_tags=3");
        }
    }

    #[test]
    fn log_partition_finite() {
        let crf = make_crf(false);
        let (features, _) = simple_data(3, 4);
        let lz = crf
            .log_partition(&features, features.len())
            .expect("lz failed");
        assert!(lz.is_finite(), "log_partition should be finite, got {lz}");
    }

    #[test]
    fn log_partition_matches_exhaustive() {
        let crf = make_crf(false);
        let (features, _) = simple_data(3, 4);
        let scores: Vec<f64> = all_paths(3, features.len())
            .iter()
            .map(|p| path_score(&crf, &features, p))
            .collect();
        let brute = logsumexp(&scores);
        let lz = crf
            .log_partition(&features, features.len())
            .expect("lz failed");
        assert!(
            (lz - brute).abs() < 1e-9,
            "forward={lz}, brute force={brute}"
        );
    }

    #[test]
    fn gradient_matches_finite_differences() {
        let crf = make_crf(false);
        let (features, labels) = simple_data(3, 4);
        let (_, grad) = crf
            .gradient_one(&features, &labels)
            .expect("gradient failed");
        let h = 1e-5;
        for p in 0..crf.weights.len() {
            let mut plus = crf.clone();
            plus.weights[p] += h;
            let mut minus = crf.clone();
            minus.weights[p] -= h;
            let (nll_plus, _) = plus.gradient_one(&features, &labels).expect("nll+");
            let (nll_minus, _) = minus.gradient_one(&features, &labels).expect("nll-");
            let numeric = (nll_plus - nll_minus) / (2.0 * h);
            assert!(
                (numeric - grad[p]).abs() < 1e-5,
                "param {p}: analytic={}, numeric={numeric}",
                grad[p]
            );
        }
    }

    #[test]
    fn update_decreases_loss() {
        let mut rng = LcgRng::new(7);
        let mut config = make_config(true);
        config.n_epochs = 30;
        config.lr = 0.1;
        let mut crf = CrfSgd::new(config, &mut rng).expect("new failed");

        // 4 samples, each of seq_len = 3 with 4 features.
        let all_feats: Vec<Vec<Vec<f64>>> = (0..4)
            .map(|seed| {
                let mut r = LcgRng::new(seed as u64 + 1);
                (0..3)
                    .map(|_| (0..4).map(|_| r.next_normal()).collect())
                    .collect()
            })
            .collect();
        let all_labels: Vec<Vec<usize>> =
            vec![vec![0, 1, 2], vec![2, 0, 1], vec![1, 2, 0], vec![0, 0, 1]];
        let losses = crf.fit(&all_feats, &all_labels).expect("fit failed");
        assert!(!losses.is_empty());

        // Mean of the last window must be below the mean of the first window.
        let window = 5.min(losses.len());
        let mean = |xs: &[f64]| xs.iter().sum::<f64>() / xs.len() as f64;
        let first = mean(&losses[..window]);
        let last = mean(&losses[losses.len() - window..]);
        assert!(
            last < first,
            "loss did not decrease: first={first:.4}, last={last:.4}"
        );
    }

    #[test]
    fn adagrad_different_from_sgd() {
        let mut rng_sgd = LcgRng::new(42);
        let mut rng_ada = LcgRng::new(42);
        let mut crf_sgd = CrfSgd::new(make_config(false), &mut rng_sgd).expect("new failed");
        let mut crf_ada = CrfSgd::new(make_config(true), &mut rng_ada).expect("new failed");

        let (features, labels) = simple_data(3, 4);
        let all_feats = vec![features];
        let all_labels = vec![labels];
        crf_sgd.fit(&all_feats, &all_labels).expect("fit sgd");
        crf_ada.fit(&all_feats, &all_labels).expect("fit ada");
        let diff: f64 = crf_sgd
            .weights
            .iter()
            .zip(&crf_ada.weights)
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-12, "adagrad and sgd produced identical weights");
    }

    #[test]
    fn viterbi_agrees_with_exhaustive() {
        // seq_len = 2, n_tags = 2: check all 4 paths.
        let mut rng = LcgRng::new(99);
        let config = CrfSgdConfig {
            n_tags: 2,
            n_features: 3,
            n_epochs: 1,
            lr: 0.01,
            l2_reg: 0.0,
            adagrad: false,
        };
        let crf = CrfSgd::new(config, &mut rng).expect("new");
        let features = vec![vec![1.0, -1.0, 0.5], vec![-0.5, 0.5, 1.0]];
        let path = crf.decode(&features, 2).expect("decode");

        let mut best_score = f64::NEG_INFINITY;
        let mut best_path = Vec::new();
        for candidate in all_paths(2, 2) {
            let s = path_score(&crf, &features, &candidate);
            if s > best_score {
                best_score = s;
                best_path = candidate;
            }
        }
        assert_eq!(path, best_path, "Viterbi path mismatch");
    }

    #[test]
    fn emission_weight_correct() {
        let crf = make_crf(false);
        for tag in 0..3 {
            for feat in 0..4 {
                let expected = crf.weights[tag * 4 + feat];
                assert_eq!(
                    crf.emission_weight(tag, feat),
                    expected,
                    "emission_weight({tag},{feat}) mismatch"
                );
            }
        }
    }

    #[test]
    fn transition_weight_correct() {
        let crf = make_crf(false);
        for prev in 0..3 {
            for curr in 0..3 {
                let expected = crf.weights[3 * 4 + prev * 3 + curr];
                assert_eq!(
                    crf.transition_weight(prev, curr),
                    expected,
                    "transition_weight({prev},{curr}) mismatch"
                );
            }
        }
    }

    #[test]
    fn n_tags_zero_error() {
        let mut rng = LcgRng::new(1);
        let config = CrfSgdConfig {
            n_tags: 0,
            n_features: 4,
            n_epochs: 1,
            lr: 0.01,
            l2_reg: 0.0,
            adagrad: false,
        };
        assert!(
            CrfSgd::new(config, &mut rng).is_err(),
            "n_tags=0 should fail"
        );
    }

    #[test]
    fn empty_sequence_error() {
        let crf = make_crf(false);
        let result = crf.decode(&[], 0);
        assert!(result.is_err(), "decode on empty should fail");
    }
}