lindera_dictionary/dictionary/
prefix_dictionary.rs1use std::ops::Deref;
2
3use rkyv::rancor::Fallible;
4use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
5use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
6use serde::{Deserialize, Serialize};
7use yada::DoubleArray;
8
9use crate::{util::Data, viterbi::WordEntry};
10
11#[derive(Debug, Clone)]
13pub struct Match {
14 pub word_idx: WordIdx,
15 pub end_char: usize,
16}
17
/// Thin, copyable wrapper around the numeric id of a dictionary word.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    /// Raw word id.
    pub word_id: u32,
}

impl WordIdx {
    /// Wraps `word_id` in a [`WordIdx`].
    pub fn new(word_id: u32) -> Self {
        WordIdx { word_id }
    }
}
28
29#[derive(Serialize, Deserialize)]
30#[serde(remote = "DoubleArray")]
31struct DoubleArrayDef<T>(pub T)
32where
33 T: Deref<Target = [u8]>;
34
/// rkyv `with`-wrapper marker used to archive a `DoubleArray<Data>` field as raw bytes.
///
/// Carries no data; the behavior lives in its `ArchiveWith`, `SerializeWith` and
/// `DeserializeWith` implementations.
pub struct DoubleArrayArchiver;
36
37impl ArchiveWith<DoubleArray<Data>> for DoubleArrayArchiver {
38 type Archived = rkyv::vec::ArchivedVec<u8>;
39 type Resolver = rkyv::vec::VecResolver;
40
41 fn resolve_with(
42 field: &DoubleArray<Data>,
43 resolver: Self::Resolver,
44 out: Place<Self::Archived>,
45 ) {
46 rkyv::vec::ArchivedVec::resolve_from_slice(&field.0[..], resolver, out);
48 }
49}
50
51impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized>
52 SerializeWith<DoubleArray<Data>, S> for DoubleArrayArchiver
53{
54 fn serialize_with(
55 field: &DoubleArray<Data>,
56 serializer: &mut S,
57 ) -> Result<Self::Resolver, S::Error> {
58 rkyv::vec::ArchivedVec::serialize_from_slice(&field.0[..], serializer)
59 }
60}
61
62impl<D: Fallible + ?Sized> DeserializeWith<rkyv::vec::ArchivedVec<u8>, DoubleArray<Data>, D>
63 for DoubleArrayArchiver
64{
65 fn deserialize_with(
66 archived: &rkyv::vec::ArchivedVec<u8>,
67 _deserializer: &mut D,
68 ) -> Result<DoubleArray<Data>, D::Error> {
69 let mut vec = Vec::with_capacity(archived.len());
70 vec.extend_from_slice(archived.as_slice());
71 Ok(DoubleArray::new(Data::Vec(vec)))
72 }
73}
74
75#[derive(Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
76pub struct PrefixDictionary {
77 #[serde(with = "DoubleArrayDef")]
78 #[rkyv(with = DoubleArrayArchiver)]
79 pub da: DoubleArray<Data>,
80 pub vals_data: Data,
81 pub words_idx_data: Data,
82 pub words_data: Data,
83 pub is_system: bool,
84}
85
86impl PrefixDictionary {
87 pub fn load(
88 da_data: impl Into<Data>,
89 vals_data: impl Into<Data>,
90 words_idx_data: impl Into<Data>,
91 words_data: impl Into<Data>,
92 is_system: bool,
93 ) -> PrefixDictionary {
94 let da = DoubleArray::new(da_data.into());
95
96 PrefixDictionary {
97 da,
98 vals_data: vals_data.into(),
99 words_idx_data: words_idx_data.into(),
100 words_data: words_data.into(),
101 is_system,
102 }
103 }
104
105 pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
106 self.da
107 .common_prefix_search(s)
108 .flat_map(move |(offset_len, prefix_len)| {
109 let len = offset_len & ((1u32 << 5) - 1u32);
110 let offset = offset_len >> 5u32;
111 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
112 let data: &[u8] = &self.vals_data[offset_bytes..];
113 (0..len as usize).map(move |i| {
114 (
115 prefix_len,
116 WordEntry::deserialize(
117 &data[WordEntry::SERIALIZED_LEN * i..],
118 self.is_system,
119 ),
120 )
121 })
122 })
123 }
124
125 pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
127 match self.da.exact_match_search(surface) {
128 Some(offset_len) => {
129 let offset = offset_len >> 5u32;
130 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
131 let data: &[u8] = &self.vals_data[offset_bytes..];
132 let len = offset_len & ((1u32 << 5) - 1u32);
133 (0..len as usize)
134 .map(|i| {
135 WordEntry::deserialize(
136 &data[WordEntry::SERIALIZED_LEN * i..],
137 self.is_system,
138 )
139 })
140 .collect::<Vec<WordEntry>>()
141 }
142 None => vec![],
143 }
144 }
145
146 pub fn find_surface_iter(&self, surface: &str) -> impl Iterator<Item = WordEntry> + '_ {
149 self.da
150 .exact_match_search(surface)
151 .map(|offset_len| {
152 let offset = offset_len >> 5u32;
153 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
154 let data = &self.vals_data[offset_bytes..];
155 let len = offset_len & ((1u32 << 5) - 1u32);
156 (0..len as usize).map(move |i| {
157 WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
158 })
159 })
160 .into_iter()
161 .flatten()
162 }
163
164 pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
166 if self.vals_data.is_empty() {
168 return Vec::new();
169 }
170
171 let suffix_str: String = suffix.iter().collect();
172 self.da
173 .common_prefix_search(&suffix_str)
174 .flat_map(|(offset_len, prefix_len)| {
175 let len = offset_len & ((1u32 << 5) - 1u32);
176 let offset = offset_len >> 5u32;
177 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
178
179 if offset_bytes >= self.vals_data.len() {
181 return vec![].into_iter();
182 }
183
184 let data: &[u8] = &self.vals_data[offset_bytes..];
185 (0..len as usize)
186 .filter_map(move |i| {
187 let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
188 if required_bytes <= data.len() {
189 let word_entry = WordEntry::deserialize(
190 &data[WordEntry::SERIALIZED_LEN * i..],
191 self.is_system,
192 );
193 Some(Match {
194 word_idx: WordIdx::new(word_entry.word_id.id),
195 end_char: prefix_len,
196 })
197 } else {
198 None
199 }
200 })
201 .collect::<Vec<_>>()
202 .into_iter()
203 })
204 .collect()
205 }
206}
207
208impl ArchivedPrefixDictionary {
209 pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
210 let da = DoubleArray::new(self.da.as_slice());
211 let matches: Vec<_> = da.common_prefix_search(s).collect();
212
213 matches
214 .into_iter()
215 .flat_map(move |(offset_len, prefix_len)| {
216 let len = offset_len & ((1u32 << 5) - 1u32);
217 let offset = offset_len >> 5u32;
218 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
219 let data: &[u8] = &self.vals_data.as_slice()[offset_bytes..];
220 (0..len as usize).map(move |i| {
221 (
222 prefix_len,
223 WordEntry::deserialize(
224 &data[WordEntry::SERIALIZED_LEN * i..],
225 self.is_system,
226 ),
227 )
228 })
229 })
230 }
231
232 pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
233 let da = DoubleArray::new(self.da.as_slice());
234 match da.exact_match_search(surface) {
235 Some(offset_len) => {
236 let offset = offset_len >> 5u32;
237 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
238 let data: &[u8] = &self.vals_data.as_slice()[offset_bytes..];
239 let len = offset_len & ((1u32 << 5) - 1u32);
240 (0..len as usize)
241 .map(|i| {
242 WordEntry::deserialize(
243 &data[WordEntry::SERIALIZED_LEN * i..],
244 self.is_system,
245 )
246 })
247 .collect::<Vec<WordEntry>>()
248 }
249 None => vec![],
250 }
251 }
252
253 pub fn find_surface_iter(&self, surface: &str) -> impl Iterator<Item = WordEntry> + '_ {
254 let da = DoubleArray::new(self.da.as_slice());
255 da.exact_match_search(surface)
256 .map(|offset_len| {
257 let offset = offset_len >> 5u32;
258 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
259 let data = &self.vals_data.as_slice()[offset_bytes..];
260 let len = offset_len & ((1u32 << 5) - 1u32);
261 (0..len as usize).map(move |i| {
262 WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
263 })
264 })
265 .into_iter()
266 .flatten()
267 }
268
269 pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
270 if self.vals_data.as_slice().is_empty() {
271 return Vec::new();
272 }
273
274 let suffix_str: String = suffix.iter().collect();
275 let da = DoubleArray::new(self.da.as_slice());
276
277 da.common_prefix_search(&suffix_str)
278 .flat_map(|(offset_len, prefix_len)| {
279 let len = offset_len & ((1u32 << 5) - 1u32);
280 let offset = offset_len >> 5u32;
281 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
282
283 if offset_bytes >= self.vals_data.as_slice().len() {
284 return vec![].into_iter();
285 }
286
287 let data: &[u8] = &self.vals_data.as_slice()[offset_bytes..];
288 (0..len as usize)
289 .filter_map(move |i| {
290 let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
291 if required_bytes <= data.len() {
292 let word_entry = WordEntry::deserialize(
293 &data[WordEntry::SERIALIZED_LEN * i..],
294 self.is_system,
295 );
296 Some(Match {
297 word_idx: WordIdx::new(word_entry.word_id.id),
298 end_char: prefix_len,
299 })
300 } else {
301 None
302 }
303 })
304 .collect::<Vec<_>>()
305 .into_iter()
306 })
307 .collect()
308 }
309}