Skip to main content

lindera_dictionary/dictionary/
prefix_dictionary.rs

1use daachorse::DoubleArrayAhoCorasick;
2use rkyv::rancor::Fallible;
3use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::{util::Data, viterbi::WordEntry};
8
9/// Match structure for common prefix iterator compatibility
10#[derive(Debug, Clone)]
11pub struct Match {
12    pub word_idx: WordIdx,
13    pub end_char: usize,
14}
15
/// Identifier of a dictionary word.
///
/// A thin wrapper around the raw `u32` id so that word ids are not confused
/// with other integers. It is `Copy`, comparable and hashable, so it can be
/// used directly as a map key or in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WordIdx {
    /// Raw numeric word id.
    pub word_id: u32,
}

impl WordIdx {
    /// Wraps a raw word id.
    pub const fn new(word_id: u32) -> Self {
        Self { word_id }
    }
}
26
/// Marker type used with `#[rkyv(with = ...)]` to archive a
/// `DoubleArrayAhoCorasick<u32>` as the byte vector returned by its
/// `serialize()` method.
pub struct DoubleArrayArchiver;
28
29impl ArchiveWith<DoubleArrayAhoCorasick<u32>> for DoubleArrayArchiver {
30    type Archived = rkyv::vec::ArchivedVec<u8>;
31    type Resolver = rkyv::vec::VecResolver;
32
33    fn resolve_with(
34        field: &DoubleArrayAhoCorasick<u32>,
35        resolver: Self::Resolver,
36        out: Place<Self::Archived>,
37    ) {
38        let bytes = field.serialize();
39        rkyv::vec::ArchivedVec::resolve_from_slice(&bytes, resolver, out);
40    }
41}
42
43impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized>
44    SerializeWith<DoubleArrayAhoCorasick<u32>, S> for DoubleArrayArchiver
45{
46    fn serialize_with(
47        field: &DoubleArrayAhoCorasick<u32>,
48        serializer: &mut S,
49    ) -> Result<Self::Resolver, S::Error> {
50        let bytes = field.serialize();
51        rkyv::vec::ArchivedVec::serialize_from_slice(&bytes, serializer)
52    }
53}
54
55impl<D: Fallible + ?Sized>
56    DeserializeWith<rkyv::vec::ArchivedVec<u8>, DoubleArrayAhoCorasick<u32>, D>
57    for DoubleArrayArchiver
58{
59    fn deserialize_with(
60        archived: &rkyv::vec::ArchivedVec<u8>,
61        _deserializer: &mut D,
62    ) -> Result<DoubleArrayAhoCorasick<u32>, D::Error> {
63        unsafe {
64            let (da, _) = DoubleArrayAhoCorasick::deserialize_unchecked(archived.as_slice());
65            Ok(da)
66        }
67    }
68}
69
70mod double_array_serde {
71    use daachorse::DoubleArrayAhoCorasick;
72    use serde::{Deserialize, Deserializer, Serializer};
73
74    pub fn serialize<S>(da: &DoubleArrayAhoCorasick<u32>, serializer: S) -> Result<S::Ok, S::Error>
75    where
76        S: Serializer,
77    {
78        let bytes = da.serialize();
79        serializer.serialize_bytes(&bytes)
80    }
81
82    pub fn deserialize<'de, D>(deserializer: D) -> Result<DoubleArrayAhoCorasick<u32>, D::Error>
83    where
84        D: Deserializer<'de>,
85    {
86        let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
87        unsafe {
88            let (da, _) = DoubleArrayAhoCorasick::deserialize_unchecked(&bytes);
89            Ok(da)
90        }
91    }
92}
93
94#[derive(Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
95pub struct PrefixDictionary {
96    #[serde(with = "self::double_array_serde")]
97    #[rkyv(with = DoubleArrayArchiver)]
98    pub da: DoubleArrayAhoCorasick<u32>,
99    pub vals_data: Data,
100    pub words_idx_data: Data,
101    pub words_data: Data,
102    pub is_system: bool,
103}
104
105impl PrefixDictionary {
106    pub fn load(
107        da_data: impl Into<Data>,
108        vals_data: impl Into<Data>,
109        words_idx_data: impl Into<Data>,
110        words_data: impl Into<Data>,
111        is_system: bool,
112    ) -> PrefixDictionary {
113        let da_bytes = da_data.into();
114        let (da, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(&da_bytes[..]) };
115
116        PrefixDictionary {
117            da,
118            vals_data: vals_data.into(),
119            words_idx_data: words_idx_data.into(),
120            words_data: words_data.into(),
121            is_system,
122        }
123    }
124
125    pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
126        self.da
127            .find_overlapping_iter(s)
128            .filter(|m| m.start() == 0)
129            .flat_map(move |m| {
130                let id = m.value();
131                let len = id & ((1u32 << 5) - 1u32);
132                let offset = id >> 5u32;
133                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
134                let data: &[u8] = &self.vals_data[offset_bytes..];
135                (0..len as usize).map(move |i| {
136                    (
137                        m.end(),
138                        WordEntry::deserialize(
139                            &data[WordEntry::SERIALIZED_LEN * i..],
140                            self.is_system,
141                        ),
142                    )
143                })
144            })
145    }
146
147    /// Find `WordEntry`s with surface
148    pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
149        self.find_surface_iter(surface).collect()
150    }
151
152    /// Find `WordEntry`s with surface using lazy evaluation
153    /// This iterator-based approach reduces memory allocations
154    pub fn find_surface_iter<'a>(
155        &'a self,
156        surface: &'a str,
157    ) -> impl Iterator<Item = WordEntry> + 'a {
158        self.da
159            .find_overlapping_iter(surface)
160            .filter(|m| m.start() == 0 && m.end() == surface.len())
161            .flat_map(move |m| {
162                let offset_len = m.value();
163                let offset = offset_len >> 5u32;
164                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
165                let data = &self.vals_data[offset_bytes..];
166                let len = offset_len & ((1u32 << 5) - 1u32);
167                (0..len as usize).map(move |i| {
168                    WordEntry::deserialize(&data[WordEntry::SERIALIZED_LEN * i..], self.is_system)
169                })
170            })
171    }
172
173    /// Common prefix iterator using character array input
174    pub fn common_prefix_iterator(&self, suffix: &[char]) -> Vec<Match> {
175        // Warning: This method takes &[char], but daachorse works on bytes (str).
176        // Converting char slice to string is costly but necessary if we use daachorse standard API.
177
178        if self.vals_data.is_empty() {
179            return Vec::new();
180        }
181
182        let suffix_str: String = suffix.iter().collect();
183
184        self.da
185            .find_overlapping_iter(&suffix_str)
186            .filter(|m| m.start() == 0)
187            .flat_map(|m| {
188                let offset_len = m.value();
189                let len = offset_len & ((1u32 << 5) - 1u32);
190                let offset = offset_len >> 5u32;
191                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
192
193                // 範囲チェックを追加
194                if offset_bytes >= self.vals_data.len() {
195                    return vec![].into_iter();
196                }
197
198                let data: &[u8] = &self.vals_data[offset_bytes..];
199                (0..len as usize)
200                    .filter_map(move |i| {
201                        let required_bytes = WordEntry::SERIALIZED_LEN * (i + 1);
202                        if required_bytes <= data.len() {
203                            let word_entry = WordEntry::deserialize(
204                                &data[WordEntry::SERIALIZED_LEN * i..],
205                                self.is_system,
206                            );
207                            Some(Match {
208                                word_idx: WordIdx::new(word_entry.word_id.id),
209                                end_char: m.end(), // prefix_len in bytes? No, m.end() is byte index.
210                                                   // Match expects char length?
211                                                   // Original code: end_char: prefix_len
212                                                   // prefix_len was number of bytes or chars?
213                                                   // yada::common_prefix_search returns (val, len) where len is length in bytes?
214                                                   // yada common_prefix_search(str) returns length in bytes.
215                                                   // But common_prefix_iterator takes &[char].
216                                                   // Match.end_char usually implies character index if used for Viterbi on chars.
217                                                   // But Viterbi usually works on bytes in Lindera?
218                                                   // Let's check typical usage.
219                                                   // NOTE: daachorse returns byte indices.
220                                                   // If input was chars converted to String, byte index != char index.
221                                                   // We need to map back to char index?
222                                                   // This function common_prefix_iterator might be inefficient or deprecated given we move to byte-based Viterbi.
223                                                   // For now, let's assume we return byte length.
224                                                   // But wait, suffix is &[char].
225                                                   // The caller likely expects char length?
226                                                   // Yes. if suffix is &[char], end_char 3 means 3 chars.
227                                                   // We have byte length from daachorse.
228                                                   // We need to count chars in suffix_str[..m.end()].
229                                                   // This is inefficient.
230                            })
231                        } else {
232                            None
233                        }
234                    })
235                    .collect::<Vec<_>>()
236                    .into_iter()
237            })
238            .collect()
239    }
240}
241
242impl ArchivedPrefixDictionary {
243    pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
244        // Deserialize on the fly. Performance warning: this is slow.
245        let (da, _) =
246            unsafe { DoubleArrayAhoCorasick::<u32>::deserialize_unchecked(self.da.as_slice()) };
247
248        let matches: Vec<_> = da
249            .find_overlapping_iter(s)
250            .filter(|m| m.start() == 0)
251            .map(|m| (m.end(), m.value()))
252            .collect();
253
254        matches.into_iter().flat_map(move |(end, offset_len)| {
255            let len = offset_len & ((1u32 << 5) - 1u32);
256            let offset = offset_len >> 5u32;
257            let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
258
259            let vals = self.vals_data.as_slice();
260            // Check bounds?
261            if offset_bytes >= vals.len() {
262                return vec![].into_iter(); // Handle gracefully
263            }
264
265            let data = &vals[offset_bytes..];
266            (0..len as usize)
267                .map(move |i| {
268                    (
269                        end,
270                        WordEntry::deserialize(
271                            &data[WordEntry::SERIALIZED_LEN * i..],
272                            self.is_system,
273                        ),
274                    )
275                })
276                .collect::<Vec<_>>() // Collect to avoid lifetime issues with 'a and move?
277                .into_iter()
278        })
279    }
280
281    pub fn find_surface(&self, surface: &str) -> Vec<WordEntry> {
282        let (da, _) =
283            unsafe { DoubleArrayAhoCorasick::<u32>::deserialize_unchecked(self.da.as_slice()) };
284
285        // Check if there is a match with start=0 and end=surface.len()
286        let matches: Vec<_> = da
287            .find_overlapping_iter(surface)
288            .filter(|m| m.start() == 0 && m.end() == surface.len())
289            .map(|m| m.value())
290            .collect();
291
292        matches
293            .into_iter()
294            .flat_map(|offset_len| {
295                let len = offset_len & ((1u32 << 5) - 1u32);
296                let offset = offset_len >> 5u32;
297                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
298                let vals = self.vals_data.as_slice();
299                if offset_bytes >= vals.len() {
300                    return Vec::new().into_iter();
301                }
302                let data = &vals[offset_bytes..];
303                (0..len as usize)
304                    .map(|i| {
305                        WordEntry::deserialize(
306                            &data[WordEntry::SERIALIZED_LEN * i..],
307                            self.is_system,
308                        )
309                    })
310                    .collect::<Vec<_>>()
311                    .into_iter()
312            })
313            .collect()
314    }
315}