1use crate::error::{SeqError, SeqResult};
13use crate::handle::LcgRng;
14
/// Computes `ln(sum_i exp(xs[i]))` without overflowing for large inputs.
///
/// The maximum element is factored out before exponentiating, so every
/// exponent passed to `exp` is `<= 0`.
///
/// Special cases:
/// * an empty slice, or one whose maximum is `-inf`, yields `-inf`
///   (the logarithm of a zero sum);
/// * a slice whose maximum is `+inf` yields `+inf`. Without this early
///   return the shifted term `inf - inf` is NaN and poisons the result.
///
/// NaN elements are skipped while locating the maximum (`f64::max` ignores
/// a NaN operand); when the maximum is finite they still propagate into the
/// result through the sum.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let max = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max.is_infinite() {
        // -inf: nothing contributes. +inf: the sum is dominated by +inf.
        return max;
    }
    let shifted_sum: f64 = xs.iter().map(|&x| (x - max).exp()).sum();
    max + shifted_sum.ln()
}
26
/// Dimensions and training hyper-parameters for a [`CrfSgd`] model.
#[derive(Debug, Clone)]
pub struct CrfSgdConfig {
    /// Number of distinct tags; `CrfSgd::new` rejects zero.
    pub n_tags: usize,
    /// Length of each per-position feature vector; `CrfSgd::new` rejects zero.
    pub n_features: usize,
    /// Number of full passes over the training set performed by `CrfSgd::fit`.
    pub n_epochs: usize,
    /// Base learning rate applied to every gradient step.
    pub lr: f64,
    /// L2 coefficient; when `> 0`, `l2_reg * weight` is added to each
    /// gradient component before the update.
    pub l2_reg: f64,
    /// When `true`, each step is scaled per parameter by the accumulated
    /// squared gradient (AdaGrad); otherwise a plain SGD step is taken.
    pub adagrad: bool,
}
45
/// Linear-chain CRF whose weights are trained one sequence at a time with
/// SGD or AdaGrad.
#[derive(Debug, Clone)]
pub struct CrfSgd {
    /// Flat parameter vector of length `n_tags * n_features + n_tags * n_tags`.
    ///
    /// The first `n_tags * n_features` entries are emission weights indexed
    /// by `tag * n_features + feat`; the remaining entries are transition
    /// weights indexed by `prev_tag * n_tags + curr_tag`.
    pub weights: Vec<f64>,
    /// Dimensions and hyper-parameters supplied at construction.
    config: CrfSgdConfig,
    /// Per-parameter running sum of squared gradients, updated only when
    /// AdaGrad is enabled; same length as `weights`.
    adagrad_acc: Vec<f64>,
}
62
63impl CrfSgd {
    /// Added to the accumulated squared gradient before the square root in
    /// the AdaGrad step, so the divisor is never zero.
    const ADAGRAD_EPS: f64 = 1e-8;
65
66 pub fn new(config: CrfSgdConfig, rng: &mut LcgRng) -> SeqResult<Self> {
70 if config.n_tags == 0 {
71 return Err(SeqError::InvalidConfiguration("n_tags must be > 0".into()));
72 }
73 if config.n_features == 0 {
74 return Err(SeqError::InvalidConfiguration(
75 "n_features must be > 0".into(),
76 ));
77 }
78 let n_emit = config.n_tags * config.n_features;
79 let n_tr = config.n_tags * config.n_tags;
80 let n_params = n_emit + n_tr;
81 let weights: Vec<f64> = (0..n_params).map(|_| rng.next_normal() * 0.1).collect();
82 let adagrad_acc = vec![0.0f64; n_params];
83 Ok(Self {
84 weights,
85 config,
86 adagrad_acc,
87 })
88 }
89
90 #[inline]
94 fn emit_idx(&self, tag: usize, feat: usize) -> usize {
95 tag * self.config.n_features + feat
96 }
97
98 #[inline]
100 fn tr_idx(&self, prev_tag: usize, curr_tag: usize) -> usize {
101 self.config.n_tags * self.config.n_features + prev_tag * self.config.n_tags + curr_tag
102 }
103
104 pub fn emission_weight(&self, tag: usize, feat: usize) -> f64 {
108 self.weights[self.emit_idx(tag, feat)]
109 }
110
111 pub fn transition_weight(&self, prev_tag: usize, curr_tag: usize) -> f64 {
113 self.weights[self.tr_idx(prev_tag, curr_tag)]
114 }
115
116 #[inline]
120 fn emit_score(&self, j: usize, feat: &[f64]) -> f64 {
121 let base = j * self.config.n_features;
122 let mut s = 0.0;
123 for f in 0..self.config.n_features {
124 s += self.weights[base + f] * feat[f];
125 }
126 s
127 }
128
129 pub fn log_partition(&self, features: &[Vec<f64>], seq_len: usize) -> SeqResult<f64> {
133 if seq_len == 0 {
134 return Err(SeqError::EmptyInput);
135 }
136 if features.len() < seq_len {
137 return Err(SeqError::ShapeMismatch {
138 expected: seq_len,
139 got: features.len(),
140 });
141 }
142 let n = self.config.n_tags;
143 let mut alpha = vec![f64::NEG_INFINITY; n];
144 for j in 0..n {
146 alpha[j] = self.emit_score(j, &features[0]);
147 }
148 let mut tmp = vec![0.0f64; n];
149 for t in 1..seq_len {
150 let mut alpha_new = vec![f64::NEG_INFINITY; n];
151 for j in 0..n {
152 for i in 0..n {
153 tmp[i] = alpha[i] + self.transition_weight(i, j);
154 }
155 alpha_new[j] = logsumexp(&tmp) + self.emit_score(j, &features[t]);
156 }
157 alpha = alpha_new;
158 }
159 Ok(logsumexp(&alpha))
160 }
161
162 fn forward_table(&self, features: &[Vec<f64>], seq_len: usize) -> Vec<Vec<f64>> {
165 let n = self.config.n_tags;
166 let mut table = vec![vec![f64::NEG_INFINITY; n]; seq_len];
167 for j in 0..n {
168 table[0][j] = self.emit_score(j, &features[0]);
169 }
170 let mut tmp = vec![0.0f64; n];
171 for t in 1..seq_len {
172 for j in 0..n {
173 for i in 0..n {
174 tmp[i] = table[t - 1][i] + self.transition_weight(i, j);
175 }
176 table[t][j] = logsumexp(&tmp) + self.emit_score(j, &features[t]);
177 }
178 }
179 table
180 }
181
182 fn backward_table(&self, features: &[Vec<f64>], seq_len: usize) -> Vec<Vec<f64>> {
183 let n = self.config.n_tags;
184 let mut table = vec![vec![0.0f64; n]; seq_len]; let mut tmp = vec![0.0f64; n];
186 for t in (0..seq_len - 1).rev() {
187 for i in 0..n {
188 for j in 0..n {
189 tmp[j] = self.transition_weight(i, j)
190 + self.emit_score(j, &features[t + 1])
191 + table[t + 1][j];
192 }
193 table[t][i] = logsumexp(&tmp);
194 }
195 }
196 table
197 }
198
199 fn gradient_one(&self, features: &[Vec<f64>], labels: &[usize]) -> SeqResult<(f64, Vec<f64>)> {
206 let seq_len = labels.len();
207 if seq_len == 0 {
208 return Err(SeqError::EmptyInput);
209 }
210 let n = self.config.n_tags;
211 let k = self.config.n_features;
212 let n_params = self.weights.len();
213
214 for (t, &y) in labels.iter().enumerate() {
216 if y >= n {
217 return Err(SeqError::IndexOutOfBounds { index: y, len: n });
218 }
219 if features[t].len() != k {
220 return Err(SeqError::ShapeMismatch {
221 expected: k,
222 got: features[t].len(),
223 });
224 }
225 }
226
227 let alpha = self.forward_table(features, seq_len);
229 let log_z = logsumexp(&alpha[seq_len - 1]);
230
231 let beta = self.backward_table(features, seq_len);
233
234 let mut score_true = self.emit_score(labels[0], &features[0]);
236 for t in 1..seq_len {
237 score_true += self.transition_weight(labels[t - 1], labels[t])
238 + self.emit_score(labels[t], &features[t]);
239 }
240 let nll = log_z - score_true;
241
242 let mut grad = vec![0.0f64; n_params];
245
246 for t in 0..seq_len {
248 let feat = &features[t];
249 for j in 0..n {
250 let log_gamma = alpha[t][j] + beta[t][j] - log_z;
251 let gamma = log_gamma.exp();
252 let base = self.emit_idx(j, 0);
253 for f in 0..k {
254 grad[base + f] += gamma * feat[f];
255 }
256 }
257 }
258 for t in 0..seq_len {
260 let feat = &features[t];
261 let j = labels[t];
262 let base = self.emit_idx(j, 0);
263 for f in 0..k {
264 grad[base + f] -= feat[f];
265 }
266 }
267
268 for t in 0..seq_len - 1 {
271 for i in 0..n {
272 for j in 0..n {
273 let log_xi = alpha[t][i]
274 + self.transition_weight(i, j)
275 + self.emit_score(j, &features[t + 1])
276 + beta[t + 1][j]
277 - log_z;
278 let xi = log_xi.exp();
279 grad[self.tr_idx(i, j)] += xi;
280 }
281 }
282 }
283 for t in 1..seq_len {
285 let (i, j) = (labels[t - 1], labels[t]);
286 grad[self.tr_idx(i, j)] -= 1.0;
287 }
288
289 Ok((nll, grad))
290 }
291
292 fn apply_update(&mut self, grad: &[f64]) {
296 let lr = self.config.lr;
297 let eps = Self::ADAGRAD_EPS;
298 let n_params = self.weights.len();
299
300 if self.config.adagrad {
301 for i in 0..n_params {
302 self.adagrad_acc[i] += grad[i] * grad[i];
303 let eff_lr = lr / (self.adagrad_acc[i] + eps).sqrt();
304 self.weights[i] -= eff_lr * grad[i];
305 }
306 } else {
307 for i in 0..n_params {
308 self.weights[i] -= lr * grad[i];
309 }
310 }
311 }
312
313 pub fn update_one(&mut self, features: &[Vec<f64>], labels: &[usize]) -> SeqResult<f64> {
319 let (nll, mut grad) = self.gradient_one(features, labels)?;
320 let l2 = self.config.l2_reg;
322 if l2 > 0.0 {
323 for i in 0..self.weights.len() {
324 grad[i] += l2 * self.weights[i];
325 }
326 }
327 self.apply_update(&grad);
328 Ok(nll)
329 }
330
331 pub fn fit(
333 &mut self,
334 all_features: &[Vec<Vec<f64>>],
335 all_labels: &[Vec<usize>],
336 ) -> SeqResult<Vec<f64>> {
337 if all_features.len() != all_labels.len() {
338 return Err(SeqError::LengthMismatch {
339 a: all_features.len(),
340 b: all_labels.len(),
341 });
342 }
343 let n_samples = all_features.len();
344 if n_samples == 0 {
345 return Err(SeqError::EmptyInput);
346 }
347 let mut epoch_losses = Vec::with_capacity(self.config.n_epochs);
348 for _epoch in 0..self.config.n_epochs {
349 let mut total_nll = 0.0;
350 for s in 0..n_samples {
351 total_nll += self.update_one(&all_features[s], &all_labels[s])?;
352 }
353 epoch_losses.push(total_nll / n_samples as f64);
354 }
355 Ok(epoch_losses)
356 }
357
358 pub fn decode(&self, features: &[Vec<f64>], seq_len: usize) -> SeqResult<Vec<usize>> {
362 if seq_len == 0 {
363 return Err(SeqError::EmptyInput);
364 }
365 if features.len() < seq_len {
366 return Err(SeqError::ShapeMismatch {
367 expected: seq_len,
368 got: features.len(),
369 });
370 }
371 let n = self.config.n_tags;
372 let mut viterbi = vec![f64::NEG_INFINITY; n];
373 let mut backptr = vec![vec![0usize; n]; seq_len];
374
375 for j in 0..n {
377 viterbi[j] = self.emit_score(j, &features[0]);
378 }
379
380 for t in 1..seq_len {
382 let mut viterbi_new = vec![f64::NEG_INFINITY; n];
383 for j in 0..n {
384 let mut best_score = f64::NEG_INFINITY;
385 let mut best_prev = 0;
386 for i in 0..n {
387 let s = viterbi[i] + self.transition_weight(i, j);
388 if s > best_score {
389 best_score = s;
390 best_prev = i;
391 }
392 }
393 viterbi_new[j] = best_score + self.emit_score(j, &features[t]);
394 backptr[t][j] = best_prev;
395 }
396 viterbi = viterbi_new;
397 }
398
399 let mut best_last = 0;
401 let mut best_val = f64::NEG_INFINITY;
402 for j in 0..n {
403 if viterbi[j] > best_val {
404 best_val = viterbi[j];
405 best_last = j;
406 }
407 }
408
409 let mut path = vec![0usize; seq_len];
411 path[seq_len - 1] = best_last;
412 for t in (0..seq_len - 1).rev() {
413 path[t] = backptr[t + 1][path[t + 1]];
414 }
415 Ok(path)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by most tests: 3 tags, 4 features,
    /// 5 epochs.
    fn make_config(adagrad: bool) -> CrfSgdConfig {
        CrfSgdConfig {
            adagrad,
            l2_reg: 1e-4,
            lr: 0.05,
            n_epochs: 5,
            n_features: 4,
            n_tags: 3,
        }
    }

    /// Model built from the baseline configuration with a fixed seed.
    fn make_crf(adagrad: bool) -> CrfSgd {
        let config = make_config(adagrad);
        let mut rng = LcgRng::new(42);
        CrfSgd::new(config, &mut rng).expect("construction failed")
    }

    /// Three-step toy sequence; feature rows are truncated to `n_features`
    /// and labels are wrapped into `0..n_tags`.
    fn simple_data(n_tags: usize, n_features: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
        const ROWS: [[f64; 4]; 3] = [
            [1.0, 0.0, 0.5, -0.5],
            [0.0, 1.0, -0.5, 0.5],
            [0.5, 0.5, 0.0, 1.0],
        ];
        let features: Vec<Vec<f64>> = ROWS
            .iter()
            .map(|row| row.iter().copied().take(n_features).collect())
            .collect();
        let labels: Vec<usize> = (0..3).map(|tag| tag % n_tags).collect();
        (features, labels)
    }

    /// Arithmetic mean of a non-empty slice.
    fn mean(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    #[test]
    fn weights_shape() {
        let model = make_crf(false);
        let expected_len = 3 * 4 + 3 * 3;
        assert_eq!(
            model.weights.len(),
            expected_len,
            "weights.len() should be n_tags*n_features + n_tags*n_tags"
        );
    }

    #[test]
    fn decode_output_len() {
        let model = make_crf(false);
        let (features, _labels) = simple_data(3, 4);
        let steps = features.len();
        let decoded = model.decode(&features, steps).expect("decode failed");
        assert_eq!(decoded.len(), steps);
    }

    #[test]
    fn decode_valid_tags() {
        let model = make_crf(false);
        let (features, _labels) = simple_data(3, 4);
        let decoded = model
            .decode(&features, features.len())
            .expect("decode failed");
        for tag in decoded.iter().copied() {
            assert!(tag < 3, "decoded tag {tag} >= n_tags=3");
        }
    }

    #[test]
    fn log_partition_finite() {
        let model = make_crf(false);
        let (features, _labels) = simple_data(3, 4);
        let steps = features.len();
        let lz = model.log_partition(&features, steps).expect("lz failed");
        assert!(lz.is_finite(), "log_partition should be finite, got {lz}");
    }

    #[test]
    fn update_decreases_loss() {
        let mut rng = LcgRng::new(7);
        let config = CrfSgdConfig {
            n_epochs: 30,
            lr: 0.1,
            n_features: 4,
            n_tags: 3,
            ..make_config(true)
        };
        let mut crf = CrfSgd::new(config, &mut rng).expect("new failed");

        // Four sequences of three steps, each drawn from its own seeded RNG.
        let all_feats: Vec<Vec<Vec<f64>>> = (1u64..=4)
            .map(|seed| {
                let mut sampler = LcgRng::new(seed);
                (0..3)
                    .map(|_| (0..4).map(|_| sampler.next_normal()).collect())
                    .collect()
            })
            .collect();
        let all_labels: Vec<Vec<usize>> =
            vec![vec![0, 1, 2], vec![2, 0, 1], vec![1, 2, 0], vec![0, 0, 1]];

        let losses = crf.fit(&all_feats, &all_labels).expect("fit failed");
        assert!(!losses.is_empty());

        // Compare the mean of the first and last (up to) five epochs.
        let window = losses.len().min(5);
        let first = mean(&losses[..window]);
        let last = mean(&losses[losses.len() - window..]);
        assert!(
            last < first,
            "loss did not decrease: first={first:.4}, last={last:.4}"
        );
    }

    #[test]
    fn adagrad_different_from_sgd() {
        let (features, labels) = simple_data(3, 4);
        let all_feats = vec![features];
        let all_labels = vec![labels];

        let config_sgd = CrfSgdConfig {
            n_epochs: 5,
            ..make_config(false)
        };
        let config_ada = CrfSgdConfig {
            n_epochs: 5,
            ..make_config(true)
        };

        // Identical seeds, so both models start from the same weights.
        let mut rng_sgd = LcgRng::new(42);
        let mut rng_ada = LcgRng::new(42);
        let mut crf_sgd = CrfSgd::new(config_sgd, &mut rng_sgd).expect("new failed");
        let mut crf_ada = CrfSgd::new(config_ada, &mut rng_ada).expect("new failed");

        crf_sgd.fit(&all_feats, &all_labels).expect("fit sgd");
        crf_ada.fit(&all_feats, &all_labels).expect("fit ada");

        let mut diff = 0.0_f64;
        for (a, b) in crf_sgd.weights.iter().zip(&crf_ada.weights) {
            diff += (a - b).abs();
        }
        assert!(diff > 1e-12, "adagrad and sgd produced identical weights");
    }

    #[test]
    fn viterbi_agrees_with_exhaustive() {
        let mut rng = LcgRng::new(99);
        let config = CrfSgdConfig {
            n_tags: 2,
            n_features: 3,
            n_epochs: 1,
            lr: 0.01,
            l2_reg: 0.0,
            adagrad: false,
        };
        let crf = CrfSgd::new(config, &mut rng).expect("new");
        let features = vec![vec![1.0, -1.0, 0.5], vec![-0.5, 0.5, 1.0]];
        let path = crf.decode(&features, 2).expect("decode");

        // Brute force over all four tag pairs, in the same order and with the
        // same strict comparison as a nested `y0`/`y1` loop.
        let candidates = [(0usize, 0usize), (0, 1), (1, 0), (1, 1)];
        let mut best_score = f64::NEG_INFINITY;
        let mut best_path = (0, 0);
        for &(y0, y1) in &candidates {
            let score = crf.emit_score(y0, &features[0])
                + crf.transition_weight(y0, y1)
                + crf.emit_score(y1, &features[1]);
            if score > best_score {
                best_score = score;
                best_path = (y0, y1);
            }
        }
        assert_eq!(path[0], best_path.0, "Viterbi y0 mismatch");
        assert_eq!(path[1], best_path.1, "Viterbi y1 mismatch");
    }

    #[test]
    fn emission_weight_correct() {
        let model = make_crf(false);
        for tag in 0..3 {
            for feat in 0..4 {
                let flat = tag * 4 + feat;
                let expected = model.weights[flat];
                assert_eq!(
                    model.emission_weight(tag, feat),
                    expected,
                    "emission_weight({tag},{feat}) mismatch"
                );
            }
        }
    }

    #[test]
    fn n_tags_zero_error() {
        let config = CrfSgdConfig {
            n_tags: 0,
            n_features: 4,
            n_epochs: 1,
            lr: 0.01,
            l2_reg: 0.0,
            adagrad: false,
        };
        let mut rng = LcgRng::new(1);
        let outcome = CrfSgd::new(config, &mut rng);
        assert!(outcome.is_err(), "n_tags=0 should fail");
    }

    #[test]
    fn empty_sequence_error() {
        let model = make_crf(false);
        let outcome = model.decode(&[], 0);
        assert!(outcome.is_err(), "decode on empty should fail");
    }
}