1use core::{num::NonZeroU32, ops::Range};
2
3use alloc::vec::Vec;
4
5use std::sync::Mutex;
6use std::thread;
7
8use hashbrown::{HashMap, HashSet, hash_map::RawEntryMut};
9
10use crate::errors::{Result, RucrfError};
11use crate::feature::FeatureProvider;
12use crate::forward_backward;
13use crate::lattice::Lattice;
14use crate::model::RawModel;
15use crate::optimizers::lbfgs;
16use crate::utils::FromU32;
17
/// Loss function (negative log-likelihood plus optional L2 penalty) evaluated
/// over a set of training lattices.
pub struct LatticesLoss<'a> {
    /// Training lattices the loss and gradient are computed over.
    pub lattices: &'a [Lattice],
    // Maps labels to their feature sets.
    provider: &'a FeatureProvider,
    // Per unigram feature: 1-based weight index, or `None` if unused
    // (1-based so `Option<NonZeroU32>` has no space overhead).
    unigram_weight_indices: &'a [Option<NonZeroU32>],
    // Per left bigram feature id: map from right feature id to weight index.
    bigram_weight_indices: &'a [HashMap<u32, u32>],
    // Number of worker threads used by `cost` and `gradient_partial`.
    n_threads: usize,
    // L2 regularization coefficient; `None` disables the penalty term.
    l2_lambda: Option<f64>,
}
26
impl<'a> LatticesLoss<'a> {
    /// Creates a new loss function over `lattices`.
    ///
    /// `n_threads` controls how many worker threads `cost` and
    /// `gradient_partial` spawn; `l2_lambda` optionally enables an L2
    /// penalty term.
    pub const fn new(
        lattices: &'a [Lattice],
        provider: &'a FeatureProvider,
        unigram_weight_indices: &'a [Option<NonZeroU32>],
        bigram_weight_indices: &'a [HashMap<u32, u32>],
        n_threads: usize,
        l2_lambda: Option<f64>,
    ) -> Self {
        Self {
            lattices,
            provider,
            unigram_weight_indices,
            bigram_weight_indices,
            n_threads,
            l2_lambda,
        }
    }

    /// Computes the gradient of the loss with respect to `param`, restricted
    /// to the lattices selected by `range`.
    ///
    /// Lattices are distributed to `n_threads` scoped worker threads through
    /// a pre-filled channel; each worker accumulates into a thread-local
    /// buffer that is merged into the shared result under a mutex.
    pub fn gradient_partial(&self, param: &[f64], range: Range<usize>) -> Vec<f64> {
        // Queue all work items up front; workers drain with `try_recv`, so an
        // empty channel acts as the termination signal.
        let (s, r) = crossbeam_channel::unbounded();
        for lattice in &self.lattices[range] {
            s.send(lattice).unwrap();
        }
        let gradients = Mutex::new(vec![0.0; param.len()]);
        thread::scope(|scope| {
            for _ in 0..self.n_threads {
                scope.spawn(|| {
                    // Scratch buffers reused across lattices to avoid
                    // reallocating per item.
                    let mut alphas = vec![];
                    let mut betas = vec![];
                    let mut local_gradients = vec![0.0; param.len()];
                    while let Ok(lattice) = r.try_recv() {
                        // z: normalization term from the forward-backward pass
                        // (presumably the log partition value — fed back into
                        // the gradient update below).
                        let z = forward_backward::calculate_alphas_betas(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            &mut alphas,
                            &mut betas,
                        );
                        forward_backward::update_gradient(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            &alphas,
                            &betas,
                            z,
                            &mut local_gradients,
                        );
                    }
                    // Merge this worker's partial gradient into the shared
                    // buffer; the lock is held only for this final summation.
                    #[allow(clippy::significant_drop_in_scrutinee)]
                    for (y, x) in gradients.lock().unwrap().iter_mut().zip(local_gradients) {
                        *y += x;
                    }
                });
            }
        });
        let mut gradients = gradients.into_inner().unwrap();

        // Gradient of the L2 penalty 0.5 * lambda * ||w||^2 is lambda * w.
        if let Some(lambda) = self.l2_lambda {
            for (g, p) in gradients.iter_mut().zip(param) {
                *g += lambda * *p;
            }
        }

        gradients
    }

