Skip to main content

lindera_crf/
feature.rs

1use core::num::NonZeroU32;
2
3use alloc::vec::Vec;
4
5use hashbrown::HashMap;
6use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
7
8use crate::errors::{Result, RucrfError};
9use crate::utils::FromU32;
10
11#[inline(always)]
12pub fn apply_bigram<F>(
13    left_label: Option<NonZeroU32>,
14    right_label: Option<NonZeroU32>,
15    provider: &FeatureProvider,
16    bigram_weight_indices: &[HashMap<u32, u32>],
17    mut f: F,
18) where
19    F: FnMut(u32),
20{
21    match (left_label, right_label) {
22        (Some(left_label), Some(right_label)) => {
23            if let (Some(left_feature_set), Some(right_feature_set)) = (
24                provider.get_feature_set(left_label),
25                provider.get_feature_set(right_label),
26            ) {
27                let left_features = left_feature_set.bigram_left();
28                let right_features = right_feature_set.bigram_right();
29                for (&left_fid, &right_fid) in left_features.iter().zip(right_features) {
30                    if let (Some(left_fid), Some(right_fid)) = (left_fid, right_fid) {
31                        let left_fid = usize::from_u32(left_fid.get());
32                        let right_fid = right_fid.get();
33                        if let Some(&widx) = bigram_weight_indices
34                            .get(left_fid)
35                            .and_then(|hm| hm.get(&right_fid))
36                        {
37                            f(widx);
38                        }
39                    }
40                }
41            }
42        }
43        (Some(left_label), None) => {
44            if let Some(feature_set) = provider.get_feature_set(left_label) {
45                for &left_fid in feature_set.bigram_left() {
46                    if let Some(left_fid) = left_fid {
47                        let left_fid = usize::from_u32(left_fid.get());
48                        if let Some(&widx) = bigram_weight_indices[left_fid].get(&0) {
49                            f(widx);
50                        }
51                    }
52                }
53            }
54        }
55        (None, Some(right_label)) => {
56            if let Some(feature_set) = provider.get_feature_set(right_label) {
57                for &right_fid in feature_set.bigram_right() {
58                    if let Some(right_fid) = right_fid {
59                        let right_fid = right_fid.get();
60                        if let Some(&widx) = bigram_weight_indices[0].get(&right_fid) {
61                            f(widx);
62                        }
63                    }
64                }
65            }
66        }
67        _ => unreachable!(),
68    }
69}
70
71/// Manages a set of features for each label.
72#[derive(Debug, Default, Archive, RkyvSerialize, RkyvDeserialize)]
73pub struct FeatureSet {
74    pub(crate) unigram: Vec<NonZeroU32>,
75    pub(crate) bigram_right: Vec<Option<NonZeroU32>>,
76    pub(crate) bigram_left: Vec<Option<NonZeroU32>>,
77}
78
79impl FeatureSet {
80    /// Creates a new [`FeatureSet`].
81    #[inline(always)]
82    #[must_use]
83    pub fn new(
84        unigram: &[NonZeroU32],
85        bigram_right: &[Option<NonZeroU32>],
86        bigram_left: &[Option<NonZeroU32>],
87    ) -> Self {
88        Self {
89            unigram: unigram.to_vec(),
90            bigram_right: bigram_right.to_vec(),
91            bigram_left: bigram_left.to_vec(),
92        }
93    }
94
95    /// Gets uni-gram feature IDs.
96    #[inline(always)]
97    #[must_use]
98    pub fn unigram(&self) -> &[NonZeroU32] {
99        &self.unigram
100    }
101
102    /// Gets right bi-gram feature IDs.
103    #[inline(always)]
104    #[must_use]
105    pub fn bigram_right(&self) -> &[Option<NonZeroU32>] {
106        &self.bigram_right
107    }
108
109    /// Gets left bi-gram feature IDs
110    #[inline(always)]
111    #[must_use]
112    pub fn bigram_left(&self) -> &[Option<NonZeroU32>] {
113        &self.bigram_left
114    }
115}
116
117/// Manages the correspondence between edge labels and feature IDs.
118#[derive(Debug, Default, Archive, RkyvSerialize, RkyvDeserialize)]
119pub struct FeatureProvider {
120    pub(crate) feature_sets: Vec<FeatureSet>,
121}
122
123impl FeatureProvider {
124    /// Creates a new [`FeatureProvider`].
125    #[inline(always)]
126    #[must_use]
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    /// Returns `true` if the manager has no item.
132    #[inline(always)]
133    #[must_use]
134    pub fn is_empty(&self) -> bool {
135        self.feature_sets.is_empty()
136    }
137
138    /// Returns the number of items.
139    #[inline(always)]
140    #[must_use]
141    pub fn len(&self) -> usize {
142        self.feature_sets.len()
143    }
144
145    /// Adds a feature set and returns its ID.
146    ///
147    /// # Errors
148    ///
149    /// The number of features must be less than 2^32 - 1.
150    #[allow(clippy::missing_panics_doc)]
151    #[inline(always)]
152    pub fn add_feature_set(&mut self, feature_set: FeatureSet) -> Result<NonZeroU32> {
153        let new_id = u32::try_from(self.feature_sets.len() + 1)
154            .map_err(|_| RucrfError::model_scale("feature set too large"))?;
155        self.feature_sets.push(feature_set);
156        Ok(NonZeroU32::new(new_id).unwrap())
157    }
158
159    /// Returns the reference to the feature set corresponding to the given ID.
160    #[inline(always)]
161    pub(crate) fn get_feature_set(&self, label: NonZeroU32) -> Option<&FeatureSet> {
162        self.feature_sets
163            .get(usize::try_from(label.get() - 1).unwrap())
164    }
165}