1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
/// Id of the end-of-input node in the lattice.
///
/// `Lattice::set_text` adds the start node first and the end node second and
/// asserts that the end node received this id; `Lattice::tokens_offset`
/// starts its backward walk from it.
const EOS_NODE: EdgeId = EdgeId(1u32);
14
/// Lexicon category recorded in a [`WordId`].
#[derive(
    Clone,
    Copy,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Default,
    Archive,
    RkyvSerialize,
    RkyvDeserialize,
)]

pub enum LexType {
    /// Default variant; chosen by `WordEntry::deserialize` when
    /// `is_system_entry` is `true`.
    #[default]
    System,
    /// Chosen by `WordEntry::deserialize` when `is_system_entry` is `false`.
    User,
    /// Makes `WordId::is_unknown` return `true` regardless of the numeric id.
    Unknown,
}
39
/// Identifier of a word together with the lexicon category it belongs to.
#[derive(
    Clone,
    Copy,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Archive,
    RkyvDeserialize,
    RkyvSerialize,
)]

pub struct WordId {
    /// Numeric id; `u32::MAX` is treated as "unknown" by `WordId::is_unknown`
    /// and is the value used by `WordId::default`.
    pub id: u32,
    /// `true` when the id refers to the system lexicon. `WordId::new` derives
    /// it from `lex_type`; the field is public, so a hand-built value can
    /// disagree with `lex_type`.
    pub is_system: bool,
    /// Lexicon category of the word.
    pub lex_type: LexType,
}
58
59impl WordId {
60 pub fn new(lex_type: LexType, id: u32) -> Self {
62 WordId {
63 id,
64 is_system: matches!(lex_type, LexType::System),
65 lex_type,
66 }
67 }
68
69 pub fn is_unknown(&self) -> bool {
70 self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
71 }
72
73 pub fn is_system(&self) -> bool {
74 self.is_system
75 }
76
77 pub fn lex_type(&self) -> LexType {
78 self.lex_type
79 }
80}
81
82impl Default for WordId {
83 fn default() -> Self {
84 WordId {
85 id: u32::MAX,
86 is_system: true,
87 lex_type: LexType::System,
88 }
89 }
90}
91
/// Dictionary entry attached to a lattice edge: the word's id, its own cost
/// and the two ids used to look up connection costs.
#[derive(
    Default,
    Clone,
    Copy,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Archive,
    RkyvSerialize,
    RkyvDeserialize,
)]

pub struct WordEntry {
    /// Identifier of the word.
    pub word_id: WordId,
    /// Cost of the word itself; `Lattice::calculate_path_costs` adds it to the
    /// best incoming path cost.
    pub word_cost: i16,
    /// Id passed as the right-hand argument of `ConnectionCostMatrix::cost`
    /// when this entry is the right edge of a connection.
    pub left_id: u16,
    /// Id passed as the left-hand argument of `ConnectionCostMatrix::cost`
    /// when this entry is the left edge of a connection.
    pub right_id: u16,
}
112
113impl WordEntry {
114 pub const SERIALIZED_LEN: usize = 10;
115
116 pub fn left_id(&self) -> u32 {
117 self.left_id as u32
118 }
119
120 pub fn right_id(&self) -> u32 {
121 self.right_id as u32
122 }
123
124 pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
125 wtr.write_u32::<LittleEndian>(self.word_id.id)?;
126 wtr.write_i16::<LittleEndian>(self.word_cost)?;
127 wtr.write_u16::<LittleEndian>(self.left_id)?;
128 wtr.write_u16::<LittleEndian>(self.right_id)?;
129 Ok(())
130 }
131
132 pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
133 let word_id = WordId::new(
134 if is_system_entry {
135 LexType::System
136 } else {
137 LexType::User
138 },
139 LittleEndian::read_u32(&data[0..4]),
140 );
141 let word_cost = LittleEndian::read_i16(&data[4..6]);
142 let left_id = LittleEndian::read_u16(&data[6..8]);
143 let right_id = LittleEndian::read_u16(&data[8..10]);
144 WordEntry {
145 word_id,
146 word_cost,
147 left_id,
148 right_id,
149 }
150 }
151}
152
/// Origin of a lattice edge.
#[derive(Clone, Copy, Debug, Default)]
pub enum EdgeType {
    /// Default variant; `Lattice::set_text` uses it for every dictionary
    /// match, from both the user dictionary and the main dictionary.
    #[default]
    KNOWN,
    /// Used by `Lattice::process_unknown_word` for unknown-word edges.
    UNKNOWN,
    /// Not constructed anywhere in this file.
    USER,
    /// Not constructed anywhere in this file.
    INSERTED,
}
161
/// Handle to an edge stored in a `Lattice`: the wrapped value is the edge's
/// index in the lattice's edge vector.
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
pub struct EdgeId(pub u32);
164
/// A candidate word in the lattice covering the byte range
/// `start_index..stop_index` of the analysed text.
#[derive(Default, Clone, Debug)]
pub struct Edge {
    /// Origin of the edge.
    pub edge_type: EdgeType,
    /// Dictionary entry attached to the edge.
    pub word_entry: WordEntry,

