1use ghostflow_core::Tensor;
7use std::collections::HashMap;
8
/// Linear-chain conditional random field over dense per-position feature vectors.
///
/// The public fields are the model dimensions and the training hyperparameters;
/// the learned parameters and the convergence flag are kept private.
pub struct LinearChainCRF {
    /// Number of distinct labels a position can take.
    pub n_labels: usize,
    /// Number of feature values read for each sequence position.
    pub n_features: usize,
    /// Upper bound on the number of training passes performed by `fit`.
    pub max_iter: usize,
    /// Step size applied to every gradient update.
    pub learning_rate: f32,
    /// Coefficient of the L2 penalty on the weights and transition scores.
    pub l2_penalty: f32,
    /// Training stops once the average loss changes by less than this amount
    /// between two consecutive passes.
    pub tol: f32,

    /// Emission weights, stored flat and indexed as `feature * n_labels + label`.
    weights: Vec<f32>,
    /// Transition scores, indexed as `transitions[previous_label][current_label]`.
    transitions: Vec<Vec<f32>>,
    /// Set by `fit` when the loss-change stopping criterion is met.
    converged: bool,
}
29
30impl LinearChainCRF {
31 pub fn new(n_labels: usize, n_features: usize) -> Self {
32 Self {
33 n_labels,
34 n_features,
35 max_iter: 100,
36 learning_rate: 0.01,
37 l2_penalty: 0.1,
38 tol: 1e-3,
39 weights: vec![0.0; n_features * n_labels],
40 transitions: vec![vec![0.0; n_labels]; n_labels],
41 converged: false,
42 }
43 }
44
45 pub fn max_iter(mut self, iter: usize) -> Self {
46 self.max_iter = iter;
47 self
48 }
49
50 pub fn learning_rate(mut self, lr: f32) -> Self {
51 self.learning_rate = lr;
52 self
53 }
54
55 pub fn l2_penalty(mut self, penalty: f32) -> Self {
56 self.l2_penalty = penalty;
57 self
58 }
59
60 pub fn fit(&mut self, sequences: &[Tensor], labels: &[Tensor]) {
62 assert_eq!(sequences.len(), labels.len(), "Number of sequences and labels must match");
63
64 let mut prev_loss = f32::INFINITY;
65
66 for iteration in 0..self.max_iter {
67 let mut total_loss = 0.0;
68 let mut n_samples = 0;
69
70 for (seq_idx, (sequence, label_seq)) in sequences.iter().zip(labels.iter()).enumerate() {
72 let seq_data = sequence.data_f32();
73 let label_data = label_seq.data_f32();
74 let seq_len = sequence.dims()[0];
75
76 let (alpha, beta, z) = self.forward_backward(&seq_data, seq_len);
78
79 let (weight_grad, trans_grad) = self.compute_gradients(
81 &seq_data,
82 &label_data,
83 &alpha,
84 &beta,
85 z,
86 seq_len,
87 );
88
89 self.update_parameters(&weight_grad, &trans_grad);
91
92 let loss = self.compute_loss(&seq_data, &label_data, seq_len);
94 total_loss += loss;
95 n_samples += 1;
96 }
97
98 let avg_loss = total_loss / n_samples as f32;
99
100 if (prev_loss - avg_loss).abs() < self.tol {
102 self.converged = true;
103 println!("CRF converged at iteration {}", iteration);
104 break;
105 }
106
107 prev_loss = avg_loss;
108
109 if iteration % 10 == 0 {
110 println!("Iteration {}: Loss = {:.4}", iteration, avg_loss);
111 }
112 }
113 }
114
115 fn forward_backward(&self, seq_data: &[f32], seq_len: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>, f32) {
117 let mut alpha = vec![vec![f32::NEG_INFINITY; self.n_labels]; seq_len];
119
120 for j in 0..self.n_labels {
122 alpha[0][j] = self.emission_score(&seq_data, 0, j);
123 }
124
125 for t in 1..seq_len {
127 for j in 0..self.n_labels {
128 let emission = self.emission_score(&seq_data, t, j);
129 let mut max_score = f32::NEG_INFINITY;
130
131 for i in 0..self.n_labels {
132 let score = alpha[t - 1][i] + self.transitions[i][j] + emission;
133 max_score = max_score.max(score);
134 }
135
136 let mut sum = 0.0;
138 for i in 0..self.n_labels {
139 let score = alpha[t - 1][i] + self.transitions[i][j] + emission;
140 sum += (score - max_score).exp();
141 }
142 alpha[t][j] = max_score + sum.ln();
143 }
144 }
145
146 let mut max_alpha = f32::NEG_INFINITY;
148 for j in 0..self.n_labels {
149 max_alpha = max_alpha.max(alpha[seq_len - 1][j]);
150 }
151
152 let mut z_sum = 0.0;
153 for j in 0..self.n_labels {
154 z_sum += (alpha[seq_len - 1][j] - max_alpha).exp();
155 }
156 let z = max_alpha + z_sum.ln();
157
158 let mut beta = vec![vec![f32::NEG_INFINITY; self.n_labels]; seq_len];
160
161 for j in 0..self.n_labels {
163 beta[seq_len - 1][j] = 0.0;
164 }
165
166 for t in (0..seq_len - 1).rev() {
168 for i in 0..self.n_labels {
169 let mut max_score = f32::NEG_INFINITY;
170
171 for j in 0..self.n_labels {
172 let emission = self.emission_score(&seq_data, t + 1, j);
173 let score = self.transitions[i][j] + emission + beta[t + 1][j];
174 max_score = max_score.max(score);
175 }
176
177 let mut sum = 0.0;
179 for j in 0..self.n_labels {
180 let emission = self.emission_score(&seq_data, t + 1, j);
181 let score = self.transitions[i][j] + emission + beta[t + 1][j];
182 sum += (score - max_score).exp();
183 }
184 beta[t][i] = max_score + sum.ln();
185 }
186 }
187
188 (alpha, beta, z)
189 }
190
191 fn emission_score(&self, seq_data: &[f32], position: usize, label: usize) -> f32 {
193 let features = &seq_data[position * self.n_features..(position + 1) * self.n_features];
194 let mut score = 0.0;
195
196 for (feat_idx, &feat_val) in features.iter().enumerate() {
197 let weight_idx = feat_idx * self.n_labels + label;
198 score += self.weights[weight_idx] * feat_val;
199 }
200
201 score
202 }
203
204 fn compute_gradients(
206 &self,
207 seq_data: &[f32],
208 label_data: &[f32],
209 alpha: &[Vec<f32>],
210 beta: &[Vec<f32>],
211 z: f32,
212 seq_len: usize,
213 ) -> (Vec<f32>, Vec<Vec<f32>>) {
214 let mut weight_grad = vec![0.0; self.n_features * self.n_labels];
215 let mut trans_grad = vec![vec![0.0; self.n_labels]; self.n_labels];
216
217 for t in 0..seq_len {
219 let features = &seq_data[t * self.n_features..(t + 1) * self.n_features];
220
221 for j in 0..self.n_labels {
222 let marginal = (alpha[t][j] + beta[t][j] - z).exp();
224
225 for (feat_idx, &feat_val) in features.iter().enumerate() {
227 let weight_idx = feat_idx * self.n_labels + j;
228 weight_grad[weight_idx] -= marginal * feat_val;
229 }
230 }
231 }
232
233 for t in 0..seq_len - 1 {
235 for i in 0..self.n_labels {
236 for j in 0..self.n_labels {
237 let emission = self.emission_score(&seq_data, t + 1, j);
238 let marginal = (alpha[t][i] + self.transitions[i][j] + emission + beta[t + 1][j] - z).exp();
239 trans_grad[i][j] -= marginal;
240 }
241 }
242 }
243
244 for t in 0..seq_len {
246 let label = label_data[t] as usize;
247 let features = &seq_data[t * self.n_features..(t + 1) * self.n_features];
248
249 for (feat_idx, &feat_val) in features.iter().enumerate() {
250 let weight_idx = feat_idx * self.n_labels + label;
251 weight_grad[weight_idx] += feat_val;
252 }
253 }
254
255 for t in 0..seq_len - 1 {
256 let prev_label = label_data[t] as usize;
257 let curr_label = label_data[t + 1] as usize;
258 trans_grad[prev_label][curr_label] += 1.0;
259 }
260
261 for i in 0..weight_grad.len() {
263 weight_grad[i] -= self.l2_penalty * self.weights[i];
264 }
265
266 for i in 0..self.n_labels {
267 for j in 0..self.n_labels {
268 trans_grad[i][j] -= self.l2_penalty * self.transitions[i][j];
269 }
270 }
271
272 (weight_grad, trans_grad)
273 }
274
275 fn update_parameters(&mut self, weight_grad: &[f32], trans_grad: &[Vec<f32>]) {
277 for i in 0..self.weights.len() {
279 self.weights[i] += self.learning_rate * weight_grad[i];
280 }
281
282 for i in 0..self.n_labels {
284 for j in 0..self.n_labels {
285 self.transitions[i][j] += self.learning_rate * trans_grad[i][j];
286 }
287 }
288 }
289
290 fn compute_loss(&self, seq_data: &[f32], label_data: &[f32], seq_len: usize) -> f32 {
292 let mut true_score = 0.0;
294
295 for t in 0..seq_len {
296 let label = label_data[t] as usize;
297 true_score += self.emission_score(&seq_data, t, label);
298 }
299
300 for t in 0..seq_len - 1 {
301 let prev_label = label_data[t] as usize;
302 let curr_label = label_data[t + 1] as usize;
303 true_score += self.transitions[prev_label][curr_label];
304 }
305
306 let (_, _, z) = self.forward_backward(&seq_data, seq_len);
308
309 let nll = z - true_score;
311
312 let mut reg_term = 0.0;
314 for &w in &self.weights {
315 reg_term += w * w;
316 }
317 for i in 0..self.n_labels {
318 for j in 0..self.n_labels {
319 reg_term += self.transitions[i][j] * self.transitions[i][j];
320 }
321 }
322 reg_term *= 0.5 * self.l2_penalty;
323
324 nll + reg_term
325 }
326
327 pub fn predict(&self, sequence: &Tensor) -> Tensor {
329 let seq_data = sequence.data_f32();
330 let seq_len = sequence.dims()[0];
331
332 let mut delta = vec![vec![f32::NEG_INFINITY; self.n_labels]; seq_len];
333 let mut psi = vec![vec![0; self.n_labels]; seq_len];
334
335 for j in 0..self.n_labels {
337 delta[0][j] = self.emission_score(&seq_data, 0, j);
338 }
339
340 for t in 1..seq_len {
342 for j in 0..self.n_labels {
343 let emission = self.emission_score(&seq_data, t, j);
344 let mut max_score = f32::NEG_INFINITY;
345 let mut max_idx = 0;
346
347 for i in 0..self.n_labels {
348 let score = delta[t - 1][i] + self.transitions[i][j] + emission;
349 if score > max_score {
350 max_score = score;
351 max_idx = i;
352 }
353 }
354
355 delta[t][j] = max_score;
356 psi[t][j] = max_idx;
357 }
358 }
359
360 let mut path = vec![0; seq_len];
362 path[seq_len - 1] = delta[seq_len - 1]
363 .iter()
364 .enumerate()
365 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
366 .map(|(idx, _)| idx)
367 .unwrap();
368
369 for t in (0..seq_len - 1).rev() {
370 path[t] = psi[t + 1][path[t + 1]];
371 }
372
373 let path_f32: Vec<f32> = path.iter().map(|&x| x as f32).collect();
374 Tensor::from_slice(&path_f32, &[seq_len]).unwrap()
375 }
376
377 pub fn predict_marginals(&self, sequence: &Tensor) -> Tensor {
379 let seq_data = sequence.data_f32();
380 let seq_len = sequence.dims()[0];
381
382 let (alpha, beta, z) = self.forward_backward(&seq_data, seq_len);
383
384 let mut marginals = Vec::with_capacity(seq_len * self.n_labels);
385
386 for t in 0..seq_len {
387 for j in 0..self.n_labels {
388 let marginal = (alpha[t][j] + beta[t][j] - z).exp();
389 marginals.push(marginal);
390 }
391 }
392
393 Tensor::from_slice(&marginals, &[seq_len, self.n_labels]).unwrap()
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains on a single three-step sequence whose features are one-hot
    /// indicators of the gold label, then checks both the shapes and the
    /// values produced by decoding and by marginal inference.
    #[test]
    fn test_linear_chain_crf() {
        let seq1 = Tensor::from_slice(
            &[
                1.0f32, 0.0, 0.0, // position 0
                0.0, 1.0, 0.0, // position 1
                0.0, 0.0, 1.0, // position 2
            ],
            &[3, 3],
        )
        .unwrap();

        let labels1 = Tensor::from_slice(&[0.0f32, 1.0, 2.0], &[3]).unwrap();

        let sequences = vec![seq1.clone()];
        let labels = vec![labels1];

        let mut crf = LinearChainCRF::new(3, 3)
            .max_iter(50)
            .learning_rate(0.1)
            .l2_penalty(0.01);

        crf.fit(&sequences, &labels);

        let predictions = crf.predict(&seq1);
        assert_eq!(predictions.dims(), &[3]);

        // The training sequence is trivially separable, so decoding it must
        // reproduce the gold labels.
        let predicted = predictions.data_f32();
        assert_eq!(&predicted[..], &[0.0f32, 1.0, 2.0][..]);

        let marginals = crf.predict_marginals(&seq1);
        assert_eq!(marginals.dims(), &[3, 3]);

        // Each row is a probability distribution over the three labels.
        let probs = marginals.data_f32();
        for (position, row) in probs[..].chunks(3).enumerate() {
            let total: f32 = row.iter().sum();
            assert!(
                (total - 1.0).abs() < 1e-3,
                "marginals at position {} sum to {}",
                position,
                total
            );
            assert!(row.iter().all(|&p| (0.0..=1.0 + 1e-3).contains(&p)));
        }
    }
}
432
433