Skip to main content

oxicuda_seq/crf/
crf_train.rs

1//! CRF training: log-likelihood + gradient via forward-backward in score space,
2//! plus a limited-memory BFGS (L-BFGS) optimiser with backtracking line search.
3
4use super::linear_chain_crf::LinearChainCrf;
5use crate::error::{SeqError, SeqResult};
6use crate::hmm::forward_backward::logsumexp;
7
/// Tuning knobs for the L-BFGS trainer [`train_crf_lbfgs`].
///
/// `Default` yields: 5 history pairs, 50 iterations, gradient tolerance `1e-6`,
/// backtracking factor `0.5`, 30 line-search trials and L2 strength `1e-3`.
#[derive(Debug, Clone)]
pub struct LbfgsConfig {
    /// How many of the most recent `(s, y)` curvature pairs are kept.
    pub memory: usize,
    /// Upper bound on the number of outer optimisation iterations.
    pub max_iter: usize,
    /// Training stops once the Euclidean norm of the objective gradient drops
    /// below this value.
    pub grad_tol: f64,
    /// Multiplier applied to the trial step after every rejected line-search
    /// trial; values in `(0, 1)` shrink the step.
    pub backtrack: f64,
    /// Number of trial steps attempted per iteration before training gives up
    /// and restores the last accepted parameters.
    pub max_line_search: usize,
    /// L2 penalty weight λ; the objective is reduced by `0.5 · λ · ‖w‖²`.
    pub l2: f64,
}

