Skip to main content

oxicuda_seq/crf/
skip_chain.rs

1//! Skip-chain Conditional Random Fields.
2//!
3//! A skip-chain CRF (Sutton & McCallum 2004; Galley 2006) augments a linear-chain
4//! CRF with long-range "skip" edges that connect non-adjacent positions believed to
5//! be related (for example, repeated tokens that should receive the same label).
6//! The skip edges turn the chain into a loopy graph, so exact forward-backward no
7//! longer applies; inference is performed with **loopy belief propagation**
8//! (sum-product for marginals, max-product for decoding).
9//!
10//! ## Parameterisation (score / log-potential space)
11//!
12//! To match the rest of the `crf` module, every potential is a **log-potential**
13//! (additive score) and probabilities are proportional to `exp(score)`:
14//!
15//! * `unary[t * n_labels + l]` — log-potential of label `l` at position `t`.
16//! * `transition[prev * n_labels + cur]` — chain edge `(t, t+1)` log-potential.
17//! * `skip_potential[l_i * n_labels + l_j]` — skip edge `(i, j)` log-potential
18//!   (with `i < j`; the first index addresses position `i`, the second `j`).
19//!
20//! Messages are kept in the log domain and normalised every iteration (max
21//! subtracted) so they neither under- nor overflow on long, high-score chains.
22//! With **no** skip edges the factor graph is a tree and loopy BP reduces to exact
23//! forward-backward (sum-product) / Viterbi (max-product).
24
25use crate::error::{SeqError, SeqResult};
26
/// Configuration for a skip-chain CRF.
///
/// Validated by [`SkipChainCrf::new`], which rejects `n_labels == 0`,
/// `max_bp_iters == 0`, and a `bp_tol` that is non-positive or NaN.
#[derive(Debug, Clone)]
pub struct SkipChainConfig {
    /// Number of labels per position (must be at least 1).
    pub n_labels: usize,
    /// Maximum number of loopy-BP iterations (must be at least 1).
    pub max_bp_iters: usize,
    /// Convergence tolerance on the max absolute message change (log domain);
    /// must be strictly positive.
    pub bp_tol: f64,
}
37
/// A skip-chain CRF holding the (shared) chain and skip-edge log-potentials.
///
/// One `transition` table is used for every chain edge and one `skip_potential`
/// table for every skip edge. Per-sequence unary scores and the list of skip
/// edges are supplied to the inference methods rather than stored here.
#[derive(Debug, Clone)]
pub struct SkipChainCrf {
    /// Validated configuration (label count, BP iteration cap, tolerance).
    cfg: SkipChainConfig,
    /// Chain transition log-potentials, `n_labels × n_labels` row-major.
    transition: Vec<f64>,
    /// Skip-edge log-potentials, `n_labels × n_labels` row-major.
    skip_potential: Vec<f64>,
    /// Damping factor for message updates, in `(0, 1]`. Set to `0.5` by `new`
    /// and replaceable through `with_damping`.
    damping: f64,
}
49
/// Internal description of an undirected pairwise edge in the factor graph.
///
/// Endpoints are stored in ascending order (`u < v`), so a potential table can be
/// indexed as (label at `u`, label at `v`) regardless of message direction.
#[derive(Debug, Clone, Copy)]
struct Edge {
    /// Lower-indexed endpoint position.
    u: usize,
    /// Higher-indexed endpoint position.
    v: usize,
    /// `true` if this is a chain edge (uses `transition`); `false` for a skip edge
    /// (uses `skip_potential`).
    is_chain: bool,
}
60
/// Numerically stable log-sum-exp of a slice: `ln(Σ exp(x))`.
///
/// The maximum is factored out before exponentiating so large scores cannot
/// overflow. Infinite maxima are returned directly:
///
/// * empty or all-`-inf` slice → `-inf` (log of a zero sum);
/// * any `+inf` entry → `+inf` (the shifted form would otherwise evaluate
///   `inf - inf`, which is NaN).
///
/// NaN entries never compare greater, so they are skipped when locating the
/// maximum, but they still poison the sum when the maximum is finite.
fn log_sum_exp(xs: &[f64]) -> f64 {
    let max = xs
        .iter()
        .fold(f64::NEG_INFINITY, |m, &x| if x > m { x } else { m });
    // Covers both `-inf` (nothing to sum) and `+inf` (sum is unbounded).
    if max.is_infinite() {
        return max;
    }
    let shifted_sum: f64 = xs.iter().map(|&x| (x - max).exp()).sum();
    max + shifted_sum.ln()
}
78
/// Largest element of `xs`, or `-inf` when the slice is empty.
///
/// NaN entries never compare greater than the running best and are skipped.
fn max_of(xs: &[f64]) -> f64 {
    xs.iter().fold(
        f64::NEG_INFINITY,
        |best, &x| if x > best { x } else { best },
    )
}
89
90impl SkipChainCrf {
91    /// Construct a skip-chain CRF, validating the potential shapes.
92    pub fn new(
93        cfg: SkipChainConfig,
94        transition: Vec<f64>,
95        skip_potential: Vec<f64>,
96    ) -> SeqResult<Self> {
97        if cfg.n_labels == 0 {
98            return Err(SeqError::InvalidConfiguration(
99                "n_labels must be >= 1".to_string(),
100            ));
101        }
102        if cfg.max_bp_iters == 0 {
103            return Err(SeqError::InvalidConfiguration(
104                "max_bp_iters must be >= 1".to_string(),
105            ));
106        }
107        if cfg.bp_tol <= 0.0 || cfg.bp_tol.is_nan() {
108            return Err(SeqError::InvalidParameter {
109                name: "bp_tol".to_string(),
110                value: cfg.bp_tol,
111            });
112        }
113        let l2 = cfg.n_labels * cfg.n_labels;
114        if transition.len() != l2 {
115            return Err(SeqError::ShapeMismatch {
116                expected: l2,
117                got: transition.len(),
118            });
119        }
120        if skip_potential.len() != l2 {
121            return Err(SeqError::ShapeMismatch {
122                expected: l2,
123                got: skip_potential.len(),
124            });
125        }
126        Ok(Self {
127            cfg,
128            transition,
129            skip_potential,
130            damping: 0.5,
131        })
132    }
133
134    /// Override the message-passing damping factor (must be in `(0, 1]`).
135    pub fn with_damping(mut self, damping: f64) -> SeqResult<Self> {
136        if damping <= 0.0 || damping > 1.0 || damping.is_nan() {
137            return Err(SeqError::InvalidParameter {
138                name: "damping".to_string(),
139                value: damping,
140            });
141        }
142        self.damping = damping;
143        Ok(self)
144    }
145
146    /// Number of labels.
147    pub fn n_labels(&self) -> usize {
148        self.cfg.n_labels
149    }
150
151    /// Validate the inference inputs and assemble the edge list (chain + skip).
152    fn prepare_edges(
153        &self,
154        unary: &[f64],
155        seq_len: usize,
156        skip_edges: &[(usize, usize)],
157    ) -> SeqResult<Vec<Edge>> {
158        let nl = self.cfg.n_labels;
159        if seq_len == 0 {
160            return Err(SeqError::EmptyInput);
161        }
162        if unary.len() != seq_len * nl {
163            return Err(SeqError::ShapeMismatch {
164                expected: seq_len * nl,
165                got: unary.len(),
166            });
167        }
168        let mut edges: Vec<Edge> = Vec::with_capacity(seq_len.saturating_sub(1) + skip_edges.len());
169        for t in 0..seq_len.saturating_sub(1) {
170            edges.push(Edge {
171                u: t,
172                v: t + 1,
173                is_chain: true,
174            });
175        }
176        for &(i, j) in skip_edges {
177            if i >= seq_len || j >= seq_len {
178                return Err(SeqError::IndexOutOfBounds {
179                    index: i.max(j),
180                    len: seq_len,
181                });
182            }
183            if i >= j {
184                return Err(SeqError::GraphInvariantViolated(format!(
185                    "skip edge ({i}, {j}) must have i < j"
186                )));
187            }
188            edges.push(Edge {
189                u: i,
190                v: j,
191                is_chain: false,
192            });
193        }
194        Ok(edges)
195    }
196
197    /// The oriented pairwise log-potential for `edge` between source label `l_src`
198    /// (at `src`) and destination label `l_dst` (at `dst`).
199    #[inline]
200    fn edge_log_potential(&self, edge: &Edge, src: usize, l_src: usize, l_dst: usize) -> f64 {
201        let nl = self.cfg.n_labels;
202        let table = if edge.is_chain {
203            &self.transition
204        } else {
205            &self.skip_potential
206        };
207        // Table is row-major over (position u, position v); orient by source.
208        if src == edge.u {
209            table[l_src * nl + l_dst]
210        } else {
211            table[l_dst * nl + l_src]
212        }
213    }
214
215    /// Run loopy BP with the supplied combine operator (`log_sum_exp` for
216    /// sum-product, `max_of` for max-product) and return the converged log-domain
217    /// directed messages plus the per-edge directions.
218    ///
219    /// Directed message layout: edge `e` direction `u→v` lives at slot `2*e`,
220    /// direction `v→u` at slot `2*e+1`; each slot holds `n_labels` log values.
221    fn run_bp(
222        &self,
223        unary: &[f64],
224        seq_len: usize,
225        edges: &[Edge],
226        combine: fn(&[f64]) -> f64,
227    ) -> (Vec<f64>, usize, bool) {
228        let nl = self.cfg.n_labels;
229        let n_slots = edges.len() * 2;
230        let mut log_msg = vec![0.0; n_slots * nl];
231        let mut new_log_msg = log_msg.clone();
232
233        // Precompute, for each position, the list of (edge_idx, incoming slot) so
234        // that gathering neighbour messages is cheap.
235        let mut incoming: Vec<Vec<(usize, usize)>> = vec![Vec::new(); seq_len];
236        for (e_idx, e) in edges.iter().enumerate() {
237            // Message arriving at u from v is stored in slot 2*e+1 (v→u).
238            incoming[e.u].push((e_idx, e_idx * 2 + 1));
239            // Message arriving at v from u is stored in slot 2*e (u→v).
240            incoming[e.v].push((e_idx, e_idx * 2));
241        }
242
243        let mut iters = 0;
244        let mut converged = false;
245        let mut terms = vec![0.0; nl];
246
247        for it in 0..self.cfg.max_bp_iters {
248            iters = it + 1;
249            for (e_idx, e) in edges.iter().enumerate() {
250                // Two directions: src=u→dst=v (slot 2*e) and src=v→dst=u (slot 2*e+1).
251                for &(src, dst, out_slot) in &[(e.u, e.v, e_idx * 2), (e.v, e.u, e_idx * 2 + 1)] {
252                    let _ = dst;
253                    let mut out = vec![f64::NEG_INFINITY; nl];
254                    for l_dst in 0..nl {
255                        for l_src in 0..nl {
256                            // Unary at src + oriented pairwise + product of all
257                            // incoming messages to src except along this edge.
258                            let mut acc = unary[src * nl + l_src]
259                                + self.edge_log_potential(e, src, l_src, l_dst);
260                            for &(k_edge, slot) in &incoming[src] {
261                                if k_edge == e_idx {
262                                    continue;
263                                }
264                                acc += log_msg[slot * nl + l_src];
265                            }
266                            terms[l_src] = acc;
267                        }
268                        out[l_dst] = combine(&terms);
269                    }
270                    // Normalise (subtract max) for stability.
271                    let m = max_of(&out);
272                    if m != f64::NEG_INFINITY {
273                        for v in out.iter_mut() {
274                            *v -= m;
275                        }
276                    }
277                    // Damped write.
278                    for l in 0..nl {
279                        new_log_msg[out_slot * nl + l] = (1.0 - self.damping)
280                            * log_msg[out_slot * nl + l]
281                            + self.damping * out[l];
282                    }
283                }
284            }
285            // Convergence on max absolute message change.
286            let mut max_diff = 0.0_f64;
287            for k in 0..log_msg.len() {
288                let d = (new_log_msg[k] - log_msg[k]).abs();
289                if d > max_diff {
290                    max_diff = d;
291                }
292            }
293            log_msg.copy_from_slice(&new_log_msg);
294            if max_diff < self.cfg.bp_tol {
295                converged = true;
296                break;
297            }
298        }
299        (log_msg, iters, converged)
300    }
301
302    /// Compute per-position belief = unary + sum of all incoming messages (log
303    /// domain) for position `pos`.
304    fn position_belief(
305        &self,
306        unary: &[f64],
307        edges: &[Edge],
308        log_msg: &[f64],
309        incoming: &[Vec<(usize, usize)>],
310        pos: usize,
311    ) -> Vec<f64> {
312        let nl = self.cfg.n_labels;
313        let mut belief = vec![0.0; nl];
314        for l in 0..nl {
315            belief[l] = unary[pos * nl + l];
316        }
317        for &(_e_idx, slot) in &incoming[pos] {
318            for l in 0..nl {
319                belief[l] += log_msg[slot * nl + l];
320            }
321        }
322        // `edges` is unused directly here but kept for signature symmetry.
323        let _ = edges;
324        belief
325    }
326
327    /// Build the per-position incoming-message index used by belief readout.
328    fn build_incoming(seq_len: usize, edges: &[Edge]) -> Vec<Vec<(usize, usize)>> {
329        let mut incoming: Vec<Vec<(usize, usize)>> = vec![Vec::new(); seq_len];
330        for (e_idx, e) in edges.iter().enumerate() {
331            incoming[e.u].push((e_idx, e_idx * 2 + 1));
332            incoming[e.v].push((e_idx, e_idx * 2));
333        }
334        incoming
335    }
336
337    /// Loopy sum-product BP returning per-position marginals (`seq_len × n_labels`,
338    /// each position normalised to sum to 1).
339    pub fn infer_marginals(
340        &self,
341        unary: &[f64],
342        seq_len: usize,
343        skip_edges: &[(usize, usize)],
344    ) -> SeqResult<Vec<f64>> {
345        let nl = self.cfg.n_labels;
346        let edges = self.prepare_edges(unary, seq_len, skip_edges)?;
347        let (log_msg, _iters, _converged) = self.run_bp(unary, seq_len, &edges, log_sum_exp);
348        let incoming = Self::build_incoming(seq_len, &edges);
349
350        let mut marginals = vec![0.0; seq_len * nl];
351        for pos in 0..seq_len {
352            let belief = self.position_belief(unary, &edges, &log_msg, &incoming, pos);
353            let logz = log_sum_exp(&belief);
354            if logz == f64::NEG_INFINITY {
355                let u = 1.0 / nl as f64;
356                for l in 0..nl {
357                    marginals[pos * nl + l] = u;
358                }
359            } else {
360                for l in 0..nl {
361                    marginals[pos * nl + l] = (belief[l] - logz).exp();
362                }
363            }
364        }
365        Ok(marginals)
366    }
367
368    /// Loopy sum-product BP returning per-position marginals together with whether
369    /// the message passing converged and the iteration count.
370    pub fn infer_marginals_with_status(
371        &self,
372        unary: &[f64],
373        seq_len: usize,
374        skip_edges: &[(usize, usize)],
375    ) -> SeqResult<(Vec<f64>, usize, bool)> {
376        let nl = self.cfg.n_labels;
377        let edges = self.prepare_edges(unary, seq_len, skip_edges)?;
378        let (log_msg, iters, converged) = self.run_bp(unary, seq_len, &edges, log_sum_exp);
379        let incoming = Self::build_incoming(seq_len, &edges);
380
381        let mut marginals = vec![0.0; seq_len * nl];
382        for pos in 0..seq_len {
383            let belief = self.position_belief(unary, &edges, &log_msg, &incoming, pos);
384            let logz = log_sum_exp(&belief);
385            if logz == f64::NEG_INFINITY {
386                let u = 1.0 / nl as f64;
387                for l in 0..nl {
388                    marginals[pos * nl + l] = u;
389                }
390            } else {
391                for l in 0..nl {
392                    marginals[pos * nl + l] = (belief[l] - logz).exp();
393                }
394            }
395        }
396        Ok((marginals, iters, converged))
397    }
398
399    /// Loopy max-product BP decoding, returning the best label per position.
400    pub fn decode(
401        &self,
402        unary: &[f64],
403        seq_len: usize,
404        skip_edges: &[(usize, usize)],
405    ) -> SeqResult<Vec<usize>> {
406        let nl = self.cfg.n_labels;
407        let edges = self.prepare_edges(unary, seq_len, skip_edges)?;
408        let (log_msg, _iters, _converged) = self.run_bp(unary, seq_len, &edges, max_of);
409        let incoming = Self::build_incoming(seq_len, &edges);
410
411        let mut labels = vec![0usize; seq_len];
412        for pos in 0..seq_len {
413            let belief = self.position_belief(unary, &edges, &log_msg, &incoming, pos);
414            let mut best_l = 0usize;
415            let mut best_v = f64::NEG_INFINITY;
416            for l in 0..nl {
417                if belief[l] > best_v {
418                    best_v = belief[l];
419                    best_l = l;
420                }
421            }
422            labels[pos] = best_l;
423        }
424        Ok(labels)
425    }
426}
427
428#[cfg(test)]
429mod tests {
430    use super::*;
431    use crate::crf::linear_chain_crf::LinearChainCrf;
432    use crate::crf::viterbi_decode::viterbi_decode;
433    use crate::hmm::forward_backward::logsumexp;
434
435    fn cfg(n_labels: usize) -> SkipChainConfig {
436        SkipChainConfig {
437            n_labels,
438            max_bp_iters: 200,
439            bp_tol: 1e-10,
440        }
441    }
442
443    /// Exact linear-chain forward-backward marginals in score space, given per
444    /// position emission log-potentials `emit` (seq_len × n_labels) and a transition
445    /// matrix.  Mirrors `crf_train::forward_scores`/`backward_scores`.
446    fn exact_chain_marginals(emit: &[f64], transition: &[f64], n: usize, t_max: usize) -> Vec<f64> {
447        let mut alpha = vec![f64::NEG_INFINITY; t_max * n];
448        alpha[..n].copy_from_slice(&emit[..n]);
449        let mut tmp = vec![0.0; n];
450        for t in 1..t_max {
451            for j in 0..n {
452                for i in 0..n {
453                    tmp[i] = alpha[(t - 1) * n + i] + transition[i * n + j];
454                }
455                alpha[t * n + j] = logsumexp(&tmp) + emit[t * n + j];
456            }
457        }
458        let mut beta = vec![0.0; t_max * n];
459        for t in (0..t_max - 1).rev() {
460            for i in 0..n {
461                for j in 0..n {
462                    tmp[j] = transition[i * n + j] + emit[(t + 1) * n + j] + beta[(t + 1) * n + j];
463                }
464                beta[t * n + i] = logsumexp(&tmp);
465            }
466        }
467        let log_z = logsumexp(&alpha[(t_max - 1) * n..]);
468        let mut marg = vec![0.0; t_max * n];
469        for t in 0..t_max {
470            for j in 0..n {
471                marg[t * n + j] = (alpha[t * n + j] + beta[t * n + j] - log_z).exp();
472            }
473        }
474        marg
475    }
476
477    #[test]
478    fn marginals_shape() {
479        let crf = SkipChainCrf::new(cfg(3), vec![0.0; 9], vec![0.0; 9]).expect("new");
480        let unary = vec![0.0; 4 * 3];
481        let m = crf.infer_marginals(&unary, 4, &[]).expect("marg");
482        assert_eq!(m.len(), 4 * 3);
483    }
484
485    #[test]
486    fn marginals_each_position_sums_to_one() {
487        let transition = vec![0.5, -0.2, 0.1, 0.3];
488        let crf = SkipChainCrf::new(cfg(2), transition, vec![0.0, 0.0, 0.0, 0.0]).expect("new");
489        let unary = vec![1.0, -0.5, 0.2, 0.7, -0.3, 0.4];
490        let m = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("marg");
491        for t in 0..3 {
492            let s: f64 = m[t * 2..t * 2 + 2].iter().sum();
493            assert!((s - 1.0).abs() < 1e-9, "pos {t} sum {s}");
494        }
495    }
496
497    #[test]
498    fn no_skip_marginals_equal_forward_backward() {
499        let n = 3;
500        let transition = vec![0.4, -0.1, 0.2, -0.3, 0.5, 0.0, 0.1, -0.2, 0.6];
501        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 9]).expect("new");
502        let t_max = 5;
503        // Arbitrary deterministic unary log-potentials.
504        let mut unary = vec![0.0; t_max * n];
505        for t in 0..t_max {
506            for l in 0..n {
507                unary[t * n + l] = ((t * 7 + l * 3) as f64 % 5.0) - 2.0 + 0.1 * (t as f64);
508            }
509        }
510        let bp = crf.infer_marginals(&unary, t_max, &[]).expect("bp");
511        let exact = exact_chain_marginals(&unary, &transition, n, t_max);
512        for k in 0..t_max * n {
513            assert!(
514                (bp[k] - exact[k]).abs() < 1e-5,
515                "idx {k}: bp={} exact={}",
516                bp[k],
517                exact[k]
518            );
519        }
520    }
521
522    #[test]
523    fn no_skip_marginals_equal_brute_force_short() {
524        // Brute-force enumeration on a length-3 chain, n_labels=2.
525        let n = 2;
526        let transition = vec![0.3, -0.4, 0.2, 0.5];
527        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 4]).expect("new");
528        let t_max = 3;
529        let unary = vec![0.5, -0.2, 0.1, 0.7, -0.3, 0.4];
530        let bp = crf.infer_marginals(&unary, t_max, &[]).expect("bp");
531        // Enumerate all 2^3 = 8 sequences.
532        let mut marg = vec![0.0; t_max * n];
533        let mut z = 0.0;
534        for a in 0..n {
535            for b in 0..n {
536                for c in 0..n {
537                    let y = [a, b, c];
538                    let mut score = 0.0;
539                    for (t, &yt) in y.iter().enumerate() {
540                        score += unary[t * n + yt];
541                        if t > 0 {
542                            score += transition[y[t - 1] * n + yt];
543                        }
544                    }
545                    let p = score.exp();
546                    z += p;
547                    for (t, &yt) in y.iter().enumerate() {
548                        marg[t * n + yt] += p;
549                    }
550                }
551            }
552        }
553        for v in marg.iter_mut() {
554            *v /= z;
555        }
556        for k in 0..t_max * n {
557            assert!(
558                (bp[k] - marg[k]).abs() < 1e-6,
559                "idx {k}: {} vs {}",
560                bp[k],
561                marg[k]
562            );
563        }
564    }
565
566    #[test]
567    fn no_skip_decode_equals_viterbi() {
568        // Build an equivalent LinearChainCrf whose emissions reproduce the unary.
569        let n = 3;
570        let k = n; // one-hot features per position -> emission = unary
571        let transition = vec![0.4, -0.1, 0.2, -0.3, 0.5, 0.0, 0.1, -0.2, 0.6];
572        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 9]).expect("new");
573        let t_max = 6;
574        let mut unary = vec![0.0; t_max * n];
575        for t in 0..t_max {
576            for l in 0..n {
577                unary[t * n + l] = ((t * 5 + l * 11) as f64 % 7.0) - 3.0;
578            }
579        }
580        let bp_labels = crf.decode(&unary, t_max, &[]).expect("decode");
581
582        // LinearChainCrf with identity emission weights and one-hot features so that
583        // emit_score(l, x_t) = unary[t, l].
584        let mut lc = LinearChainCrf::zeros(n, k).expect("lc");
585        lc.transitions = transition;
586        // emissions[l*k + f] = 1 if l==f else 0.
587        for l in 0..n {
588            for f in 0..k {
589                lc.emissions[l * k + f] = if l == f { 1.0 } else { 0.0 };
590            }
591        }
592        let mut x = vec![0.0; t_max * k];
593        for t in 0..t_max {
594            for f in 0..k {
595                x[t * k + f] = unary[t * n + f];
596            }
597        }
598        let vit = viterbi_decode(&lc, &x).expect("viterbi");
599        assert_eq!(bp_labels, vit);
600    }
601
602    #[test]
603    fn skip_edge_pulls_marginals_to_agreement() {
604        // Two positions with identical evidence and a strongly-attractive skip edge
605        // should agree more than without the edge.  Use a diagonal skip potential.
606        let n = 2;
607        let transition = vec![0.0, 0.0, 0.0, 0.0];
608        // Attractive: large on equal labels, small on disagreement.
609        let skip = vec![2.0, -2.0, -2.0, 2.0];
610        let crf = SkipChainCrf::new(cfg(n), transition, skip).expect("new");
611        let t_max = 3;
612        // Position 0 prefers label 0; position 2 prefers label 1 (conflicting),
613        // position 1 neutral.
614        let unary = vec![1.0, -1.0, 0.0, 0.0, -1.0, 1.0];
615        let no_skip = crf.infer_marginals(&unary, t_max, &[]).expect("ns");
616        let with_skip = crf.infer_marginals(&unary, t_max, &[(0, 2)]).expect("ws");
617        // Without the skip edge, position 0 favours label 0 and position 2 favours
618        // label 1, so they disagree.  The attractive skip edge should bring the
619        // marginals of position 0 and 2 closer together.
620        let dist_no = (no_skip[0] - no_skip[2 * n]).abs();
621        let dist_ws = (with_skip[0] - with_skip[2 * n]).abs();
622        assert!(
623            dist_ws < dist_no,
624            "skip edge should reduce disagreement: no={dist_no} ws={dist_ws}"
625        );
626    }
627
628    #[test]
629    fn decode_returns_valid_labels() {
630        let n = 4;
631        let crf = SkipChainCrf::new(cfg(n), vec![0.1; 16], vec![0.0; 16]).expect("new");
632        let t_max = 5;
633        let mut unary = vec![0.0; t_max * n];
634        for (i, v) in unary.iter_mut().enumerate() {
635            *v = (i as f64 % 3.0) - 1.0;
636        }
637        let labels = crf.decode(&unary, t_max, &[(0, 3), (1, 4)]).expect("dec");
638        assert_eq!(labels.len(), t_max);
639        for &l in &labels {
640            assert!(l < n);
641        }
642    }
643
644    #[test]
645    fn bp_converges_on_short_sequence() {
646        let n = 2;
647        let transition = vec![0.5, -0.2, 0.1, 0.3];
648        let skip = vec![0.4, -0.1, -0.1, 0.4];
649        let crf = SkipChainCrf::new(cfg(n), transition, skip).expect("new");
650        let unary = vec![0.3, -0.2, 0.1, 0.4, -0.3, 0.2, 0.0, 0.1];
651        let (_m, iters, converged) = crf
652            .infer_marginals_with_status(&unary, 4, &[(0, 3)])
653            .expect("bp");
654        assert!(converged, "BP should converge");
655        assert!(iters <= 200);
656    }
657
658    #[test]
659    fn uniform_unary_uniform_potentials_uniform_marginals() {
660        let n = 3;
661        let crf = SkipChainCrf::new(cfg(n), vec![0.0; 9], vec![0.0; 9]).expect("new");
662        let t_max = 4;
663        let unary = vec![0.0; t_max * n];
664        let m = crf.infer_marginals(&unary, t_max, &[(0, 2)]).expect("m");
665        for t in 0..t_max {
666            for l in 0..n {
667                assert!(
668                    (m[t * n + l] - 1.0 / n as f64).abs() < 1e-9,
669                    "pos {t} label {l}: {}",
670                    m[t * n + l]
671                );
672            }
673        }
674    }
675
676    #[test]
677    fn deterministic_inference() {
678        let n = 2;
679        let transition = vec![0.5, -0.2, 0.1, 0.3];
680        let skip = vec![0.4, -0.1, -0.1, 0.4];
681        let crf = SkipChainCrf::new(cfg(n), transition, skip).expect("new");
682        let unary = vec![0.3, -0.2, 0.1, 0.4, -0.3, 0.2];
683        let a = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("a");
684        let b = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("b");
685        assert_eq!(a, b);
686        let da = crf.decode(&unary, 3, &[(0, 2)]).expect("da");
687        let db = crf.decode(&unary, 3, &[(0, 2)]).expect("db");
688        assert_eq!(da, db);
689    }
690
691    #[test]
692    fn seq_len_one_marginal_is_softmax() {
693        let n = 3;
694        let crf = SkipChainCrf::new(cfg(n), vec![0.0; 9], vec![0.0; 9]).expect("new");
695        let unary = vec![1.0, 0.0, -1.0];
696        let m = crf.infer_marginals(&unary, 1, &[]).expect("m");
697        // Softmax of unary.
698        let mx = unary.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
699        let exps: Vec<f64> = unary.iter().map(|&u| (u - mx).exp()).collect();
700        let s: f64 = exps.iter().sum();
701        for l in 0..n {
702            assert!((m[l] - exps[l] / s).abs() < 1e-12, "label {l}");
703        }
704    }
705
706    #[test]
707    fn single_label_trivial() {
708        let crf = SkipChainCrf::new(cfg(1), vec![0.0], vec![0.0]).expect("new");
709        let unary = vec![3.0, -1.0, 0.5];
710        let m = crf.infer_marginals(&unary, 3, &[(0, 2)]).expect("m");
711        for v in &m {
712            assert!((v - 1.0).abs() < 1e-12);
713        }
714        let labels = crf.decode(&unary, 3, &[(0, 2)]).expect("dec");
715        assert_eq!(labels, vec![0, 0, 0]);
716    }
717
718    #[test]
719    fn err_transition_wrong_length() {
720        let r = SkipChainCrf::new(cfg(2), vec![0.0; 3], vec![0.0; 4]);
721        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
722    }
723
724    #[test]
725    fn err_skip_potential_wrong_length() {
726        let r = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 5]);
727        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
728    }
729
730    #[test]
731    fn err_unary_wrong_length() {
732        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
733        let r = crf.infer_marginals(&[0.0, 0.0, 0.0], 3, &[]);
734        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
735    }
736
737    #[test]
738    fn err_skip_edge_out_of_range() {
739        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
740        let unary = vec![0.0; 6];
741        let r = crf.infer_marginals(&unary, 3, &[(0, 9)]);
742        assert!(matches!(r, Err(SeqError::IndexOutOfBounds { .. })));
743    }
744
745    #[test]
746    fn err_skip_edge_i_ge_j() {
747        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
748        let unary = vec![0.0; 6];
749        let r = crf.infer_marginals(&unary, 3, &[(2, 1)]);
750        assert!(matches!(r, Err(SeqError::GraphInvariantViolated(_))));
751        let r2 = crf.infer_marginals(&unary, 3, &[(1, 1)]);
752        assert!(matches!(r2, Err(SeqError::GraphInvariantViolated(_))));
753    }
754
755    #[test]
756    fn err_n_labels_zero() {
757        let r = SkipChainCrf::new(cfg(0), vec![], vec![]);
758        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
759    }
760
761    #[test]
762    fn err_max_bp_iters_zero() {
763        let c = SkipChainConfig {
764            n_labels: 2,
765            max_bp_iters: 0,
766            bp_tol: 1e-6,
767        };
768        let r = SkipChainCrf::new(c, vec![0.0; 4], vec![0.0; 4]);
769        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
770    }
771
772    #[test]
773    fn err_bp_tol_non_positive() {
774        let c = SkipChainConfig {
775            n_labels: 2,
776            max_bp_iters: 10,
777            bp_tol: 0.0,
778        };
779        let r = SkipChainCrf::new(c, vec![0.0; 4], vec![0.0; 4]);
780        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
781        let c2 = SkipChainConfig {
782            n_labels: 2,
783            max_bp_iters: 10,
784            bp_tol: -1.0,
785        };
786        let r2 = SkipChainCrf::new(c2, vec![0.0; 4], vec![0.0; 4]);
787        assert!(matches!(r2, Err(SeqError::InvalidParameter { .. })));
788    }
789
790    #[test]
791    fn err_empty_input_seq_len_zero() {
792        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
793        let r = crf.infer_marginals(&[], 0, &[]);
794        assert!(matches!(r, Err(SeqError::EmptyInput)));
795    }
796
797    #[test]
798    fn n_labels_accessor() {
799        let crf = SkipChainCrf::new(cfg(5), vec![0.0; 25], vec![0.0; 25]).expect("new");
800        assert_eq!(crf.n_labels(), 5);
801    }
802
803    #[test]
804    fn with_damping_validates() {
805        let crf = SkipChainCrf::new(cfg(2), vec![0.0; 4], vec![0.0; 4]).expect("new");
806        assert!(crf.clone().with_damping(0.3).is_ok());
807        assert!(crf.clone().with_damping(1.0).is_ok());
808        assert!(crf.clone().with_damping(0.0).is_err());
809        assert!(crf.with_damping(1.5).is_err());
810    }
811
812    #[test]
813    fn no_skip_decode_equals_viterbi_two_labels() {
814        // A second decode/Viterbi agreement test with a unique (non-degenerate) MAP
815        // sequence 0,1,1,0 induced by asymmetric unary log-potentials.
816        let n = 2;
817        let transition = vec![0.8, -0.5, -0.3, 0.6];
818        let crf = SkipChainCrf::new(cfg(n), transition.clone(), vec![0.0; 4]).expect("new");
819        let t_max = 4;
820        let unary = vec![3.0, -1.0, -1.0, 3.0, -2.0, 2.0, 2.5, -1.5];
821        let bp_labels = crf.decode(&unary, t_max, &[]).expect("dec");
822        let mut lc = LinearChainCrf::zeros(n, n).expect("lc");
823        lc.transitions = transition;
824        for l in 0..n {
825            lc.emissions[l * n + l] = 1.0;
826        }
827        let x = unary.clone();
828        let vit = viterbi_decode(&lc, &x).expect("vit");
829        assert_eq!(bp_labels, vit);
830    }
831}