Skip to main content

lindera_dictionary/dictionary/
prefix_dictionary.rs

1use daachorse::DoubleArrayAhoCorasick;
2use rkyv::rancor::{Fallible, Source};
3use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::{LinderaResult, error::LinderaErrorKind, util::Data, viterbi::WordEntry};
8
9/// Match structure for common prefix iterator compatibility
10#[derive(Debug, Clone)]
11pub struct Match {
12    pub word_idx: WordIdx,
13    pub end_char: usize,
14}
15
/// Lightweight handle wrapping the numeric id of a dictionary word.
///
/// The type is a plain `u32` wrapper, so it is `Copy` and can be compared,
/// hashed and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WordIdx {
    /// Numeric id of the word.
    pub word_id: u32,
}

impl WordIdx {
    /// Create a `WordIdx` for the given word id.
    ///
    /// # Arguments
    ///
    /// * `word_id` - Numeric id of the word.
    #[must_use]
    pub const fn new(word_id: u32) -> Self {
        Self { word_id }
    }
}
26
27pub struct DoubleArrayArchiver;
28
29impl ArchiveWith<DoubleArrayAhoCorasick<u32>> for DoubleArrayArchiver {
30    type Archived = rkyv::vec::ArchivedVec<u8>;
31    type Resolver = rkyv::vec::VecResolver;
32
33    fn resolve_with(
34        field: &DoubleArrayAhoCorasick<u32>,
35        resolver: Self::Resolver,
36        out: Place<Self::Archived>,
37    ) {
38        let bytes = field.serialize();
39        rkyv::vec::ArchivedVec::resolve_from_slice(&bytes, resolver, out);
40    }
41}
42
43impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized>
44    SerializeWith<DoubleArrayAhoCorasick<u32>, S> for DoubleArrayArchiver
45{
46    fn serialize_with(
47        field: &DoubleArrayAhoCorasick<u32>,
48        serializer: &mut S,
49    ) -> Result<Self::Resolver, S::Error> {
50        let bytes = field.serialize();
51        rkyv::vec::ArchivedVec::serialize_from_slice(&bytes, serializer)
52    }
53}
54
55impl<D: Fallible<Error: Source> + ?Sized>
56    DeserializeWith<rkyv::vec::ArchivedVec<u8>, DoubleArrayAhoCorasick<u32>, D>
57    for DoubleArrayArchiver
58{
59    /// Deserialize the archived byte vector into a `DoubleArrayAhoCorasick`.
60    ///
61    /// # Returns
62    ///
63    /// The deserialized `DoubleArrayAhoCorasick`, or an error if deserialization fails.
64    fn deserialize_with(
65        archived: &rkyv::vec::ArchivedVec<u8>,
66        _deserializer: &mut D,
67    ) -> Result<DoubleArrayAhoCorasick<u32>, D::Error> {
68        let (da, _) = DoubleArrayAhoCorasick::deserialize(archived.as_slice()).map_err(|err| {
69            D::Error::new(std::io::Error::new(
70                std::io::ErrorKind::InvalidData,
71                err.to_string(),
72            ))
73        })?;
74        Ok(da)
75    }
76}
77
78mod double_array_serde {
79    use daachorse::DoubleArrayAhoCorasick;
80    use serde::{Deserialize, Deserializer, Serializer};
81
82    pub fn serialize<S>(da: &DoubleArrayAhoCorasick<u32>, serializer: S) -> Result<S::Ok, S::Error>
83    where
84        S: Serializer,
85    {
86        let bytes = da.serialize();
87        serializer.serialize_bytes(&bytes)
88    }
89
90    pub fn deserialize<'de, D>(deserializer: D) -> Result<DoubleArrayAhoCorasick<u32>, D::Error>
91    where
92        D: Deserializer<'de>,
93    {
94        let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
95        let (da, _) = DoubleArrayAhoCorasick::deserialize(&bytes)
96            .map_err(|err| serde::de::Error::custom(err.to_string()))?;
97        Ok(da)
98    }
99}
100
101#[derive(Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
102pub struct PrefixDictionary {
103    #[serde(with = "self::double_array_serde")]
104    #[rkyv(with = DoubleArrayArchiver)]
105    pub da: DoubleArrayAhoCorasick<u32>,
106    pub vals_data: Data,
107    pub words_idx_data: Data,
108    pub words_data: Data,
109    pub is_system: bool,
110}
111
112impl PrefixDictionary {
113    /// Decode the `(offset, count)` pair stored in the double-array value.
114    ///
115    /// System dictionaries use 8-bit count (supports up to 255 variants per
116    /// surface), while user dictionaries retain the legacy 5-bit count
117    /// encoding (max 31) for binary backward compatibility with pre-built
118    /// `.bin` user dictionary files.
119    #[inline]
120    pub(crate) fn decode_val(&self, val: u32) -> (u32, u32) {
121        if self.is_system {
122            (val >> 8u32, val & ((1u32 << 8) - 1u32))
123        } else {
124            (val >> 5u32, val & ((1u32 << 5) - 1u32))
125        }
126    }
127
128    /// Load a `PrefixDictionary` from raw binary data.
129    ///
130    /// # Arguments
131    ///
132    /// * `da_data` - Double-array data bytes.
133    /// * `vals_data` - Values data bytes.
134    /// * `words_idx_data` - Word index data bytes.
135    /// * `words_data` - Words data bytes.
136    /// * `is_system` - Whether this is a system dictionary.
137    ///
138    /// # Returns
139    ///
140    /// A `PrefixDictionary`, or an error if deserialization fails.
141    pub fn load(
142        da_data: impl Into<Data>,
143        vals_data: impl Into<Data>,
144        words_idx_data: impl Into<Data>,
145        words_data: impl Into<Data>,
146        is_system: bool,
147    ) -> LinderaResult<PrefixDictionary> {
148        let da_bytes = da_data.into();
149        let (da, _) = DoubleArrayAhoCorasick::deserialize(&da_bytes[..]).map_err(|err| {
150            LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
151        })?;
152
153        Ok(PrefixDictionary {
154            da,
155            vals_data: vals_data.into(),
156            words_idx_data: words_idx_data.into(),
157            words_data: words_data.into(),
158            is_system,
159        })
160    }
161
162    pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
163        self.da
164            .find_overlapping_iter(s)
165            .filter(|m| m.start() == 0)
166            .flat_map(move |m| {
167                let (offset, len) = self.decode_val(m.value());
168                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
169                let data: &[u8] = &self.vals_data[offset_bytes..];
170                (0..len as usize).map(move |i| {
171                    (
172                        m.end(),
173                        WordEntry::deserialize(
174                            &data[WordEntry::SERIALIZED_LEN * i..],
175                            self.is_system,
176                        ),
177                    )
178                })
179            })
180    }
181
182    /// Find `WordEntry`s with surface
183    pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
184        self.find_surface_iter(surface).collect()
185    }
186
187    /// Find `WordEntry`s with surface using lazy evaluation
188    /// This iterator-based approach reduces memory allocations
189    pub fn find_surface_iter<'a>(
190        &'a self,
191        surface: &'a str,
192    ) -> impl Iterator<Item = WordEntry> + 'a {
193        self.da
194            .find_overlapping_iter(surface)
195            .filter(|m| m.start() == 0 && m.end() == surface.len())
196            .flat_map(move |m| {
197                let (offset, len) = self.decode_val(m.value());
198                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
199                let data = &self.vals_data[offset_bytes..];
200                (0..len as usize).map(move |i| {
201                    WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
202                })
203            })
204    }
205
206    /// Common prefix iterator using character array input
207    pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
208        // Warning: This method takes &[char], but daachorse works on bytes (str).
209        // Converting char slice to string is costly but necessary if we use daachorse standard API.
210
211        if self.vals_data.is_empty() {
212            return Vec::new();
213        }
214
215        let suffix_str: String = suffix.iter().collect();
216
217        self.da
218            .find_overlapping_iter(&suffix_str)
219            .filter(|m| m.start() == 0)
220            .flat_map(|m| {
221                let (offset, len) = self.decode_val(m.value());
222                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
223
224                // 範囲チェックを追加
225                if offset_bytes >= self.vals_data.len() {
226                    return vec![].into_iter();
227                }
228
229                let data: &[u8] = &self.vals_data[offset_bytes..];
230                (0..len as usize)
231                    .filter_map(move |i| {
232                        let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
233                        if required_bytes <= data.len() {
234                            let word_entry = WordEntry::deserialize(
235                                &data[WordEntry::SERIALIZED_LEN * i..],
236                                self.is_system,
237                            );
238                            Some(Match {
239                                word_idx: WordIdx::new(word_entry.word_id.id),
240                                end_char: m.end(), // prefix_len in bytes? No, m.end() is byte index.
241                                                   // Match expects char length?
242                                                   // Original code: end_char: prefix_len
243                                                   // prefix_len was number of bytes or chars?
244                                                   // yada::common_prefix_search returns (val, len) where len is length in bytes?
245                                                   // yada common_prefix_search(str) returns length in bytes.
246                                                   // But common_prefix_iterator takes &[char].
247                                                   // Match.end_char usually implies character index if used for Viterbi on chars.
248                                                   // But Viterbi usually works on bytes in Lindera?
249                                                   // Let's check typical usage.
250                                                   // NOTE: daachorse returns byte indices.
251                                                   // If input was chars converted to String, byte index != char index.
252                                                   // We need to map back to char index?
253                                                   // This function common_prefix_iterator might be inefficient or deprecated given we move to byte-based Viterbi.
254                                                   // For now, let's assume we return byte length.
255                                                   // But wait, suffix is &[char].
256                                                   // The caller likely expects char length?
257                                                   // Yes. if suffix is &[char], end_char 3 means 3 chars.
258                                                   // We have byte length from daachorse.
259                                                   // We need to count chars in suffix_str[..m.end()].
260                                                   // This is inefficient.
261                            })
262                        } else {
263                            None
264                        }
265                    })
266                    .collect::<Vec<_>>()
267                    .into_iter()
268            })
269            .collect()
270    }
271}
272
273impl ArchivedPrefixDictionary {
274    /// Decode the `(offset, count)` pair. See [`PrefixDictionary::decode_val`].
275    #[inline]
276    fn decode_val(&self, val: u32) -> (u32, u32) {
277        if self.is_system {
278            (val >> 8u32, val & ((1u32 << 8) - 1u32))
279        } else {
280            (val >> 5u32, val & ((1u32 << 5) - 1u32))
281        }
282    }
283
284    /// Find all prefix matches for the given string using the archived dictionary.
285    ///
286    /// # Arguments
287    ///
288    /// * `s` - The input string to search for prefix matches.
289    ///
290    /// # Returns
291    ///
292    /// An iterator of `(end_position, WordEntry)` pairs, or an error if deserialization fails.
293    pub fn prefix<'a>(
294        &'a self,
295        s: &'a str,
296    ) -> LinderaResult<impl Iterator<Item = (usize, WordEntry)> + 'a> {
297        // Deserialize on the fly. Performance warning: this is slow.
298        let (da, _) =
299            DoubleArrayAhoCorasick::<u32>::deserialize(self.da.as_slice()).map_err(|err| {
300                LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
301            })?;
302
303        let matches: Vec<_> = da
304            .find_overlapping_iter(s)
305            .filter(|m| m.start() == 0)
306            .map(|m| (m.end(), m.value()))
307            .collect();
308
309        Ok(matches.into_iter().flat_map(move |(end, offset_len)| {
310            let (offset, len) = self.decode_val(offset_len);
311            let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
312
313            let vals = self.vals_data.as_slice();
314            if offset_bytes >= vals.len() {
315                return vec![].into_iter();
316            }
317
318            let data = &vals[offset_bytes..];
319            (0..len as usize)
320                .map(move |i| {
321                    (
322                        end,
323                        WordEntry::deserialize(
324                            &data[WordEntry::SERIALIZED_LEN * i..],
325                            self.is_system,
326                        ),
327                    )
328                })
329                .collect::<Vec<_>>()
330                .into_iter()
331        }))
332    }
333
334    /// Find `WordEntry`s matching the exact surface in the archived dictionary.
335    ///
336    /// # Arguments
337    ///
338    /// * `surface` - The surface string to search for.
339    ///
340    /// # Returns
341    ///
342    /// A vector of matching `WordEntry`s, or an error if deserialization fails.
343    pub fn find_surface(&self, surface: &str) -> LinderaResult<Vec<WordEntry>> {
344        let (da, _) =
345            DoubleArrayAhoCorasick::<u32>::deserialize(self.da.as_slice()).map_err(|err| {
346                LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
347            })?;
348
349        let matches: Vec<_> = da
350            .find_overlapping_iter(surface)
351            .filter(|m| m.start() == 0 && m.end() == surface.len())
352            .map(|m| m.value())
353            .collect();
354
355        Ok(matches
356            .into_iter()
357            .flat_map(|offset_len| {
358                let (offset, len) = self.decode_val(offset_len);
359                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
360                let vals = self.vals_data.as_slice();
361                if offset_bytes >= vals.len() {
362                    return Vec::new().into_iter();
363                }
364                let data = &vals[offset_bytes..];
365                (0..len as usize)
366                    .map(|i| {
367                        WordEntry::deserialize(
368                            &data[WordEntry::SERIALIZED_LEN * i..],
369                            self.is_system,
370                        )
371                    })
372                    .collect::<Vec<_>>()
373                    .into_iter()
374            })
375            .collect())
376    }
377}