//! CRF trainer: loss/gradient evaluation over training lattices and the
//! L-BFGS-based `Trainer` (lindera_crf/trainer.rs).

1use core::{num::NonZeroU32, ops::Range};
2
3use alloc::vec::Vec;
4
5use std::sync::Mutex;
6use std::thread;
7
8use hashbrown::{HashMap, HashSet, hash_map::RawEntryMut};
9
10use crate::errors::{Result, RucrfError};
11use crate::feature::FeatureProvider;
12use crate::forward_backward;
13use crate::lattice::Lattice;
14use crate::model::RawModel;
15use crate::optimizers::lbfgs;
16use crate::utils::FromU32;
17
/// Loss function evaluated over a set of training lattices, with an
/// optional L2 penalty term.
pub struct LatticesLoss<'a> {
    /// Training lattices the loss and gradient are computed over.
    pub lattices: &'a [Lattice],
    // Maps labels to their unigram/bigram feature sets.
    provider: &'a FeatureProvider,
    // 0-based unigram feature id -> 1-based weight index (`None` = unused slot).
    unigram_weight_indices: &'a [Option<NonZeroU32>],
    // Left bigram feature id -> (right bigram feature id -> 0-based weight index).
    bigram_weight_indices: &'a [HashMap<u32, u32>],
    // Number of scoped worker threads used by `cost`/`gradient_partial`.
    n_threads: usize,
    // L2 coefficient; `None` disables the penalty term entirely.
    l2_lambda: Option<f64>,
}
26
impl<'a> LatticesLoss<'a> {
    /// Creates a new loss function over `lattices`.
    ///
    /// `unigram_weight_indices` maps a 0-based unigram feature id to a
    /// 1-based index into the weight vector; `bigram_weight_indices[left]`
    /// maps a right bigram feature id to a 0-based weight index.
    /// Passing `l2_lambda = None` disables the L2 penalty.
    pub const fn new(
        lattices: &'a [Lattice],
        provider: &'a FeatureProvider,
        unigram_weight_indices: &'a [Option<NonZeroU32>],
        bigram_weight_indices: &'a [HashMap<u32, u32>],
        n_threads: usize,
        l2_lambda: Option<f64>,
    ) -> Self {
        Self {
            lattices,
            provider,
            unigram_weight_indices,
            bigram_weight_indices,
            n_threads,
            l2_lambda,
        }
    }

    /// Computes the gradient of the loss w.r.t. `param`, restricted to the
    /// lattices selected by `range`.
    ///
    /// Work distribution: all selected lattices are queued on an unbounded
    /// channel before any worker starts, then `n_threads` scoped threads
    /// drain the queue. Each worker accumulates into a private buffer and
    /// merges it into the shared `gradients` vector under a mutex, so the
    /// lock is taken once per worker rather than once per lattice.
    pub fn gradient_partial(&self, param: &[f64], range: Range<usize>) -> Vec<f64> {
        let (s, r) = crossbeam_channel::unbounded();
        for lattice in &self.lattices[range] {
            s.send(lattice).unwrap();
        }
        let gradients = Mutex::new(vec![0.0; param.len()]);
        thread::scope(|scope| {
            for _ in 0..self.n_threads {
                scope.spawn(|| {
                    // Scratch buffers reused across lattices to avoid reallocating.
                    let mut alphas = vec![];
                    let mut betas = vec![];
                    let mut local_gradients = vec![0.0; param.len()];
                    // `try_recv` suffices: the queue is fully populated up front,
                    // so an empty channel means there is no more work.
                    while let Ok(lattice) = r.try_recv() {
                        let z = forward_backward::calculate_alphas_betas(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            &mut alphas,
                            &mut betas,
                        );
                        forward_backward::update_gradient(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            &alphas,
                            &betas,
                            z,
                            &mut local_gradients,
                        );
                    }
                    // Merge this worker's partial gradient into the shared buffer.
                    #[allow(clippy::significant_drop_in_scrutinee)]
                    for (y, x) in gradients.lock().unwrap().iter_mut().zip(local_gradients) {
                        *y += x;
                    }
                });
            }
        });
        let mut gradients = gradients.into_inner().unwrap();

        // d/dp of the penalty 0.5 * lambda * ||param||^2 is lambda * param.
        if let Some(lambda) = self.l2_lambda {
            for (g, p) in gradients.iter_mut().zip(param) {
                *g += lambda * *p;
            }
        }

        gradients
    }