    /// Cost of the best path ending with this edge. `Lattice::create_edge`
    /// initialises it to `i32::MAX`; `Lattice::calculate_path_costs` fills it in.
    pub path_cost: i32,
    /// Predecessor on the best path, set by `Lattice::calculate_path_costs`;
    /// `None` until then.
    pub left_edge: Option<EdgeId>,

    /// Byte offset in the text where the edge starts.
    pub start_index: u32,
    /// Byte offset in the text where the edge stops (exclusive).
    pub stop_index: u32,

    /// Whether the covered text consists only of kanji, as decided by
    /// `is_kanji_only` when the edge is created.
    pub kanji_only: bool,
}
178
179impl Edge {
180 pub fn num_chars(&self) -> usize {
182 (self.stop_index - self.start_index) as usize / 3
183 }
184}
185
/// Word lattice over a text: every candidate edge plus, for each byte
/// position, the edges that start and end there.
#[derive(Clone, Default)]
pub struct Lattice {
    /// Largest text length (in bytes) passed to `set_capacity` so far.
    capacity: usize,
    /// All edges of the current text; an `EdgeId` is an index into this vector.
    edges: Vec<Edge>,
    /// `starts_at[i]` holds the ids of the edges whose `start_index` is `i`;
    /// the end node is registered at the text length.
    starts_at: Vec<Vec<EdgeId>>,
    /// `ends_at[i]` holds the ids of the edges whose `stop_index` is `i`;
    /// the start node is registered at position 0.
    ends_at: Vec<Vec<EdgeId>>,
    /// Scratch buffer handed out (emptied) by `get_edge_buffer`.
    edge_buffer: Vec<Edge>,
    /// Scratch buffer handed out (emptied) by `get_edge_id_buffer`.
    edge_id_buffer: Vec<EdgeId>,
}
196
/// Returns `true` when `c` lies in the code point range U+4E00..=U+9FAF
/// (inside the CJK Unified Ideographs block).
#[inline]
fn is_kanji(c: char) -> bool {
    matches!(c, '\u{4E00}'..='\u{9FAF}')
}
203
204#[inline]
205fn is_kanji_only(s: &str) -> bool {
206 !s.is_empty() && s.chars().all(is_kanji)
208}
209
210impl Lattice {
211 #[inline]
213 fn create_edge(
214 edge_type: EdgeType,
215 word_entry: WordEntry,
216 start: usize,
217 stop: usize,
218 kanji_only: bool,
219 ) -> Edge {
220 Edge {
221 edge_type,
222 word_entry,
223 left_edge: None,
224 start_index: start as u32,
225 stop_index: stop as u32,
226 path_cost: i32::MAX,
227 kanji_only,
228 }
229 }
230
231 pub fn clear(&mut self) {
232 for edge_vec in &mut self.starts_at {
233 edge_vec.clear();
234 }
235 for edge_vec in &mut self.ends_at {
236 edge_vec.clear();
237 }
238 self.edges.clear();
239 self.edge_buffer.clear();
241 self.edge_id_buffer.clear();
242 }
243
244 pub fn get_edge_buffer(&mut self) -> &mut Vec<Edge> {
246 self.edge_buffer.clear();
247 &mut self.edge_buffer
248 }
249
250 pub fn get_edge_id_buffer(&mut self) -> &mut Vec<EdgeId> {
252 self.edge_id_buffer.clear();
253 &mut self.edge_id_buffer
254 }
255
256 fn set_capacity(&mut self, text_len: usize) {
257 self.clear();
258 if self.capacity < text_len {
259 self.capacity = text_len;
260 self.edges.clear();
261 self.starts_at.resize(text_len + 1, Vec::new());
262 self.ends_at.resize(text_len + 1, Vec::new());
263 }
264 }
265
266 #[inline(never)]
267 pub fn set_text(
268 &mut self,
269 dict: &PrefixDictionary,
270 user_dict: &Option<&PrefixDictionary>,
271 char_definitions: &CharacterDefinition,
272 unknown_dictionary: &UnknownDictionary,
273 text: &str,
274 search_mode: &Mode,
275 ) {
276 let len = text.len();
277 self.set_capacity(len);
278
279 let start_edge_id = self.add_edge(Edge::default());
280 let end_edge_id = self.add_edge(Edge::default());
281
282 assert_eq!(EOS_NODE, end_edge_id);
283 self.ends_at[0].push(start_edge_id);
284 self.starts_at[len].push(end_edge_id);
285
286 let mut unknown_word_end: Option<usize> = None;
288
289 for start in 0..len {
290 if self.ends_at[start].is_empty() {
293 continue;
294 }
295
296 let suffix = &text[start..];
297
298 let mut found: bool = false;
299
300 if user_dict.is_some() {
302 let dict = user_dict.as_ref().unwrap();
303 for (prefix_len, word_entry) in dict.prefix(suffix) {
304 let edge = Self::create_edge(
305 EdgeType::KNOWN,
306 word_entry,
307 start,
308 start + prefix_len,
309 is_kanji_only(&suffix[..prefix_len]),
310 );
311 self.add_edge_in_lattice(edge);
312 found = true;
313 }
314 }
315
316 for (prefix_len, word_entry) in dict.prefix(suffix) {
319 let edge = Self::create_edge(
320 EdgeType::KNOWN,
321 word_entry,
322 start,
323 start + prefix_len,
324 is_kanji_only(&suffix[..prefix_len]),
325 );
326 self.add_edge_in_lattice(edge);
327 found = true;
328 }
329
330 if (search_mode.is_search()
332 || unknown_word_end.map(|index| index <= start).unwrap_or(true))
333 && let Some(first_char) = suffix.chars().next()
334 {
335 let categories = char_definitions.lookup_categories(first_char);
336 for (category_ord, &category) in categories.iter().enumerate() {
337 unknown_word_end = self.process_unknown_word(
338 char_definitions,
339 unknown_dictionary,
340 category,
341 category_ord,
342 unknown_word_end,
343 start,
344 suffix,
345 found,
346 );
347 }
348 }
349 }
350 }
351
352 #[allow(clippy::too_many_arguments)]
353 fn process_unknown_word(
354 &mut self,
355 char_definitions: &CharacterDefinition,
356 unknown_dictionary: &UnknownDictionary,
357 category: CategoryId,
358 category_ord: usize,
359 unknown_word_index: Option<usize>,
360 start: usize,
361 suffix: &str,
362 found: bool,
363 ) -> Option<usize> {
364 let mut unknown_word_num_chars: usize = 0;
365 let category_data = char_definitions.lookup_definition(category);
366 if category_data.invoke || !found {
367 unknown_word_num_chars = 1;
368 if category_data.group {
369 for c in suffix.chars().skip(1) {
370 let categories = char_definitions.lookup_categories(c);
371 if categories.len() > category_ord && categories[category_ord] == category {
372 unknown_word_num_chars += 1;
373 } else {
374 break;
375 }
376 }
377 }
378 }
379 if unknown_word_num_chars > 0 {
380 let byte_end = suffix
382 .char_indices()
383 .nth(unknown_word_num_chars)
384 .map_or(suffix.len(), |(pos, _)| pos);
385 let unknown_word = &suffix[..byte_end];
386 for &word_id in unknown_dictionary.lookup_word_ids(category) {
387 let word_entry = unknown_dictionary.word_entry(word_id);
388 let edge = Self::create_edge(
389 EdgeType::UNKNOWN,
390 word_entry,
391 start,
392 start + unknown_word.len(),
393 is_kanji_only(unknown_word),
394 );
395 self.add_edge_in_lattice(edge);
396 }
397 return Some(start + unknown_word.len());
398 }
399 unknown_word_index
400 }
401
402 fn add_edge_in_lattice(&mut self, edge: Edge) {
403 let start_index = edge.start_index as usize;
404 let stop_index = edge.stop_index as usize;
405 let edge_id = self.add_edge(edge);
406 self.starts_at[start_index].push(edge_id);
407 self.ends_at[stop_index].push(edge_id);
408 }
409
410 fn add_edge(&mut self, edge: Edge) -> EdgeId {
411 let edge_id = EdgeId(self.edges.len() as u32);
412 self.edges.push(edge);
413 edge_id
414 }
415
416 pub fn edge(&self, edge_id: EdgeId) -> &Edge {
417 &self.edges[edge_id.0 as usize]
418 }
419
420 #[inline(never)]
421 pub fn calculate_path_costs(&mut self, cost_matrix: &ConnectionCostMatrix, mode: &Mode) {
422 let text_len = self.starts_at.len();
423 for i in 0..text_len {
424 let left_edge_ids = &self.ends_at[i];
425 let right_edge_ids = &self.starts_at[i];
426
427 for &right_edge_id in right_edge_ids {
428 let right_edge = &self.edges[right_edge_id.0 as usize];
430 let right_word_entry = right_edge.word_entry;
431 let right_left_id = right_word_entry.left_id();
432
433 let mut best_cost = i32::MAX;
435 let mut best_left = None;
436
437 for &left_edge_id in left_edge_ids {
438 let left_edge = &self.edges[left_edge_id.0 as usize];
439 let left_right_id = left_edge.word_entry.right_id();
440
441 let mut path_cost =
443 left_edge.path_cost + cost_matrix.cost(left_right_id, right_left_id);
444 path_cost += mode.penalty_cost(left_edge);
445
446 if path_cost < best_cost {
448 best_cost = path_cost;
449 best_left = Some(left_edge_id);
450 }
451 }
452
453 if let Some(best_left_id) = best_left {
455 let edge = &mut self.edges[right_edge_id.0 as usize];
456 edge.left_edge = Some(best_left_id);
457 edge.path_cost = right_word_entry.word_cost as i32 + best_cost;
458 }
459 }
460 }
461 }
462
463 pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
464 let mut offsets = Vec::new();
465 let mut edge_id = EOS_NODE;
466 let _edge = self.edge(EOS_NODE);
467 loop {
468 let edge = self.edge(edge_id);
469 if let Some(left_edge_id) = edge.left_edge {
470 offsets.push((edge.start_index as usize, edge.word_entry.word_id));
471 edge_id = left_edge_id;
472 } else {
473 break;
474 }
475 }
476 offsets.reverse();
477 offsets.pop();
478 offsets
479 }
480}
481
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// A system entry survives a serialize/deserialize round trip and
    /// occupies exactly `WordEntry::SERIALIZED_LEN` bytes.
    #[test]
    fn test_word_entry() {
        let original = WordEntry {
            word_id: WordId {
                id: 1u32,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17i16,
            left_id: 1411u16,
            right_id: 1412u16,
        };

        let mut bytes = Vec::new();
        original.serialize(&mut bytes).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, bytes.len());

        let decoded = WordEntry::deserialize(&bytes, true);
        assert_eq!(original, decoded);
    }
}