lindera_dictionary/dictionary/
prefix_dictionary.rs1use daachorse::DoubleArrayAhoCorasick;
2use rkyv::rancor::Fallible;
3use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::{util::Data, viterbi::WordEntry};
8
9#[derive(Debug, Clone)]
11pub struct Match {
12 pub word_idx: WordIdx,
13 pub end_char: usize,
14}
15
/// Small copyable handle wrapping the numeric id of a dictionary word.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    /// Raw numeric word id.
    pub word_id: u32,
}

impl WordIdx {
    /// Wraps `word_id` in a [`WordIdx`].
    pub fn new(word_id: u32) -> Self {
        WordIdx { word_id }
    }
}
26
/// Marker type for `#[rkyv(with = ...)]`: archives a
/// `DoubleArrayAhoCorasick<u32>` as the byte vector returned by its
/// `serialize()` method.
pub struct DoubleArrayArchiver;
28
29impl ArchiveWith<DoubleArrayAhoCorasick<u32>> for DoubleArrayArchiver {
30 type Archived = rkyv::vec::ArchivedVec<u8>;
31 type Resolver = rkyv::vec::VecResolver;
32
33 fn resolve_with(
34 field: &DoubleArrayAhoCorasick<u32>,
35 resolver: Self::Resolver,
36 out: Place<Self::Archived>,
37 ) {
38 let bytes = field.serialize();
39 rkyv::vec::ArchivedVec::resolve_from_slice(&bytes, resolver, out);
40 }
41}
42
43impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized>
44 SerializeWith<DoubleArrayAhoCorasick<u32>, S> for DoubleArrayArchiver
45{
46 fn serialize_with(
47 field: &DoubleArrayAhoCorasick<u32>,
48 serializer: &mut S,
49 ) -> Result<Self::Resolver, S::Error> {
50 let bytes = field.serialize();
51 rkyv::vec::ArchivedVec::serialize_from_slice(&bytes, serializer)
52 }
53}
54
55impl<D: Fallible + ?Sized>
56 DeserializeWith<rkyv::vec::ArchivedVec<u8>, DoubleArrayAhoCorasick<u32>, D>
57 for DoubleArrayArchiver
58{
59 fn deserialize_with(
60 archived: &rkyv::vec::ArchivedVec<u8>,
61 _deserializer: &mut D,
62 ) -> Result<DoubleArrayAhoCorasick<u32>, D::Error> {
63 unsafe {
64 let (da, _) = DoubleArrayAhoCorasick::deserialize_unchecked(archived.as_slice());
65 Ok(da)
66 }
67 }
68}
69
70mod double_array_serde {
71 use daachorse::DoubleArrayAhoCorasick;
72 use serde::{Deserialize, Deserializer, Serializer};
73
74 pub fn serialize<S>(da: &DoubleArrayAhoCorasick<u32>, serializer: S) -> Result<S::Ok, S::Error>
75 where
76 S: Serializer,
77 {
78 let bytes = da.serialize();
79 serializer.serialize_bytes(&bytes)
80 }
81
82 pub fn deserialize<'de, D>(deserializer: D) -> Result<DoubleArrayAhoCorasick<u32>, D::Error>
83 where
84 D: Deserializer<'de>,
85 {
86 let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
87 unsafe {
88 let (da, _) = DoubleArrayAhoCorasick::deserialize_unchecked(&bytes);
89 Ok(da)
90 }
91 }
92}
93
94#[derive(Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
95pub struct PrefixDictionary {
96 #[serde(with = "self::double_array_serde")]
97 #[rkyv(with = DoubleArrayArchiver)]
98 pub da: DoubleArrayAhoCorasick<u32>,
99 pub vals_data: Data,
100 pub words_idx_data: Data,
101 pub words_data: Data,
102 pub is_system: bool,
103}
104
105impl PrefixDictionary {
106 pub fn load(
107 da_data: impl Into<Data>,
108 vals_data: impl Into<Data>,
109 words_idx_data: impl Into<Data>,
110 words_data: impl Into<Data>,
111 is_system: bool,
112 ) -> PrefixDictionary {
113 let da_bytes = da_data.into();
114 let (da, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(&da_bytes[..]) };
115
116 PrefixDictionary {
117 da,
118 vals_data: vals_data.into(),
119 words_idx_data: words_idx_data.into(),
120 words_data: words_data.into(),
121 is_system,
122 }
123 }
124
125 pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
126 self.da
127 .find_overlapping_iter(s)
128 .filter(|m| m.start() == 0)
129 .flat_map(move |m| {
130 let id = m.value();
131 let len = id & ((1u32 << 5) - 1u32);
132 let offset = id >> 5u32;
133 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
134 let data: &[u8] = &self.vals_data[offset_bytes..];
135 (0..len as usize).map(move |i| {
136 (
137 m.end(),
138 WordEntry::deserialize(
139 &data[WordEntry::SERIALIZED_LEN * i..],
140 self.is_system,
141 ),
142 )
143 })
144 })
145 }
146
147 pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
149 self.find_surface_iter(surface).collect()
150 }
151
152 pub fn find_surface_iter<'a>(
155 &'a self,
156 surface: &'a str,
157 ) -> impl Iterator<Item = WordEntry> + 'a {
158 self.da
159 .find_overlapping_iter(surface)
160 .filter(|m| m.start() == 0 && m.end() == surface.len())
161 .flat_map(move |m| {
162 let offset_len = m.value();
163 let offset = offset_len >> 5u32;
164 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
165 let data = &self.vals_data[offset_bytes..];
166 let len = offset_len & ((1u32 << 5) - 1u32);
167 (0..len as usize).map(move |i| {
168 WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
169 })
170 })
171 }
172
173 pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
175 if self.vals_data.is_empty() {
179 return Vec::new();
180 }
181
182 let suffix_str: String = suffix.iter().collect();
183
184 self.da
185 .find_overlapping_iter(&suffix_str)
186 .filter(|m| m.start() == 0)
187 .flat_map(|m| {
188 let offset_len = m.value();
189 let len = offset_len & ((1u32 << 5) - 1u32);
190 let offset = offset_len >> 5u32;
191 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
192
193 if offset_bytes >= self.vals_data.len() {
195 return vec![].into_iter();
196 }
197
198 let data: &[u8] = &self.vals_data[offset_bytes..];
199 (0..len as usize)
200 .filter_map(move |i| {
201 let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
202 if required_bytes <= data.len() {
203 let word_entry = WordEntry::deserialize(
204 &data[WordEntry::SERIALIZED_LEN * i..],
205 self.is_system,
206 );
207 Some(Match {
208 word_idx: WordIdx::new(word_entry.word_id.id),
209 end_char: m.end(), })
231 } else {
232 None
233 }
234 })
235 .collect::<Vec<_>>()
236 .into_iter()
237 })
238 .collect()
239 }
240}
241
242impl ArchivedPrefixDictionary {
243 pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
244 let (da, _) =
246 unsafe { DoubleArrayAhoCorasick::<u32>::deserialize_unchecked(self.da.as_slice()) };
247
248 let matches: Vec<_> = da
249 .find_overlapping_iter(s)
250 .filter(|m| m.start() == 0)
251 .map(|m| (m.end(), m.value()))
252 .collect();
253
254 matches.into_iter().flat_map(move |(end, offset_len)| {
255 let len = offset_len & ((1u32 << 5) - 1u32);
256 let offset = offset_len >> 5u32;
257 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
258
259 let vals = self.vals_data.as_slice();
260 if offset_bytes >= vals.len() {
262 return vec![].into_iter(); }
264
265 let data = &vals[offset_bytes..];
266 (0..len as usize)
267 .map(move |i| {
268 (
269 end,
270 WordEntry::deserialize(
271 &data[WordEntry::SERIALIZED_LEN * i..],
272 self.is_system,
273 ),
274 )
275 })
276 .collect::<Vec<_>>() .into_iter()
278 })
279 }
280
281 pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
282 let (da, _) =
283 unsafe { DoubleArrayAhoCorasick::<u32>::deserialize_unchecked(self.da.as_slice()) };
284
285 let matches: Vec<_> = da
287 .find_overlapping_iter(surface)
288 .filter(|m| m.start() == 0 && m.end() == surface.len())
289 .map(|m| m.value())
290 .collect();
291
292 matches
293 .into_iter()
294 .flat_map(|offset_len| {
295 let len = offset_len & ((1u32 << 5) - 1u32);
296 let offset = offset_len >> 5u32;
297 let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
298 let vals = self.vals_data.as_slice();
299 if offset_bytes >= vals.len() {
300 return Vec::new().into_iter();
301 }
302 let data = &vals[offset_bytes..];
303 (0..len as usize)
304 .map(|i| {
305 WordEntry::deserialize(
306 &data[WordEntry::SERIALIZED_LEN * i..],
307 self.is_system,
308 )
309 })
310 .collect::<Vec<_>>()
311 .into_iter()
312 })
313 .collect()
314 }
315}