1use crate::error::{SeqError, SeqResult};
27
/// A first-order (linear-chain) structured perceptron over dense feature frames.
///
/// A label sequence `y` for an input of `T` frames is scored as the sum of the
/// per-frame emission scores (dot product of frame `x_t` with the emission row of
/// `y_t`) plus the transition weights `transitions[y_{t-1} * n_labels + y_t]`.
#[derive(Clone, Debug)]
pub struct StructuredPerceptron {
    /// Number of distinct output labels.
    pub n_labels: usize,
    /// Number of feature values in each input frame.
    pub n_features: usize,
    /// Emission weights, row-major: `emissions[label * n_features + feature]`.
    pub emissions: Vec<f64>,
    /// Transition weights, row-major: `transitions[from_label * n_labels + to_label]`.
    pub transitions: Vec<f64>,
}
46
47impl StructuredPerceptron {
48 pub fn zeros(n_labels: usize, n_features: usize) -> SeqResult<Self> {
54 if n_labels == 0 || n_features == 0 {
55 return Err(SeqError::InvalidConfiguration(
56 "n_labels and n_features must be > 0".to_string(),
57 ));
58 }
59 Ok(Self {
60 n_labels,
61 n_features,
62 emissions: vec![0.0; n_labels * n_features],
63 transitions: vec![0.0; n_labels * n_labels],
64 })
65 }
66
67 #[must_use]
69 pub fn param_count(&self) -> usize {
70 self.n_labels * self.n_features + self.n_labels * self.n_labels
71 }
72
73 fn emit_score(&self, label: usize, x: &[f64]) -> f64 {
75 let base = label * self.n_features;
76 let mut s = 0.0;
77 for (k, &xv) in x.iter().enumerate() {
78 s += self.emissions[base + k] * xv;
79 }
80 s
81 }
82
83 pub fn decode(&self, x: &[f64]) -> SeqResult<Vec<usize>> {
92 if x.is_empty() {
93 return Err(SeqError::EmptyInput);
94 }
95 let k = self.n_features;
96 if x.len() % k != 0 {
97 return Err(SeqError::ShapeMismatch {
98 expected: x.len().div_ceil(k) * k,
99 got: x.len(),
100 });
101 }
102 let n = self.n_labels;
103 let t_max = x.len() / k;
104
105 let mut delta = vec![f64::NEG_INFINITY; t_max * n];
106 let mut psi = vec![0usize; t_max * n];
107
108 for j in 0..n {
110 delta[j] = self.emit_score(j, &x[..k]);
111 }
112 for t in 1..t_max {
113 let xt = &x[t * k..(t + 1) * k];
114 for j in 0..n {
115 let emit = self.emit_score(j, xt);
116 let mut best = f64::NEG_INFINITY;
117 let mut argmax = 0usize;
118 for i in 0..n {
119 let v = delta[(t - 1) * n + i] + self.transitions[i * n + j];
120 if v > best {
121 best = v;
122 argmax = i;
123 }
124 }
125 delta[t * n + j] = best + emit;
126 psi[t * n + j] = argmax;
127 }
128 }
129
130 let mut best = f64::NEG_INFINITY;
132 let mut last = 0usize;
133 for j in 0..n {
134 let v = delta[(t_max - 1) * n + j];
135 if v > best {
136 best = v;
137 last = j;
138 }
139 }
140 let mut path = vec![0usize; t_max];
141 path[t_max - 1] = last;
142 for t in (1..t_max).rev() {
143 path[t - 1] = psi[t * n + path[t]];
144 }
145 Ok(path)
146 }
147
148 pub fn sequence_score(&self, x: &[f64], y: &[usize]) -> SeqResult<f64> {
156 if y.is_empty() {
157 return Err(SeqError::EmptyInput);
158 }
159 let k = self.n_features;
160 let t_max = y.len();
161 if x.len() != t_max * k {
162 return Err(SeqError::ShapeMismatch {
163 expected: t_max * k,
164 got: x.len(),
165 });
166 }
167 let mut s = 0.0;
168 for t in 0..t_max {
169 if y[t] >= self.n_labels {
170 return Err(SeqError::IndexOutOfBounds {
171 index: y[t],
172 len: self.n_labels,
173 });
174 }
175 s += self.emit_score(y[t], &x[t * k..(t + 1) * k]);
176 if t > 0 {
177 s += self.transitions[y[t - 1] * self.n_labels + y[t]];
178 }
179 }
180 Ok(s)
181 }
182
183 pub fn update(&mut self, x: &[f64], gold: &[usize], pred: &[usize]) -> SeqResult<usize> {
194 if gold.len() != pred.len() {
195 return Err(SeqError::LengthMismatch {
196 a: gold.len(),
197 b: pred.len(),
198 });
199 }
200 let k = self.n_features;
201 let t_max = gold.len();
202 if x.len() != t_max * k {
203 return Err(SeqError::ShapeMismatch {
204 expected: t_max * k,
205 got: x.len(),
206 });
207 }
208 let n = self.n_labels;
209 for &lbl in gold.iter().chain(pred.iter()) {
210 if lbl >= n {
211 return Err(SeqError::IndexOutOfBounds { index: lbl, len: n });
212 }
213 }
214
215 let mut mistakes = 0usize;
216 for t in 0..t_max {
218 if gold[t] == pred[t] {
219 continue;
220 }
221 mistakes += 1;
222 let xt = &x[t * k..(t + 1) * k];
223 let gbase = gold[t] * k;
224 let pbase = pred[t] * k;
225 for (idx, &xv) in xt.iter().enumerate() {
226 self.emissions[gbase + idx] += xv;
227 self.emissions[pbase + idx] -= xv;
228 }
229 }
230 for t in 1..t_max {
232 let g = gold[t - 1] * n + gold[t];
233 let p = pred[t - 1] * n + pred[t];
234 if g != p {
235 self.transitions[g] += 1.0;
236 self.transitions[p] -= 1.0;
237 }
238 }
239 Ok(mistakes)
240 }
241}
242
/// Training hyper-parameters for [`train_perceptron`].
#[derive(Clone, Debug)]
pub struct PerceptronConfig {
    /// Maximum number of passes over the training set; a value of 0 still runs one pass.
    pub epochs: usize,
    /// When `true`, the trained model holds the mean of the weights recorded after every
    /// training example instead of the final weights.
    pub averaged: bool,
}

