// lindera_dictionary/dictionary/prefix_dictionary.rs
use daachorse::DoubleArrayAhoCorasick;
2use rkyv::rancor::{Fallible, Source};
3use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::{LinderaResult, error::LinderaErrorKind, util::Data, viterbi::WordEntry};
8
/// A single dictionary hit: the id of the matched word together with the end
/// position of the match in the queried text.
#[derive(Debug, Clone)]
pub struct Match {
    // Identifier of the matched dictionary word.
    pub word_idx: WordIdx,
    // End position of the match. NOTE(review): this is populated from
    // `m.end()`, which is a *byte* offset into the searched string; the name
    // suggests a character index — confirm what callers expect.
    pub end_char: usize,
}
15
/// Lightweight identifier of a word in the prefix dictionary.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    // Raw 32-bit word id as stored in the dictionary data.
    pub word_id: u32,
}
20
21impl WordIdx {
22 pub fn new(word_id: u32) -> Self {
23 Self { word_id }
24 }
25}
26
/// rkyv `with`-wrapper that archives a `DoubleArrayAhoCorasick<u32>` as the
/// byte vector produced by the automaton's own `serialize()`.
pub struct DoubleArrayArchiver;
28
impl ArchiveWith<DoubleArrayAhoCorasick<u32>> for DoubleArrayArchiver {
    // The automaton is archived as a plain byte vector.
    type Archived = rkyv::vec::ArchivedVec<u8>;
    type Resolver = rkyv::vec::VecResolver;