    /// Computes the total loss over all lattices for the given `param`,
    /// plus `0.5 * lambda * ||param||^2` when `l2_lambda` is set.
    ///
    /// Uses the same queue-then-drain threading scheme as
    /// [`Self::gradient_partial`], but per-thread partial sums are returned
    /// through the join handles instead of a mutex.
    pub fn cost(&self, param: &[f64]) -> f64 {
        let (s, r) = crossbeam_channel::unbounded();
        for lattice in self.lattices {
            s.send(lattice).unwrap();
        }
        let mut loss_total = thread::scope(|scope| {
            let mut threads = vec![];
            for _ in 0..self.n_threads {
                let t = scope.spawn(|| {
                    // Scratch buffers reused across lattices.
                    let mut alphas = vec![];
                    let mut betas = vec![];
                    let mut loss_total = 0.0;
                    while let Ok(lattice) = r.try_recv() {
                        let z = forward_backward::calculate_alphas_betas(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            &mut alphas,
                            &mut betas,
                        );
                        let loss = forward_backward::calculate_loss(
                            lattice,
                            self.provider,
                            param,
                            self.unigram_weight_indices,
                            self.bigram_weight_indices,
                            z,
                        );
                        loss_total += loss;
                    }
                    loss_total
                });
                threads.push(t);
            }
            // Sum the per-thread partial losses.
            let mut loss_total = 0.0;
            for t in threads {
                let loss = t.join().unwrap();
                loss_total += loss;
            }
            loss_total
        });

        if let Some(lambda) = self.l2_lambda {
            let mut norm2 = 0.0;
            for &p in param {
                norm2 += p * p;
            }
            loss_total += lambda * norm2 * 0.5;
        }

        loss_total
    }
}
153
/// Regularization settings: L1, L2, or Elastic Net (a mix of both).
#[cfg_attr(docsrs, doc(cfg(feature = "train")))]
#[derive(Copy, Clone, PartialEq)]
pub enum Regularization {
    /// Performs L1-regularization.
    L1,

    /// Performs L2-regularization.
    L2,

