Skip to main content

lindera_crf/
model.rs

1use core::num::NonZeroU32;
2
3use alloc::vec::Vec;
4
5use hashbrown::HashMap;
6use rkyv::rancor::Fallible;
7use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
9
10use crate::errors::{Result, RucrfError};
11use crate::feature::{self, FeatureProvider};
12use crate::lattice::{Edge, Lattice};
13use crate::utils::FromU32;
14
/// `rkyv` adapter that archives a `Vec<HashMap<K, V>>` as a `Vec<Vec<(K, V)>>`.
///
/// `hashbrown::HashMap` has no `Archive` implementation of its own, so each map is flattened
/// into a vector of key-value pairs when serializing and rebuilt into a map when
/// deserializing. Apply it to a field with `#[rkyv(with = VecHashMapAsVec)]`.
pub(crate) struct VecHashMapAsVec;
21
22impl<K, V> ArchiveWith<Vec<HashMap<K, V>>> for VecHashMapAsVec
23where
24    K: Copy + core::hash::Hash + Eq,
25    V: Copy,
26    Vec<Vec<(K, V)>>: Archive,
27{
28    type Archived = <Vec<Vec<(K, V)>> as Archive>::Archived;
29    type Resolver = <Vec<Vec<(K, V)>> as Archive>::Resolver;
30
31    fn resolve_with(
32        field: &Vec<HashMap<K, V>>,
33        resolver: Self::Resolver,
34        out: Place<Self::Archived>,
35    ) {
36        let vec: Vec<Vec<(K, V)>> = field
37            .iter()
38            .map(|hm| hm.iter().map(|(&k, &v)| (k, v)).collect())
39            .collect();
40        Archive::resolve(&vec, resolver, out);
41    }
42}
43
44impl<K, V, S> SerializeWith<Vec<HashMap<K, V>>, S> for VecHashMapAsVec
45where
46    K: Copy + core::hash::Hash + Eq,
47    V: Copy,
48    S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized,
49    Vec<Vec<(K, V)>>: RkyvSerialize<S>,
50{
51    fn serialize_with(
52        field: &Vec<HashMap<K, V>>,
53        serializer: &mut S,
54    ) -> core::result::Result<Self::Resolver, S::Error> {
55        let vec: Vec<Vec<(K, V)>> = field
56            .iter()
57            .map(|hm| hm.iter().map(|(&k, &v)| (k, v)).collect())
58            .collect();
59        RkyvSerialize::serialize(&vec, serializer)
60    }
61}
62
63impl<K, V, D> DeserializeWith<<Vec<Vec<(K, V)>> as Archive>::Archived, Vec<HashMap<K, V>>, D>
64    for VecHashMapAsVec
65where
66    K: Copy + core::hash::Hash + Eq + Archive,
67    V: Copy + Archive,
68    D: Fallible + ?Sized,
69    <Vec<Vec<(K, V)>> as Archive>::Archived: RkyvDeserialize<Vec<Vec<(K, V)>>, D>,
70{
71    fn deserialize_with(
72        archived: &<Vec<Vec<(K, V)>> as Archive>::Archived,
73        deserializer: &mut D,
74    ) -> core::result::Result<Vec<HashMap<K, V>>, D::Error> {
75        let vec: Vec<Vec<(K, V)>> = RkyvDeserialize::deserialize(archived, deserializer)?;
76        Ok(vec.into_iter().map(|v| v.into_iter().collect()).collect())
77    }
78}
79
80/// The `Model` trait allows for searching the best path in the lattice.
81pub trait Model {
82    /// Searches the best path and returns the path and its score.
83    #[must_use]
84    fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64);
85}
86
87/// Represents a raw model.
88#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
89pub struct RawModel {
90    weights: Vec<f64>,
91    unigram_weight_indices: Vec<Option<NonZeroU32>>,
92    #[rkyv(with = VecHashMapAsVec)]
93    bigram_weight_indices: Vec<HashMap<u32, u32>>,
94    provider: FeatureProvider,
95}
96
97impl RawModel {
98    #[cfg(feature = "train")]
99    pub(crate) const fn new(
100        weights: Vec<f64>,
101        unigram_weight_indices: Vec<Option<NonZeroU32>>,
102        bigram_weight_indices: Vec<HashMap<u32, u32>>,
103        provider: FeatureProvider,
104    ) -> Self {
105        Self {
106            weights,
107            unigram_weight_indices,
108            bigram_weight_indices,
109            provider,
110        }
111    }
112
113    /// Returns a mutable reference of the feature provider.
114    pub fn feature_provider(&mut self) -> &mut FeatureProvider {
115        &mut self.provider
116    }
117
118    /// Merges this model and returns [`MergedModel`].
119    ///
120    /// This process integrates the features, so that each edge has three items: a uni-gram cost,
121    /// a left-connection ID, and a right-connection ID.
122    ///
123    /// # Errors
124    ///
125    /// Generated left/right connection ID must be smaller than 2^32.
126    #[allow(clippy::missing_panics_doc)]
127    pub fn merge(&self) -> Result<MergedModel> {
128        let mut left_conn_ids = HashMap::new();
129        let mut right_conn_ids = HashMap::new();
130        let mut left_conn_to_right_feats = vec![];
131        let mut right_conn_to_left_feats = vec![];
132        let mut new_feature_sets = vec![];
133        for feature_set in &self.provider.feature_sets {
134            let mut weight = 0.0;
135            for fid in feature_set.unigram() {
136                let fid = usize::from_u32(fid.get() - 1);
137                if let Some(widx) = self.unigram_weight_indices.get(fid).copied().flatten() {
138                    weight += self.weights[usize::from_u32(widx.get() - 1)];
139                }
140            }
141            let left_id = {
142                let new_id = u32::try_from(left_conn_to_right_feats.len() + 1)
143                    .map_err(|_| RucrfError::model_scale("connection ID too large"))?;
144                *left_conn_ids
145                    .raw_entry_mut()
146                    .from_key(feature_set.bigram_right())
147                    .or_insert_with(|| {
148                        let features = feature_set.bigram_right().to_vec();
149                        left_conn_to_right_feats.push(features.clone());
150                        (features, NonZeroU32::new(new_id).unwrap())
151                    })
152                    .1
153            };
154            let right_id = {
155                let new_id = u32::try_from(right_conn_to_left_feats.len() + 1)
156                    .map_err(|_| RucrfError::model_scale("connection ID too large"))?;
157                *right_conn_ids
158                    .raw_entry_mut()
159                    .from_key(feature_set.bigram_left())
160                    .or_insert_with(|| {
161                        let features = feature_set.bigram_left().to_vec();
162                        right_conn_to_left_feats.push(features.clone());
163                        (features, NonZeroU32::new(new_id).unwrap())
164                    })
165                    .1
166            };
167            new_feature_sets.push(MergedFeatureSet {
168                weight,
169                left_id,
170                right_id,
171            });
172        }
173        let mut matrix = vec![];
174
175        // BOS
176        let mut m = HashMap::new();
177        for (i, left_ids) in left_conn_to_right_feats.iter().enumerate() {
178            let mut weight = 0.0;
179            for fid in left_ids.iter().flatten() {
180                if let Some(&widx) = self.bigram_weight_indices[0].get(&fid.get()) {
181                    weight += self.weights[usize::from_u32(widx)];
182                }
183            }
184            if weight.abs() >= f64::EPSILON {
185                m.insert(
186                    u32::try_from(i + 1)
187                        .map_err(|_| RucrfError::model_scale("connection ID too large"))?,
188                    weight,
189                );
190            }
191        }
192        matrix.push(m);
193
194        for right_ids in &right_conn_to_left_feats {
195            let mut m = HashMap::new();
196
197            // EOS
198            let mut weight = 0.0;
199            for fid in right_ids.iter().flatten() {
200                let right_id = usize::from_u32(fid.get());
201                if let Some(&widx) = self
202                    .bigram_weight_indices
203                    .get(right_id)
204                    .and_then(|hm| hm.get(&0))
205                {
206                    weight += self.weights[usize::from_u32(widx)];
207                }
208            }
209            if weight.abs() >= f64::EPSILON {
210                m.insert(0, weight);
211            }
212
213            for (i, left_ids) in left_conn_to_right_feats.iter().enumerate() {
214                let mut weight = 0.0;
215                for (right_id, left_id) in right_ids.iter().zip(left_ids) {
216                    if let (Some(right_id), Some(left_id)) = (right_id, left_id) {
217                        let right_id = usize::from_u32(right_id.get());
218                        let left_id = left_id.get();
219                        if let Some(&widx) = self
220                            .bigram_weight_indices
221                            .get(right_id)
222                            .and_then(|hm| hm.get(&left_id))
223                        {
224                            weight += self.weights[usize::from_u32(widx)];
225                        }
226                    }
227                }
228                if weight.abs() >= f64::EPSILON {
229                    m.insert(
230                        u32::try_from(i + 1)
231                            .map_err(|_| RucrfError::model_scale("connection ID too large"))?,
232                        weight,
233                    );
234                }
235            }
236
237            matrix.push(m);
238        }
239
240        Ok(MergedModel {
241            feature_sets: new_feature_sets,
242            matrix,
243            left_conn_to_right_feats,
244            right_conn_to_left_feats,
245        })
246    }
247
248    /// Returns the relation between uni-gram feature IDs and weight indices.
249    #[must_use]
250    pub fn unigram_weight_indices(&self) -> &[Option<NonZeroU32>] {
251        &self.unigram_weight_indices
252    }
253
254    /// Returns the relation between bi-gram feature IDs and weight indices.
255    #[must_use]
256    pub fn bigram_weight_indices(&self) -> &[HashMap<u32, u32>] {
257        &self.bigram_weight_indices
258    }
259
260    /// Returns weights.
261    #[must_use]
262    pub fn weights(&self) -> &[f64] {
263        &self.weights
264    }
265}
266
267impl Model for RawModel {
268    fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64) {
269        let mut best_scores = vec![vec![]; lattice.nodes().len()];
270        best_scores[lattice.nodes().len() - 1].push((0, 0, None, 0.0));
271        for (i, node) in lattice.nodes().iter().enumerate() {
272            for edge in node.edges() {
273                let mut score = 0.0;
274                if let Some(feature_set) = self.provider.get_feature_set(edge.label) {
275                    for &fid in feature_set.unigram() {
276                        let fid = usize::from_u32(fid.get() - 1);
277                        if let Some(widx) = self.unigram_weight_indices[fid] {
278                            score += self.weights[usize::from_u32(widx.get() - 1)];
279                        }
280                    }
281                }
282                best_scores[i].push((edge.target(), 0, Some(edge.label), score));
283            }
284        }
285        for i in (0..lattice.nodes().len() - 1).rev() {
286            for j in 0..best_scores[i].len() {
287                let (k, _, curr_label, _) = best_scores[i][j];
288                let mut best_score = f64::NEG_INFINITY;
289                let mut best_idx = 0;
290                for (p, &(_, _, next_label, mut score)) in best_scores[k].iter().enumerate() {
291                    feature::apply_bigram(
292                        curr_label,
293                        next_label,
294                        &self.provider,
295                        &self.bigram_weight_indices,
296                        |widx| {
297                            score += self.weights[usize::from_u32(widx)];
298                        },
299                    );
300                    if score > best_score {
301                        best_score = score;
302                        best_idx = p;
303                    }
304                }
305                best_scores[i][j].1 = best_idx;
306                best_scores[i][j].3 += best_score;
307            }
308        }
309        let mut best_score = f64::NEG_INFINITY;
310        let mut idx = 0;
311        for (p, &(_, _, next_label, mut score)) in best_scores[0].iter().enumerate() {
312            feature::apply_bigram(
313                None,
314                next_label,
315                &self.provider,
316                &self.bigram_weight_indices,
317                |widx| {
318                    score += self.weights[usize::from_u32(widx)];
319                },
320            );
321            if score > best_score {
322                best_score = score;
323                idx = p;
324            }
325        }
326        let mut pos = 0;
327        let mut best_path = vec![];
328        while pos < lattice.nodes().len() - 1 {
329            let edge = &lattice.nodes()[pos].edges()[idx];
330            idx = best_scores[pos][idx].1;
331            pos = edge.target();
332            best_path.push(Edge::new(pos, edge.label()));
333        }
334        (best_path, best_score)
335    }
336}
337
338/// Represents a merged feature set.
339#[derive(Clone, Copy, Debug, Archive, RkyvSerialize, RkyvDeserialize)]
340pub struct MergedFeatureSet {
341    /// Weight.
342    pub weight: f64,
343    /// Left bi-gram connection ID.
344    pub left_id: NonZeroU32,
345    /// Right bi-gram connection ID.
346    pub right_id: NonZeroU32,
347}
348
349/// Represents a merged model.
350#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
351pub struct MergedModel {
352    /// Feature sets corresponding to label IDs.
353    pub feature_sets: Vec<MergedFeatureSet>,
354    /// Bi-gram weight matrix.
355    #[rkyv(with = VecHashMapAsVec)]
356    pub matrix: Vec<HashMap<u32, f64>>,
357    /// Relation between the left connection IDs and the right bi-gram feature IDs.
358    pub left_conn_to_right_feats: Vec<Vec<Option<NonZeroU32>>>,
359    /// Relation between the right connection IDs and the left bi-gram feature IDs.
360    pub right_conn_to_left_feats: Vec<Vec<Option<NonZeroU32>>>,
361}
362
363impl Model for MergedModel {
364    fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64) {
365        let mut best_scores = vec![vec![]; lattice.nodes().len()];
366        best_scores[lattice.nodes().len() - 1].push((0, 0, None, 0.0));
367        for (i, node) in lattice.nodes().iter().enumerate() {
368            for edge in node.edges() {
369                let label = usize::from_u32(edge.label.get() - 1);
370                let score = self.feature_sets.get(label).map_or(0.0, |s| s.weight);
371                best_scores[i].push((edge.target(), 0, Some(edge.label), score));
372            }
373        }
374        for i in (0..lattice.nodes().len() - 1).rev() {
375            for j in 0..best_scores[i].len() {
376                let (k, _, curr_label, _) = best_scores[i][j];
377                let mut best_score = f64::NEG_INFINITY;
378                let mut best_idx = 0;
379                let curr_id = curr_label.map_or(Some(0), |label| {
380                    self.feature_sets
381                        .get(usize::from_u32(label.get() - 1))
382                        .map(|s| s.right_id.get())
383                });
384                for (p, &(_, _, next_label, mut score)) in best_scores[k].iter().enumerate() {
385                    let next_id = next_label.map_or(Some(0), |label| {
386                        self.feature_sets
387                            .get(usize::from_u32(label.get() - 1))
388                            .map(|s| s.left_id.get())
389                    });
390                    if let (Some(curr_id), Some(next_id)) = (curr_id, next_id) {
391                        score += self
392                            .matrix
393                            .get(usize::from_u32(curr_id))
394                            .and_then(|hm| hm.get(&next_id))
395                            .unwrap_or(&0.0);
396                    }
397                    if score > best_score {
398                        best_score = score;
399                        best_idx = p;
400                    }
401                }
402                best_scores[i][j].1 = best_idx;
403                best_scores[i][j].3 += best_score;
404            }
405        }
406        let mut best_score = f64::NEG_INFINITY;
407        let mut idx = 0;
408        for (p, &(_, _, next_label, mut score)) in best_scores[0].iter().enumerate() {
409            let next_id = next_label.map_or(Some(0), |label| {
410                self.feature_sets
411                    .get(usize::from_u32(label.get() - 1))
412                    .map(|s| s.right_id.get())
413            });
414            if let Some(next_id) = next_id {
415                score += self
416                    .matrix
417                    .first()
418                    .and_then(|hm| hm.get(&next_id))
419                    .unwrap_or(&0.0);
420            }
421            if score > best_score {
422                best_score = score;
423                idx = p;
424            }
425        }
426        let mut pos = 0;
427        let mut best_path = vec![];
428        while pos < lattice.nodes().len() - 1 {
429            let edge = &lattice.nodes()[pos].edges()[idx];
430            idx = best_scores[pos][idx].1;
431            pos = edge.target();
432            best_path.push(Edge::new(pos, edge.label()));
433        }
434        (best_path, best_score)
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    use core::num::NonZeroU32;

    use crate::lattice::Edge;
    use crate::test_utils::{self, hashmap};

    /// Builds the raw model shared by the tests in this module.
    ///
    /// Lattice and expected scores:
    ///
    /// ```text
    /// 0     1     2     3     4     5
    ///  /-1-\ /-2-\ /----3----\ /-4-\
    /// *     *     *     *     *     *
    ///  \----5----/ \-6-/ \-7-/
    /// weights:
    /// 0->1: 4 (0-1:1 0-2:3)
    /// 0->5: 6 (0-2:3 0-2:3)
    /// 1->2: 30 (1-4:13 2-3:17)
    /// 2->3: 48 (3-2:21 4-3:27)
    /// 2->6: 18 (3-4:13 4-1:5)
    /// 5->3: 38 (2-2:16 3-3:22)
    /// 5->6: 38 (2-4:18 3-1:20)
    /// 6->7: 45 (2-3:17 4-4:6)
    /// 3->4: 31 (1-2:11 3-1:20)
    /// 7->4: 36 (4-2:26 1-1:10)
    /// 4->0: 33 (1-0:9 4-0:24)
    /// 1: 6
    /// 2: 14
    /// 3: 8
    /// 4: 10
    /// 5: 10
    /// 6: 10
    /// 7: 10
    /// ```
    fn generate_test_model() -> RawModel {
        RawModel {
            weights: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
            ],
            unigram_weight_indices: vec![
                NonZeroU32::new(2),
                NonZeroU32::new(4),
                NonZeroU32::new(6),
                NonZeroU32::new(8),
            ],
            bigram_weight_indices: vec![
                hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
                hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
                hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
                hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
                hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
            ],
            provider: test_utils::generate_test_feature_provider(),
        }
    }

    /// Best path through the test lattice for [`generate_test_model`]; it scores 194.
    fn expected_best_path() -> Vec<Edge> {
        vec![
            Edge::new(1, NonZeroU32::new(1).unwrap()),
            Edge::new(2, NonZeroU32::new(2).unwrap()),
            Edge::new(3, NonZeroU32::new(6).unwrap()),
            Edge::new(4, NonZeroU32::new(7).unwrap()),
            Edge::new(5, NonZeroU32::new(4).unwrap()),
        ]
    }

    #[test]
    fn test_search_best_path() {
        let model = generate_test_model();
        let lattice = test_utils::generate_test_lattice();

        let (path, score) = model.search_best_path(&lattice);

        assert_eq!(expected_best_path(), path);
        assert!((194.0 - score).abs() < f64::EPSILON);
    }

    #[test]
    fn test_hashed_search_best_path() {
        let merged_model = generate_test_model().merge().unwrap();
        let lattice = test_utils::generate_test_lattice();

        let (path, score) = merged_model.search_best_path(&lattice);

        assert_eq!(expected_best_path(), path);
        assert!((194.0 - score).abs() < f64::EPSILON);
    }
}