    /// Resolves the archived representation by writing the automaton's
    /// serialized bytes into `out` as an `ArchivedVec<u8>`.
    fn resolve_with(
        field: &DoubleArrayAhoCorasick<u32>,
        resolver: Self::Resolver,
        out: Place<Self::Archived>,
    ) {
        let bytes = field.serialize();
        rkyv::vec::ArchivedVec::resolve_from_slice(&bytes, resolver, out);
    }
}
42
// Serialization half of the `DoubleArrayArchiver` wrapper: emits the
// automaton's serialized bytes through the rkyv serializer as a byte vector.
impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized>
    SerializeWith<DoubleArrayAhoCorasick<u32>, S> for DoubleArrayArchiver
{
    fn serialize_with(
        field: &DoubleArrayAhoCorasick<u32>,
        serializer: &mut S,
    ) -> Result<Self::Resolver, S::Error> {
        let bytes = field.serialize();
        rkyv::vec::ArchivedVec::serialize_from_slice(&bytes, serializer)
    }
}
54
// Deserialization half of the wrapper: rebuilds the automaton from the
// archived byte vector. Failures are converted into the deserializer's error
// type via the `rkyv::rancor::Source` bound.
impl<D: Fallible<Error: Source> + ?Sized>
    DeserializeWith<rkyv::vec::ArchivedVec<u8>, DoubleArrayAhoCorasick<u32>, D>
    for DoubleArrayArchiver
{
    fn deserialize_with(
        archived: &rkyv::vec::ArchivedVec<u8>,
        _deserializer: &mut D,
    ) -> Result<DoubleArrayAhoCorasick<u32>, D::Error> {
        // `deserialize` returns the automaton plus any trailing bytes; the
        // remainder is intentionally ignored here.
        let (da, _) = DoubleArrayAhoCorasick::deserialize(archived.as_slice()).map_err(|err| {
            D::Error::new(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                err.to_string(),
            ))
        })?;
        Ok(da)
    }
}
77
78mod double_array_serde {
79 use daachorse::DoubleArrayAhoCorasick;
80 use serde::{Deserialize, Deserializer, Serializer};
81
82 pub fn serialize<S>(da: &DoubleArrayAhoCorasick<u32>, serializer: S) -> Result<S::Ok, S::Error>
83 where
84 S: Serializer,
85 {
86 let bytes = da.serialize();
87 serializer.serialize_bytes(&bytes)
88 }
89
90 pub fn deserialize<'de, D>(deserializer: D) -> Result<DoubleArrayAhoCorasick<u32>, D::Error>
91 where
92 D: Deserializer<'de>,
93 {
94 let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
95 let (da, _) = DoubleArrayAhoCorasick::deserialize(&bytes)
96 .map_err(|err| serde::de::Error::custom(err.to_string()))?;
97 Ok(da)
98 }
99}
100
/// Prefix dictionary: a double-array Aho-Corasick automaton over surface
/// forms plus the raw data sections that the automaton's values index into.
#[derive(Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct PrefixDictionary {
    // `da` is (de)serialized through custom wrappers for both serde and rkyv.
    #[serde(with = "self::double_array_serde")]
    #[rkyv(with = DoubleArrayArchiver)]
    pub da: DoubleArrayAhoCorasick<u32>,
    // Serialized `WordEntry` records, located via values decoded from `da`.
    pub vals_data: Data,
    // NOTE(review): not read in this file — presumably an index into
    // `words_data`; confirm usage elsewhere in the crate.
    pub words_idx_data: Data,
    pub words_data: Data,
    // Selects the bit layout used by `decode_val` (8-bit vs 5-bit length).
    pub is_system: bool,
}
111
112impl PrefixDictionary {
113 #[inline]
120 pub(crate) fn decode_val(&self, val: u32) -> (u32, u32) {
121 if self.is_system {
122 (val >> 8u32, val & ((1u32 << 8) - 1u32))
123 } else {
124 (val >> 5u32, val & ((1u32 << 5) - 1u32))
125 }
126 }
127
128 pub fn load(
142 da_data: impl Into<Data>,
143 vals_data: impl Into<Data>,
144 words_idx_data: impl Into<Data>,
145 words_data: impl Into<Data>,
146 is_system: bool,
147 ) -> LinderaResult<PrefixDictionary> {
148 let da_bytes = da_data.into();
149 let (da, _) = DoubleArrayAhoCorasick::deserialize(&da_bytes[..]).map_err(|err| {
150 LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
151 })?;
152
153 Ok(PrefixDictionary {
154 da,
155 vals_data: vals_data.into(),
156 words_idx_data: words_idx_data.into(),
157 words_data: words_data.into(),
158 is_system,
159 })
160 }
161
162 pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
163 self.da
164 .find_overlapping_iter(s)
165 .filter(|m| m.start() == 0)
166 .flat_map(move |m| {
167 let (offset, len) = self.decode_val(m.value());
168 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
169 let data: &[u8] = &self.vals_data[offset_bytes..];
170 (0..len as usize).map(move |i| {
171 (
172 m.end(),
173 WordEntry::deserialize(
174 &data[WordEntry::SERIALIZED_LEN * i..],
175 self.is_system,
176 ),
177 )
178 })
179 })
180 }
181
182 pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
184 self.find_surface_iter(surface).collect()
185 }
186
187 pub fn find_surface_iter<'a>(
190 &'a self,
191 surface: &'a str,
192 ) -> impl Iterator<Item = WordEntry> + 'a {
193 self.da
194 .find_overlapping_iter(surface)
195 .filter(|m| m.start() == 0 && m.end() == surface.len())
196 .flat_map(move |m| {
197 let (offset, len) = self.decode_val(m.value());
198 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
199 let data = &self.vals_data[offset_bytes..];
200 (0..len as usize).map(move |i| {
201 WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
202 })
203 })
204 }
205
206 pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
208 if self.vals_data.is_empty() {
212 return Vec::new();
213 }
214
215 let suffix_str: String = suffix.iter().collect();
216
217 self.da
218 .find_overlapping_iter(&suffix_str)
219 .filter(|m| m.start() == 0)
220 .flat_map(|m| {
221 let (offset, len) = self.decode_val(m.value());
222 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
223
224 if offset_bytes >= self.vals_data.len() {
226 return vec![].into_iter();
227 }
228
229 let data: &[u8] = &self.vals_data[offset_bytes..];
230 (0..len as usize)
231 .filter_map(move |i| {
232 let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
233 if required_bytes <= data.len() {
234 let word_entry = WordEntry::deserialize(
235 &data[WordEntry::SERIALIZED_LEN * i..],
236 self.is_system,
237 );
238 Some(Match {
239 word_idx: WordIdx::new(word_entry.word_id.id),
240 end_char: m.end(), })
262 } else {
263 None
264 }
265 })
266 .collect::<Vec<_>>()
267 .into_iter()
268 })
269 .collect()
270 }
271}
272
273impl ArchivedPrefixDictionary {
274 #[inline]
276 fn decode_val(&self, val: u32) -> (u32, u32) {
277 if self.is_system {
278 (val >> 8u32, val & ((1u32 << 8) - 1u32))
279 } else {
280 (val >> 5u32, val & ((1u32 << 5) - 1u32))
281 }
282 }
283
284 pub fn prefix<'a>(
294 &'a self,
295 s: &'a str,
296 ) -> LinderaResult<impl Iterator<Item = (usize, WordEntry)> + 'a> {
297 let (da, _) =
299 DoubleArrayAhoCorasick::<u32>::deserialize(self.da.as_slice()).map_err(|err| {
300 LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
301 })?;
302
303 let matches: Vec<_> = da
304 .find_overlapping_iter(s)
305 .filter(|m| m.start() == 0)
306 .map(|m| (m.end(), m.value()))
307 .collect();
308
309 Ok(matches.into_iter().flat_map(move |(end, offset_len)| {
310 let (offset, len) = self.decode_val(offset_len);
311 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
312
313 let vals = self.vals_data.as_slice();
314 if offset_bytes >= vals.len() {
315 return vec![].into_iter();
316 }
317
318 let data = &vals[offset_bytes..];
319 (0..len as usize)
320 .map(move |i| {
321 (
322 end,
323 WordEntry::deserialize(
324 &data[WordEntry::SERIALIZED_LEN * i..],
325 self.is_system,
326 ),
327 )
328 })
329 .collect::<Vec<_>>()
330 .into_iter()
331 }))
332 }
333
334 pub fn find_surface(&self, surface: &str) -> LinderaResult<Vec<WordEntry>> {
344 let (da, _) =
345 DoubleArrayAhoCorasick::<u32>::deserialize(self.da.as_slice()).map_err(|err| {
346 LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
347 })?;
348
349 let matches: Vec<_> = da
350 .find_overlapping_iter(surface)
351 .filter(|m| m.start() == 0 && m.end() == surface.len())
352 .map(|m| m.value())
353 .collect();
354
355 Ok(matches
356 .into_iter()
357 .flat_map(|offset_len| {
358 let (offset, len) = self.decode_val(offset_len);
359 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
360 let vals = self.vals_data.as_slice();
361 if offset_bytes >= vals.len() {
362 return Vec::new().into_iter();
363 }
364 let data = &vals[offset_bytes..];
365 (0..len as usize)
366 .map(|i| {
367 WordEntry::deserialize(
368 &data[WordEntry::SERIALIZED_LEN * i..],
369 self.is_system,
370 )
371 })
372 .collect::<Vec<_>>()
373 .into_iter()
374 })
375 .collect())
376 }
377}