    /// Computes the total loss over all lattices for the given parameters,
    /// including the L2 penalty when `l2_lambda` is set.
    ///
    /// Parallelized like `gradient_partial`, except the per-thread partial
    /// sums are combined by joining the workers rather than through a mutex.
    pub fn cost(&self, param: &[f64]) -> f64 {
        let (s, r) = crossbeam_channel::unbounded();
        for lattice in self.lattices {
            s.send(lattice).unwrap();
        }
        let mut loss_total = thread::scope(|scope| {
            let mut threads = vec![];
            for _ in 0..self.n_threads {
                let t = scope.spawn(|| {
                    // Scratch buffers reused across lattices.
                    let mut alphas = vec![];
                    let mut betas = vec![];
                    let mut loss_total = 0.0;
                    while let Ok(lattice) = r.try_recv() {
                        let z = forward_backward::calculate_alphas_betas(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            &mut alphas,
                            &mut betas,
                        );
                        let loss = forward_backward::calculate_loss(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            z,
                        );
                        loss_total += loss;
                    }
                    loss_total
                });
                threads.push(t);
            }
            // Sum the per-thread partial losses.
            let mut loss_total = 0.0;
            for t in threads {
                let loss = t.join().unwrap();
                loss_total += loss;
            }
            loss_total
        });

        // L2 penalty: 0.5 * lambda * ||param||^2.
        if let Some(lambda) = self.l2_lambda {
            let mut norm2 = 0.0;
            for &p in param {
                norm2 += p * p;
            }
            loss_total += lambda * norm2 * 0.5;
        }

        loss_total
    }
}
153
/// Regularization method applied to the weights during training.
#[cfg_attr(docsrs, doc(cfg(feature = "train")))]
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Regularization {
    /// L1 regularization (lasso); encourages sparse weight vectors.
    L1,

    /// L2 regularization (ridge); penalizes the squared norm of the weights.
    L2,

    /// A weighted combination of the L1 and L2 penalties.
    ElasticNet {
        /// Fraction of the penalty assigned to the L1 term
        /// (must lie in `0.0..=1.0`; validated in `Trainer::regularization`).
        l1_ratio: f64,
    },
}
172
/// Builder-style configuration for training a CRF model.
#[cfg_attr(docsrs, doc(cfg(feature = "train")))]
pub struct Trainer {
    // Maximum number of optimizer iterations (validated to be >= 1).
    max_iter: u64,
    // Number of worker threads (validated to be >= 1).
    n_threads: usize,
    // Regularization method passed to the optimizer.
    regularization: Regularization,
    // Regularization strength (validated to be >= 0).
    lambda: f64,
}
181
impl Trainer {
    /// Creates a trainer with default settings: `max_iter = 100`,
    /// `n_threads = 1`, and L1 regularization with `lambda = 0.1`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_iter: 100,
            n_threads: 1,
            regularization: Regularization::L1,
            lambda: 0.1,
        }
    }

    /// Sets the maximum number of optimizer iterations.
    ///
    /// # Errors
    ///
    /// Returns an error when `max_iter` is zero.
    pub const fn max_iter(mut self, max_iter: u64) -> Result<Self> {
        if max_iter == 0 {
            return Err(RucrfError::invalid_argument("max_iter must be >= 1"));
        }
        self.max_iter = max_iter;
        Ok(self)
    }

    /// Sets the regularization method and its strength.
    ///
    /// # Errors
    ///
    /// Returns an error when `lambda` is negative, or when
    /// `ElasticNet`'s `l1_ratio` lies outside `0.0..=1.0`.
    pub fn regularization(mut self, regularization: Regularization, lambda: f64) -> Result<Self> {
        if lambda < 0.0 {
            return Err(RucrfError::invalid_argument("lambda must be >= 0"));
        }
        if let Regularization::ElasticNet { l1_ratio } = regularization {
            if !(0.0..=1.0).contains(&l1_ratio) {
                return Err(RucrfError::invalid_argument(
                    "l1_ratio must be between 0.0 and 1.0",
                ));
            }
        }
        self.regularization = regularization;
        self.lambda = lambda;
        Ok(self)
    }

    /// Sets the number of worker threads used during training.
    ///
    /// # Errors
    ///
    /// Returns an error when `n_threads` is zero.
    pub const fn n_threads(mut self, n_threads: usize) -> Result<Self> {
        if n_threads == 0 {
            return Err(RucrfError::invalid_argument("n_thread must be >= 1"));
        }
        self.n_threads = n_threads;
        Ok(self)
    }

    /// Registers the unigram features of `label`, allocating a fresh weight
    /// slot for each feature seen for the first time.
    ///
    /// Stored weight indices are 1-based (`weights.len() + 1`) so the
    /// `Option<NonZeroU32>` wrapper carries no space overhead.
    #[inline(always)]
    fn update_unigram_feature(
        provider: &FeatureProvider,
        label: NonZeroU32,
        unigram_weight_indices: &mut Vec<Option<NonZeroU32>>,
        weights: &mut Vec<f64>,
    ) {
        if let Some(feature_set) = provider.get_feature_set(label) {
            for &fid in feature_set.unigram() {
                // Unigram feature ids are 1-based NonZeroU32s; convert to a
                // 0-based vector index.
                let fid = usize::from_u32(fid.get() - 1);
                // NOTE(review): the `<= fid + 1` test is looser than needed
                // (`<= fid` would suffice, matching the bigram code below);
                // the extra case is a no-op resize, so behavior is unchanged.
                if unigram_weight_indices.len() <= fid + 1 {
                    unigram_weight_indices.resize(fid + 1, None);
                }
                if unigram_weight_indices[fid].is_none() {
                    unigram_weight_indices[fid] =
                        Some(NonZeroU32::new(u32::try_from(weights.len()).unwrap() + 1).unwrap());
                    weights.push(0.0);
                }
            }
        }
    }

    /// Registers the bigram features induced by an edge from `left_label`
    /// to `right_label`, growing `bigram_weight_indices` and `weights` as
    /// needed.
    ///
    /// A `None` label marks a lattice boundary: row 0 of
    /// `bigram_weight_indices` and right-feature key 0 are reserved for
    /// those virtual positions. `(None, None)` is never passed by
    /// `update_features`.
    #[inline(always)]
    fn update_bigram_feature(
        provider: &FeatureProvider,
        left_label: Option<NonZeroU32>,
        right_label: Option<NonZeroU32>,
        bigram_weight_indices: &mut Vec<HashMap<u32, u32>>,
        weights: &mut Vec<f64>,
    ) {
        match (left_label, right_label) {
            // Both labels present: pair each left feature with the right
            // feature at the same position in the zipped feature lists.
            (Some(left_label), Some(right_label)) => {
                if let (Some(left_feature_set), Some(right_feature_set)) = (
                    provider.get_feature_set(left_label),
                    provider.get_feature_set(right_label),
                ) {
                    let left_features = left_feature_set.bigram_left();
                    let right_features = right_feature_set.bigram_right();
                    for (left_fid, right_fid) in left_features.iter().zip(right_features) {
                        if let (Some(left_fid), Some(right_fid)) = (left_fid, right_fid) {
                            let left_fid = usize::try_from(left_fid.get()).unwrap();
                            let right_fid = right_fid.get();
                            if bigram_weight_indices.len() <= left_fid {
                                bigram_weight_indices.resize(left_fid + 1, HashMap::new());
                            }
                            let features = &mut bigram_weight_indices[left_fid];
                            // Raw entry: hash the key once and insert only if
                            // absent, so each (left, right) pair gets exactly
                            // one weight slot.
                            if let RawEntryMut::Vacant(v) =
                                features.raw_entry_mut().from_key(&right_fid)
                            {
                                v.insert(right_fid, u32::try_from(weights.len()).unwrap());
                                weights.push(0.0);
                            }
                        }
                    }
                }
            }
            // Right side is the lattice end: record the left features paired
            // with the reserved right key 0.
            (Some(left_label), None) => {
                if let Some(feature_set) = provider.get_feature_set(left_label) {
                    for left_fid in feature_set.bigram_left().iter().flatten() {
                        let left_fid = usize::try_from(left_fid.get()).unwrap();
                        if bigram_weight_indices.len() <= left_fid {
                            bigram_weight_indices.resize(left_fid + 1, HashMap::new());
                        }
                        let features = &mut bigram_weight_indices[left_fid];
                        if let RawEntryMut::Vacant(v) = features.raw_entry_mut().from_key(&0) {
                            v.insert(0, u32::try_from(weights.len()).unwrap());
                            weights.push(0.0);
                        }
                    }
                }
            }
            // Left side is the lattice start: record the right features in
            // the reserved row 0.
            (None, Some(right_label)) => {
                if let Some(feature_set) = provider.get_feature_set(right_label) {
                    for right_fid in feature_set.bigram_right().iter().flatten() {
                        let right_fid = right_fid.get();
                        if bigram_weight_indices.is_empty() {
                            bigram_weight_indices.resize(1, HashMap::new());
                        }
                        let features = &mut bigram_weight_indices[0];
                        if let RawEntryMut::Vacant(v) =
                            features.raw_entry_mut().from_key(&right_fid)
                        {
                            v.insert(right_fid, u32::try_from(weights.len()).unwrap());
                            weights.push(0.0);
                        }
                    }
                }
            }
            // `update_features` always supplies at least one label.
            _ => unreachable!(),
        }
    }

    /// Walks a lattice and registers every unigram and bigram feature that
    /// can occur on any path, so each one has a weight slot before
    /// optimization starts.
    fn update_features(
        lattice: &Lattice,
        provider: &FeatureProvider,
        unigram_weight_indices: &mut Vec<Option<NonZeroU32>>,
        bigram_weight_indices: &mut Vec<HashMap<u32, u32>>,
        weights: &mut Vec<f64>,
    ) {
        for (i, node) in lattice.nodes().iter().enumerate() {
            // Edges leaving the first node form boundary bigrams with a
            // virtual start label (left = None).
            if i == 0 {
                for curr_edge in node.edges() {
                    Self::update_bigram_feature(
                        provider,
                        None,
                        Some(curr_edge.label),
                        bigram_weight_indices,
                        weights,
                    );
                }
            }
            for curr_edge in node.edges() {
                // Bigrams between this edge and every edge leaving its
                // target node.
                for next_edge in lattice.nodes()[curr_edge.target()].edges() {
                    Self::update_bigram_feature(
                        provider,
                        Some(curr_edge.label),
                        Some(next_edge.label),
                        bigram_weight_indices,
                        weights,
                    );
                }
                // Edges reaching the last node form boundary bigrams with a
                // virtual end label (right = None).
                if curr_edge.target() == lattice.nodes().len() - 1 {
                    Self::update_bigram_feature(
                        provider,
                        Some(curr_edge.label),
                        None,
                        bigram_weight_indices,
                        weights,
                    );
                }
                Self::update_unigram_feature(
                    provider,
                    curr_edge.label,
                    unigram_weight_indices,
                    weights,
                );
            }
        }
    }

    /// Trains a model on the given lattices and returns it.
    ///
    /// Weight slots are first allocated by scanning all lattices, then
    /// optimized with L-BFGS. Afterwards, near-zero weights
    /// (`|w| < f64::EPSILON`) are pruned, the surviving weights are
    /// compacted, and the feature sets in `provider` are filtered so they
    /// only refer to features that kept a weight.
    #[allow(clippy::missing_panics_doc)]
    #[must_use]
    pub fn train(&self, lattices: &[Lattice], mut provider: FeatureProvider) -> RawModel {
        let mut unigram_weight_indices = vec![];
        let mut bigram_weight_indices = vec![];
        let mut weights_init = vec![];

        // Allocate a weight slot for every feature observed in the training
        // lattices.
        for lattice in lattices {
            Self::update_features(
                lattice,
                &provider,
                &mut unigram_weight_indices,
                &mut bigram_weight_indices,
                &mut weights_init,
            );
        }

        let weights = lbfgs::optimize(
            lattices,
            &provider,
            &unigram_weight_indices,
            &bigram_weight_indices,
            &weights_init,
            self.regularization,
            self.lambda,
            self.max_iter,
            self.n_threads,
        );

        // Drop near-zero weights and build a map from old weight index to
        // compacted new index.
        let mut weight_id_map = HashMap::new();
        let mut new_weights = vec![];
        for (i, w) in weights.into_iter().enumerate() {
            if w.abs() < f64::EPSILON {
                continue;
            }
            weight_id_map.insert(
                u32::try_from(i).unwrap(),
                u32::try_from(new_weights.len()).unwrap(),
            );
            new_weights.push(w);
        }
        // Re-map unigram indices (stored 1-based) through the compaction
        // map; pruned features become `None`.
        let mut new_unigram_weight_indices = vec![];
        for old_idx in unigram_weight_indices {
            new_unigram_weight_indices.push(old_idx.and_then(|old_idx| {
                weight_id_map
                    .get(&(old_idx.get() - 1))
                    .and_then(|&new_idx| NonZeroU32::new(new_idx + 1))
            }));
        }
        // Re-map bigram indices, remembering which right feature ids still
        // back at least one surviving weight.
        let mut new_bigram_weight_indices = vec![];
        let mut right_id_used = HashSet::new();
        for fids in bigram_weight_indices {
            let mut new_fids = HashMap::new();
            for (k, v) in fids {
                if let Some(&v) = weight_id_map.get(&v) {
                    new_fids.insert(k, v);
                    right_id_used.insert(k);
                }
            }
            new_bigram_weight_indices.push(new_fids);
        }

        // Filter the feature sets so they only reference features that kept
        // a weight after pruning.
        for feature_set in &mut provider.feature_sets {
            let mut new_unigram = vec![];
            for &fid in feature_set.unigram() {
                if new_unigram_weight_indices
                    .get(usize::from_u32(fid.get() - 1))
                    .copied()
                    .flatten()
                    .is_some()
                {
                    new_unigram.push(fid);
                }
            }
            feature_set.unigram = new_unigram;
            // Keep a left bigram feature only if its row still has entries.
            for fid in &mut feature_set.bigram_left {
                *fid = fid.filter(|fid| {
                    !new_bigram_weight_indices
                        .get(usize::from_u32(fid.get()))
                        .is_none_or(HashMap::is_empty)
                });
            }
            // Keep a right bigram feature only if some surviving weight
            // references its id.
            for fid in &mut feature_set.bigram_right {
                *fid = fid.filter(|fid| right_id_used.contains(&fid.get()));
            }
        }

        RawModel::new(
            new_weights,
            new_unigram_weight_indices,
            new_bigram_weight_indices,
            provider,
        )
    }
}
478
479impl Default for Trainer {
480 fn default() -> Self {
481 Self::new()
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_utils::{self, hashmap, logsumexp};

    // Checks `LatticesLoss::cost` on a small hand-constructed lattice whose
    // four path scores (184, 194, 186, 176) are known, with no
    // regularization and a single thread.
    #[test]
    fn test_loss() {
        let weights = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 46.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 42.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
        ];
        let provider = test_utils::generate_test_feature_provider();
        let lattices = vec![test_utils::generate_test_lattice()];
        // 1-based weight indices for the four unigram features.
        let unigram_weight_indices = vec![
            NonZeroU32::new(2),
            NonZeroU32::new(4),
            NonZeroU32::new(6),
            NonZeroU32::new(8),
        ];
        // Row = left feature id; key = right feature id; value = weight index.
        let bigram_weight_indices = vec![
            hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
            hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
            hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
            hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
            hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
        ];
        let loss_function = LatticesLoss::new(
            &lattices,
            &provider,
            &unigram_weight_indices,
            &bigram_weight_indices,
            1,
            None,
        );

        // Negative log-likelihood: log-sum-exp over all path scores minus
        // the score of the reference path (184.0).
        let expected = logsumexp!(184.0, 194.0, 186.0, 176.0) - 184.0;
        let result = loss_function.cost(&weights);

        assert!((expected - result).abs() < f64::EPSILON);
    }

    // Checks `gradient_partial` against gradients computed by hand from the
    // four path probabilities. Repeated indices in the expectation arrays
    // below are intentional: a feature that fires more than once on a path
    // contributes multiple times.
    #[test]
    fn test_gradient() {
        let weights = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 46.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 42.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
        ];
        let provider = test_utils::generate_test_feature_provider();
        let lattices = vec![test_utils::generate_test_lattice()];
        let unigram_weight_indices = vec![
            NonZeroU32::new(2),
            NonZeroU32::new(4),
            NonZeroU32::new(6),
            NonZeroU32::new(8),
        ];
        let bigram_weight_indices = vec![
            hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
            hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
            hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
            hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
            hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
        ];
        let loss_function = LatticesLoss::new(
            &lattices,
            &provider,
            &unigram_weight_indices,
            &bigram_weight_indices,
            1,
            None,
        );

        // Path probabilities from the softmax over the four path scores.
        let z = logsumexp!(184.0, 194.0, 186.0, 176.0);
        let prob1 = (184.0 - z).exp();
        let prob2 = (194.0 - z).exp();
        let prob3 = (186.0 - z).exp();
        let prob4 = (176.0 - z).exp();

        let mut expected = vec![0.0; 29];
        // Unigram weights: -1 for each firing on the reference path, plus
        // prob_k for each firing on path k.
        for i in [1, 3, 5, 7, 1, 5, 7, 1] {
            expected[i] -= 1.0;
        }
        for i in [1, 3, 5, 7, 1, 5, 7, 1] {
            expected[i] += prob1;
        }
        for i in [1, 3, 5, 7, 1, 7, 3, 5, 7, 1] {
            expected[i] += prob2;
        }
        for i in [3, 5, 1, 5, 7, 1] {
            expected[i] += prob3;
        }
        for i in [3, 5, 1, 7, 3, 5, 7, 1] {
            expected[i] += prob4;
        }
        // Bigram weights: same accounting for the bigram features of each
        // path.
        for i in [0, 2, 12, 16, 20, 26, 10, 19, 8, 23] {
            expected[i] -= 1.0;
        }
        for i in [0, 2, 12, 16, 20, 26, 10, 19, 8, 23] {
            expected[i] += prob1;
        }
        for i in [0, 2, 12, 16, 22, 24, 16, 27, 25, 9, 8, 23] {
            expected[i] += prob2;
        }
        for i in [2, 2, 15, 21, 10, 19, 8, 23] {
            expected[i] += prob3;
        }
        for i in [2, 2, 17, 19, 16, 27, 25, 9, 8, 23] {
            expected[i] += prob4;
        }

        let result = loss_function.gradient_partial(&weights, 0..lattices.len());

        // Compare via L1 distance to tolerate floating-point rounding.
        let norm = expected
            .iter()
            .zip(&result)
            .fold(0.0, |acc, (a, b)| acc + (a - b).abs());

        assert!(norm < 1e-12);
    }
}