1use core::num::NonZeroU32;
2
3use alloc::vec::Vec;
4
5use hashbrown::HashMap;
6use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
7
8use crate::errors::{Result, RucrfError};
9use crate::utils::FromU32;
10
11#[inline(always)]
12pub fn apply_bigram<F>(
13 left_label: Option<NonZeroU32>,
14 right_label: Option<NonZeroU32>,
15 provider: &FeatureProvider,
16 bigram_weight_indices: &[HashMap<u32, u32>],
17 mut f: F,
18) where
19 F: FnMut(u32),
20{
21 match (left_label, right_label) {
22 (Some(left_label), Some(right_label)) => {
23 if let (Some(left_feature_set), Some(right_feature_set)) = (
24 provider.get_feature_set(left_label),
25 provider.get_feature_set(right_label),
26 ) {
27 let left_features = left_feature_set.bigram_left();
28 let right_features = right_feature_set.bigram_right();
29 for (&left_fid, &right_fid) in left_features.iter().zip(right_features) {
30 if let (Some(left_fid), Some(right_fid)) = (left_fid, right_fid) {
31 let left_fid = usize::from_u32(left_fid.get());
32 let right_fid = right_fid.get();
33 if let Some(&widx) = bigram_weight_indices
34 .get(left_fid)
35 .and_then(|hm| hm.get(&right_fid))
36 {
37 f(widx);
38 }
39 }
40 }
41 }
42 }
43 (Some(left_label), None) => {
44 if let Some(feature_set) = provider.get_feature_set(left_label) {
45 for &left_fid in feature_set.bigram_left() {
46 if let Some(left_fid) = left_fid {
47 let left_fid = usize::from_u32(left_fid.get());
48 if let Some(&widx) = bigram_weight_indices[left_fid].get(&0) {
49 f(widx);
50 }
51 }
52 }
53 }
54 }
55 (None, Some(right_label)) => {
56 if let Some(feature_set) = provider.get_feature_set(right_label) {
57 for &right_fid in feature_set.bigram_right() {
58 if let Some(right_fid) = right_fid {
59 let right_fid = right_fid.get();
60 if let Some(&widx) = bigram_weight_indices[0].get(&right_fid) {
61 f(widx);
62 }
63 }
64 }
65 }
66 }
67 _ => unreachable!(),
68 }
69}
70
71#[derive(Debug, Default, Archive, RkyvSerialize, RkyvDeserialize)]
73pub struct FeatureSet {
74 pub(crate) unigram: Vec<NonZeroU32>,
75 pub(crate) bigram_right: Vec<Option<NonZeroU32>>,
76 pub(crate) bigram_left: Vec<Option<NonZeroU32>>,
77}
78
79impl FeatureSet {
80 #[inline(always)]
82 #[must_use]
83 pub fn new(
84 unigram: &[NonZeroU32],
85 bigram_right: &[Option<NonZeroU32>],
86 bigram_left: &[Option<NonZeroU32>],
87 ) -> Self {
88 Self {
89 unigram: unigram.to_vec(),
90 bigram_right: bigram_right.to_vec(),
91 bigram_left: bigram_left.to_vec(),
92 }
93 }
94
95 #[inline(always)]
97 #[must_use]
98 pub fn unigram(&self) -> &[NonZeroU32] {
99 &self.unigram
100 }
101
102 #[inline(always)]
104 #[must_use]
105 pub fn bigram_right(&self) -> &[Option<NonZeroU32>] {
106 &self.bigram_right
107 }
108
109 #[inline(always)]
111 #[must_use]
112 pub fn bigram_left(&self) -> &[Option<NonZeroU32>] {
113 &self.bigram_left
114 }
115}
116
117#[derive(Debug, Default, Archive, RkyvSerialize, RkyvDeserialize)]
119pub struct FeatureProvider {
120 pub(crate) feature_sets: Vec<FeatureSet>,
121}
122
123impl FeatureProvider {
124 #[inline(always)]
126 #[must_use]
127 pub fn new() -> Self {
128 Self::default()
129 }
130
131 #[inline(always)]
133 #[must_use]
134 pub fn is_empty(&self) -> bool {
135 self.feature_sets.is_empty()
136 }
137
138 #[inline(always)]
140 #[must_use]
141 pub fn len(&self) -> usize {
142 self.feature_sets.len()
143 }
144
145 #[allow(clippy::missing_panics_doc)]
151 #[inline(always)]
152 pub fn add_feature_set(&mut self, feature_set: FeatureSet) -> Result<NonZeroU32> {
153 let new_id = u32::try_from(self.feature_sets.len() + 1)
154 .map_err(|_| RucrfError::model_scale("feature set too large"))?;
155 self.feature_sets.push(feature_set);
156 Ok(NonZeroU32::new(new_id).unwrap())
157 }
158
159 #[inline(always)]
161 pub(crate) fn get_feature_set(&self, label: NonZeroU32) -> Option<&FeatureSet> {
162 self.feature_sets
163 .get(usize::try_from(label.get() - 1).unwrap())
164 }
165}