1use core::num::NonZeroU32;
2
3use alloc::vec::Vec;
4
5use hashbrown::HashMap;
6use rkyv::rancor::Fallible;
7use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
9
10use crate::errors::{Result, RucrfError};
11use crate::feature::{self, FeatureProvider};
12use crate::lattice::{Edge, Lattice};
13use crate::utils::FromU32;
14
/// rkyv `with` wrapper that archives a `Vec<HashMap<K, V>>` as a `Vec<Vec<(K, V)>>`.
///
/// On serialization every map is flattened into a vector of its `(key, value)` pairs, in the
/// map's iteration order; on deserialization each vector is collected back into a map.
pub(crate) struct VecHashMapAsVec;
21
22impl<K, V> ArchiveWith<Vec<HashMap<K, V>>> for VecHashMapAsVec
23where
24 K: Copy + core::hash::Hash + Eq,
25 V: Copy,
26 Vec<Vec<(K, V)>>: Archive,
27{
28 type Archived = <Vec<Vec<(K, V)>> as Archive>::Archived;
29 type Resolver = <Vec<Vec<(K, V)>> as Archive>::Resolver;
30
31 fn resolve_with(
32 field: &Vec<HashMap<K, V>>,
33 resolver: Self::Resolver,
34 out: Place<Self::Archived>,
35 ) {
36 let vec: Vec<Vec<(K, V)>> = field
37 .iter()
38 .map(|hm| hm.iter().map(|(&k, &v)| (k, v)).collect())
39 .collect();
40 Archive::resolve(&vec, resolver, out);
41 }
42}
43
44impl<K, V, S> SerializeWith<Vec<HashMap<K, V>>, S> for VecHashMapAsVec
45where
46 K: Copy + core::hash::Hash + Eq,
47 V: Copy,
48 S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized,
49 Vec<Vec<(K, V)>>: RkyvSerialize<S>,
50{
51 fn serialize_with(
52 field: &Vec<HashMap<K, V>>,
53 serializer: &mut S,
54 ) -> core::result::Result<Self::Resolver, S::Error> {
55 let vec: Vec<Vec<(K, V)>> = field
56 .iter()
57 .map(|hm| hm.iter().map(|(&k, &v)| (k, v)).collect())
58 .collect();
59 RkyvSerialize::serialize(&vec, serializer)
60 }
61}
62
63impl<K, V, D> DeserializeWith<<Vec<Vec<(K, V)>> as Archive>::Archived, Vec<HashMap<K, V>>, D>
64 for VecHashMapAsVec
65where
66 K: Copy + core::hash::Hash + Eq + Archive,
67 V: Copy + Archive,
68 D: Fallible + ?Sized,
69 <Vec<Vec<(K, V)>> as Archive>::Archived: RkyvDeserialize<Vec<Vec<(K, V)>>, D>,
70{
71 fn deserialize_with(
72 archived: &<Vec<Vec<(K, V)>> as Archive>::Archived,
73 deserializer: &mut D,
74 ) -> core::result::Result<Vec<HashMap<K, V>>, D::Error> {
75 let vec: Vec<Vec<(K, V)>> = RkyvDeserialize::deserialize(archived, deserializer)?;
76 Ok(vec.into_iter().map(|v| v.into_iter().collect()).collect())
77 }
78}
79
/// A model that can find the best-scoring path through a [`Lattice`].
pub trait Model {
    /// Finds the path through `lattice` with the highest total score.
    ///
    /// Returns the edges of that path, ordered from the first node of the lattice towards the
    /// last one, together with the path's score.
    #[must_use]
    fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64);
}
86
/// A model that keeps the raw weight vector and the tables mapping features to weights.
///
/// Scores are looked up feature by feature at search time; [`RawModel::merge`] precomputes
/// them into a [`MergedModel`].
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
pub struct RawModel {
    // Weight values, addressed through the two index tables below.
    weights: Vec<f64>,
    // Unigram feature ID `id` is stored at position `id - 1`; the value is a 1-based index
    // into `weights`, or `None` when the feature has no weight.
    unigram_weight_indices: Vec<Option<NonZeroU32>>,
    // Bigram table: `bigram_weight_indices[a][b]` is a 0-based index into `weights`, where `a`
    // is a bigram-left feature ID of the earlier node and `b` a bigram-right feature ID of the
    // later node. Row 0 is used for the transition from the start of the lattice and key 0
    // for the transition into its end.
    #[rkyv(with = VecHashMapAsVec)]
    bigram_weight_indices: Vec<HashMap<u32, u32>>,
    // Maps edge labels to their feature sets.
    provider: FeatureProvider,
}
96
97impl RawModel {
98 #[cfg(feature = "train")]
99 pub(crate) const fn new(
100 weights: Vec<f64>,
101 unigram_weight_indices: Vec<Option<NonZeroU32>>,
102 bigram_weight_indices: Vec<HashMap<u32, u32>>,
103 provider: FeatureProvider,
104 ) -> Self {
105 Self {
106 weights,
107 unigram_weight_indices,
108 bigram_weight_indices,
109 provider,
110 }
111 }
112
113 pub fn feature_provider(&mut self) -> &mut FeatureProvider {
115 &mut self.provider
116 }
117
118 #[allow(clippy::missing_panics_doc)]
127 pub fn merge(&self) -> Result<MergedModel> {
128 let mut left_conn_ids = HashMap::new();
129 let mut right_conn_ids = HashMap::new();
130 let mut left_conn_to_right_feats = vec![];
131 let mut right_conn_to_left_feats = vec![];
132 let mut new_feature_sets = vec![];
133 for feature_set in &self.provider.feature_sets {
134 let mut weight = 0.0;
135 for fid in feature_set.unigram() {
136 let fid = usize::from_u32(fid.get() - 1);
137 if let Some(widx) = self.unigram_weight_indices.get(fid).copied().flatten() {
138 weight += self.weights[usize::from_u32(widx.get() - 1)];
139 }
140 }
141 let left_id = {
142 let new_id = u32::try_from(left_conn_to_right_feats.len() + 1)
143 .map_err(|_| RucrfError::model_scale("connection ID too large"))?;
144 *left_conn_ids
145 .raw_entry_mut()
146 .from_key(feature_set.bigram_right())
147 .or_insert_with(|| {
148 let features = feature_set.bigram_right().to_vec();
149 left_conn_to_right_feats.push(features.clone());
150 (features, NonZeroU32::new(new_id).unwrap())
151 })
152 .1
153 };
154 let right_id = {
155 let new_id = u32::try_from(right_conn_to_left_feats.len() + 1)
156 .map_err(|_| RucrfError::model_scale("connection ID too large"))?;
157 *right_conn_ids
158 .raw_entry_mut()
159 .from_key(feature_set.bigram_left())
160 .or_insert_with(|| {
161 let features = feature_set.bigram_left().to_vec();
162 right_conn_to_left_feats.push(features.clone());
163 (features, NonZeroU32::new(new_id).unwrap())
164 })
165 .1
166 };
167 new_feature_sets.push(MergedFeatureSet {
168 weight,
169 left_id,
170 right_id,
171 });
172 }
173 let mut matrix = vec![];
174
175 let mut m = HashMap::new();
177 for (i, left_ids) in left_conn_to_right_feats.iter().enumerate() {
178 let mut weight = 0.0;
179 for fid in left_ids.iter().flatten() {
180 if let Some(&widx) = self.bigram_weight_indices[0].get(&fid.get()) {
181 weight += self.weights[usize::from_u32(widx)];
182 }
183 }
184 if weight.abs() >= f64::EPSILON {
185 m.insert(
186 u32::try_from(i + 1)
187 .map_err(|_| RucrfError::model_scale("connection ID too large"))?,
188 weight,
189 );
190 }
191 }
192 matrix.push(m);
193
194 for right_ids in &right_conn_to_left_feats {
195 let mut m = HashMap::new();
196
197 let mut weight = 0.0;
199 for fid in right_ids.iter().flatten() {
200 let right_id = usize::from_u32(fid.get());
201 if let Some(&widx) = self
202 .bigram_weight_indices
203 .get(right_id)
204 .and_then(|hm| hm.get(&0))
205 {
206 weight += self.weights[usize::from_u32(widx)];
207 }
208 }
209 if weight.abs() >= f64::EPSILON {
210 m.insert(0, weight);
211 }
212
213 for (i, left_ids) in left_conn_to_right_feats.iter().enumerate() {
214 let mut weight = 0.0;
215 for (right_id, left_id) in right_ids.iter().zip(left_ids) {
216 if let (Some(right_id), Some(left_id)) = (right_id, left_id) {
217 let right_id = usize::from_u32(right_id.get());
218 let left_id = left_id.get();
219 if let Some(&widx) = self
220 .bigram_weight_indices
221 .get(right_id)
222 .and_then(|hm| hm.get(&left_id))
223 {
224 weight += self.weights[usize::from_u32(widx)];
225 }
226 }
227 }
228 if weight.abs() >= f64::EPSILON {
229 m.insert(
230 u32::try_from(i + 1)
231 .map_err(|_| RucrfError::model_scale("connection ID too large"))?,
232 weight,
233 );
234 }
235 }
236
237 matrix.push(m);
238 }
239
240 Ok(MergedModel {
241 feature_sets: new_feature_sets,
242 matrix,
243 left_conn_to_right_feats,
244 right_conn_to_left_feats,
245 })
246 }
247
248 #[must_use]
250 pub fn unigram_weight_indices(&self) -> &[Option<NonZeroU32>] {
251 &self.unigram_weight_indices
252 }
253
254 #[must_use]
256 pub fn bigram_weight_indices(&self) -> &[HashMap<u32, u32>] {
257 &self.bigram_weight_indices
258 }
259
260 #[must_use]
262 pub fn weights(&self) -> &[f64] {
263 &self.weights
264 }
265}
266
267impl Model for RawModel {
268 fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64) {
269 let mut best_scores = vec![vec![]; lattice.nodes().len()];
270 best_scores[lattice.nodes().len() - 1].push((0, 0, None, 0.0));
271 for (i, node) in lattice.nodes().iter().enumerate() {
272 for edge in node.edges() {
273 let mut score = 0.0;
274 if let Some(feature_set) = self.provider.get_feature_set(edge.label) {
275 for &fid in feature_set.unigram() {
276 let fid = usize::from_u32(fid.get() - 1);
277 if let Some(widx) = self.unigram_weight_indices[fid] {
278 score += self.weights[usize::from_u32(widx.get() - 1)];
279 }
280 }
281 }
282 best_scores[i].push((edge.target(), 0, Some(edge.label), score));
283 }
284 }
285 for i in (0..lattice.nodes().len() - 1).rev() {
286 for j in 0..best_scores[i].len() {
287 let (k, _, curr_label, _) = best_scores[i][j];
288 let mut best_score = f64::NEG_INFINITY;
289 let mut best_idx = 0;
290 for (p, &(_, _, next_label, mut score)) in best_scores[k].iter().enumerate() {
291 feature::apply_bigram(
292 curr_label,
293 next_label,
294 &self.provider,
295 &self.bigram_weight_indices,
296 |widx| {
297 score += self.weights[usize::from_u32(widx)];
298 },
299 );
300 if score > best_score {
301 best_score = score;
302 best_idx = p;
303 }
304 }
305 best_scores[i][j].1 = best_idx;
306 best_scores[i][j].3 += best_score;
307 }
308 }
309 let mut best_score = f64::NEG_INFINITY;
310 let mut idx = 0;
311 for (p, &(_, _, next_label, mut score)) in best_scores[0].iter().enumerate() {
312 feature::apply_bigram(
313 None,
314 next_label,
315 &self.provider,
316 &self.bigram_weight_indices,
317 |widx| {
318 score += self.weights[usize::from_u32(widx)];
319 },
320 );
321 if score > best_score {
322 best_score = score;
323 idx = p;
324 }
325 }
326 let mut pos = 0;
327 let mut best_path = vec![];
328 while pos < lattice.nodes().len() - 1 {
329 let edge = &lattice.nodes()[pos].edges()[idx];
330 idx = best_scores[pos][idx].1;
331 pos = edge.target();
332 best_path.push(Edge::new(pos, edge.label()));
333 }
334 (best_path, best_score)
335 }
336}
337
/// Per-label data of a [`MergedModel`].
#[derive(Clone, Copy, Debug, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct MergedFeatureSet {
    /// Sum of the unigram weights of the label's feature set.
    pub weight: f64,
    /// Connection ID used when this label is the later node of a transition; it is the key
    /// looked up within a row of [`MergedModel::matrix`].
    pub left_id: NonZeroU32,
    /// Connection ID used when this label is the earlier node of a transition; it selects the
    /// row of [`MergedModel::matrix`].
    pub right_id: NonZeroU32,
}
348
/// A model with unigram weights pre-summed per label and bigram weights pre-summed into a
/// connection matrix. It can be built with [`RawModel::merge`].
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
pub struct MergedModel {
    /// Per-label data; label `l` is stored at position `l - 1`.
    pub feature_sets: Vec<MergedFeatureSet>,
    /// Connection weights: `matrix[right_id][left_id]` is the weight of a transition from a
    /// node with `right_id` to a node with `left_id`. Row `0` holds transitions from the start
    /// of the lattice and key `0` transitions into its end. Missing entries weigh zero.
    #[rkyv(with = VecHashMapAsVec)]
    pub matrix: Vec<HashMap<u32, f64>>,
    /// Bigram-right feature IDs of each left connection ID (`left_id` at position
    /// `left_id - 1`).
    pub left_conn_to_right_feats: Vec<Vec<Option<NonZeroU32>>>,
    /// Bigram-left feature IDs of each right connection ID (`right_id` at position
    /// `right_id - 1`).
    pub right_conn_to_left_feats: Vec<Vec<Option<NonZeroU32>>>,
}
362
363impl Model for MergedModel {
364 fn search_best_path(&self, lattice: &Lattice) -> (Vec<Edge>, f64) {
365 let mut best_scores = vec![vec![]; lattice.nodes().len()];
366 best_scores[lattice.nodes().len() - 1].push((0, 0, None, 0.0));
367 for (i, node) in lattice.nodes().iter().enumerate() {
368 for edge in node.edges() {
369 let label = usize::from_u32(edge.label.get() - 1);
370 let score = self.feature_sets.get(label).map_or(0.0, |s| s.weight);
371 best_scores[i].push((edge.target(), 0, Some(edge.label), score));
372 }
373 }
374 for i in (0..lattice.nodes().len() - 1).rev() {
375 for j in 0..best_scores[i].len() {
376 let (k, _, curr_label, _) = best_scores[i][j];
377 let mut best_score = f64::NEG_INFINITY;
378 let mut best_idx = 0;
379 let curr_id = curr_label.map_or(Some(0), |label| {
380 self.feature_sets
381 .get(usize::from_u32(label.get() - 1))
382 .map(|s| s.right_id.get())
383 });
384 for (p, &(_, _, next_label, mut score)) in best_scores[k].iter().enumerate() {
385 let next_id = next_label.map_or(Some(0), |label| {
386 self.feature_sets
387 .get(usize::from_u32(label.get() - 1))
388 .map(|s| s.left_id.get())
389 });
390 if let (Some(curr_id), Some(next_id)) = (curr_id, next_id) {
391 score += self
392 .matrix
393 .get(usize::from_u32(curr_id))
394 .and_then(|hm| hm.get(&next_id))
395 .unwrap_or(&0.0);
396 }
397 if score > best_score {
398 best_score = score;
399 best_idx = p;
400 }
401 }
402 best_scores[i][j].1 = best_idx;
403 best_scores[i][j].3 += best_score;
404 }
405 }
406 let mut best_score = f64::NEG_INFINITY;
407 let mut idx = 0;
408 for (p, &(_, _, next_label, mut score)) in best_scores[0].iter().enumerate() {
409 let next_id = next_label.map_or(Some(0), |label| {
410 self.feature_sets
411 .get(usize::from_u32(label.get() - 1))
412 .map(|s| s.right_id.get())
413 });
414 if let Some(next_id) = next_id {
415 score += self
416 .matrix
417 .first()
418 .and_then(|hm| hm.get(&next_id))
419 .unwrap_or(&0.0);
420 }
421 if score > best_score {
422 best_score = score;
423 idx = p;
424 }
425 }
426 let mut pos = 0;
427 let mut best_path = vec![];
428 while pos < lattice.nodes().len() - 1 {
429 let edge = &lattice.nodes()[pos].edges()[idx];
430 idx = best_scores[pos][idx].1;
431 pos = edge.target();
432 best_path.push(Edge::new(pos, edge.label()));
433 }
434 (best_path, best_score)
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    use core::num::NonZeroU32;

    use crate::lattice::Edge;
    use crate::test_utils::{self, hashmap};

    /// Raw model shared by the search tests.
    fn build_raw_model() -> RawModel {
        RawModel {
            weights: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 13.0, 24.0, 5.0, 26.0, 27.0, 6.0,
            ],
            unigram_weight_indices: vec![
                NonZeroU32::new(2),
                NonZeroU32::new(4),
                NonZeroU32::new(6),
                NonZeroU32::new(8),
            ],
            bigram_weight_indices: vec![
                hashmap![0 => 28, 1 => 0, 2 => 2, 3 => 4, 4 => 6],
                hashmap![0 => 8, 1 => 9, 2 => 10, 3 => 11, 4 => 12],
                hashmap![0 => 13, 1 => 14, 2 => 15, 3 => 16, 4 => 17],
                hashmap![0 => 18, 1 => 19, 2 => 20, 3 => 21, 4 => 22],
                hashmap![0 => 23, 1 => 24, 2 => 25, 3 => 26, 4 => 27],
            ],
            provider: test_utils::generate_test_feature_provider(),
        }
    }

    /// Best path expected on the test lattice, built from `(target node, label)` pairs.
    fn expected_path() -> Vec<Edge> {
        [(1, 1), (2, 2), (3, 6), (4, 7), (5, 4)]
            .iter()
            .map(|&(target, label)| Edge::new(target, NonZeroU32::new(label).unwrap()))
            .collect()
    }

    #[test]
    fn test_search_best_path() {
        let raw = build_raw_model();
        let lattice = test_utils::generate_test_lattice();

        let (path, score) = raw.search_best_path(&lattice);

        assert_eq!(expected_path(), path);
        assert!((194.0 - score).abs() < f64::EPSILON);
    }

    #[test]
    fn test_hashed_search_best_path() {
        let merged = build_raw_model().merge().unwrap();
        let lattice = test_utils::generate_test_lattice();

        let (path, score) = merged.search_best_path(&lattice);

        assert_eq!(expected_path(), path);
        assert!((194.0 - score).abs() < f64::EPSILON);
    }
}