// oxicuda_seq/crf/linear_chain_crf.rs
use crate::error::{SeqError, SeqResult};

/// Linear-chain conditional random field (CRF) scoring model.
///
/// Weights are stored as flat, row-major matrices:
///
/// * `emissions` holds `n_labels * n_features` entries; the weight linking
///   label `l` to feature `k` lives at index `l * n_features + k`.
/// * `transitions` holds `n_labels * n_labels` entries; the score for moving
///   from label `prev` to label `cur` lives at index `prev * n_labels + cur`.
///
/// All fields are public. Callers that mutate them directly are responsible
/// for keeping the vector lengths consistent with `n_labels` and `n_features`.
#[derive(Debug, Clone, PartialEq)]
pub struct LinearChainCrf {
    /// Number of distinct output labels.
    pub n_labels: usize,
    /// Number of features in each per-position input vector.
    pub n_features: usize,
    /// Emission weights, `n_labels x n_features`, row-major by label.
    pub emissions: Vec<f64>,
    /// Transition weights, `n_labels x n_labels`, row-major by previous label.
    pub transitions: Vec<f64>,
}

18impl LinearChainCrf {
19 pub fn zeros(n_labels: usize, n_features: usize) -> SeqResult<Self> {
21 if n_labels == 0 || n_features == 0 {
22 return Err(SeqError::InvalidConfiguration(
23 "n_labels and n_features must be > 0".to_string(),
24 ));
25 }
26 Ok(Self {
27 n_labels,
28 n_features,
29 emissions: vec![0.0; n_labels * n_features],
30 transitions: vec![0.0; n_labels * n_labels],
31 })
32 }
33
34 pub fn param_count(&self) -> usize {
36 self.n_labels * self.n_features + self.n_labels * self.n_labels
37 }
38
39 pub fn to_params(&self) -> Vec<f64> {
41 let mut v = Vec::with_capacity(self.param_count());
42 v.extend_from_slice(&self.emissions);
43 v.extend_from_slice(&self.transitions);
44 v
45 }
46
47 pub fn from_params(&mut self, p: &[f64]) -> SeqResult<()> {
49 let expected = self.param_count();
50 if p.len() != expected {
51 return Err(SeqError::ShapeMismatch {
52 expected,
53 got: p.len(),
54 });
55 }
56 let cut = self.n_labels * self.n_features;
57 self.emissions.copy_from_slice(&p[..cut]);
58 self.transitions.copy_from_slice(&p[cut..]);
59 Ok(())
60 }
61
62 pub fn emit_score(&self, label: usize, x: &[f64]) -> SeqResult<f64> {
64 if label >= self.n_labels {
65 return Err(SeqError::IndexOutOfBounds {
66 index: label,
67 len: self.n_labels,
68 });
69 }
70 if x.len() != self.n_features {
71 return Err(SeqError::ShapeMismatch {
72 expected: self.n_features,
73 got: x.len(),
74 });
75 }
76 let mut s = 0.0;
77 for (k, &xv) in x.iter().enumerate() {
78 s += self.emissions[label * self.n_features + k] * xv;
79 }
80 Ok(s)
81 }
82
83 pub fn trans_score(&self, prev: usize, cur: usize) -> SeqResult<f64> {
85 if prev >= self.n_labels || cur >= self.n_labels {
86 return Err(SeqError::IndexOutOfBounds {
87 index: prev.max(cur),
88 len: self.n_labels,
89 });
90 }
91 Ok(self.transitions[prev * self.n_labels + cur])
92 }
93
94 pub fn sequence_score(&self, x: &[f64], y: &[usize]) -> SeqResult<f64> {
96 if y.is_empty() {
97 return Err(SeqError::EmptyInput);
98 }
99 let t_max = y.len();
100 if x.len() != t_max * self.n_features {
101 return Err(SeqError::ShapeMismatch {
102 expected: t_max * self.n_features,
103 got: x.len(),
104 });
105 }
106 let mut s = 0.0;
107 for t in 0..t_max {
108 s += self.emit_score(y[t], &x[t * self.n_features..(t + 1) * self.n_features])?;
109 if t > 0 {
110 s += self.trans_score(y[t - 1], y[t])?;
111 }
112 }
113 Ok(s)
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// 2 labels x 2 features with distinct, exactly representable weights, so
    /// any indexing mix-up (e.g. transposed transitions) changes the scores.
    fn small_crf() -> LinearChainCrf {
        let mut crf = LinearChainCrf::zeros(2, 2).expect("2x2 model is valid");
        crf.emissions = vec![1.0, 2.0, 3.0, 4.0];
        crf.transitions = vec![0.5, 0.25, 0.125, 2.0];
        crf
    }

    #[test]
    fn zero_crf_has_zero_score() {
        let crf = LinearChainCrf::zeros(3, 4).expect("ok");
        let s = crf.sequence_score(&[1.0; 8], &[0, 1]).expect("ok");
        assert!(s.abs() < 1e-12);
    }

    #[test]
    fn zeros_rejects_empty_dimensions() {
        assert!(matches!(
            LinearChainCrf::zeros(0, 3),
            Err(SeqError::InvalidConfiguration(_))
        ));
        assert!(matches!(
            LinearChainCrf::zeros(3, 0),
            Err(SeqError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let mut crf = LinearChainCrf::zeros(2, 3).expect("ok");
        for v in crf.emissions.iter_mut() {
            *v = 0.5;
        }
        for v in crf.transitions.iter_mut() {
            *v = 1.5;
        }
        let p = crf.to_params();
        assert_eq!(p.len(), crf.param_count());
        let mut crf2 = LinearChainCrf::zeros(2, 3).expect("ok");
        crf2.from_params(&p).expect("ok");
        assert_eq!(crf.emissions, crf2.emissions);
        assert_eq!(crf.transitions, crf2.transitions);
    }

    #[test]
    fn from_params_rejects_wrong_length() {
        let mut crf = LinearChainCrf::zeros(2, 3).expect("ok");
        assert!(matches!(
            crf.from_params(&[0.0; 3]),
            Err(SeqError::ShapeMismatch {
                expected: 10,
                got: 3
            })
        ));
    }

    #[test]
    fn emit_score_is_dot_product_of_label_row() {
        let crf = small_crf();
        let s = crf.emit_score(1, &[1.0, 1.0]).expect("valid label");
        assert!((s - 7.0).abs() < 1e-12);
    }

    #[test]
    fn trans_score_uses_prev_major_layout() {
        let crf = small_crf();
        assert!((crf.trans_score(0, 1).expect("ok") - 0.25).abs() < 1e-12);
        assert!((crf.trans_score(1, 0).expect("ok") - 0.125).abs() < 1e-12);
    }

    #[test]
    fn sequence_score_sums_emissions_and_transitions() {
        let crf = small_crf();
        // Rows of x: [1, 0], [0, 1], [1, 1]; labels 0 -> 1 -> 1.
        // Emissions: 1 + 4 + 7 = 12; transitions: (0->1) 0.25 + (1->1) 2.0.
        let x = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let s = crf.sequence_score(&x, &[0, 1, 1]).expect("valid sequence");
        assert!((s - 14.25).abs() < 1e-12);
    }

    #[test]
    fn single_step_sequence_has_no_transition_term() {
        let crf = small_crf();
        let s = crf.sequence_score(&[1.0, 1.0], &[1]).expect("valid sequence");
        assert!((s - 7.0).abs() < 1e-12);
    }

    #[test]
    fn score_methods_reject_invalid_inputs() {
        let crf = small_crf();
        assert!(matches!(
            crf.emit_score(2, &[0.0, 0.0]),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
        assert!(matches!(
            crf.emit_score(0, &[0.0]),
            Err(SeqError::ShapeMismatch {
                expected: 2,
                got: 1
            })
        ));
        assert!(matches!(
            crf.trans_score(0, 7),
            Err(SeqError::IndexOutOfBounds { index: 7, len: 2 })
        ));
        assert!(matches!(
            crf.sequence_score(&[], &[]),
            Err(SeqError::EmptyInput)
        ));
        assert!(matches!(
            crf.sequence_score(&[1.0; 3], &[0, 1]),
            Err(SeqError::ShapeMismatch {
                expected: 4,
                got: 3
            })
        ));
        assert!(matches!(
            crf.sequence_score(&[1.0; 4], &[0, 5]),
            Err(SeqError::IndexOutOfBounds { index: 5, len: 2 })
        ));
    }
}