lindera_dictionary/dictionary/
prefix_dictionary.rs

1use std::ops::Deref;
2
3use rkyv::rancor::Fallible;
4use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
5use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
6use serde::{Deserialize, Serialize};
7use yada::DoubleArray;
8
9use crate::{util::Data, viterbi::WordEntry};
10
11/// Match structure for common prefix iterator compatibility
12#[derive(Debug, Clone)]
13pub struct Match {
14    pub word_idx: WordIdx,
15    pub end_char: usize,
16}
17
/// Thin, copyable wrapper around the numeric id of a dictionary word.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    /// Raw word id.
    pub word_id: u32,
}

impl WordIdx {
    /// Wraps `word_id` in a [`WordIdx`].
    pub fn new(word_id: u32) -> Self {
        WordIdx { word_id }
    }
}
28
29#[derive(Serialize, Deserialize)]
30#[serde(remote = "DoubleArray")]
31struct DoubleArrayDef<T>(pub T)
32where
33    T: Deref<Target = [u8]>;
34
/// Zero-sized rkyv `with`-wrapper that archives a `DoubleArray<Data>` as the
/// plain byte vector backing it.
pub struct DoubleArrayArchiver;
36
37impl ArchiveWith<DoubleArray<Data>> for DoubleArrayArchiver {
38    type Archived = rkyv::vec::ArchivedVec<u8>;
39    type Resolver = rkyv::vec::VecResolver;
40
41    fn resolve_with(
42        field: &DoubleArray<Data>,
43        resolver: Self::Resolver,
44        out: Place<Self::Archived>,
45    ) {
46        // DoubleArray<Data> derefs to [u8] via Data
47        rkyv::vec::ArchivedVec::resolve_from_slice(&field.0[..], resolver, out);
48    }
49}
50
51impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized>
52    SerializeWith<DoubleArray<Data>, S> for DoubleArrayArchiver
53{
54    fn serialize_with(
55        field: &DoubleArray<Data>,
56        serializer: &mut S,
57    ) -> Result<Self::Resolver, S::Error> {
58        rkyv::vec::ArchivedVec::serialize_from_slice(&field.0[..], serializer)
59    }
60}
61
62impl<D: Fallible + ?Sized> DeserializeWith<rkyv::vec::ArchivedVec<u8>, DoubleArray<Data>, D>
63    for DoubleArrayArchiver
64{
65    fn deserialize_with(
66        archived: &rkyv::vec::ArchivedVec<u8>,
67        _deserializer: &mut D,
68    ) -> Result<DoubleArray<Data>, D::Error> {
69        let mut vec = Vec::with_capacity(archived.len());
70        vec.extend_from_slice(archived.as_slice());
71        Ok(DoubleArray::new(Data::Vec(vec)))
72    }
73}
74
75#[derive(Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
76pub struct PrefixDictionary {
77    #[serde(with = "DoubleArrayDef")]
78    #[rkyv(with = DoubleArrayArchiver)]
79    pub da: DoubleArray<Data>,
80    pub vals_data: Data,
81    pub words_idx_data: Data,
82    pub words_data: Data,
83    pub is_system: bool,
84}
85
86impl PrefixDictionary {
87    pub fn load(
88        da_data: impl Into<Data>,
89        vals_data: impl Into<Data>,
90        words_idx_data: impl Into<Data>,
91        words_data: impl Into<Data>,
92        is_system: bool,
93    ) -> PrefixDictionary {
94        let da = DoubleArray::new(da_data.into());
95
96        PrefixDictionary {
97            da,
98            vals_data: vals_data.into(),
99            words_idx_data: words_idx_data.into(),
100            words_data: words_data.into(),
101            is_system,
102        }
103    }
104
105    pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
106        self.da
107            .common_prefix_search(s)
108            .flat_map(move |(offset_len, prefix_len)| {
109                let len = offset_len & ((1u32 << 5) - 1u32);
110                let offset = offset_len >> 5u32;
111                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
112                let data: &[u8] = &self.vals_data[offset_bytes..];
113                (0..len as usize).map(move |i| {
114                    (
115                        prefix_len,
116                        WordEntry::deserialize(
117                            &data[WordEntry::SERIALIZED_LEN * i..],
118                            self.is_system,
119                        ),
120                    )
121                })
122            })
123    }
124
125    /// Find `WordEntry`s with surface
126    pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
127        match self.da.exact_match_search(surface) {
128            Some(offset_len) => {
129                let offset = offset_len >> 5u32;
130                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
131                let data: &[u8] = &self.vals_data[offset_bytes..];
132                let len = offset_len & ((1u32 << 5) - 1u32);
133                (0..len as usize)
134                    .map(|i| {
135                        WordEntry::deserialize(
136                            &data[WordEntry::SERIALIZED_LEN * i..],
137                            self.is_system,
138                        )
139                    })
140                    .collect::<Vec<WordEntry>>()
141            }
142            None => vec![],
143        }
144    }
145
146    /// Find `WordEntry`s with surface using lazy evaluation
147    /// This iterator-based approach reduces memory allocations
148    pub fn find_surface_iter(&self, surface: &str) -> impl Iterator<Item = WordEntry> + '_ {
149        self.da
150            .exact_match_search(surface)
151            .map(|offset_len| {
152                let offset = offset_len >> 5u32;
153                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
154                let data = &self.vals_data[offset_bytes..];
155                let len = offset_len & ((1u32 << 5) - 1u32);
156                (0..len as usize).map(move |i| {
157                    WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
158                })
159            })
160            .into_iter()
161            .flatten()
162    }
163
164    /// Common prefix iterator using character array input
165    pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
166        // 空の辞書の場合は空のマッチを返す
167        if self.vals_data.is_empty() {
168            return Vec::new();
169        }
170
171        let suffix_str: String = suffix.iter().collect();
172        self.da
173            .common_prefix_search(&suffix_str)
174            .flat_map(|(offset_len, prefix_len)| {
175                let len = offset_len & ((1u32 << 5) - 1u32);
176                let offset = offset_len >> 5u32;
177                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
178
179                // 範囲チェックを追加
180                if offset_bytes >= self.vals_data.len() {
181                    return vec![].into_iter();
182                }
183
184                let data: &[u8] = &self.vals_data[offset_bytes..];
185                (0..len as usize)
186                    .filter_map(move |i| {
187                        let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
188                        if required_bytes <= data.len() {
189                            let word_entry = WordEntry::deserialize(
190                                &data[WordEntry::SERIALIZED_LEN * i..],
191                                self.is_system,
192                            );
193                            Some(Match {
194                                word_idx: WordIdx::new(word_entry.word_id.id),
195                                end_char: prefix_len,
196                            })
197                        } else {
198                            None
199                        }
200                    })
201                    .collect::<Vec<_>>()
202                    .into_iter()
203            })
204            .collect()
205    }
206}
207
208impl ArchivedPrefixDictionary {
209    pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
210        let da = DoubleArray::new(self.da.as_slice());
211        let matches: Vec<_> = da.common_prefix_search(s).collect();
212
213        matches
214            .into_iter()
215            .flat_map(move |(offset_len, prefix_len)| {
216                let len = offset_len & ((1u32 << 5) - 1u32);
217                let offset = offset_len >> 5u32;
218                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
219                let data: &[u8] = &self.vals_data.as_slice()[offset_bytes..];
220                (0..len as usize).map(move |i| {
221                    (
222                        prefix_len,
223                        WordEntry::deserialize(
224                            &data[WordEntry::SERIALIZED_LEN * i..],
225                            self.is_system,
226                        ),
227                    )
228                })
229            })
230    }
231
232    pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
233        let da = DoubleArray::new(self.da.as_slice());
234        match da.exact_match_search(surface) {
235            Some(offset_len) => {
236                let offset = offset_len >> 5u32;
237                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
238                let data: &[u8] = &self.vals_data.as_slice()[offset_bytes..];
239                let len = offset_len & ((1u32 << 5) - 1u32);
240                (0..len as usize)
241                    .map(|i| {
242                        WordEntry::deserialize(
243                            &data[WordEntry::SERIALIZED_LEN * i..],
244                            self.is_system,
245                        )
246                    })
247                    .collect::<Vec<WordEntry>>()
248            }
249            None => vec![],
250        }
251    }
252
253    pub fn find_surface_iter(&self, surface: &str) -> impl Iterator<Item = WordEntry> + '_ {
254        let da = DoubleArray::new(self.da.as_slice());
255        da.exact_match_search(surface)
256            .map(|offset_len| {
257                let offset = offset_len >> 5u32;
258                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
259                let data = &self.vals_data.as_slice()[offset_bytes..];
260                let len = offset_len & ((1u32 << 5) - 1u32);
261                (0..len as usize).map(move |i| {
262                    WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
263                })
264            })
265            .into_iter()
266            .flatten()
267    }
268
269    pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
270        if self.vals_data.as_slice().is_empty() {
271            return Vec::new();
272        }
273
274        let suffix_str: String = suffix.iter().collect();
275        let da = DoubleArray::new(self.da.as_slice());
276
277        da.common_prefix_search(&suffix_str)
278            .flat_map(|(offset_len, prefix_len)| {
279                let len = offset_len & ((1u32 << 5) - 1u32);
280                let offset = offset_len >> 5u32;
281                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
282
283                if offset_bytes >= self.vals_data.as_slice().len() {
284                    return vec![].into_iter();
285                }
286
287                let data: &[u8] = &self.vals_data.as_slice()[offset_bytes..];
288                (0..len as usize)
289                    .filter_map(move |i| {
290                        let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
291                        if required_bytes <= data.len() {
292                            let word_entry = WordEntry::deserialize(
293                                &data[WordEntry::SERIALIZED_LEN * i..],
294                                self.is_system,
295                            );
296                            Some(Match {
297                                word_idx: WordIdx::new(word_entry.word_id.id),
298                                end_char: prefix_len,
299                            })
300                        } else {
301                            None
302                        }
303                    })
304                    .collect::<Vec<_>>()
305                    .into_iter()
306            })
307            .collect()
308    }
309}