1use super::linear_chain_crf::LinearChainCrf;
5use crate::error::{SeqError, SeqResult};
6use crate::hmm::forward_backward::logsumexp;
7
/// Tuning knobs for `train_crf_lbfgs`.
#[derive(Debug, Clone)]
pub struct LbfgsConfig {
    /// How many of the most recent curvature pairs `(s, y)` are retained for
    /// the L-BFGS two-loop recursion.
    pub memory: usize,
    /// Upper bound on the number of outer optimisation iterations.
    pub max_iter: usize,
    /// Training stops once the Euclidean norm of the objective gradient drops
    /// below this threshold.
    pub grad_tol: f64,
    /// Factor applied to the trial step length after each rejected
    /// line-search trial.
    pub backtrack: f64,
    /// Maximum number of trial steps attempted in a single line search.
    pub max_line_search: usize,
    /// L2 regularisation strength; the objective is penalised by
    /// `0.5 * l2 * ||w||^2` over all emission and transition weights.
    pub l2: f64,
}

impl Default for LbfgsConfig {
    /// `memory = 5`, `max_iter = 50`, `grad_tol = 1e-6`, `backtrack = 0.5`,
    /// `max_line_search = 30`, `l2 = 1e-3`.
    fn default() -> Self {
        LbfgsConfig {
            l2: 1e-3,
            max_line_search: 30,
            backtrack: 0.5,
            grad_tol: 1e-6,
            max_iter: 50,
            memory: 5,
        }
    }
}
37
38fn forward_scores(crf: &LinearChainCrf, emit: &[f64]) -> Vec<f64> {
40 let n = crf.n_labels;
41 let t_max = emit.len() / n;
42 let mut alpha = vec![f64::NEG_INFINITY; t_max * n];
43 alpha[..n].copy_from_slice(&emit[..n]);
44 let mut tmp = vec![0.0; n];
45 for t in 1..t_max {
46 for j in 0..n {
47 for i in 0..n {
48 tmp[i] = alpha[(t - 1) * n + i] + crf.transitions[i * n + j];
49 }
50 alpha[t * n + j] = logsumexp(&tmp) + emit[t * n + j];
51 }
52 }
53 alpha
54}
55
56fn backward_scores(crf: &LinearChainCrf, emit: &[f64]) -> Vec<f64> {
59 let n = crf.n_labels;
60 let t_max = emit.len() / n;
61 let mut beta = vec![0.0; t_max * n];
62 let mut tmp = vec![0.0; n];
63 for t in (0..t_max - 1).rev() {
64 for i in 0..n {
65 for j in 0..n {
66 tmp[j] = crf.transitions[i * n + j] + emit[(t + 1) * n + j] + beta[(t + 1) * n + j];
67 }
68 beta[t * n + i] = logsumexp(&tmp);
69 }
70 }
71 beta
72}
73
74pub fn crf_log_likelihood_and_gradient(
79 crf: &LinearChainCrf,
80 x: &[f64],
81 y: &[usize],
82) -> SeqResult<(f64, Vec<f64>, Vec<f64>)> {
83 let n = crf.n_labels;
84 let k = crf.n_features;
85 if y.is_empty() {
86 return Err(SeqError::EmptyInput);
87 }
88 let t_max = y.len();
89 if x.len() != t_max * k {
90 return Err(SeqError::ShapeMismatch {
91 expected: t_max * k,
92 got: x.len(),
93 });
94 }
95
96 let mut emit = vec![0.0; t_max * n];
98 for t in 0..t_max {
99 for j in 0..n {
100 emit[t * n + j] = crf.emit_score(j, &x[t * k..(t + 1) * k])?;
101 }
102 }
103
104 let alpha = forward_scores(crf, &emit);
105 let beta = backward_scores(crf, &emit);
106
107 let last_alpha = &alpha[(t_max - 1) * n..];
109 let log_z = logsumexp(last_alpha);
110
111 let true_score = crf.sequence_score(x, y)?;
113 let ll = true_score - log_z;
114
115 let mut p_node = vec![0.0; t_max * n];
117 for t in 0..t_max {
118 for j in 0..n {
119 p_node[t * n + j] = (alpha[t * n + j] + beta[t * n + j] - log_z).exp();
120 }
121 let s: f64 = p_node[t * n..t * n + n].iter().sum();
122 if s > 0.0 {
123 for v in p_node[t * n..t * n + n].iter_mut() {
124 *v /= s;
125 }
126 }
127 }
128 let mut p_edge = vec![0.0; t_max.saturating_sub(1) * n * n];
130 for t in 0..t_max.saturating_sub(1) {
131 let mut s = 0.0;
132 for i in 0..n {
133 for j in 0..n {
134 let v = (alpha[t * n + i]
135 + crf.transitions[i * n + j]
136 + emit[(t + 1) * n + j]
137 + beta[(t + 1) * n + j]
138 - log_z)
139 .exp();
140 p_edge[t * n * n + i * n + j] = v;
141 s += v;
142 }
143 }
144 if s > 0.0 {
145 for v in p_edge[t * n * n..(t + 1) * n * n].iter_mut() {
146 *v /= s;
147 }
148 }
149 }
150
151 let mut grad_emit = vec![0.0; n * k];
153 let mut grad_trans = vec![0.0; n * n];
154
155 for t in 0..t_max {
157 let yt = y[t];
158 for f in 0..k {
159 grad_emit[yt * k + f] += x[t * k + f];
160 }
161 if t > 0 {
162 grad_trans[y[t - 1] * n + y[t]] += 1.0;
163 }
164 }
165 for t in 0..t_max {
167 for j in 0..n {
168 let p = p_node[t * n + j];
169 for f in 0..k {
170 grad_emit[j * k + f] -= p * x[t * k + f];
171 }
172 }
173 if t < t_max - 1 {
174 for i in 0..n {
175 for j in 0..n {
176 grad_trans[i * n + j] -= p_edge[t * n * n + i * n + j];
177 }
178 }
179 }
180 }
181
182 Ok((ll, grad_emit, grad_trans))
183}
184
185fn objective_and_grad(
187 crf: &LinearChainCrf,
188 examples: &[(Vec<f64>, Vec<usize>)],
189 l2: f64,
190) -> SeqResult<(f64, Vec<f64>)> {
191 let mut total_ll = 0.0;
192 let mut g_emit = vec![0.0; crf.emissions.len()];
193 let mut g_trans = vec![0.0; crf.transitions.len()];
194 for (x, y) in examples {
195 let (ll, ge, gt) = crf_log_likelihood_and_gradient(crf, x, y)?;
196 total_ll += ll;
197 for (a, b) in g_emit.iter_mut().zip(ge.iter()) {
198 *a += *b;
199 }
200 for (a, b) in g_trans.iter_mut().zip(gt.iter()) {
201 *a += *b;
202 }
203 }
204 let mut reg = 0.0;
206 for &e in &crf.emissions {
207 reg += e * e;
208 }
209 for &t in &crf.transitions {
210 reg += t * t;
211 }
212 total_ll -= 0.5 * l2 * reg;
213 for (g, w) in g_emit.iter_mut().zip(crf.emissions.iter()) {
214 *g -= l2 * *w;
215 }
216 for (g, w) in g_trans.iter_mut().zip(crf.transitions.iter()) {
217 *g -= l2 * *w;
218 }
219
220 let mut grad = Vec::with_capacity(g_emit.len() + g_trans.len());
221 grad.extend(g_emit);
222 grad.extend(g_trans);
223 Ok((total_ll, grad))
224}
225
/// L-BFGS two-loop recursion: returns `H * grad` for the implicit inverse-Hessian
/// estimate `H` built from the stored pairs.
///
/// `s_hist[i]` / `y_hist[i]` are the parameter and gradient differences of pair `i`
/// (oldest first), and `rho[i] = 1 / (s_i . y_i)`. The initial estimate is
/// `gamma * I` with `gamma = s.y / y.y` taken from the newest pair. `gamma` falls
/// back to 1 when there is no history or `y.y <= 1e-30`. With an empty history the
/// result equals `grad`.
///
/// # Panics
/// Panics if any stored vector is shorter than `grad`, or if `y_hist` / `rho` are
/// shorter than `s_hist`.
fn lbfgs_direction(
    grad: &[f64],
    s_hist: &[Vec<f64>],
    y_hist: &[Vec<f64>],
    rho: &[f64],
) -> Vec<f64> {
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).fold(0.0, |acc, (u, v)| acc + u * v)
    }

    let dim = grad.len();
    let depth = s_hist.len();
    let mut q = grad.to_vec();
    let mut coeffs = vec![0.0; depth];

    // First pass: newest pair to oldest.
    for i in (0..depth).rev() {
        let a = rho[i] * dot(&s_hist[i][..dim], &q);
        coeffs[i] = a;
        for (qk, &yk) in q.iter_mut().zip(&y_hist[i][..dim]) {
            *qk -= a * yk;
        }
    }

    let gamma = if depth == 0 {
        1.0
    } else {
        let s_last = &s_hist[depth - 1][..dim];
        let y_last = &y_hist[depth - 1][..dim];
        let yy = dot(y_last, y_last);
        if yy > 1e-30 {
            dot(s_last, y_last) / yy
        } else {
            1.0
        }
    };
    let mut r: Vec<f64> = q.iter().map(|v| v * gamma).collect();

    // Second pass: oldest pair to newest.
    for i in 0..depth {
        let b = rho[i] * dot(&y_hist[i][..dim], &r);
        let correction = coeffs[i] - b;
        for (rk, &sk) in r.iter_mut().zip(&s_hist[i][..dim]) {
            *rk += sk * correction;
        }
    }
    r
}
287
288pub fn train_crf_lbfgs(
292 crf: &mut LinearChainCrf,
293 examples: &[(Vec<f64>, Vec<usize>)],
294 cfg: &LbfgsConfig,
295) -> SeqResult<f64> {
296 if examples.is_empty() {
297 return Err(SeqError::EmptyInput);
298 }
299 let n_params = crf.param_count();
300 let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(cfg.memory);
301 let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(cfg.memory);
302 let mut rho_hist: Vec<f64> = Vec::with_capacity(cfg.memory);
303
304 let (mut f_val, mut grad) = objective_and_grad(crf, examples, cfg.l2)?;
305
306 for _it in 0..cfg.max_iter {
307 let grad_norm: f64 = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
308 if grad_norm < cfg.grad_tol {
309 break;
310 }
311
312 let mut dir = if s_hist.is_empty() {
315 grad.clone()
316 } else {
317 lbfgs_direction(&grad, &s_hist, &y_hist, &rho_hist)
318 };
319
320 let mut dot_gd: f64 = grad.iter().zip(dir.iter()).map(|(a, b)| a * b).sum();
322 if dot_gd <= 0.0 {
323 dir = grad.clone();
324 dot_gd = grad.iter().map(|g| g * g).sum();
325 }
326 let dir_norm: f64 = dir.iter().map(|d| d * d).sum::<f64>().sqrt();
328 if dir_norm > 0.0 && s_hist.is_empty() {
329 let scale = 1.0_f64 / dir_norm.max(1.0);
330 for v in dir.iter_mut() {
331 *v *= scale;
332 }
333 }
334
335 let armijo = 1e-4_f64;
337 let mut step = 1.0_f64;
338 let p_old = crf.to_params();
339 let mut accepted = false;
340 let mut f_new = f_val;
341 let mut grad_new = grad.clone();
342 for _ls in 0..cfg.max_line_search {
343 let mut p_try = p_old.clone();
344 for k in 0..n_params {
345 p_try[k] = p_old[k] + step * dir[k];
346 }
347 crf.from_params(&p_try)?;
348 let (fc, gc) = objective_and_grad(crf, examples, cfg.l2)?;
349 if fc >= f_val + armijo * step * dot_gd {
350 f_new = fc;
351 grad_new = gc;
352 accepted = true;
353 break;
354 }
355 step *= cfg.backtrack;
356 }
357 if !accepted {
358 crf.from_params(&p_old)?;
360 return Ok(f_val);
361 }
362
363 let p_new = crf.to_params();
365 let s_vec: Vec<f64> = p_new
366 .iter()
367 .zip(p_old.iter())
368 .map(|(a, b)| *a - *b)
369 .collect();
370 let y_vec: Vec<f64> = grad_new
371 .iter()
372 .zip(grad.iter())
373 .map(|(a, b)| *a - *b)
374 .collect();
375 let ys: f64 = s_vec.iter().zip(y_vec.iter()).map(|(a, b)| a * b).sum();
376 if ys.abs() > 1e-30 {
377 if s_hist.len() == cfg.memory {
378 s_hist.remove(0);
379 y_hist.remove(0);
380 rho_hist.remove(0);
381 }
382 s_hist.push(s_vec);
383 y_hist.push(y_vec);
384 rho_hist.push(1.0 / ys);
385 }
386 f_val = f_new;
387 grad = grad_new;
388 }
389
390 Ok(f_val)
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Central-difference estimate of d(log-likelihood)/d(theta) for the single
    /// parameter theta that `nudge` shifts by its second argument.
    fn central_difference<F>(
        crf: &LinearChainCrf,
        x: &[f64],
        y: &[usize],
        eps: f64,
        nudge: F,
    ) -> f64
    where
        F: Fn(&mut LinearChainCrf, f64),
    {
        let ll_at = |delta: f64| {
            let mut shifted = crf.clone();
            nudge(&mut shifted, delta);
            crf_log_likelihood_and_gradient(&shifted, x, y).expect("ok").0
        };
        (ll_at(eps) - ll_at(-eps)) / (2.0 * eps)
    }

    #[test]
    fn gradient_finite_difference() {
        let mut crf = LinearChainCrf::zeros(2, 2).expect("ok");
        crf.emissions = vec![0.5, -0.3, 0.1, 0.4];
        crf.transitions = vec![0.2, -0.1, -0.4, 0.3];

        let x = vec![1.0, 0.5, 0.0, 1.0, 0.7, 0.2];
        let y = vec![0usize, 1, 0];

        let (_, ge, gt) = crf_log_likelihood_and_gradient(&crf, &x, &y).expect("ok");

        let eps = 1e-5;
        for idx in 0..crf.emissions.len() {
            let num = central_difference(&crf, &x, &y, eps, |c, d| c.emissions[idx] += d);
            assert!(
                (num - ge[idx]).abs() < 1e-3,
                "emit grad {idx}: num={num}, ana={}",
                ge[idx]
            );
        }
        for idx in 0..crf.transitions.len() {
            let num = central_difference(&crf, &x, &y, eps, |c, d| c.transitions[idx] += d);
            assert!(
                (num - gt[idx]).abs() < 1e-3,
                "trans grad {idx}: num={num}, ana={}",
                gt[idx]
            );
        }
    }

    #[test]
    fn rejects_empty_and_misshaped_input() {
        let crf = LinearChainCrf::zeros(2, 2).expect("ok");
        assert!(matches!(
            crf_log_likelihood_and_gradient(&crf, &[], &[]),
            Err(SeqError::EmptyInput)
        ));
        assert!(matches!(
            crf_log_likelihood_and_gradient(&crf, &[1.0, 0.0, 1.0], &[0, 1]),
            Err(SeqError::ShapeMismatch {
                expected: 4,
                got: 3
            })
        ));
    }

    #[test]
    fn lbfgs_direction_recovers_newton_step_on_diagonal_quadratic() {
        // Curvature pairs of a quadratic with Hessian diag(2, 4): y = H s.
        let s_hist = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let y_hist = vec![vec![2.0, 0.0], vec![0.0, 4.0]];
        let rho = vec![0.5, 0.25];
        let dir = lbfgs_direction(&[2.0, 8.0], &s_hist, &y_hist, &rho);
        assert_eq!(dir, vec![1.0, 2.0]);
        assert_eq!(
            lbfgs_direction(&[1.5, -2.0], &[], &[], &[]),
            vec![1.5, -2.0]
        );
    }

    #[test]
    fn train_increases_likelihood() {
        let mut crf = LinearChainCrf::zeros(2, 2).expect("ok");
        let x1 = vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0];
        let y1 = vec![0usize, 0, 1];
        let x2 = vec![0.0, 1.0, 1.0, 0.0];
        let y2 = vec![1usize, 0];
        let examples = vec![(x1, y1), (x2, y2)];
        let (ll0, _) = objective_and_grad(&crf, &examples, 0.0).expect("ok");
        let cfg = LbfgsConfig {
            memory: 3,
            max_iter: 20,
            grad_tol: 1e-8,
            backtrack: 0.5,
            max_line_search: 20,
            l2: 0.0,
        };
        let ll_final = train_crf_lbfgs(&mut crf, &examples, &cfg).expect("ok");
        assert!(ll_final >= ll0 - 1e-6, "ll0={ll0}, ll_final={ll_final}");
    }
}