    /// Performs Elastic Net regularization (L1 + L2 combination).
    /// The parameter `l1_ratio` controls the mix: 1.0 = pure L1, 0.0 = pure L2.
    /// L1 penalty = lambda * l1_ratio, L2 penalty = lambda * (1 - l1_ratio).
    ElasticNet {
        /// Ratio of L1 vs L2 penalty (0.0 to 1.0).
        l1_ratio: f64,
    },
}
172
/// CRF trainer.
///
/// Configured via the builder-style setters and consumed by [`Trainer::train`].
#[cfg_attr(docsrs, doc(cfg(feature = "train")))]
pub struct Trainer {
    // Maximum number of optimizer iterations (default: 100, see `new`).
    max_iter: u64,
    // Number of worker threads used during training (default: 1).
    n_threads: usize,
    // Regularization kind (default: L1).
    regularization: Regularization,
    // Regularization strength (default: 0.1).
    lambda: f64,
}
181
182impl Trainer {
183    /// Creates a new trainer.
184    #[must_use]
185    pub const fn new() -> Self {
186        Self {
187            max_iter: 100,
188            n_threads: 1,
189            regularization: Regularization::L1,
190            lambda: 0.1,
191        }
192    }
193
194    /// Sets the maximum number of iterations.
195    ///
196    /// # Errors
197    ///
198    /// `max_iter` must be >= 1.
199    pub const fn max_iter(mut self, max_iter: u64) -> Result<Self> {
200        if max_iter == 0 {
201            return Err(RucrfError::invalid_argument("max_iter must be >= 1"));
202        }
203        self.max_iter = max_iter;
204        Ok(self)
205    }
206
207    /// Sets regularization settings.
208    ///
209    /// # Errors
210    ///
211    /// `lambda` must be >= 0. For `ElasticNet`, `l1_ratio` must be in [0, 1].
212    pub fn regularization(mut self, regularization: Regularization, lambda: f64) -> Result<Self> {
213        if lambda < 0.0 {
214            return Err(RucrfError::invalid_argument("lambda must be >= 0"));
215        }
216        if let Regularization::ElasticNet { l1_ratio } = regularization {
217            if !(0.0..=1.0).contains(&l1_ratio) {
218                return Err(RucrfError::invalid_argument(
219                    "l1_ratio must be between 0.0 and 1.0",
220                ));
221            }
222        }
223        self.regularization = regularization;
224        self.lambda = lambda;
225        Ok(self)
226    }
227
228    /// Sets the number of threads.
229    ///
230    /// # Errors
231    ///
232    /// `n_threads` must be >= 1.
233    pub const fn n_threads(mut self, n_threads: usize) -> Result<Self> {
234        if n_threads == 0 {
235            return Err(RucrfError::invalid_argument("n_thread must be >= 1"));
236        }
237        self.n_threads = n_threads;
238        Ok(self)
239    }
240
241    #[inline(always)]
242    fn update_unigram_feature(
243        provider: &FeatureProvider,
244        label: NonZeroU32,
245        unigram_weight_indices: &mut Vec<Option<NonZeroU32>>,
246        weights: &mut Vec<f64>,
247    ) {
248        if let Some(feature_set) = provider.get_feature_set(label) {
249            for &fid in feature_set.unigram() {
250                let fid = usize::from_u32(fid.get() - 1);
251                if unigram_weight_indices.len() <= fid + 1 {
252                    unigram_weight_indices.resize(fid + 1, None);
253                }
254                if unigram_weight_indices[fid].is_none() {
255                    unigram_weight_indices[fid] =
256                        Some(NonZeroU32::new(u32::try_from(weights.len()).unwrap() + 1).unwrap());
257                    weights.push(0.0);
258                }
259            }
260        }
261    }
262
263    #[inline(always)]
264    fn update_bigram_feature(
265        provider: &FeatureProvider,
266        left_label: Option<NonZeroU32>,
267        right_label: Option<NonZeroU32>,
268        bigram_weight_indices: &mut Vec<HashMap<u32, u32>>,
269        weights: &mut Vec<f64>,
270    ) {
271        match (left_label, right_label) {
272            (Some(left_label), Some(right_label)) => {
273                if let (Some(left_feature_set), Some(right_feature_set)) = (
274                    provider.get_feature_set(left_label),
275                    provider.get_feature_set(right_label),
276                ) {
277                    let left_features = left_feature_set.bigram_left();
278                    let right_features = right_feature_set.bigram_right();
279                    for (left_fid, right_fid) in left_features.iter().zip(right_features) {
280                        if let (Some(left_fid), Some(right_fid)) = (left_fid, right_fid) {
281                            let left_fid = usize::try_from(left_fid.get()).unwrap();
282                            let right_fid = right_fid.get();
283                            if bigram_weight_indices.len() <= left_fid {
284                                bigram_weight_indices.resize(left_fid + 1, HashMap::new());
285                            }
286                            let features = &mut bigram_weight_indices[left_fid];
287                            if let RawEntryMut::Vacant(v) =
288                                features.raw_entry_mut().from_key(&right_fid)
289                            {
290                                v.insert(right_fid, u32::try_from(weights.len()).unwrap());
291                                weights.push(0.0);
292                            }
293                        }
294                    }
295                }
296            }
297            (Some(left_label), None) => {
298                if let Some(feature_set) = provider.get_feature_set(left_label) {
299                    for left_fid in feature_set.bigram_left().iter().flatten() {
300                        let left_fid = usize::try_from(left_fid.get()).unwrap();
301                        if bigram_weight_indices.len() <= left_fid {
302                            bigram_weight_indices.resize(left_fid + 1, HashMap::new());
303                        }
304                        let features = &mut bigram_weight_indices[left_fid];
305                        if let RawEntryMut::Vacant(v) = features.raw_entry_mut().from_key(&0) {
306                            v.insert(0, u32::try_from(weights.len()).unwrap());
307                            weights.push(0.0);
308                        }
309                    }
310                }
311            }
312            (None, Some(right_label)) => {
313                if let Some(feature_set) = provider.get_feature_set(right_label) {
314                    for right_fid in feature_set.bigram_right().iter().flatten() {
315                        let right_fid = right_fid.get();
316                        if bigram_weight_indices.is_empty() {
317                            bigram_weight_indices.resize(1, HashMap::new());
318                        }
319                        let features = &mut bigram_weight_indices[0];
320                        if let RawEntryMut::Vacant(v) =
321                            features.raw_entry_mut().from_key(&right_fid)
322                        {
323                            v.insert(right_fid, u32::try_from(weights.len()).unwrap());
324                            weights.push(0.0);
325                        }
326                    }
327                }
328            }
329            _ => unreachable!(),
330        }
331    }
332
333    fn update_features(
334        lattice: &Lattice,
335        provider: &FeatureProvider,
336        unigram_weight_indices: &mut Vec<Option<NonZeroU32>>,
337        bigram_weight_indices: &mut Vec<HashMap<u32, u32>>,
338        weights: &mut Vec<f64>,
339    ) {
340        for (i, node) in lattice.nodes().iter().enumerate() {
341            if i == 0 {
342                for curr_edge in node.edges() {
343                    Self::update_bigram_feature(
344                        provider,
345                        None,
346                        Some(curr_edge.label),
347                        bigram_weight_indices,
348                        weights,
349                    );
350                }
351            }
352            for curr_edge in node.edges() {
353                for next_edge in lattice.nodes()[curr_edge.target()].edges() {
354                    Self::update_bigram_feature(
355                        provider,
356                        Some(curr_edge.label),
357                        Some(next_edge.label),
358                        bigram_weight_indices,
359                        weights,
360                    );
361                }
362                if curr_edge.target() == lattice.nodes().len() - 1 {
363                    Self::update_bigram_feature(
364                        provider,
365                        Some(curr_edge.label),
366                        None,
367                        bigram_weight_indices,
368                        weights,
369                    );
370                }
371                Self::update_unigram_feature(
372                    provider,
373                    curr_edge.label,
374                    unigram_weight_indices,
375                    weights,
376                );
377            }
378        }
379    }
380
381    /// Trains a model from the given dataset.
382    #[allow(clippy::missing_panics_doc)]
383    #[must_use]
384    pub fn train(&self, lattices: &[Lattice], mut provider: FeatureProvider) -> RawModel {
385        let mut unigram_weight_indices = vec![];
386        let mut bigram_weight_indices = vec![];
387        let mut weights_init = vec![];
388
389        for lattice in lattices {
390            Self::update_features(
391                lattice,
392                &provider,
393                &mut unigram_weight_indices,
394                &mut bigram_weight_indices,
395                &mut weights_init,
396            );
397        }
398
399        let weights = lbfgs::optimize(
400            lattices,
401            &provider,
402            &unigram_weight_indices,
403            &bigram_weight_indices,
404            &weights_init,
405            self.regularization,
406            self.lambda,
407            self.max_iter,
408            self.n_threads,
409        );
410
411        // Removes zero weighted features
412        let mut weight_id_map = HashMap::new();
413        let mut new_weights = vec![];
414        for (i, w) in weights.into_iter().enumerate() {
415            if w.abs() < f64::EPSILON {
416                continue;
417            }
418            weight_id_map.insert(
419                u32::try_from(i).unwrap(),
420                u32::try_from(new_weights.len()).unwrap(),
421            );
422            new_weights.push(w);
423        }
424        let mut new_unigram_weight_indices = vec![];
425        for old_idx in unigram_weight_indices {
426            new_unigram_weight_indices.push(old_idx.and_then(|old_idx| {
427                weight_id_map
428                    .get(&(old_idx.get() - 1))
429                    .and_then(|&new_idx| NonZeroU32::new(new_idx + 1))
430            }));
431        }
432        let mut new_bigram_weight_indices = vec![];
433        let mut right_id_used = HashSet::new();
434        for fids in bigram_weight_indices {
435            let mut new_fids = HashMap::new();
436            for (k, v) in fids {
437                if let Some(&v) = weight_id_map.get(&v) {
438                    new_fids.insert(k, v);
439                    right_id_used.insert(k);
440                }
441            }
442            new_bigram_weight_indices.push(new_fids);
443        }
444
445        for feature_set in &mut provider.feature_sets {
446            let mut new_unigram = vec![];
447            for &fid in feature_set.unigram() {
448                if new_unigram_weight_indices
449                    .get(usize::from_u32(fid.get() - 1))
450                    .copied()
451                    .flatten()
452                    .is_some()
453                {
454                    new_unigram.push(fid);
455                }
456            }
457            feature_set.unigram = new_unigram;
458            for fid in &mut feature_set.bigram_left {
459                *fid = fid.filter(|fid| {
460                    !new_bigram_weight_indices
461                        .get(usize::from_u32(fid.get()))
462                        .is_none_or(HashMap::is_empty)
463                });
464            }
465            for fid in &mut feature_set.bigram_right {
466                *fid = fid.filter(|fid| right_id_used.contains(&fid.get()));
467            }
468        }
469
470        RawModel::new(
471            new_weights,
472            new_unigram_weight_indices,
473            new_bigram_weight_indices,
474            provider,
475        )
476    }
477}
478
impl Default for Trainer {
    /// Equivalent to [`Trainer::new`].
    fn default() -> Self {
        Self::new()
    }
}
484
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_utils::{self, hashmap, logsumexp};

    // Test lattice topology and hand-computed scores used by both tests below.
    //
    // 0     1     2     3     4     5
    //  /-1-\ /-2-\ /----3----\ /-4-\
    // *     *     *     *     *     *
    //  \----5----/ \-6-/ \-7-/
    // weights:
    // 0->1: 4 (0-1:1 0-2:3)
    // 0->5: 6 (0-2:3 0-2:3)
    // 1->2: 30 (1-4:13 2-3:17)
    // 2->3: 48 (3-2:21 4-3:27)
    // 2->6: 18 (3-4:13 4-1:5)
    // 5->3: 88 (2-2:46 3-3:42)
    // 5->6: 38 (2-4:18 3-1:20)
    // 6->7: 45 (2-3:17 4-4:6)
    // 3->4: 31 (1-2:11 3-1:20)
    // 7->4: 36 (4-2:26 1-1:10)
    // 4->0: 33 (1-0:9 4-0:24)
    // 1: 6
    // 2: 14
    // 3: 8
    // 4: 10
    // 5: 10
    // 6: 10
    // 7: 10
    //
    // 1-2-3-4: 184 *
    // 1-2-6-7-4: 194
    // 5-3-4: 186
    // 5-6-7-4: 176
    //
    // loss = logsumexp(184,194,186,176) - 184
    #[test]
    fn test_loss() {
        let weights = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 46.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 42.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
        ];
        let provider = test_utils::generate_test_feature_provider();
        let lattices = vec![test_utils::generate_test_lattice()];
        // 1-based indices into `weights` for each unigram feature id.
        let unigram_weight_indices = vec![
            NonZeroU32::new(2),
            NonZeroU32::new(4),
            NonZeroU32::new(6),
            NonZeroU32::new(8),
        ];
        // bigram_weight_indices[left][right] = 0-based index into `weights`.
        let bigram_weight_indices = vec![
            hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
            hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
            hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
            hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
            hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
        ];
        // Single-threaded, no regularization, so the result is exact.
        let loss_function = LatticesLoss::new(
            &lattices,
            &provider,
            &unigram_weight_indices,
            &bigram_weight_indices,
            1,
            None,
        );

        let expected = logsumexp!(184.0, 194.0, 186.0, 176.0) - 184.0;
        let result = loss_function.cost(&weights);

        assert!((expected - result).abs() < f64::EPSILON);
    }

    #[test]
    fn test_gradient() {
        let weights = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 46.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 42.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
        ];
        let provider = test_utils::generate_test_feature_provider();
        let lattices = vec![test_utils::generate_test_lattice()];
        let unigram_weight_indices = vec![
            NonZeroU32::new(2),
            NonZeroU32::new(4),
            NonZeroU32::new(6),
            NonZeroU32::new(8),
        ];
        let bigram_weight_indices = vec![
            hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
            hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
            hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
            hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
            hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
        ];
        let loss_function = LatticesLoss::new(
            &lattices,
            &provider,
            &unigram_weight_indices,
            &bigram_weight_indices,
            1,
            None,
        );

        // Per-path probabilities under the current weights.
        let z = logsumexp!(184.0, 194.0, 186.0, 176.0);
        let prob1 = (184.0 - z).exp();
        let prob2 = (194.0 - z).exp();
        let prob3 = (186.0 - z).exp();
        let prob4 = (176.0 - z).exp();

        // Expected gradient: -1 for each feature occurrence on the reference
        // path, +path probability for each occurrence on every path.
        let mut expected = vec![0.0; 29];
        // unigram gradients
        for i in [1, 3, 5, 7, 1, 5, 7, 1] {
            expected[i] -= 1.0;
        }
        for i in [1, 3, 5, 7, 1, 5, 7, 1] {
            expected[i] += prob1;
        }
        for i in [1, 3, 5, 7, 1, 7, 3, 5, 7, 1] {
            expected[i] += prob2;
        }
        for i in [3, 5, 1, 5, 7, 1] {
            expected[i] += prob3;
        }
        for i in [3, 5, 1, 7, 3, 5, 7, 1] {
            expected[i] += prob4;
        }
        // bigram gradients
        for i in [0, 2, 12, 16, 20, 26, 10, 19, 8, 23] {
            expected[i] -= 1.0;
        }
        for i in [0, 2, 12, 16, 20, 26, 10, 19, 8, 23] {
            expected[i] += prob1;
        }
        for i in [0, 2, 12, 16, 22, 24, 16, 27, 25, 9, 8, 23] {
            expected[i] += prob2;
        }
        for i in [2, 2, 15, 21, 10, 19, 8, 23] {
            expected[i] += prob3;
        }
        for i in [2, 2, 17, 19, 16, 27, 25, 9, 8, 23] {
            expected[i] += prob4;
        }

        let result = loss_function.gradient_partial(&weights, 0..lattices.len());

        // Compare via L1 norm to tolerate accumulated rounding.
        let norm = expected
            .iter()
            .zip(&result)
            .fold(0.0, |acc, (a, b)| acc + (a - b).abs());

        assert!(norm < 1e-12);
    }
}