impl Default for PerceptronConfig {
    /// Ten epochs with weight averaging turned on.
    fn default() -> Self {
        let epochs = 10;
        let averaged = true;
        PerceptronConfig { epochs, averaged }
    }
}
262
/// A single labelled training sequence.
#[derive(Clone, Debug)]
pub struct PerceptronExample {
    /// Input frames stored one after another: `y.len() * n_features` values in total.
    pub x: Vec<f64>,
    /// Gold label for each frame, in frame order.
    pub y: Vec<usize>,
}
272
273#[derive(Debug, Clone)]
275pub struct PerceptronTrainResult {
276 pub model: StructuredPerceptron,
278 pub final_epoch_mistakes: usize,
280 pub epochs_run: usize,
282}
283
284pub fn train_perceptron(
298 n_labels: usize,
299 n_features: usize,
300 examples: &[PerceptronExample],
301 config: &PerceptronConfig,
302) -> SeqResult<PerceptronTrainResult> {
303 if examples.is_empty() {
304 return Err(SeqError::EmptyInput);
305 }
306 let mut model = StructuredPerceptron::zeros(n_labels, n_features)?;
307 let p = model.param_count();
308 let mut total = vec![0.0_f64; p];
310 let mut n_updates = 0u64;
311 let mut final_mistakes = 0usize;
312
313 for epoch in 0..config.epochs.max(1) {
314 let mut epoch_mistakes = 0usize;
315 for ex in examples {
316 let t_max = ex.y.len();
317 if t_max == 0 || ex.x.len() != t_max * n_features {
318 return Err(SeqError::ShapeMismatch {
319 expected: t_max * n_features,
320 got: ex.x.len(),
321 });
322 }
323 let pred = model.decode(&ex.x)?;
324 let mistakes = model.update(&ex.x, &ex.y, &pred)?;
325 epoch_mistakes += mistakes;
326
327 if config.averaged {
328 accumulate(&model, &mut total);
330 n_updates += 1;
331 }
332 }
333 final_mistakes = epoch_mistakes;
334 if epoch_mistakes == 0 {
336 return finish(model, total, n_updates, final_mistakes, epoch + 1, config);
338 }
339 }
340
341 finish(
342 model,
343 total,
344 n_updates,
345 final_mistakes,
346 config.epochs.max(1),
347 config,
348 )
349}
350
351fn accumulate(model: &StructuredPerceptron, total: &mut [f64]) {
353 let cut = model.emissions.len();
354 for (t, &e) in total[..cut].iter_mut().zip(model.emissions.iter()) {
355 *t += e;
356 }
357 for (t, &tr) in total[cut..].iter_mut().zip(model.transitions.iter()) {
358 *t += tr;
359 }
360}
361
362fn finish(
364 mut model: StructuredPerceptron,
365 total: Vec<f64>,
366 n_updates: u64,
367 final_mistakes: usize,
368 epochs_run: usize,
369 config: &PerceptronConfig,
370) -> SeqResult<PerceptronTrainResult> {
371 if config.averaged && n_updates > 0 {
372 let inv = 1.0 / n_updates as f64;
373 let cut = model.emissions.len();
374 for (e, &t) in model.emissions.iter_mut().zip(total[..cut].iter()) {
375 *e = t * inv;
376 }
377 for (tr, &t) in model.transitions.iter_mut().zip(total[cut..].iter()) {
378 *tr = t * inv;
379 }
380 }
381 Ok(PerceptronTrainResult {
382 model,
383 final_epoch_mistakes: final_mistakes,
384 epochs_run,
385 })
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Two separable sequences over 2 labels / 2 features: label `i` always appears
    /// with one-hot feature `i`.
    fn toy_examples() -> Vec<PerceptronExample> {
        vec![
            PerceptronExample {
                x: vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                y: vec![0, 1, 0],
            },
            PerceptronExample {
                x: vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
                y: vec![1, 0, 1],
            },
        ]
    }

    #[test]
    fn zeros_rejects_bad_dims() {
        assert!(StructuredPerceptron::zeros(0, 3).is_err());
        assert!(StructuredPerceptron::zeros(3, 0).is_err());
        assert!(StructuredPerceptron::zeros(2, 2).is_ok());
    }

    #[test]
    fn param_count_correct() {
        let m = StructuredPerceptron::zeros(3, 4).expect("ok");
        assert_eq!(m.param_count(), 3 * 4 + 3 * 3);
    }

    #[test]
    fn zero_model_decodes_first_label() {
        // With all-zero weights every path ties, and ties resolve to label 0.
        let m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let y = m.decode(&[1.0, 0.0, 0.0, 1.0]).expect("ok");
        assert_eq!(y, vec![0, 0]);
    }

    #[test]
    fn decode_rejects_empty() {
        let m = StructuredPerceptron::zeros(2, 2).expect("ok");
        assert!(matches!(m.decode(&[]), Err(SeqError::EmptyInput)));
    }

    #[test]
    fn decode_rejects_bad_shape() {
        let m = StructuredPerceptron::zeros(2, 3).expect("ok");
        // 4 values cannot be split into frames of 3 features.
        assert!(matches!(
            m.decode(&[1.0, 2.0, 3.0, 4.0]),
            Err(SeqError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn sequence_score_rejects_oob_label() {
        let m = StructuredPerceptron::zeros(2, 2).expect("ok");
        assert!(matches!(
            m.sequence_score(&[1.0, 0.0], &[5]),
            Err(SeqError::IndexOutOfBounds { .. })
        ));
    }

    #[test]
    fn update_rejects_length_mismatch() {
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        assert!(matches!(
            m.update(&[1.0, 0.0], &[0], &[0, 1]),
            Err(SeqError::LengthMismatch { .. })
        ));
    }

    #[test]
    fn update_counts_mistakes_and_moves_weights() {
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let x = vec![1.0, 0.0, 0.0, 1.0];
        let gold = vec![0, 1];
        let pred = vec![1, 0];
        let mistakes = m.update(&x, &gold, &pred).expect("ok");
        assert_eq!(mistakes, 2);
        // Frame 0 is [1, 0] with gold label 0, so emission (label 0, feature 0) grew.
        assert!(m.emissions[0] > 0.0);
        // After the update the gold sequence must outscore the mistaken prediction.
        let sg = m.sequence_score(&x, &gold).expect("ok");
        let sp = m.sequence_score(&x, &pred).expect("ok");
        assert!(sg > sp, "gold {sg} should exceed pred {sp}");
    }

    #[test]
    fn update_no_mistakes_is_noop() {
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let x = vec![1.0, 0.0, 0.0, 1.0];
        let y = vec![0, 1];
        let before = m.emissions.clone();
        let mistakes = m.update(&x, &y, &y).expect("ok");
        assert_eq!(mistakes, 0);
        assert_eq!(before, m.emissions);
    }

    #[test]
    fn train_rejects_empty() {
        let cfg = PerceptronConfig::default();
        assert!(matches!(
            train_perceptron(2, 2, &[], &cfg),
            Err(SeqError::EmptyInput)
        ));
    }

    #[test]
    fn train_learns_separable_data() {
        let ex = toy_examples();
        let cfg = PerceptronConfig {
            epochs: 20,
            averaged: false,
        };
        let res = train_perceptron(2, 2, &ex, &cfg).expect("ok");
        for e in &ex {
            let pred = res.model.decode(&e.x).expect("ok");
            assert_eq!(pred, e.y, "model failed to fit training example");
        }
    }

    #[test]
    fn train_converges_to_zero_mistakes() {
        let ex = toy_examples();
        let cfg = PerceptronConfig {
            epochs: 50,
            averaged: false,
        };
        let res = train_perceptron(2, 2, &ex, &cfg).expect("ok");
        assert_eq!(
            res.final_epoch_mistakes, 0,
            "separable data should converge to 0 mistakes"
        );
        assert!(res.epochs_run <= 50);
    }

    #[test]
    fn averaged_perceptron_fits_and_is_finite() {
        let ex = toy_examples();
        let cfg = PerceptronConfig {
            epochs: 20,
            averaged: true,
        };
        let res = train_perceptron(2, 2, &ex, &cfg).expect("ok");
        assert!(res.model.emissions.iter().all(|v| v.is_finite()));
        for e in &ex {
            let pred = res.model.decode(&e.x).expect("ok");
            assert_eq!(pred, e.y);
        }
    }

    #[test]
    fn averaging_equals_mean_of_trajectory() {
        let ex = vec![PerceptronExample {
            x: vec![1.0, 0.0, 0.0, 1.0],
            y: vec![0, 1],
        }];
        let avg = train_perceptron(
            2,
            2,
            &ex,
            &PerceptronConfig {
                epochs: 2,
                averaged: true,
            },
        )
        .expect("ok");
        let raw = train_perceptron(
            2,
            2,
            &ex,
            &PerceptronConfig {
                epochs: 2,
                averaged: false,
            },
        )
        .expect("ok");
        assert_eq!(avg.model.decode(&ex[0].x).expect("d"), ex[0].y);
        assert_eq!(raw.model.decode(&ex[0].x).expect("d"), ex[0].y);

        // Replay training by hand, snapshotting every weight after each example, and
        // check that the averaged model is exactly the mean of those snapshots.
        let mut m = StructuredPerceptron::zeros(2, 2).expect("ok");
        let mut sum = vec![0.0; m.param_count()];
        let mut snapshots = 0.0;
        for _ in 0..avg.epochs_run {
            let pred = m.decode(&ex[0].x).expect("d");
            m.update(&ex[0].x, &ex[0].y, &pred).expect("u");
            let weights = m.emissions.iter().chain(m.transitions.iter());
            for (s, &w) in sum.iter_mut().zip(weights) {
                *s += w;
            }
            snapshots += 1.0;
        }
        let averaged = avg
            .model
            .emissions
            .iter()
            .chain(avg.model.transitions.iter());
        for (got, s) in averaged.zip(sum.iter()) {
            let expected = s / snapshots;
            assert!(
                (got - expected).abs() < 1e-12,
                "averaged weight {got} != trajectory mean {expected}"
            );
        }

        for (a, r) in avg.model.emissions.iter().zip(raw.model.emissions.iter()) {
            assert!(a.abs() <= r.abs() + 1e-9, "avg {a} exceeds raw {r}");
        }
    }

    #[test]
    fn averaging_shrinks_when_trajectory_varies() {
        // Identical inputs with conflicting labels: non-separable, so the raw weights
        // oscillate and their average differs from the final snapshot.
        let ex = vec![
            PerceptronExample {
                x: vec![1.0, 0.0],
                y: vec![0],
            },
            PerceptronExample {
                x: vec![1.0, 0.0],
                y: vec![1],
            },
        ];
        let avg = train_perceptron(
            2,
            2,
            &ex,
            &PerceptronConfig {
                epochs: 6,
                averaged: true,
            },
        )
        .expect("ok");
        let raw = train_perceptron(
            2,
            2,
            &ex,
            &PerceptronConfig {
                epochs: 6,
                averaged: false,
            },
        )
        .expect("ok");
        let diff: f64 = avg
            .model
            .emissions
            .iter()
            .zip(raw.model.emissions.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-9,
            "with a non-separable oscillating dataset averaging must differ from final"
        );
    }

    #[test]
    fn train_rejects_inconsistent_example_shape() {
        // Two labels require 2 * 2 = 4 feature values, but only 3 are supplied.
        let bad = vec![PerceptronExample {
            x: vec![1.0, 0.0, 0.0],
            y: vec![0, 1],
        }];
        let cfg = PerceptronConfig::default();
        assert!(matches!(
            train_perceptron(2, 2, &bad, &cfg),
            Err(SeqError::ShapeMismatch { .. })
        ));
    }
}