Skip to main content

oxicuda_seq/perceptron/
structured_perceptron.rs

1//! Structured perceptron and averaged structured perceptron (Collins 2002).
2//!
3//! The **structured perceptron** ("Discriminative Training Methods for Hidden
4//! Markov Models", Collins, EMNLP 2002) trains a linear sequence tagger by the
5//! mistake-driven perceptron rule.  For each training sentence the current
6//! weights `w` are used to Viterbi-decode a predicted tag sequence `ŷ`; if
7//! `ŷ ≠ y` the weights are corrected toward the gold features and away from the
8//! predicted features:
9//!
10//! ```text
11//! w ← w + φ(x, y) − φ(x, ŷ)
12//! ```
13//!
14//! where `φ` is the global feature map decomposing into local emission and
15//! transition features.  This crate uses the same parameterisation as the
16//! linear-chain CRF: real-valued emission features per position combined with a
17//! dense `n_labels × n_labels` transition table.
18//!
//! The **averaged perceptron** (Freund & Schapire 1999, popularised by Collins)
//! returns the *average* of all weight vectors seen during training rather than
//! the final one, which dramatically reduces variance and overfitting.  We keep
//! a running total (`w_total += w` after every example) and divide it by the
//! number of examples seen once at the end.  Note that this costs `O(P)` per
//! example; the lazy-update trick of Daumé, which avoids that cost, is not used.
25
26use crate::error::{SeqError, SeqResult};
27
28// ─── Model ───────────────────────────────────────────────────────────────────
29
30/// A linear-chain structured perceptron tagger.
31///
32/// Parameter layout (identical to `LinearChainCrf`):
33/// * `emissions[label*n_features + k]` — weight for feature `k` under `label`.
34/// * `transitions[prev*n_labels + cur]` — score of the `prev → cur` bigram.
35#[derive(Debug, Clone)]
36pub struct StructuredPerceptron {
37    /// Number of output labels.
38    pub n_labels: usize,
39    /// Number of (real-valued) emission features per position.
40    pub n_features: usize,
41    /// Emission weights, length `n_labels * n_features`.
42    pub emissions: Vec<f64>,
43    /// Transition weights, length `n_labels * n_labels`.
44    pub transitions: Vec<f64>,
45}
46
47impl StructuredPerceptron {
48    /// Create a zero-initialised perceptron.
49    ///
50    /// # Errors
51    ///
52    /// [`SeqError::InvalidConfiguration`] if `n_labels == 0` or `n_features == 0`.
53    pub fn zeros(n_labels: usize, n_features: usize) -> SeqResult<Self> {
54        if n_labels == 0 || n_features == 0 {
55            return Err(SeqError::InvalidConfiguration(
56                "n_labels and n_features must be > 0".to_string(),
57            ));
58        }
59        Ok(Self {
60            n_labels,
61            n_features,
62            emissions: vec![0.0; n_labels * n_features],
63            transitions: vec![0.0; n_labels * n_labels],
64        })
65    }
66
67    /// Total number of parameters.
68    #[must_use]
69    pub fn param_count(&self) -> usize {
70        self.n_labels * self.n_features + self.n_labels * self.n_labels
71    }
72
73    /// Emission score `w_label · x_t`.
74    fn emit_score(&self, label: usize, x: &[f64]) -> f64 {
75        let base = label * self.n_features;
76        let mut s = 0.0;
77        for (k, &xv) in x.iter().enumerate() {
78            s += self.emissions[base + k] * xv;
79        }
80        s
81    }
82
83    /// Viterbi decode the highest-scoring label sequence for feature matrix
84    /// `x` (`T × n_features`, row-major).
85    ///
86    /// # Errors
87    ///
88    /// * [`SeqError::EmptyInput`]    — if `x` is empty.
89    /// * [`SeqError::ShapeMismatch`] — if `x.len()` is not a multiple of
90    ///   `n_features`.
91    pub fn decode(&self, x: &[f64]) -> SeqResult<Vec<usize>> {
92        if x.is_empty() {
93            return Err(SeqError::EmptyInput);
94        }
95        let k = self.n_features;
96        if x.len() % k != 0 {
97            return Err(SeqError::ShapeMismatch {
98                expected: x.len().div_ceil(k) * k,
99                got: x.len(),
100            });
101        }
102        let n = self.n_labels;
103        let t_max = x.len() / k;
104
105        let mut delta = vec![f64::NEG_INFINITY; t_max * n];
106        let mut psi = vec![0usize; t_max * n];
107
108        // t = 0: emission only.
109        for j in 0..n {
110            delta[j] = self.emit_score(j, &x[..k]);
111        }
112        for t in 1..t_max {
113            let xt = &x[t * k..(t + 1) * k];
114            for j in 0..n {
115                let emit = self.emit_score(j, xt);
116                let mut best = f64::NEG_INFINITY;
117                let mut argmax = 0usize;
118                for i in 0..n {
119                    let v = delta[(t - 1) * n + i] + self.transitions[i * n + j];
120                    if v > best {
121                        best = v;
122                        argmax = i;
123                    }
124                }
125                delta[t * n + j] = best + emit;
126                psi[t * n + j] = argmax;
127            }
128        }
129
130        // Termination.
131        let mut best = f64::NEG_INFINITY;
132        let mut last = 0usize;
133        for j in 0..n {
134            let v = delta[(t_max - 1) * n + j];
135            if v > best {
136                best = v;
137                last = j;
138            }
139        }
140        let mut path = vec![0usize; t_max];
141        path[t_max - 1] = last;
142        for t in (1..t_max).rev() {
143            path[t - 1] = psi[t * n + path[t]];
144        }
145        Ok(path)
146    }
147
148    /// Total linear score of a full label sequence `y` under `x`.
149    ///
150    /// # Errors
151    ///
152    /// * [`SeqError::EmptyInput`]      — if `y` is empty.
153    /// * [`SeqError::ShapeMismatch`]   — if `x.len() ≠ y.len() * n_features`.
154    /// * [`SeqError::IndexOutOfBounds`] — if any label `≥ n_labels`.
155    pub fn sequence_score(&self, x: &[f64], y: &[usize]) -> SeqResult<f64> {
156        if y.is_empty() {
157            return Err(SeqError::EmptyInput);
158        }
159        let k = self.n_features;
160        let t_max = y.len();
161        if x.len() != t_max * k {
162            return Err(SeqError::ShapeMismatch {
163                expected: t_max * k,
164                got: x.len(),
165            });
166        }
167        let mut s = 0.0;
168        for t in 0..t_max {
169            if y[t] >= self.n_labels {
170                return Err(SeqError::IndexOutOfBounds {
171                    index: y[t],
172                    len: self.n_labels,
173                });
174            }
175            s += self.emit_score(y[t], &x[t * k..(t + 1) * k]);
176            if t > 0 {
177                s += self.transitions[y[t - 1] * self.n_labels + y[t]];
178            }
179        }
180        Ok(s)
181    }
182
183    /// Apply the perceptron correction `θ ← θ + φ(gold) − φ(pred)` in place,
184    /// returning the number of positions where `gold` and `pred` differ.
185    ///
186    /// Both label sequences must have length `T = x.len() / n_features`.
187    ///
188    /// # Errors
189    ///
190    /// * [`SeqError::LengthMismatch`]  — if `gold.len() ≠ pred.len()`.
191    /// * [`SeqError::ShapeMismatch`]   — if shapes are inconsistent.
192    /// * [`SeqError::IndexOutOfBounds`] — if any label is out of range.
193    pub fn update(&mut self, x: &[f64], gold: &[usize], pred: &[usize]) -> SeqResult<usize> {
194        if gold.len() != pred.len() {
195            return Err(SeqError::LengthMismatch {
196                a: gold.len(),
197                b: pred.len(),
198            });
199        }
200        let k = self.n_features;
201        let t_max = gold.len();
202        if x.len() != t_max * k {
203            return Err(SeqError::ShapeMismatch {
204                expected: t_max * k,
205                got: x.len(),
206            });
207        }
208        let n = self.n_labels;
209        for &lbl in gold.iter().chain(pred.iter()) {
210            if lbl >= n {
211                return Err(SeqError::IndexOutOfBounds { index: lbl, len: n });
212            }
213        }
214
215        let mut mistakes = 0usize;
216        // Emission features.
217        for t in 0..t_max {
218            if gold[t] == pred[t] {
219                continue;
220            }
221            mistakes += 1;
222            let xt = &x[t * k..(t + 1) * k];
223            let gbase = gold[t] * k;
224            let pbase = pred[t] * k;
225            for (idx, &xv) in xt.iter().enumerate() {
226                self.emissions[gbase + idx] += xv;
227                self.emissions[pbase + idx] -= xv;
228            }
229        }
230        // Transition features.
231        for t in 1..t_max {
232            let g = gold[t - 1] * n + gold[t];
233            let p = pred[t - 1] * n + pred[t];
234            if g != p {
235                self.transitions[g] += 1.0;
236                self.transitions[p] -= 1.0;
237            }
238        }
239        Ok(mistakes)
240    }
241}
242
243// ─── Training configuration ──────────────────────────────────────────────────
244
/// Configuration for perceptron training.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PerceptronConfig {
    /// Number of passes (epochs) over the training set.  `train_perceptron`
    /// treats `0` as `1`, so at least one pass is always made.
    pub epochs: usize,
    /// Whether to return the averaged weight vector (Collins averaging).
    pub averaged: bool,
}

