Skip to main content

oxicuda_seq/crf/
linear_chain_crf.rs

1//! Linear-chain CRF parameterisation.
2
3use crate::error::{SeqError, SeqResult};
4
/// Parameters of a linear-chain CRF: linear emission weights over real-valued
/// features plus a full dense label-to-label transition matrix.
///
/// Both weight tables are stored flattened in row-major order:
///
/// * `emissions[label * n_features + k]` is the emission weight w_k for `label`.
/// * `transitions[prev * n_labels + cur]` is the transition weight from `prev`
///   to `cur`.
#[derive(Clone, Debug)]
pub struct LinearChainCrf {
    /// Number of distinct labels.
    pub n_labels: usize,
    /// Number of real-valued features per position.
    pub n_features: usize,
    /// Flattened emission weights, indexed as `label * n_features + k`.
    pub emissions: Vec<f64>,
    /// Flattened transition weights, indexed as `prev * n_labels + cur`.
    pub transitions: Vec<f64>,
}
17
18impl LinearChainCrf {
19    /// Construct a zero-initialised CRF.
20    pub fn zeros(n_labels: usize, n_features: usize) -> SeqResult<Self> {
21        if n_labels == 0 || n_features == 0 {
22            return Err(SeqError::InvalidConfiguration(
23                "n_labels and n_features must be > 0".to_string(),
24            ));
25        }
26        Ok(Self {
27            n_labels,
28            n_features,
29            emissions: vec![0.0; n_labels * n_features],
30            transitions: vec![0.0; n_labels * n_labels],
31        })
32    }
33
34    /// Number of free parameters (emissions + transitions).
35    pub fn param_count(&self) -> usize {
36        self.n_labels * self.n_features + self.n_labels * self.n_labels
37    }
38
39    /// Pack `(emissions, transitions)` into a flat parameter vector.
40    pub fn to_params(&self) -> Vec<f64> {
41        let mut v = Vec::with_capacity(self.param_count());
42        v.extend_from_slice(&self.emissions);
43        v.extend_from_slice(&self.transitions);
44        v
45    }
46
47    /// Unpack a flat parameter vector into `(emissions, transitions)`.
48    pub fn from_params(&mut self, p: &[f64]) -> SeqResult<()> {
49        let expected = self.param_count();
50        if p.len() != expected {
51            return Err(SeqError::ShapeMismatch {
52                expected,
53                got: p.len(),
54            });
55        }
56        let cut = self.n_labels * self.n_features;
57        self.emissions.copy_from_slice(&p[..cut]);
58        self.transitions.copy_from_slice(&p[cut..]);
59        Ok(())
60    }
61
62    /// Compute the emission score for a single (label, feature_vec) pair.
63    pub fn emit_score(&self, label: usize, x: &[f64]) -> SeqResult<f64> {
64        if label >= self.n_labels {
65            return Err(SeqError::IndexOutOfBounds {
66                index: label,
67                len: self.n_labels,
68            });
69        }
70        if x.len() != self.n_features {
71            return Err(SeqError::ShapeMismatch {
72                expected: self.n_features,
73                got: x.len(),
74            });
75        }
76        let mut s = 0.0;
77        for (k, &xv) in x.iter().enumerate() {
78            s += self.emissions[label * self.n_features + k] * xv;
79        }
80        Ok(s)
81    }
82
83    /// Score for a single (prev_label, cur_label) transition.
84    pub fn trans_score(&self, prev: usize, cur: usize) -> SeqResult<f64> {
85        if prev >= self.n_labels || cur >= self.n_labels {
86            return Err(SeqError::IndexOutOfBounds {
87                index: prev.max(cur),
88                len: self.n_labels,
89            });
90        }
91        Ok(self.transitions[prev * self.n_labels + cur])
92    }
93
94    /// Score of a full label sequence given a feature matrix (T × n_features).
95    pub fn sequence_score(&self, x: &[f64], y: &[usize]) -> SeqResult<f64> {
96        if y.is_empty() {
97            return Err(SeqError::EmptyInput);
98        }
99        let t_max = y.len();
100        if x.len() != t_max * self.n_features {
101            return Err(SeqError::ShapeMismatch {
102                expected: t_max * self.n_features,
103                got: x.len(),
104            });
105        }
106        let mut s = 0.0;
107        for t in 0..t_max {
108            s += self.emit_score(y[t], &x[t * self.n_features..(t + 1) * self.n_features])?;
109            if t > 0 {
110                s += self.trans_score(y[t - 1], y[t])?;
111            }
112        }
113        Ok(s)
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// With every weight at zero, any label sequence must score (numerically) zero.
    #[test]
    fn zero_crf_has_zero_score() {
        let model = LinearChainCrf::zeros(3, 4).expect("ok");
        // Two positions with four features each.
        let features = [1.0; 8];
        let labels = [0, 1];
        let score = model.sequence_score(&features, &labels).expect("ok");
        assert!(score.abs() < 1e-12);
    }

    /// Packing a model and unpacking into a fresh one reproduces both weight tables.
    #[test]
    fn pack_unpack_roundtrip() {
        let mut source = LinearChainCrf::zeros(2, 3).expect("ok");
        source.emissions.fill(0.5);
        source.transitions.fill(1.5);

        let packed = source.to_params();
        let mut restored = LinearChainCrf::zeros(2, 3).expect("ok");
        restored.from_params(&packed).expect("ok");

        assert_eq!(source.emissions, restored.emissions);
        assert_eq!(source.transitions, restored.transitions);
    }
}