1use crate::error::{SeqError, SeqResult};
26
/// Configuration for a [`SkipChainCrf`] and its belief-propagation loop.
#[derive(Debug, Clone)]
pub struct SkipChainConfig {
    /// Number of labels per sequence position. `SkipChainCrf::new` rejects `0`.
    pub n_labels: usize,
    /// Upper bound on the number of belief-propagation sweeps.
    /// `SkipChainCrf::new` rejects `0`.
    pub max_bp_iters: usize,
    /// Convergence tolerance: message passing stops once the largest absolute
    /// change of any message entry in a sweep falls below this value.
    /// `SkipChainCrf::new` rejects non-positive and NaN values.
    pub bp_tol: f64,
}
37
/// A conditional random field over a label sequence with two kinds of pairwise
/// edges: chain edges between consecutive positions and caller-supplied skip
/// edges between arbitrary position pairs.
#[derive(Debug, Clone)]
pub struct SkipChainCrf {
    /// Validated configuration (label count, iteration cap, tolerance).
    cfg: SkipChainConfig,
    /// Row-major `n_labels x n_labels` table of log-potentials used for chain
    /// edges; entry `[a * n_labels + b]` scores label `a` at the earlier
    /// position together with label `b` at the later one.
    transition: Vec<f64>,
    /// Same layout as `transition`, used for skip edges instead.
    skip_potential: Vec<f64>,
    /// Weight given to the freshly computed message when blending it with the
    /// previous one. Starts at `0.5`; `with_damping` restricts it to `(0, 1]`.
    damping: f64,
}
49
/// One undirected pairwise edge of the inference graph.
#[derive(Debug, Clone, Copy)]
struct Edge {
    /// Lower-indexed endpoint (position in the sequence).
    u: usize,
    /// Higher-indexed endpoint; `u + 1` for chain edges, any `j > u` for skip edges.
    v: usize,
    /// `true` for an edge between consecutive positions (scored with the
    /// transition table), `false` for a skip edge (scored with the skip table).
    is_chain: bool,
}
60
/// Computes `ln(sum_i exp(xs[i]))` in a numerically stable way.
///
/// The largest element is factored out before exponentiating so that large
/// inputs do not overflow.
///
/// Returns `f64::NEG_INFINITY` for an empty slice or when no element compares
/// greater than negative infinity, and `f64::INFINITY` when the largest
/// element is positive infinity.
fn log_sum_exp(xs: &[f64]) -> f64 {
    let mut m = f64::NEG_INFINITY;
    for &x in xs {
        if x > m {
            m = x;
        }
    }
    // An infinite maximum must be returned directly: subtracting it below
    // would evaluate `inf - inf`, which is NaN and would poison the sum.
    if m.is_infinite() {
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
78
/// Returns the largest element of `xs`, or `f64::NEG_INFINITY` if the slice is
/// empty.
///
/// An element replaces the running maximum only when it compares strictly
/// greater, so NaN entries are skipped.
fn max_of(xs: &[f64]) -> f64 {
    xs.iter().fold(
        f64::NEG_INFINITY,
        |best, &x| if x > best { x } else { best },
    )
}
89
90impl SkipChainCrf {
91 pub fn new(
93 cfg: SkipChainConfig,
94 transition: Vec<f64>,
95 skip_potential: Vec<f64>,
96 ) -> SeqResult<Self> {
97 if cfg.n_labels == 0 {
98 return Err(SeqError::InvalidConfiguration(
99 "n_labels must be >= 1".to_string(),
100 ));
101 }
102 if cfg.max_bp_iters == 0 {
103 return Err(SeqError::InvalidConfiguration(
104 "max_bp_iters must be >= 1".to_string(),
105 ));
106 }
107 if cfg.bp_tol <= 0.0 || cfg.bp_tol.is_nan() {
108 return Err(SeqError::InvalidParameter {
109 name: "bp_tol".to_string(),
110 value: cfg.bp_tol,
111 });
112 }
113 let l2 = cfg.n_labels * cfg.n_labels;
114 if transition.len() != l2 {
115 return Err(SeqError::ShapeMismatch {
116 expected: l2,
117 got: transition.len(),
118 });
119 }
120 if skip_potential.len() != l2 {
121 return Err(SeqError::ShapeMismatch {
122 expected: l2,
123 got: skip_potential.len(),
124 });
125 }
126 Ok(Self {
127 cfg,
128 transition,
129 skip_potential,
130 damping: 0.5,
131 })
132 }
133
134 pub fn with_damping(mut self, damping: f64) -> SeqResult<Self> {
136 if damping <= 0.0 || damping > 1.0 || damping.is_nan() {
137 return Err(SeqError::InvalidParameter {
138 name: "damping".to_string(),
139 value: damping,
140 });
141 }
142 self.damping = damping;
143 Ok(self)
144 }
145
146 pub fn n_labels(&self) -> usize {
148 self.cfg.n_labels
149 }
150
151 fn prepare_edges(
153 &self,
154 unary: &[f64],
155 seq_len: usize,
156 skip_edges: &[(usize, usize)],
157 ) -> SeqResult<Vec<Edge>> {
158 let nl = self.cfg.n_labels;
159 if seq_len == 0 {
160 return Err(SeqError::EmptyInput);
161 }
162 if unary.len() != seq_len * nl {
163 return Err(SeqError::ShapeMismatch {
164 expected: seq_len * nl,
165 got: unary.len(),
166 });
167 }
168 let mut edges: Vec<Edge> = Vec::with_capacity(seq_len.saturating_sub(1) + skip_edges.len());
169 for t in 0..seq_len.saturating_sub(1) {
170 edges.push(Edge {
171 u: t,
172 v: t + 1,
173 is_chain: true,
174 });
175 }
176 for &(i, j) in skip_edges {
177 if i >= seq_len || j >= seq_len {
178 return Err(SeqError::IndexOutOfBounds {
179 index: i.max(j),
180 len: seq_len,
181 });
182 }
183 if i >= j {
184 return Err(SeqError::GraphInvariantViolated(format!(
185 "skip edge ({i}, {j}) must have i < j"
186 )));
187 }
188 edges.push(Edge {
189 u: i,
190 v: j,
191 is_chain: false,
192 });
193 }
194 Ok(edges)
195 }
196
197 #[inline]
200 fn edge_log_potential(&self, edge: &Edge, src: usize, l_src: usize, l_dst: usize) -> f64 {
201 let nl = self.cfg.n_labels;
202 let table = if edge.is_chain {
203 &self.transition
204 } else {
205 &self.skip_potential
206 };
207 if src == edge.u {
209 table[l_src * nl + l_dst]
210 } else {
211 table[l_dst * nl + l_src]
212 }
213 }
214
215 fn run_bp(
222 &self,
223 unary: &[f64],
224 seq_len: usize,
225 edges: &[Edge],
226 combine: fn(&[f64]) -> f64,
227 ) -> (Vec<f64>, usize, bool) {
228 let nl = self.cfg.n_labels;
229 let n_slots = edges.len() * 2;
230 let mut log_msg = vec![0.0; n_slots * nl];
231 let mut new_log_msg = log_msg.clone();
232
233 let mut incoming: Vec<Vec<(usize, usize)>> = vec![Vec::new(); seq_len];
236 for (e_idx, e) in edges.iter().enumerate() {
237 incoming[e.u].push((e_idx, e_idx * 2 + 1));
239 incoming[e.v].push((e_idx, e_idx * 2));
241 }
242
243 let mut iters = 0;
244 let mut converged = false;
245 let mut terms = vec![0.0; nl];
246
247 for it in 0..self.cfg.max_bp_iters {
248 iters = it + 1;
249 for (e_idx, e) in edges.iter().enumerate() {
250 for &(src, dst, out_slot) in &[(e.u, e.v, e_idx * 2), (e.v, e.u, e_idx * 2 + 1)] {
252 let _ = dst;
253 let mut out = vec![f64::NEG_INFINITY; nl];
254 for l_dst in 0..nl {
255 for l_src in 0..nl {
256 let mut acc = unary[src * nl + l_src]
259 + self.edge_log_potential(e, src, l_src, l_dst);
260 for &(k_edge, slot) in &incoming[src] {
261 if k_edge == e_idx {
262 continue;
263 }
264 acc += log_msg[slot * nl + l_src];
265 }
266 terms[l_src] = acc;
267 }
268 out[l_dst] = combine(&terms);
269 }
270 let m = max_of(&out);
272 if m != f64::NEG_INFINITY {
273 for v in out.iter_mut() {
274 *v -= m;
275 }
276 }
277 for l in 0..nl {
279 new_log_msg[out_slot * nl + l] = (1.0 - self.damping)
280 * log_msg[out_slot * nl + l]
281 + self.damping * out[l];
282 }
283 }
284 }
285 let mut max_diff = 0.0_f64;
287 for k in 0..log_msg.len() {
288 let d = (new_log_msg[k] - log_msg[k]).abs();
289 if d > max_diff {
290 max_diff = d;
291 }
292 }
293 log_msg.copy_from_slice(&new_log_msg);
294 if max_diff < self.cfg.bp_tol {
295 converged = true;
296 break;
297 }
298 }
299 (log_msg, iters, converged)
300 }
301
302 fn position_belief(
305 &self,
306 unary: &[f64],
307 edges: &[Edge],
308 log_msg: &[f64],
309 incoming: &[Vec<(usize, usize)>],
310 pos: usize,
311 ) -> Vec<f64> {
312 let nl = self.cfg.n_labels;
313 let mut belief = vec![0.0; nl];
314 for l in 0..nl {
315 belief[l] = unary[pos * nl + l];
316 }
317 for &(_e_idx, slot) in &incoming[pos] {
318 for l in 0..nl {
319 belief[l] += log_msg[slot * nl + l];
320 }
321 }
322 let _ = edges;
324 belief
325 }
326
327 fn build_incoming(seq_len: usize, edges: &[Edge]) -> Vec<Vec<(usize, usize)>> {
329 let mut incoming: Vec<Vec<(usize, usize)>> = vec![Vec::new(); seq_len];
330 for (e_idx, e) in edges.iter().enumerate() {
331 incoming[e.u].push((e_idx, e_idx * 2 + 1));
332 incoming[e.v].push((e_idx, e_idx * 2));
333 }
334 incoming
335 }
336
337 pub fn infer_marginals(
340 &self,
341 unary: &[f64],
342 seq_len: usize,
343 skip_edges: &[(usize, usize)],
344 ) -> SeqResult<Vec<f64>> {
345 let nl = self.cfg.n_labels;
346 let edges = self.prepare_edges(unary, seq_len, skip_edges)?;
347 let (log_msg, _iters, _converged) = self.run_bp(unary, seq_len, &edges, log_sum_exp);
348 let incoming = Self::build_incoming(seq_len, &edges);
349
350 let mut marginals = vec![0.0; seq_len * nl];
351 for pos in 0..seq_len {
352 let belief = self.position_belief(unary, &edges, &log_msg, &incoming, pos);
353 let logz = log_sum_exp(&belief);
354 if logz == f64::NEG_INFINITY {
355 let u = 1.0 / nl as f64;
356 for l in 0..nl {
357 marginals[pos * nl + l] = u;
358 }
359 } else {
360 for l in 0..nl {
361 marginals[pos * nl + l] = (belief[l] - logz).exp();
362 }
363 }
364 }
365 Ok(marginals)
366 }
367
368 pub fn infer_marginals_with_status(
371 &self,
372 unary: &[f64],
373 seq_len: usize,
374 skip_edges: &[(usize, usize)],
375 ) -> SeqResult<(Vec<f64>, usize, bool)> {
376 let nl = self.cfg.n_labels;
377 let edges = self.prepare_edges(unary, seq_len, skip_edges)?;
378 let (log_msg, iters, converged) = self.run_bp(unary, seq_len, &edges, log_sum_exp);
379 let incoming = Self::build_incoming(seq_len, &edges);
380
381 let mut marginals = vec![0.0; seq_len * nl];
382 for pos in 0..seq_len {
383 let belief = self.position_belief(unary, &edges, &log_msg, &incoming, pos);
384 let logz = log_sum_exp(&belief);
385 if logz == f64::NEG_INFINITY {
386 let u = 1.0 / nl as f64;
387 for l in 0..nl {
388 marginals[pos * nl + l] = u;
389 }
390 } else {
391 for l in 0..nl {
392 marginals[pos * nl + l] = (belief[l] - logz).exp();
393 }
394 }
395 }
396 Ok((marginals, iters, converged))
397 }
398
399 pub fn decode(
401 &self,
402 unary: &[f64],
403 seq_len: usize,
404 skip_edges: &[(usize, usize)],
405 ) -> SeqResult<Vec<usize>> {
406 let nl = self.cfg.n_labels;
407 let edges = self.prepare_edges(unary, seq_len, skip_edges)?;
408 let (log_msg, _iters, _converged) = self.run_bp(unary, seq_len, &edges, max_of);
409 let incoming = Self::build_incoming(seq_len, &edges);
410
411 let mut labels = vec![0usize; seq_len];
412 for pos in 0..seq_len {
413 let belief = self.position_belief(unary, &edges, &log_msg, &incoming, pos);
414 let mut best_l = 0usize;
415 let mut best_v = f64::NEG_INFINITY;
416 for l in 0..nl {
417 if belief[l] > best_v {
418 best_v = belief[l];
419 best_l = l;
420 }
421 }
422 labels[pos] = best_l;
423 }
424 Ok(labels)
425 }
426}
427
// Unit tests for the skip-chain CRF: agreement with exact chain inference when
// no skip edges are present, basic sanity properties with skip edges, and the
// error paths of construction and input validation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crf::linear_chain_crf::LinearChainCrf;
    use crate::crf::viterbi_decode::viterbi_decode;
    use crate::hmm::forward_backward::logsumexp;

    // Default test configuration: generous iteration cap and a tight tolerance.
    fn cfg(n_labels: usize) -> SkipChainConfig {
        SkipChainConfig {
            n_labels,
            max_bp_iters: 200,
            bp_tol: 1e-10,
        }
    }

    // Reference marginals for a plain chain, computed with a log-domain
    // forward (alpha) / backward (beta) recursion over `t_max` positions and
    // `n` labels. `emit` and the result are laid out as `[t * n + label]`.
    fn exact_chain_marginals(emit: &[f64], transition: &[f64], n: usize, t_max: usize) -> Vec<f64> {
        let mut alpha = vec![f64::NEG_INFINITY; t_max * n];
        alpha[..n].copy_from_slice(&emit[..n]);
        let mut tmp = vec![0.0; n];
        for t in 1..t_max {
            for j in 0..n {
                for i in 0..n {
                    tmp[i] = alpha[(t - 1) * n + i] + transition[i * n + j];
                }
                alpha[t * n + j] = logsumexp(&tmp) + emit[t * n + j];
            }
        }
        let mut beta = vec![0.0; t_max * n];
        for t in (0..t_max - 1).rev() {
            for i in 0..n {
                for j in 0..n {
                    tmp[j] = transition[i * n + j] + emit[(t + 1) * n + j] + beta[(t + 1) * n + j];
                }
                beta[t * n + i] = logsumexp(&tmp);
            }
        }
        let log_z = logsumexp(&alpha[(t_max - 1) * n..]);
        let mut marg = vec![0.0; t_max * n];
        for t in 0..t_max {
            for j in 0..n {
                marg[t * n + j] = (alpha[t * n + j] + beta[t * n + j] - log_z).exp();
            }
        }
        marg
    }

    // The marginal vector has one entry per (position, label) pair.
    #[test]
    fn marginals_shape() {
        let crf = SkipChainCrf::new(cfg(3), vec![0.0; 9], vec![0.0; 9]).expect("new");
        let unary = vec![0.0; 4 * 3];
        let m = crf.infer_marginals(&unary, 4, &[]).expect("marg");
        assert_eq!(m.len(), 4 * 3);
    }

    // Each position's marginals form a probability distribution, even with a
    // skip edge present.
    #[test]
    fn marginals_each_position_sums_to_one() {
        let transition = vec![0.5, -0.2, 0.1, 0.3];
        let crf = SkipChainCrf::new(cfg(2), transition, vec![0.0, 0.0, 0.0, 0.0]).expect("new");
        let unary = vec![1.0, -0.5, 0.2, 0.7, -0.3, 0.4];
        let m = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("marg");
        for t in 0..3 {
            let s: f64 = m[t * 2..t * 2 + 2].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "pos {t} sum {s}");
        }
    }

    // Without skip edges the result must match the forward-backward reference.
    #[test]
    fn no_skip_marginals_equal_forward_backward() {
        let n = 3;
        let transition = vec![0.4, -0.1, 0.2, -0.3, 0.5, 0.0, 0.1, -0.2, 0.6];
        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 9]).expect("new");
        let t_max = 5;
        let mut unary = vec![0.0; t_max * n];
        // Deterministic, non-uniform unary scores.
        for t in 0..t_max {
            for l in 0..n {
                unary[t * n + l] = ((t * 7 + l * 3) as f64 % 5.0) - 2.0 + 0.1 * (t as f64);
            }
        }
        let bp = crf.infer_marginals(&unary, t_max, &[]).expect("bp");
        let exact = exact_chain_marginals(&unary, &transition, n, t_max);
        for k in 0..t_max * n {
            assert!(
                (bp[k] - exact[k]).abs() < 1e-5,
                "idx {k}: bp={} exact={}",
                bp[k],
                exact[k]
            );
        }
    }

    // Without skip edges the result must match marginals obtained by
    // enumerating all 2^3 labelings of a length-3 sequence.
    #[test]
    fn no_skip_marginals_equal_brute_force_short() {
        let n = 2;
        let transition = vec![0.3, -0.4, 0.2, 0.5];
        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 4]).expect("new");
        let t_max = 3;
        let unary = vec![0.5, -0.2, 0.1, 0.7, -0.3, 0.4];
        let bp = crf.infer_marginals(&unary, t_max, &[]).expect("bp");
        let mut marg = vec![0.0; t_max * n];
        // `z` accumulates the partition function over all labelings.
        let mut z = 0.0;
        for a in 0..n {
            for b in 0..n {
                for c in 0..n {
                    let y = [a, b, c];
                    let mut score = 0.0;
                    for (t, &yt) in y.iter().enumerate() {
                        score += unary[t * n + yt];
                        if t > 0 {
                            score += transition[y[t - 1] * n + yt];
                        }
                    }
                    let p = score.exp();
                    z += p;
                    for (t, &yt) in y.iter().enumerate() {
                        marg[t * n + yt] += p;
                    }
                }
            }
        }
        for v in marg.iter_mut() {
            *v /= z;
        }
        for k in 0..t_max * n {
            assert!(
                (bp[k] - marg[k]).abs() < 1e-6,
                "idx {k}: {} vs {}",
                bp[k],
                marg[k]
            );
        }
    }

    // Without skip edges, decoding must agree with Viterbi on a linear-chain
    // CRF that uses the same transitions.
    #[test]
    fn no_skip_decode_equals_viterbi() {
        let n = 3;
        let k = n; let transition = vec![0.4, -0.1, 0.2, -0.3, 0.5, 0.0, 0.1, -0.2, 0.6];
        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 9]).expect("new");
        let t_max = 6;
        let mut unary = vec![0.0; t_max * n];
        for t in 0..t_max {
            for l in 0..n {
                unary[t * n + l] = ((t * 5 + l * 11) as f64 % 7.0) - 3.0;
            }
        }
        let bp_labels = crf.decode(&unary, t_max, &[]).expect("decode");

        // NOTE(review): presumably `LinearChainCrf` scores a label as the dot
        // product of its emission row with the feature vector, so identity
        // emissions make the features act as unary scores — the type is
        // defined outside this file; confirm there.
        let mut lc = LinearChainCrf::zeros(n, k).expect("lc");
        lc.transitions = transition;
        for l in 0..n {
            for f in 0..k {
                lc.emissions[l * k + f] = if l == f { 1.0 } else { 0.0 };
            }
        }
        let mut x = vec![0.0; t_max * k];
        for t in 0..t_max {
            for f in 0..k {
                x[t * k + f] = unary[t * n + f];
            }
        }
        let vit = viterbi_decode(&lc, &x).expect("viterbi");
        assert_eq!(bp_labels, vit);
    }

    // A skip potential that rewards equal labels should pull the marginals of
    // the two linked positions closer together than they are without the edge.
    #[test]
    fn skip_edge_pulls_marginals_to_agreement() {
        let n = 2;
        let transition = vec![0.0, 0.0, 0.0, 0.0];
        let skip = vec![2.0, -2.0, -2.0, 2.0];
        let crf = SkipChainCrf::new(cfg(n), transition, skip).expect("new");
        let t_max = 3;
        // Position 0 prefers label 0, position 2 prefers label 1.
        let unary = vec![1.0, -1.0, 0.0, 0.0, -1.0, 1.0];
        let no_skip = crf.infer_marginals(&unary, t_max, &[]).expect("ns");
        let with_skip = crf.infer_marginals(&unary, t_max, &[(0, 2)]).expect("ws");
        // Compare P(label 0) at position 0 against P(label 0) at position 2.
        let dist_no = (no_skip[0] - no_skip[2 * n]).abs();
        let dist_ws = (with_skip[0] - with_skip[2 * n]).abs();
        assert!(
            dist_ws < dist_no,
            "skip edge should reduce disagreement: no={dist_no} ws={dist_ws}"
        );
    }

    // Decoding returns one in-range label per position.
    #[test]
    fn decode_returns_valid_labels() {
        let n = 4;
        let crf = SkipChainCrf::new(cfg(n), vec![0.1; 16], vec![0.0; 16]).expect("new");
        let t_max = 5;
        let mut unary = vec![0.0; t_max * n];
        for (i, v) in unary.iter_mut().enumerate() {
            *v = (i as f64 % 3.0) - 1.0;
        }
        let labels = crf.decode(&unary, t_max, &[(0, 3), (1, 4)]).expect("dec");
        assert_eq!(labels.len(), t_max);
        for &l in &labels {
            assert!(l < n);
        }
    }

    // Message passing reports convergence within the iteration cap on a small
    // graph with one skip edge.
    #[test]
    fn bp_converges_on_short_sequence() {
        let n = 2;
        let transition = vec![0.5, -0.2, 0.1, 0.3];
        let skip = vec![0.4, -0.1, -0.1, 0.4];
        let crf = SkipChainCrf::new(cfg(n), transition, skip).expect("new");
        let unary = vec![0.3, -0.2, 0.1, 0.4, -0.3, 0.2, 0.0, 0.1];
        let (_m, iters, converged) = crf
            .infer_marginals_with_status(&unary, 4, &[(0, 3)])
            .expect("bp");
        assert!(converged, "BP should converge");
        assert!(iters <= 200);
    }

    // All-zero scores everywhere give uniform marginals.
    #[test]
    fn uniform_unary_uniform_potentials_uniform_marginals() {
        let n = 3;
        let crf = SkipChainCrf::new(cfg(n), vec![0.0; 9], vec![0.0; 9]).expect("new");
        let t_max = 4;
        let unary = vec![0.0; t_max * n];
        let m = crf.infer_marginals(&unary, t_max, &[(0, 2)]).expect("m");
        for t in 0..t_max {
            for l in 0..n {
                assert!(
                    (m[t * n + l] - 1.0 / n as f64).abs() < 1e-9,
                    "pos {t} label {l}: {}",
                    m[t * n + l]
                );
            }
        }
    }

    // Repeated calls with identical inputs give bit-identical outputs.
    #[test]
    fn deterministic_inference() {
        let n = 2;
        let transition = vec![0.5, -0.2, 0.1, 0.3];
        let skip = vec![0.4, -0.1, -0.1, 0.4];
        let crf = SkipChainCrf::new(cfg(n), transition, skip).expect("new");
        let unary = vec![0.3, -0.2, 0.1, 0.4, -0.3, 0.2];
        let a = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("a");
        let b = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("b");
        assert_eq!(a, b);
        let da = crf.decode(&unary, 3, &[(0, 2)]).expect("da");
        let db = crf.decode(&unary, 3, &[(0, 2)]).expect("db");
        assert_eq!(da, db);
    }

    // A single position has no edges, so its marginal is the softmax of its
    // unary scores.
    #[test]
    fn seq_len_one_marginal_is_softmax() {
        let n = 3;
        let crf = SkipChainCrf::new(cfg(n), vec![0.0; 9], vec![0.0; 9]).expect("new");
        let unary = vec![1.0, 0.0, -1.0];
        let m = crf.infer_marginals(&unary, 1, &[]).expect("m");
        let mx = unary.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = unary.iter().map(|&u| (u - mx).exp()).collect();
        let s: f64 = exps.iter().sum();
        for l in 0..n {
            assert!((m[l] - exps[l] / s).abs() < 1e-12, "label {l}");
        }
    }

    // With a single label every marginal is 1 and every decoded label is 0.
    #[test]
    fn single_label_trivial() {
        let crf = SkipChainCrf::new(cfg(1), vec![0.0], vec![0.0]).expect("new");
        let unary = vec![3.0, -1.0, 0.5];
        let m = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("m");
        for v in &m {
            assert!((v - 1.0).abs() < 1e-12);
        }
        let labels = crf.decode(&unary, 3, &[(0, 2)]).expect("dec");
        assert_eq!(labels, vec![0, 0, 0]);
    }

    // A transition table of the wrong length is rejected.
    #[test]
    fn err_transition_wrong_length() {
        let r = SkipChainCrf::new(cfg(2), vec![0.0; 3], vec![0.0; 4]);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    // A skip-potential table of the wrong length is rejected.
    #[test]
    fn err_skip_potential_wrong_length() {
        let r = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 5]);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    // Unary scores must have exactly `seq_len * n_labels` entries.
    #[test]
    fn err_unary_wrong_length() {
        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
        let r = crf.infer_marginals(&[0.0, 0.0, 0.0], 3, &[]);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    // A skip edge endpoint beyond the sequence is rejected.
    #[test]
    fn err_skip_edge_out_of_range() {
        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
        let unary = vec![0.0; 6];
        let r = crf.infer_marginals(&unary, 3, &[(0, 9)]);
        assert!(matches!(r, Err(SeqError::IndexOutOfBounds { .. })));
    }

    // Skip edges must be strictly ordered: reversed edges and self-loops fail.
    #[test]
    fn err_skip_edge_i_ge_j() {
        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
        let unary = vec![0.0; 6];
        let r = crf.infer_marginals(&unary, 3, &[(2, 1)]);
        assert!(matches!(r, Err(SeqError::GraphInvariantViolated(_))));
        let r2 = crf.infer_marginals(&unary, 3, &[(1, 1)]);
        assert!(matches!(r2, Err(SeqError::GraphInvariantViolated(_))));
    }

    // Zero labels is an invalid configuration.
    #[test]
    fn err_n_labels_zero() {
        let r = SkipChainCrf::new(cfg(0), vec![], vec![]);
        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
    }

    // Zero allowed iterations is an invalid configuration.
    #[test]
    fn err_max_bp_iters_zero() {
        let c = SkipChainConfig {
            n_labels: 2,
            max_bp_iters: 0,
            bp_tol: 1e-6,
        };
        let r = SkipChainCrf::new(c, vec![0.0; 4], vec![0.0; 4]);
        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
    }

    // Zero and negative tolerances are rejected.
    #[test]
    fn err_bp_tol_non_positive() {
        let c = SkipChainConfig {
            n_labels: 2,
            max_bp_iters: 10,
            bp_tol: 0.0,
        };
        let r = SkipChainCrf::new(c, vec![0.0; 4], vec![0.0; 4]);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
        let c2 = SkipChainConfig {
            n_labels: 2,
            max_bp_iters: 10,
            bp_tol: -1.0,
        };
        let r2 = SkipChainCrf::new(c2, vec![0.0; 4], vec![0.0; 4]);
        assert!(matches!(r2, Err(SeqError::InvalidParameter { .. })));
    }

    // An empty sequence is rejected before any other validation.
    #[test]
    fn err_empty_input_seq_len_zero() {
        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
        let r = crf.infer_marginals(&[], 0, &[]);
        assert!(matches!(r, Err(SeqError::EmptyInput)));
    }

    // The accessor reports the configured label count.
    #[test]
    fn n_labels_accessor() {
        let crf = SkipChainCrf::new(cfg(5), vec![0.0; 25], vec![0.0; 25]).expect("new");
        assert_eq!(crf.n_labels(), 5);
    }

    // Damping must lie in (0, 1]: both ends of the range are probed.
    #[test]
    fn with_damping_validates() {
        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
        assert!(crf.clone().with_damping(0.3).is_ok());
        assert!(crf.clone().with_damping(1.0).is_ok());
        assert!(crf.clone().with_damping(0.0).is_err());
        assert!(crf.with_damping(1.5).is_err());
    }

    // Two-label variant of the Viterbi agreement check with strongly peaked
    // unary scores.
    #[test]
    fn no_skip_decode_equals_viterbi_two_labels() {
        let n = 2;
        let transition = vec![0.8, -0.5, -0.3, 0.6];
        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 4]).expect("new");
        let t_max = 4;
        let unary = vec![3.0, -1.0, -1.0, 3.0, -2.0, 2.0, 2.5, -1.5];
        let bp_labels = crf.decode(&unary, t_max, &[]).expect("dec");
        let mut lc = LinearChainCrf::zeros(n, n).expect("lc");
        lc.transitions = transition;
        for l in 0..n {
            lc.emissions[l * n + l] = 1.0;
        }
        let x = unary.clone();
        let vit = viterbi_decode(&lc, &x).expect("vit");
        assert_eq!(bp_labels, vit);
    }
}