impl Default for PerceptronConfig {
    /// Ten epochs with averaging enabled.
    fn default() -> Self {
        Self {
            epochs: 10,
            averaged: true,
        }
    }
}
262
/// One training example: a feature matrix `x` (`T × n_features`) and a gold
/// label sequence `y` (`T`).
#[derive(Debug, Clone, PartialEq)]
pub struct PerceptronExample {
    /// Feature matrix, row-major `T × n_features`.
    pub x: Vec<f64>,
    /// Gold labels, length `T`.
    pub y: Vec<usize>,
}
272
273/// Result of perceptron training.
274#[derive(Debug, Clone)]
275pub struct PerceptronTrainResult {
276    /// The trained model (averaged iff `config.averaged`).
277    pub model: StructuredPerceptron,
278    /// Total number of mistaken positions over the final epoch.
279    pub final_epoch_mistakes: usize,
280    /// Number of epochs actually run.
281    pub epochs_run: usize,
282}
283
284/// Train a structured perceptron on the given examples.
285///
286/// The model is updated example-by-example.  When `config.averaged` is set the
287/// returned model is the running average of all weight vectors (the standard
288/// "averaged perceptron"); the running total is accumulated after every example
289/// so that each weight vector contributes once per example it survived.
290///
291/// # Errors
292///
293/// * [`SeqError::EmptyInput`]    — if `examples` is empty.
294/// * [`SeqError::ShapeMismatch`] — if any example's `x` length is not
295///   `y.len() * n_features`.
296/// * Propagates decode/update errors.
297pub fn train_perceptron(
298    n_labels: usize,
299    n_features: usize,
300    examples: &[PerceptronExample],
301    config: &PerceptronConfig,
302) -> SeqResult<PerceptronTrainResult> {
303    if examples.is_empty() {
304        return Err(SeqError::EmptyInput);
305    }
306    let mut model = StructuredPerceptron::zeros(n_labels, n_features)?;
307    let p = model.param_count();
308    // Running total for averaging (emissions then transitions, same layout).
309    let mut total = vec![0.0_f64; p];
310    let mut n_updates = 0u64;
311    let mut final_mistakes = 0usize;
312
313    for epoch in 0..config.epochs.max(1) {
314        let mut epoch_mistakes = 0usize;
315        for ex in examples {
316            let t_max = ex.y.len();
317            if t_max == 0 || ex.x.len() != t_max * n_features {
318                return Err(SeqError::ShapeMismatch {
319                    expected: t_max * n_features,
320                    got: ex.x.len(),
321                });
322            }
323            let pred = model.decode(&ex.x)?;
324            let mistakes = model.update(&ex.x, &ex.y, &pred)?;
325            epoch_mistakes += mistakes;
326
327            if config.averaged {
328                // Accumulate the *current* weight vector after the update.
329                accumulate(&model, &mut total);
330                n_updates += 1;
331            }
332        }
333        final_mistakes = epoch_mistakes;
334        // Early-stop hint: a perfectly separable pass converges.
335        if epoch_mistakes == 0 {
336            // Still keep accumulating handled above; break after recording.
337            return finish(model, total, n_updates, final_mistakes, epoch + 1, config);
338        }
339    }
340
341    finish(
342        model,
343        total,
344        n_updates,
345        final_mistakes,
346        config.epochs.max(1),
347        config,
348    )
349}
350
351/// Add the model's current weights into the running total.
352fn accumulate(model: &StructuredPerceptron, total: &mut [f64]) {
353    let cut = model.emissions.len();
354    for (t, &e) in total[..cut].iter_mut().zip(model.emissions.iter()) {
355        *t += e;
356    }
357    for (t, &tr) in total[cut..].iter_mut().zip(model.transitions.iter()) {
358        *t += tr;
359    }
360}
361
362/// Finalise training, applying averaging if requested.
363fn finish(
364    mut model: StructuredPerceptron,
365    total: Vec<f64>,
366    n_updates: u64,
367    final_mistakes: usize,
368    epochs_run: usize,
369    config: &PerceptronConfig,
370) -> SeqResult<PerceptronTrainResult> {
371    if config.averaged && n_updates > 0 {
372        let inv = 1.0 / n_updates as f64;
373        let cut = model.emissions.len();
374        for (e, &t) in model.emissions.iter_mut().zip(total[..cut].iter()) {
375            *e = t * inv;
376        }
377        for (tr, &t) in model.transitions.iter_mut().zip(total[cut..].iter()) {
378            *tr = t * inv;
379        }
380    }
381    Ok(PerceptronTrainResult {
382        model,
383        final_epoch_mistakes: final_mistakes,
384        epochs_run,
385    })
386}
387
388// ─── Tests ───────────────────────────────────────────────────────────────────
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a separable toy dataset: 2 labels, 2 features where feature `j`
    /// is an indicator that the gold label is `j`.  A perceptron must learn to
    /// route feature 0 → label 0 and feature 1 → label 1.
    fn toy_examples() -> Vec<PerceptronExample> {
        vec![
            PerceptronExample {
                // tags 0,1,0
                x: vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                y: vec![0, 1, 0],
            },
            PerceptronExample {
                // tags 1,0,1
                x: vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
                y: vec![1, 0, 1],
            },
        ]
    }

    /// Train on `examples` for `epochs` passes with or without averaging.
    fn train(
        examples: &[PerceptronExample],
        epochs: usize,
        averaged: bool,
    ) -> PerceptronTrainResult {
        train_perceptron(2, 2, examples, &PerceptronConfig { epochs, averaged }).expect("ok")
    }

    #[test]
    fn zeros_rejects_bad_dims() {
        assert!(StructuredPerceptron::zeros(0, 3).is_err());
        assert!(StructuredPerceptron::zeros(3, 0).is_err());
        assert!(StructuredPerceptron::zeros(2, 2).is_ok());
    }

    #[test]
    fn param_count_correct() {
        let m = StructuredPerceptron::zeros(3, 4).expect("ok");
        assert_eq!(m.param_count(), 3 * 4 + 3 * 3);
    }

    #[test]
    fn zero_model_decodes_first_label() {
        // All-zero weights: every path scores 0; argmax ties resolve to label 0.
        let m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let y = m.decode(&[1.0, 0.0, 0.0, 1.0]).expect("ok");
        assert_eq!(y, vec![0, 0]);
    }

    #[test]
    fn decode_rejects_empty() {
        let m = StructuredPerceptron::zeros(2, 2).expect("ok");
        assert!(matches!(m.decode(&[]), Err(SeqError::EmptyInput)));
    }

    #[test]
    fn decode_rejects_bad_shape() {
        let m = StructuredPerceptron::zeros(2, 3).expect("ok");
        // length 4 is not a multiple of n_features=3.
        assert!(matches!(
            m.decode(&[1.0, 2.0, 3.0, 4.0]),
            Err(SeqError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn sequence_score_rejects_oob_label() {
        let m = StructuredPerceptron::zeros(2, 2).expect("ok");
        assert!(matches!(
            m.sequence_score(&[1.0, 0.0], &[5]),
            Err(SeqError::IndexOutOfBounds { .. })
        ));
    }

    #[test]
    fn update_rejects_length_mismatch() {
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        assert!(matches!(
            m.update(&[1.0, 0.0], &[0], &[0, 1]),
            Err(SeqError::LengthMismatch { .. })
        ));
    }

    #[test]
    fn update_counts_mistakes_and_moves_weights() {
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let x = vec![1.0, 0.0, 0.0, 1.0]; // T=2
        let gold = vec![0, 1];
        let pred = vec![1, 0];
        let mistakes = m.update(&x, &gold, &pred).expect("ok");
        assert_eq!(mistakes, 2);
        // emission[label0, feat0] should have increased (gold uses feat0→label0).
        assert!(m.emissions[0] > 0.0);
        // After the correction, gold now scores higher than pred.
        let sg = m.sequence_score(&x, &gold).expect("ok");
        let sp = m.sequence_score(&x, &pred).expect("ok");
        assert!(sg > sp, "gold {sg} should exceed pred {sp}");
    }

    #[test]
    fn update_no_mistakes_is_noop() {
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let x = vec![1.0, 0.0, 0.0, 1.0];
        let y = vec![0, 1];
        let emissions_before = m.emissions.clone();
        let transitions_before = m.transitions.clone();
        let mistakes = m.update(&x, &y, &y).expect("ok");
        assert_eq!(mistakes, 0);
        // Neither parameter table may move when prediction and gold agree.
        assert_eq!(emissions_before, m.emissions);
        assert_eq!(transitions_before, m.transitions);
    }

    #[test]
    fn train_rejects_empty() {
        let cfg = PerceptronConfig::default();
        assert!(matches!(
            train_perceptron(2, 2, &[], &cfg),
            Err(SeqError::EmptyInput)
        ));
    }

    #[test]
    fn train_learns_separable_data() {
        let ex = toy_examples();
        let res = train(&ex, 20, false);
        // After training the model must decode every example correctly.
        for e in &ex {
            let pred = res.model.decode(&e.x).expect("ok");
            assert_eq!(pred, e.y, "model failed to fit training example");
        }
    }

    #[test]
    fn train_converges_to_zero_mistakes() {
        let ex = toy_examples();
        let res = train(&ex, 50, false);
        assert_eq!(
            res.final_epoch_mistakes, 0,
            "separable data should converge to 0 mistakes"
        );
        assert!(res.epochs_run <= 50);
    }

    #[test]
    fn averaged_perceptron_fits_and_is_finite() {
        let ex = toy_examples();
        let res = train(&ex, 20, true);
        assert!(res.model.emissions.iter().all(|v| v.is_finite()));
        assert!(res.model.transitions.iter().all(|v| v.is_finite()));
        for e in &ex {
            let pred = res.model.decode(&e.x).expect("ok");
            assert_eq!(pred, e.y);
        }
    }

    #[test]
    fn averaging_equals_mean_of_trajectory() {
        // Deterministic averaging check on a *single* example over 2 epochs.
        // Epoch 0: weights start at 0, prediction is all-label-0, the example
        // has a mistake → one update giving weight vector w₁.  Epoch 1: the
        // example is now classified correctly → no update, weight stays w₁.
        // The averaged model accumulates [w₁ (after ep0), w₁ (after ep1)], so
        // the average is exactly w₁ == the raw final weights here.
        let ex = vec![PerceptronExample {
            x: vec![1.0, 0.0, 0.0, 1.0],
            y: vec![0, 1],
        }];
        let avg = train(&ex, 2, true);
        let raw = train(&ex, 2, false);

        // Both runs converge in the second epoch and decode the example.
        assert_eq!(avg.final_epoch_mistakes, 0);
        assert_eq!(raw.final_epoch_mistakes, 0);
        assert_eq!(avg.model.decode(&ex[0].x).expect("d"), ex[0].y);
        assert_eq!(raw.model.decode(&ex[0].x).expect("d"), ex[0].y);

        // The mean of [w₁, w₁] is w₁, in both parameter tables.
        let averaged = avg.model.emissions.iter().chain(&avg.model.transitions);
        let finals = raw.model.emissions.iter().chain(&raw.model.transitions);
        for (a, r) in averaged.zip(finals) {
            assert!((a - r).abs() <= 1e-12, "avg {a} differs from raw {r}");
        }
    }

    #[test]
    fn averaging_shrinks_when_trajectory_varies() {
        // A dataset whose two examples *pull weights in opposite directions*
        // each epoch keeps the weight trajectory oscillating, so the running
        // average has strictly smaller magnitude than the final weights for at
        // least one coordinate.
        let ex = vec![
            PerceptronExample {
                x: vec![1.0, 0.0],
                y: vec![0],
            },
            PerceptronExample {
                x: vec![1.0, 0.0],
                y: vec![1],
            },
        ];
        let avg = train(&ex, 6, true);
        let raw = train(&ex, 6, false);
        let diff: f64 = avg
            .model
            .emissions
            .iter()
            .zip(raw.model.emissions.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-9,
            "with a non-separable oscillating dataset averaging must differ from final"
        );
        let shrunk = avg
            .model
            .emissions
            .iter()
            .zip(raw.model.emissions.iter())
            .any(|(a, r)| a.abs() < r.abs());
        assert!(shrunk, "averaging should shrink at least one coordinate");
    }

    #[test]
    fn train_rejects_inconsistent_example_shape() {
        let bad = vec![PerceptronExample {
            x: vec![1.0, 0.0, 0.0], // 3 entries, but T=2 needs 4
            y: vec![0, 1],
        }];
        let cfg = PerceptronConfig::default();
        assert!(matches!(
            train_perceptron(2, 2, &bad, &cfg),
            Err(SeqError::ShapeMismatch { .. })
        ));
    }
}