impl Default for LbfgsConfig {
    fn default() -> Self {
        LbfgsConfig {
            memory: 5,
            max_iter: 50,
            grad_tol: 1.0e-6,
            backtrack: 0.5,
            max_line_search: 30,
            l2: 1.0e-3,
        }
    }
}
37
38/// Forward in score space.  α_t(j) = logsumexp_i(α_{t-1}(i) + tr[i,j]) + emit_t(j).
39fn forward_scores(crf: &LinearChainCrf, emit: &[f64]) -> Vec<f64> {
40    let n = crf.n_labels;
41    let t_max = emit.len() / n;
42    let mut alpha = vec![f64::NEG_INFINITY; t_max * n];
43    alpha[..n].copy_from_slice(&emit[..n]);
44    let mut tmp = vec![0.0; n];
45    for t in 1..t_max {
46        for j in 0..n {
47            for i in 0..n {
48                tmp[i] = alpha[(t - 1) * n + i] + crf.transitions[i * n + j];
49            }
50            alpha[t * n + j] = logsumexp(&tmp) + emit[t * n + j];
51        }
52    }
53    alpha
54}
55
56/// Backward in score space.  β_{T-1}(i) = 0;
57/// β_t(i) = logsumexp_j(tr[i,j] + emit_{t+1}(j) + β_{t+1}(j)).
58fn backward_scores(crf: &LinearChainCrf, emit: &[f64]) -> Vec<f64> {
59    let n = crf.n_labels;
60    let t_max = emit.len() / n;
61    let mut beta = vec![0.0; t_max * n];
62    let mut tmp = vec![0.0; n];
63    for t in (0..t_max - 1).rev() {
64        for i in 0..n {
65            for j in 0..n {
66                tmp[j] = crf.transitions[i * n + j] + emit[(t + 1) * n + j] + beta[(t + 1) * n + j];
67            }
68            beta[t * n + i] = logsumexp(&tmp);
69        }
70    }
71    beta
72}
73
74/// Compute log-likelihood and gradient for a single (x, y) example.
75///
76/// Returns `(log_likelihood, grad_emissions, grad_transitions)` where the gradient
77/// has the **same sign as the objective** — i.e. an *ascent* direction.
78pub fn crf_log_likelihood_and_gradient(
79    crf: &LinearChainCrf,
80    x: &[f64],
81    y: &[usize],
82) -> SeqResult<(f64, Vec<f64>, Vec<f64>)> {
83    let n = crf.n_labels;
84    let k = crf.n_features;
85    if y.is_empty() {
86        return Err(SeqError::EmptyInput);
87    }
88    let t_max = y.len();
89    if x.len() != t_max * k {
90        return Err(SeqError::ShapeMismatch {
91            expected: t_max * k,
92            got: x.len(),
93        });
94    }
95
96    // Pre-compute emission scores for every (t, j).
97    let mut emit = vec![0.0; t_max * n];
98    for t in 0..t_max {
99        for j in 0..n {
100            emit[t * n + j] = crf.emit_score(j, &x[t * k..(t + 1) * k])?;
101        }
102    }
103
104    let alpha = forward_scores(crf, &emit);
105    let beta = backward_scores(crf, &emit);
106
107    // log Z(x) = logsumexp_j α_{T-1}(j)
108    let last_alpha = &alpha[(t_max - 1) * n..];
109    let log_z = logsumexp(last_alpha);
110
111    // score(y, x)
112    let true_score = crf.sequence_score(x, y)?;
113    let ll = true_score - log_z;
114
115    // Marginals: p_t(j) ∝ exp(α_t(j) + β_t(j) − log Z)
116    let mut p_node = vec![0.0; t_max * n];
117    for t in 0..t_max {
118        for j in 0..n {
119            p_node[t * n + j] = (alpha[t * n + j] + beta[t * n + j] - log_z).exp();
120        }
121        let s: f64 = p_node[t * n..t * n + n].iter().sum();
122        if s > 0.0 {
123            for v in p_node[t * n..t * n + n].iter_mut() {
124                *v /= s;
125            }
126        }
127    }
128    // Edge marginals: p_t(i,j) ∝ exp(α_t(i) + tr[i,j] + emit_{t+1}(j) + β_{t+1}(j) − log Z)
129    let mut p_edge = vec![0.0; t_max.saturating_sub(1) * n * n];
130    for t in 0..t_max.saturating_sub(1) {
131        let mut s = 0.0;
132        for i in 0..n {
133            for j in 0..n {
134                let v = (alpha[t * n + i]
135                    + crf.transitions[i * n + j]
136                    + emit[(t + 1) * n + j]
137                    + beta[(t + 1) * n + j]
138                    - log_z)
139                    .exp();
140                p_edge[t * n * n + i * n + j] = v;
141                s += v;
142            }
143        }
144        if s > 0.0 {
145            for v in p_edge[t * n * n..(t + 1) * n * n].iter_mut() {
146                *v /= s;
147            }
148        }
149    }
150
151    // Gradient = empirical features − expected features
152    let mut grad_emit = vec![0.0; n * k];
153    let mut grad_trans = vec![0.0; n * n];
154
155    // Empirical: +1 at observed pairs
156    for t in 0..t_max {
157        let yt = y[t];
158        for f in 0..k {
159            grad_emit[yt * k + f] += x[t * k + f];
160        }
161        if t > 0 {
162            grad_trans[y[t - 1] * n + y[t]] += 1.0;
163        }
164    }
165    // Expected: -p
166    for t in 0..t_max {
167        for j in 0..n {
168            let p = p_node[t * n + j];
169            for f in 0..k {
170                grad_emit[j * k + f] -= p * x[t * k + f];
171            }
172        }
173        if t < t_max - 1 {
174            for i in 0..n {
175                for j in 0..n {
176                    grad_trans[i * n + j] -= p_edge[t * n * n + i * n + j];
177                }
178            }
179        }
180    }
181
182    Ok((ll, grad_emit, grad_trans))
183}
184
185/// Aggregate log-likelihood and gradient over a dataset, with L2 regularisation.
186fn objective_and_grad(
187    crf: &LinearChainCrf,
188    examples: &[(Vec<f64>, Vec<usize>)],
189    l2: f64,
190) -> SeqResult<(f64, Vec<f64>)> {
191    let mut total_ll = 0.0;
192    let mut g_emit = vec![0.0; crf.emissions.len()];
193    let mut g_trans = vec![0.0; crf.transitions.len()];
194    for (x, y) in examples {
195        let (ll, ge, gt) = crf_log_likelihood_and_gradient(crf, x, y)?;
196        total_ll += ll;
197        for (a, b) in g_emit.iter_mut().zip(ge.iter()) {
198            *a += *b;
199        }
200        for (a, b) in g_trans.iter_mut().zip(gt.iter()) {
201            *a += *b;
202        }
203    }
204    // L2 regularisation: −0.5 λ ||w||² to objective; gradient −= λ w
205    let mut reg = 0.0;
206    for &e in &crf.emissions {
207        reg += e * e;
208    }
209    for &t in &crf.transitions {
210        reg += t * t;
211    }
212    total_ll -= 0.5 * l2 * reg;
213    for (g, w) in g_emit.iter_mut().zip(crf.emissions.iter()) {
214        *g -= l2 * *w;
215    }
216    for (g, w) in g_trans.iter_mut().zip(crf.transitions.iter()) {
217        *g -= l2 * *w;
218    }
219
220    let mut grad = Vec::with_capacity(g_emit.len() + g_trans.len());
221    grad.extend(g_emit);
222    grad.extend(g_trans);
223    Ok((total_ll, grad))
224}
225
/// L-BFGS two-loop recursion: returns `H · grad`, where `H` is the implicit
/// inverse-Hessian approximation built from the stored curvature history.
///
/// * `s_hist[i]`, `y_hist[i]` — the i-th pair (parameter step, gradient change),
///   expected oldest first; the last pair seeds the initial scaling.
/// * `rho[i]` — `1 / (s_i · y_i)`.
///
/// The seed matrix is `γ·I` with `γ = (s·y)/(y·y)` from the newest pair, or
/// `γ = 1` when there is no history or `y·y ≤ 1e-30`.  If every pair satisfies
/// `s·y > 0` (with `y` the gradient change of the function being *minimised*),
/// `H` is positive definite and the result satisfies `grad · result > 0`.
fn lbfgs_direction(
    grad: &[f64],
    s_hist: &[Vec<f64>],
    y_hist: &[Vec<f64>],
    rho: &[f64],
) -> Vec<f64> {
    // Dot product over the first `len` entries of both slices.
    fn dot(a: &[f64], b: &[f64], len: usize) -> f64 {
        a[..len].iter().zip(&b[..len]).fold(0.0, |acc, (u, v)| acc + u * v)
    }

    let len = grad.len();
    let pairs = s_hist.len();

    // Newest → oldest: project each pair's contribution out of a gradient copy.
    let mut q = grad.to_vec();
    let mut coeffs = vec![0.0; pairs];
    for i in (0..pairs).rev() {
        let a_i = rho[i] * dot(&s_hist[i], &q, len);
        coeffs[i] = a_i;
        for (q_k, y_k) in q.iter_mut().zip(&y_hist[i][..len]) {
            *q_k -= a_i * y_k;
        }
    }

    // Initial inverse-Hessian scaling from the newest pair.
    let gamma = match pairs.checked_sub(1) {
        Some(newest) => {
            let yy = dot(&y_hist[newest], &y_hist[newest], len);
            if yy > 1e-30 {
                dot(&s_hist[newest], &y_hist[newest], len) / yy
            } else {
                1.0
            }
        }
        None => 1.0,
    };

    // Oldest → newest: add the corrections back in.
    let mut r: Vec<f64> = q.iter().map(|v| v * gamma).collect();
    for i in 0..pairs {
        let b_i = rho[i] * dot(&y_hist[i], &r, len);
        let shift = coeffs[i] - b_i;
        for (r_k, s_k) in r.iter_mut().zip(&s_hist[i][..len]) {
            *r_k += s_k * shift;
        }
    }
    r
}
287
288/// Train a linear-chain CRF by maximising the log-likelihood with L-BFGS.
289///
290/// Returns the final log-likelihood and updates `crf` in place.
291pub fn train_crf_lbfgs(
292    crf: &mut LinearChainCrf,
293    examples: &[(Vec<f64>, Vec<usize>)],
294    cfg: &LbfgsConfig,
295) -> SeqResult<f64> {
296    if examples.is_empty() {
297        return Err(SeqError::EmptyInput);
298    }
299    let n_params = crf.param_count();
300    let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(cfg.memory);
301    let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(cfg.memory);
302    let mut rho_hist: Vec<f64> = Vec::with_capacity(cfg.memory);
303
304    let (mut f_val, mut grad) = objective_and_grad(crf, examples, cfg.l2)?;
305
306    for _it in 0..cfg.max_iter {
307        let grad_norm: f64 = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
308        if grad_norm < cfg.grad_tol {
309            break;
310        }
311
312        // Compute direction.  Negative-gradient ascent search direction for the
313        // first iteration when no history yet.
314        let mut dir = if s_hist.is_empty() {
315            grad.clone()
316        } else {
317            lbfgs_direction(&grad, &s_hist, &y_hist, &rho_hist)
318        };
319
320        // Make sure dir is an ascent direction (g·d > 0); otherwise flip to +grad
321        let mut dot_gd: f64 = grad.iter().zip(dir.iter()).map(|(a, b)| a * b).sum();
322        if dot_gd <= 0.0 {
323            dir = grad.clone();
324            dot_gd = grad.iter().map(|g| g * g).sum();
325        }
326        // Normalise direction to a sane magnitude for the first step
327        let dir_norm: f64 = dir.iter().map(|d| d * d).sum::<f64>().sqrt();
328        if dir_norm > 0.0 && s_hist.is_empty() {
329            let scale = 1.0_f64 / dir_norm.max(1.0);
330            for v in dir.iter_mut() {
331                *v *= scale;
332            }
333        }
334
335        // Backtracking line search (Armijo).
336        let armijo = 1e-4_f64;
337        let mut step = 1.0_f64;
338        let p_old = crf.to_params();
339        let mut accepted = false;
340        let mut f_new = f_val;
341        let mut grad_new = grad.clone();
342        for _ls in 0..cfg.max_line_search {
343            let mut p_try = p_old.clone();
344            for k in 0..n_params {
345                p_try[k] = p_old[k] + step * dir[k];
346            }
347            crf.from_params(&p_try)?;
348            let (fc, gc) = objective_and_grad(crf, examples, cfg.l2)?;
349            if fc >= f_val + armijo * step * dot_gd {
350                f_new = fc;
351                grad_new = gc;
352                accepted = true;
353                break;
354            }
355            step *= cfg.backtrack;
356        }
357        if !accepted {
358            // Restore and stop
359            crf.from_params(&p_old)?;
360            return Ok(f_val);
361        }
362
363        // Update L-BFGS history
364        let p_new = crf.to_params();
365        let s_vec: Vec<f64> = p_new
366            .iter()
367            .zip(p_old.iter())
368            .map(|(a, b)| *a - *b)
369            .collect();
370        let y_vec: Vec<f64> = grad_new
371            .iter()
372            .zip(grad.iter())
373            .map(|(a, b)| *a - *b)
374            .collect();
375        let ys: f64 = s_vec.iter().zip(y_vec.iter()).map(|(a, b)| a * b).sum();
376        if ys.abs() > 1e-30 {
377            if s_hist.len() == cfg.memory {
378                s_hist.remove(0);
379                y_hist.remove(0);
380                rho_hist.remove(0);
381            }
382            s_hist.push(s_vec);
383            y_hist.push(y_vec);
384            rho_hist.push(1.0 / ys);
385        }
386        f_val = f_new;
387        grad = grad_new;
388    }
389
390    Ok(f_val)
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Central finite difference `(f(w + ε) − f(w − ε)) / 2ε`, where `bump(c, δ)`
    /// shifts one chosen parameter of a copy of `crf` by `δ` and `eval` maps the
    /// perturbed CRF to the scalar being differentiated.
    fn central_diff(
        crf: &LinearChainCrf,
        eps: f64,
        bump: impl Fn(&mut LinearChainCrf, f64),
        eval: impl Fn(&LinearChainCrf) -> f64,
    ) -> f64 {
        let at = |delta: f64| {
            let mut c = crf.clone();
            bump(&mut c, delta);
            eval(&c)
        };
        (at(eps) - at(-eps)) / (2.0 * eps)
    }

    fn sample_crf() -> LinearChainCrf {
        let mut crf = LinearChainCrf::zeros(2, 2).expect("ok");
        crf.emissions = vec![0.5, -0.3, 0.1, 0.4];
        crf.transitions = vec![0.2, -0.1, -0.4, 0.3];
        crf
    }

    #[test]
    fn gradient_finite_difference() {
        let crf = sample_crf();
        let x = vec![1.0, 0.5, 0.0, 1.0, 0.7, 0.2];
        let y = vec![0usize, 1, 0];
        let (_, ge, gt) = crf_log_likelihood_and_gradient(&crf, &x, &y).expect("ok");
        let ll = |c: &LinearChainCrf| crf_log_likelihood_and_gradient(c, &x, &y).expect("ok").0;

        let eps = 1e-5;
        for idx in 0..crf.emissions.len() {
            let num = central_diff(&crf, eps, |c, d| c.emissions[idx] += d, ll);
            assert!(
                (num - ge[idx]).abs() < 1e-3,
                "emit grad {idx}: num={num}, ana={}",
                ge[idx]
            );
        }
        for idx in 0..crf.transitions.len() {
            let num = central_diff(&crf, eps, |c, d| c.transitions[idx] += d, ll);
            assert!(
                (num - gt[idx]).abs() < 1e-3,
                "trans grad {idx}: num={num}, ana={}",
                gt[idx]
            );
        }
    }

    #[test]
    fn regularised_objective_gradient_finite_difference() {
        let crf = sample_crf();
        let examples = vec![
            (vec![1.0, 0.5, 0.0, 1.0, 0.7, 0.2], vec![0usize, 1, 0]),
            (vec![0.0, 1.0, 1.0, 0.0], vec![1usize, 0]),
        ];
        let l2 = 0.1;
        let (_, grad) = objective_and_grad(&crf, &examples, l2).expect("ok");
        let n_emit = crf.emissions.len();
        assert_eq!(grad.len(), n_emit + crf.transitions.len());
        let obj = |c: &LinearChainCrf| objective_and_grad(c, &examples, l2).expect("ok").0;

        let eps = 1e-5;
        for idx in 0..n_emit {
            let num = central_diff(&crf, eps, |c, d| c.emissions[idx] += d, obj);
            assert!((num - grad[idx]).abs() < 1e-3, "emit {idx}: num={num}");
        }
        for idx in 0..crf.transitions.len() {
            let num = central_diff(&crf, eps, |c, d| c.transitions[idx] += d, obj);
            assert!(
                (num - grad[n_emit + idx]).abs() < 1e-3,
                "trans {idx}: num={num}"
            );
        }
    }

    #[test]
    fn lbfgs_direction_recovers_newton_step_on_quadratic() {
        // Hessian diag(2, 4); exact curvature pairs give H = diag(1/2, 1/4).
        let s_hist = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let y_hist = vec![vec![2.0, 0.0], vec![0.0, 4.0]];
        let rho = vec![0.5, 0.25];
        let d = lbfgs_direction(&[2.0, 4.0], &s_hist, &y_hist, &rho);
        assert_eq!(d, vec![1.0, 1.0]);
        // Without history the direction is the gradient itself (γ = 1).
        assert_eq!(lbfgs_direction(&[1.0, -2.0], &[], &[], &[]), vec![1.0, -2.0]);
    }

    #[test]
    fn train_increases_likelihood() {
        let mut crf = LinearChainCrf::zeros(2, 2).expect("ok");
        let x1 = vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0];
        let y1 = vec![0usize, 0, 1];
        let x2 = vec![0.0, 1.0, 1.0, 0.0];
        let y2 = vec![1usize, 0];
        let examples = vec![(x1, y1), (x2, y2)];
        let (ll0, _) = objective_and_grad(&crf, &examples, 0.0).expect("ok");
        let cfg = LbfgsConfig {
            memory: 3,
            max_iter: 20,
            grad_tol: 1e-8,
            backtrack: 0.5,
            max_line_search: 20,
            l2: 0.0,
        };
        let ll_final = train_crf_lbfgs(&mut crf, &examples, &cfg).expect("ok");
        // The initial gradient is non-zero, so at least one step must be accepted.
        assert!(ll_final > ll0, "ll0={ll0}, ll_final={ll_final}");
        // The returned value is the objective at the parameters left in `crf`.
        let (ll_check, _) = objective_and_grad(&crf, &examples, 0.0).expect("ok");
        assert!(
            (ll_check - ll_final).abs() < 1e-9,
            "ll_check={ll_check}, ll_final={ll_final}"
        );
    }
}