// ghostflow_ml/crf.rs

1//! Conditional Random Fields (CRF)
2//!
3//! Discriminative probabilistic models for structured prediction,
4//! particularly useful for sequence labeling tasks.
5
6use ghostflow_core::Tensor;
7use std::collections::HashMap;
8
/// Linear-chain Conditional Random Field.
///
/// A discriminative sequence model scoring label sequences via per-position
/// emission scores (feature weights) plus pairwise transition scores.
///
/// Used for sequence labeling tasks like:
/// - Named Entity Recognition (NER)
/// - Part-of-Speech (POS) tagging
/// - Chunking
/// - Segmentation
pub struct LinearChainCRF {
    /// Number of distinct output labels.
    pub n_labels: usize,
    /// Number of input features per sequence position.
    pub n_features: usize,
    /// Maximum number of SGD passes over the training set.
    pub max_iter: usize,
    /// Step size used by `update_parameters`.
    pub learning_rate: f32,
    /// L2 regularization strength applied to both weights and transitions.
    pub l2_penalty: f32,
    /// Convergence tolerance on the change in average loss between passes.
    pub tol: f32,
    
    // Model parameters
    weights: Vec<f32>,              // Feature weights, row-major (n_features * n_labels)
    transitions: Vec<Vec<f32>>,     // Transition scores (n_labels, n_labels)
    converged: bool,                // Set by `fit` when it stops early via `tol`
}
29
30impl LinearChainCRF {
31    pub fn new(n_labels: usize, n_features: usize) -> Self {
32        Self {
33            n_labels,
34            n_features,
35            max_iter: 100,
36            learning_rate: 0.01,
37            l2_penalty: 0.1,
38            tol: 1e-3,
39            weights: vec![0.0; n_features * n_labels],
40            transitions: vec![vec![0.0; n_labels]; n_labels],
41            converged: false,
42        }
43    }
44
45    pub fn max_iter(mut self, iter: usize) -> Self {
46        self.max_iter = iter;
47        self
48    }
49
50    pub fn learning_rate(mut self, lr: f32) -> Self {
51        self.learning_rate = lr;
52        self
53    }
54
55    pub fn l2_penalty(mut self, penalty: f32) -> Self {
56        self.l2_penalty = penalty;
57        self
58    }
59
60    /// Fit the CRF using stochastic gradient descent
61    pub fn fit(&mut self, sequences: &[Tensor], labels: &[Tensor]) {
62        assert_eq!(sequences.len(), labels.len(), "Number of sequences and labels must match");
63
64        let mut prev_loss = f32::INFINITY;
65
66        for iteration in 0..self.max_iter {
67            let mut total_loss = 0.0;
68            let mut n_samples = 0;
69
70            // Process each sequence
71            for (seq_idx, (sequence, label_seq)) in sequences.iter().zip(labels.iter()).enumerate() {
72                let seq_data = sequence.data_f32();
73                let label_data = label_seq.data_f32();
74                let seq_len = sequence.dims()[0];
75
76                // Forward-backward to compute marginals
77                let (alpha, beta, z) = self.forward_backward(&seq_data, seq_len);
78
79                // Compute gradients
80                let (weight_grad, trans_grad) = self.compute_gradients(
81                    &seq_data,
82                    &label_data,
83                    &alpha,
84                    &beta,
85                    z,
86                    seq_len,
87                );
88
89                // Update parameters
90                self.update_parameters(&weight_grad, &trans_grad);
91
92                // Compute loss for this sequence
93                let loss = self.compute_loss(&seq_data, &label_data, seq_len);
94                total_loss += loss;
95                n_samples += 1;
96            }
97
98            let avg_loss = total_loss / n_samples as f32;
99
100            // Check convergence
101            if (prev_loss - avg_loss).abs() < self.tol {
102                self.converged = true;
103                println!("CRF converged at iteration {}", iteration);
104                break;
105            }
106
107            prev_loss = avg_loss;
108
109            if iteration % 10 == 0 {
110                println!("Iteration {}: Loss = {:.4}", iteration, avg_loss);
111            }
112        }
113    }
114
115    /// Forward-backward algorithm for CRF
116    fn forward_backward(&self, seq_data: &[f32], seq_len: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>, f32) {
117        // Forward pass
118        let mut alpha = vec![vec![f32::NEG_INFINITY; self.n_labels]; seq_len];
119        
120        // Initialize first position
121        for j in 0..self.n_labels {
122            alpha[0][j] = self.emission_score(&seq_data, 0, j);
123        }
124
125        // Forward recursion
126        for t in 1..seq_len {
127            for j in 0..self.n_labels {
128                let emission = self.emission_score(&seq_data, t, j);
129                let mut max_score = f32::NEG_INFINITY;
130                
131                for i in 0..self.n_labels {
132                    let score = alpha[t - 1][i] + self.transitions[i][j] + emission;
133                    max_score = max_score.max(score);
134                }
135                
136                // Log-sum-exp for numerical stability
137                let mut sum = 0.0;
138                for i in 0..self.n_labels {
139                    let score = alpha[t - 1][i] + self.transitions[i][j] + emission;
140                    sum += (score - max_score).exp();
141                }
142                alpha[t][j] = max_score + sum.ln();
143            }
144        }
145
146        // Compute partition function Z
147        let mut max_alpha = f32::NEG_INFINITY;
148        for j in 0..self.n_labels {
149            max_alpha = max_alpha.max(alpha[seq_len - 1][j]);
150        }
151        
152        let mut z_sum = 0.0;
153        for j in 0..self.n_labels {
154            z_sum += (alpha[seq_len - 1][j] - max_alpha).exp();
155        }
156        let z = max_alpha + z_sum.ln();
157
158        // Backward pass
159        let mut beta = vec![vec![f32::NEG_INFINITY; self.n_labels]; seq_len];
160        
161        // Initialize last position
162        for j in 0..self.n_labels {
163            beta[seq_len - 1][j] = 0.0;
164        }
165
166        // Backward recursion
167        for t in (0..seq_len - 1).rev() {
168            for i in 0..self.n_labels {
169                let mut max_score = f32::NEG_INFINITY;
170                
171                for j in 0..self.n_labels {
172                    let emission = self.emission_score(&seq_data, t + 1, j);
173                    let score = self.transitions[i][j] + emission + beta[t + 1][j];
174                    max_score = max_score.max(score);
175                }
176                
177                // Log-sum-exp
178                let mut sum = 0.0;
179                for j in 0..self.n_labels {
180                    let emission = self.emission_score(&seq_data, t + 1, j);
181                    let score = self.transitions[i][j] + emission + beta[t + 1][j];
182                    sum += (score - max_score).exp();
183                }
184                beta[t][i] = max_score + sum.ln();
185            }
186        }
187
188        (alpha, beta, z)
189    }
190
191    /// Compute emission score for a position and label
192    fn emission_score(&self, seq_data: &[f32], position: usize, label: usize) -> f32 {
193        let features = &seq_data[position * self.n_features..(position + 1) * self.n_features];
194        let mut score = 0.0;
195        
196        for (feat_idx, &feat_val) in features.iter().enumerate() {
197            let weight_idx = feat_idx * self.n_labels + label;
198            score += self.weights[weight_idx] * feat_val;
199        }
200        
201        score
202    }
203
204    /// Compute gradients using forward-backward marginals
205    fn compute_gradients(
206        &self,
207        seq_data: &[f32],
208        label_data: &[f32],
209        alpha: &[Vec<f32>],
210        beta: &[Vec<f32>],
211        z: f32,
212        seq_len: usize,
213    ) -> (Vec<f32>, Vec<Vec<f32>>) {
214        let mut weight_grad = vec![0.0; self.n_features * self.n_labels];
215        let mut trans_grad = vec![vec![0.0; self.n_labels]; self.n_labels];
216
217        // Compute expected feature counts (model)
218        for t in 0..seq_len {
219            let features = &seq_data[t * self.n_features..(t + 1) * self.n_features];
220            
221            for j in 0..self.n_labels {
222                // Marginal probability of label j at position t
223                let marginal = (alpha[t][j] + beta[t][j] - z).exp();
224                
225                // Update weight gradients (expected - observed)
226                for (feat_idx, &feat_val) in features.iter().enumerate() {
227                    let weight_idx = feat_idx * self.n_labels + j;
228                    weight_grad[weight_idx] -= marginal * feat_val;
229                }
230            }
231        }
232
233        // Compute expected transition counts (model)
234        for t in 0..seq_len - 1 {
235            for i in 0..self.n_labels {
236                for j in 0..self.n_labels {
237                    let emission = self.emission_score(&seq_data, t + 1, j);
238                    let marginal = (alpha[t][i] + self.transitions[i][j] + emission + beta[t + 1][j] - z).exp();
239                    trans_grad[i][j] -= marginal;
240                }
241            }
242        }
243
244        // Add observed counts
245        for t in 0..seq_len {
246            let label = label_data[t] as usize;
247            let features = &seq_data[t * self.n_features..(t + 1) * self.n_features];
248            
249            for (feat_idx, &feat_val) in features.iter().enumerate() {
250                let weight_idx = feat_idx * self.n_labels + label;
251                weight_grad[weight_idx] += feat_val;
252            }
253        }
254
255        for t in 0..seq_len - 1 {
256            let prev_label = label_data[t] as usize;
257            let curr_label = label_data[t + 1] as usize;
258            trans_grad[prev_label][curr_label] += 1.0;
259        }
260
261        // Add L2 regularization
262        for i in 0..weight_grad.len() {
263            weight_grad[i] -= self.l2_penalty * self.weights[i];
264        }
265
266        for i in 0..self.n_labels {
267            for j in 0..self.n_labels {
268                trans_grad[i][j] -= self.l2_penalty * self.transitions[i][j];
269            }
270        }
271
272        (weight_grad, trans_grad)
273    }
274
275    /// Update parameters using gradients
276    fn update_parameters(&mut self, weight_grad: &[f32], trans_grad: &[Vec<f32>]) {
277        // Update weights
278        for i in 0..self.weights.len() {
279            self.weights[i] += self.learning_rate * weight_grad[i];
280        }
281
282        // Update transitions
283        for i in 0..self.n_labels {
284            for j in 0..self.n_labels {
285                self.transitions[i][j] += self.learning_rate * trans_grad[i][j];
286            }
287        }
288    }
289
290    /// Compute negative log-likelihood loss
291    fn compute_loss(&self, seq_data: &[f32], label_data: &[f32], seq_len: usize) -> f32 {
292        // Compute score of true sequence
293        let mut true_score = 0.0;
294        
295        for t in 0..seq_len {
296            let label = label_data[t] as usize;
297            true_score += self.emission_score(&seq_data, t, label);
298        }
299        
300        for t in 0..seq_len - 1 {
301            let prev_label = label_data[t] as usize;
302            let curr_label = label_data[t + 1] as usize;
303            true_score += self.transitions[prev_label][curr_label];
304        }
305
306        // Compute partition function
307        let (_, _, z) = self.forward_backward(&seq_data, seq_len);
308
309        // Negative log-likelihood
310        let nll = z - true_score;
311
312        // Add L2 regularization term
313        let mut reg_term = 0.0;
314        for &w in &self.weights {
315            reg_term += w * w;
316        }
317        for i in 0..self.n_labels {
318            for j in 0..self.n_labels {
319                reg_term += self.transitions[i][j] * self.transitions[i][j];
320            }
321        }
322        reg_term *= 0.5 * self.l2_penalty;
323
324        nll + reg_term
325    }
326
327    /// Predict label sequence using Viterbi algorithm
328    pub fn predict(&self, sequence: &Tensor) -> Tensor {
329        let seq_data = sequence.data_f32();
330        let seq_len = sequence.dims()[0];
331
332        let mut delta = vec![vec![f32::NEG_INFINITY; self.n_labels]; seq_len];
333        let mut psi = vec![vec![0; self.n_labels]; seq_len];
334
335        // Initialize
336        for j in 0..self.n_labels {
337            delta[0][j] = self.emission_score(&seq_data, 0, j);
338        }
339
340        // Viterbi recursion
341        for t in 1..seq_len {
342            for j in 0..self.n_labels {
343                let emission = self.emission_score(&seq_data, t, j);
344                let mut max_score = f32::NEG_INFINITY;
345                let mut max_idx = 0;
346
347                for i in 0..self.n_labels {
348                    let score = delta[t - 1][i] + self.transitions[i][j] + emission;
349                    if score > max_score {
350                        max_score = score;
351                        max_idx = i;
352                    }
353                }
354
355                delta[t][j] = max_score;
356                psi[t][j] = max_idx;
357            }
358        }
359
360        // Backtrack
361        let mut path = vec![0; seq_len];
362        path[seq_len - 1] = delta[seq_len - 1]
363            .iter()
364            .enumerate()
365            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
366            .map(|(idx, _)| idx)
367            .unwrap();
368
369        for t in (0..seq_len - 1).rev() {
370            path[t] = psi[t + 1][path[t + 1]];
371        }
372
373        let path_f32: Vec<f32> = path.iter().map(|&x| x as f32).collect();
374        Tensor::from_slice(&path_f32, &[seq_len]).unwrap()
375    }
376
377    /// Predict marginal probabilities for each position
378    pub fn predict_marginals(&self, sequence: &Tensor) -> Tensor {
379        let seq_data = sequence.data_f32();
380        let seq_len = sequence.dims()[0];
381
382        let (alpha, beta, z) = self.forward_backward(&seq_data, seq_len);
383
384        let mut marginals = Vec::with_capacity(seq_len * self.n_labels);
385
386        for t in 0..seq_len {
387            for j in 0..self.n_labels {
388                let marginal = (alpha[t][j] + beta[t][j] - z).exp();
389                marginals.push(marginal);
390            }
391        }
392
393        Tensor::from_slice(&marginals, &[seq_len, self.n_labels]).unwrap()
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: train on a single toy sequence and check that decoding and
    // marginal prediction produce tensors of the expected shapes.
    #[test]
    fn test_linear_chain_crf() {
        // Simple sequence labeling task: three positions with one-hot
        // feature vectors, each position tagged with its own label.
        let seq1 = Tensor::from_slice(
            &[
                1.0f32, 0.0, 0.0,  // Position 0
                0.0, 1.0, 0.0,     // Position 1
                0.0, 0.0, 1.0,     // Position 2
            ],
            &[3, 3],
        ).unwrap();

        let labels1 = Tensor::from_slice(&[0.0f32, 1.0, 2.0], &[3]).unwrap();

        let sequences = vec![seq1.clone()];
        let labels = vec![labels1];

        // 3 labels, 3 features per position.
        let mut crf = LinearChainCRF::new(3, 3)
            .max_iter(50)
            .learning_rate(0.1)
            .l2_penalty(0.01);

        crf.fit(&sequences, &labels);

        // Viterbi decoding returns one label id per position.
        let predictions = crf.predict(&seq1);
        assert_eq!(predictions.dims(), &[3]);

        // Marginals have shape (seq_len, n_labels).
        let marginals = crf.predict_marginals(&seq1);
        assert_eq!(marginals.dims(), &[3, 3]);
    }
}
432
433