Skip to main content

lindera_dictionary/
viterbi.rs

1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13/// Type of lexicon containing the word
14#[derive(
15    Clone,
16    Copy,
17    Debug,
18    Eq,
19    PartialEq,
20    Serialize,
21    Deserialize,
22    Default,
23    Archive,
24    RkyvSerialize,
25    RkyvDeserialize,
26)]
27
28pub enum LexType {
29    /// System dictionary (base dictionary)
30    #[default]
31    System,
32    /// User dictionary (additional vocabulary)
33    User,
34    /// Unknown words (OOV handling)
35    Unknown,
36}
37
38#[derive(
39    Clone,
40    Copy,
41    Debug,
42    Eq,
43    PartialEq,
44    Serialize,
45    Deserialize,
46    Archive,
47    RkyvDeserialize,
48    RkyvSerialize,
49)]
50
51pub struct WordId {
52    pub id: u32,
53    pub is_system: bool,
54    pub lex_type: LexType,
55}
56
57impl WordId {
58    /// Creates a new WordId with specified lexicon type
59    pub fn new(lex_type: LexType, id: u32) -> Self {
60        WordId {
61            id,
62            is_system: matches!(lex_type, LexType::System),
63            lex_type,
64        }
65    }
66
67    pub fn is_unknown(&self) -> bool {
68        matches!(self.lex_type, LexType::Unknown)
69    }
70
71    pub fn is_system(&self) -> bool {
72        self.is_system
73    }
74
75    pub fn lex_type(&self) -> LexType {
76        self.lex_type
77    }
78}
79
80impl Default for WordId {
81    fn default() -> Self {
82        WordId {
83            id: u32::MAX,
84            is_system: true,
85            lex_type: LexType::System,
86        }
87    }
88}
89
90#[derive(
91    Default,
92    Clone,
93    Copy,
94    Debug,
95    Eq,
96    PartialEq,
97    Serialize,
98    Deserialize,
99    Archive,
100    RkyvSerialize,
101    RkyvDeserialize,
102)]
103
104pub struct WordEntry {
105    pub word_id: WordId,
106    pub word_cost: i16,
107    pub left_id: u16,
108    pub right_id: u16,
109}
110
111impl WordEntry {
112    pub const SERIALIZED_LEN: usize = 10;
113
114    pub fn left_id(&self) -> u32 {
115        self.left_id as u32
116    }
117
118    pub fn right_id(&self) -> u32 {
119        self.right_id as u32
120    }
121
122    pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
123        wtr.write_u32::<LittleEndian>(self.word_id.id)?;
124        wtr.write_i16::<LittleEndian>(self.word_cost)?;
125        wtr.write_u16::<LittleEndian>(self.left_id)?;
126        wtr.write_u16::<LittleEndian>(self.right_id)?;
127        Ok(())
128    }
129
130    pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
131        let word_id = WordId::new(
132            if is_system_entry {
133                LexType::System
134            } else {
135                LexType::User
136            },
137            LittleEndian::read_u32(&data[0..4]),
138        );
139        let word_cost = LittleEndian::read_i16(&data[4..6]);
140        let left_id = LittleEndian::read_u16(&data[6..8]);
141        let right_id = LittleEndian::read_u16(&data[8..10]);
142        WordEntry {
143            word_id,
144            word_cost,
145            left_id,
146            right_id,
147        }
148    }
149}
150
/// Kind of an edge in the lattice.
#[derive(Debug, Default, Clone, Copy)]
pub enum EdgeType {
    /// Edge created from a dictionary match. This is the default variant.
    #[default]
    KNOWN,
    /// Edge created by unknown-word processing.
    UNKNOWN,
    /// NOTE(review): not constructed in the visible part of this file; presumably
    /// reserved for user-dictionary edges.
    USER,
    /// NOTE(review): not constructed in the visible part of this file.
    INSERTED,
}
159
160#[derive(Default, Clone, Debug)]
161pub struct Edge {
162    pub edge_type: EdgeType,
163    pub word_entry: WordEntry,
164
165    pub path_cost: i32,
166    pub left_index: u16, // Index in the previous position's vector
167
168    pub start_index: u32,
169    pub stop_index: u32,
170
171    pub kanji_only: bool,
172}
173
174impl Edge {
175    pub fn num_chars(&self) -> usize {
176        (self.stop_index - self.start_index) as usize / 3
177    }
178}
179
/// One transition from a predecessor ("left") edge into an edge of the lattice.
///
/// N-Best mode keeps every such transition, whereas 1-best mode only remembers the
/// cheapest predecessor of each edge.
#[derive(Debug, Clone)]
pub struct PathEntry {
    /// Position of the target edge inside `ends_at[stop_index]`.
    pub edge_index: u16,
    /// Byte offset at which the left edge ends (equal to the target edge's `start_index`).
    pub left_pos: u32,
    /// Position of the left edge inside `ends_at[left_pos]`.
    pub left_index: u16,
    /// Forward cost of the transition: the left edge's `path_cost` plus the connection
    /// cost, plus the mode penalty when one applies.
    pub cost: i32,
}
194
195#[derive(Clone, Default)]
196pub struct Lattice {
197    capacity: usize,
198    ends_at: Vec<Vec<Edge>>, // Now stores edges directly
199    char_info_buffer: Vec<CharData>,
200    categories_buffer: Vec<CategoryId>,
201    char_category_cache: Vec<Vec<CategoryId>>,
202
203    // N-Best fields (only populated when set_text_nbest is called)
204    all_paths: Vec<Vec<PathEntry>>,
205    nbest_capacity: usize,
206    /// The text length (in bytes) of the last set_text/set_text_nbest call
207    last_text_len: usize,
208}
209
/// Pre-computed information about one character of the text being analysed.
#[derive(Debug, Default, Clone, Copy)]
struct CharData {
    /// Byte offset of the character within the text.
    byte_offset: u32,
    /// Result of `is_kanji` for this character.
    is_kanji: bool,
    /// Index of the character's first category in the shared categories buffer.
    categories_start: u32,
    /// Number of categories stored for the character.
    categories_len: u16,
    /// Byte length of the uninterrupted kanji run that starts at this character
    /// (0 when the character itself is not kanji).
    kanji_run_byte_len: u32,
}
218
/// Returns `true` when `c` lies in one of the two code point ranges this module treats
/// as kanji: U+3400..=U+4DBF (CJK Extension A) or U+4E00..=U+9FAF (CJK Unified Ideographs).
///
/// NOTE(review): the Unicode "CJK Unified Ideographs" block runs up to U+9FFF; code
/// points U+9FB0..=U+9FFF are not treated as kanji here.
#[inline]
pub fn is_kanji(c: char) -> bool {
    matches!(c as u32, 0x3400..=0x4DBF | 0x4E00..=0x9FAF)
}
225
226impl Lattice {
227    /// Helper method to create an edge efficiently
228    #[inline]
229    fn create_edge(
230        edge_type: EdgeType,
231        word_entry: WordEntry,
232        start: usize,
233        stop: usize,
234        kanji_only: bool,
235    ) -> Edge {
236        Edge {
237            edge_type,
238            word_entry,
239            left_index: u16::MAX,
240            start_index: start as u32,
241            stop_index: stop as u32,
242            path_cost: i32::MAX,
243            kanji_only,
244        }
245    }
246
247    pub fn clear(&mut self) {
248        for edge_vec in &mut self.ends_at {
249            edge_vec.clear();
250        }
251        for path_vec in &mut self.all_paths {
252            path_vec.clear();
253        }
254        self.char_info_buffer.clear();
255        self.categories_buffer.clear();
256    }
257
258    #[inline]
259    fn is_kanji_all(&self, char_idx: usize, byte_len: usize) -> bool {
260        self.char_info_buffer[char_idx].kanji_run_byte_len >= byte_len as u32
261    }
262
263    #[inline]
264    fn get_cached_category(&self, char_idx: usize, category_ord: usize) -> CategoryId {
265        let char_data = &self.char_info_buffer[char_idx];
266        self.categories_buffer[char_data.categories_start as usize + category_ord]
267    }
268
269    fn set_capacity(&mut self, text_len: usize) {
270        self.clear();
271        self.last_text_len = text_len;
272        if self.capacity <= text_len {
273            self.capacity = text_len;
274            self.ends_at.resize(text_len + 1, Vec::new());
275        }
276        for vec in &mut self.ends_at {
277            vec.clear();
278        }
279    }
280
281    fn set_capacity_nbest(&mut self, text_len: usize) {
282        self.set_capacity(text_len);
283        if self.nbest_capacity <= text_len {
284            self.nbest_capacity = text_len;
285            self.all_paths.resize(text_len + 1, Vec::new());
286        }
287        for vec in &mut self.all_paths {
288            vec.clear();
289        }
290    }
291
    /// Builds the lattice for `text` and computes best path costs in the same pass
    /// (forward Viterbi), avoiding a separate traversal of the finished lattice.
    ///
    /// Steps, in order:
    /// 1. Reset the lattice and size it for `text.len()` bytes.
    /// 2. Record, for every character, its byte offset, kanji flag and categories, then
    ///    compute kanji run lengths.
    /// 3. Seed `ends_at[0]` with a zero-cost start edge.
    /// 4. Collect all matches of the system dictionary and, if present, the user
    ///    dictionary, grouped by their start byte offset.
    /// 5. For every character position reachable by some edge, add the dictionary edges
    ///    starting there and run unknown-word processing.
    /// 6. If any edge reaches the end of the text, append an EOS edge to `ends_at[len]`.
    ///
    /// Edges for both dictionaries are created with `EdgeType::KNOWN`; the lexicon is only
    /// distinguishable through the `WordId` of the entry.
    #[inline(never)]
    #[allow(clippy::too_many_arguments)]
    pub fn set_text(
        &mut self,
        dict: &PrefixDictionary,
        user_dict: &Option<&PrefixDictionary>,
        char_definitions: &CharacterDefinition,
        unknown_dictionary: &UnknownDictionary,
        cost_matrix: &ConnectionCostMatrix,
        text: &str,
        search_mode: &Mode,
    ) {
        let len = text.len();
        self.set_capacity(len);

        // Pre-calculate character information for the text
        self.char_info_buffer.clear();
        self.categories_buffer.clear();

        // Lazily create the per-code-point cache for characters below U+0100.
        // NOTE(review): nothing visible here ever invalidates this cache, so it assumes the
        // same `char_definitions` is used for the lifetime of this lattice — confirm.
        if self.char_category_cache.is_empty() {
            self.char_category_cache.resize(256, Vec::new());
        }

        for (byte_offset, c) in text.char_indices() {
            let categories_start = self.categories_buffer.len() as u32;

            if (c as u32) < 256 {
                let cached = &mut self.char_category_cache[c as usize];
                // An empty slot means "not looked up yet"; a character that has no
                // categories at all is therefore looked up again every time.
                if cached.is_empty() {
                    let cats = char_definitions.lookup_categories(c);
                    for &category in cats {
                        cached.push(category);
                    }
                }
                for &category in cached.iter() {
                    self.categories_buffer.push(category);
                }
            } else {
                let categories = char_definitions.lookup_categories(c);
                for &category in categories {
                    self.categories_buffer.push(category);
                }
            }

            let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;

            self.char_info_buffer.push(CharData {
                byte_offset: byte_offset as u32,
                is_kanji: is_kanji(c),
                categories_start,
                categories_len,
                kanji_run_byte_len: 0,
            });
        }
        // Sentinel for end of text: lets `char_info_buffer[i + 1].byte_offset` be used to
        // find the end of the last character.
        self.char_info_buffer.push(CharData {
            byte_offset: len as u32,
            is_kanji: false,
            categories_start: 0,
            categories_len: 0,
            kanji_run_byte_len: 0,
        });

        // Pre-calculate Kanji run lengths (backwards), so each kanji character knows how
        // many bytes of consecutive kanji start at its position. The sentinel is skipped.
        for i in (0..self.char_info_buffer.len() - 1).rev() {
            if self.char_info_buffer[i].is_kanji {
                let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
                let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
                self.char_info_buffer[i].kanji_run_byte_len =
                    char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
            } else {
                self.char_info_buffer[i].kanji_run_byte_len = 0;
            }
        }

        // Start-of-text edge: zero cost and no predecessor (`left_index == u16::MAX`).
        let start_edge = Edge {
            path_cost: 0,
            left_index: u16::MAX,
            ..Default::default()
        };
        self.ends_at[0].push(start_edge);

        // Byte offset just past the most recently processed unknown word, if any
        let mut unknown_word_end: Option<usize> = None;

        // Pre-scan the text with the dictionaries' overlapping-match iterators.
        // Optimization: Use flat vectors instead of Vec<Vec<_>> to avoid many small allocations.
        // Linked list structure: matches_head[start_idx] -> index in matches_store, and each
        // stored tuple is (match end offset, entry, index of the next match or usize::MAX).
        // New matches are inserted at the head, so a list is later walked in reverse
        // insertion order (user dictionary matches before system dictionary matches).
        let mut matches_head = vec![usize::MAX; len + 1];
        let mut matches_store: Vec<(usize, WordEntry, usize)> = Vec::with_capacity(len * 10);

        // System dictionary scan
        for m in dict.da.find_overlapping_iter(text) {
            let start = m.start();
            let id = m.value();
            // The match value packs the number of entries into its low 5 bits and the index
            // of the first entry (in units of SERIALIZED_LEN) into the remaining bits.
            let count = id & ((1u32 << 5) - 1u32);
            let offset = id >> 5u32;
            let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;

            // Bounds check for safety, though daachorse should guarantee valid ids if built correctly
            if offset_bytes < dict.vals_data.len() {
                let data_slice = &dict.vals_data[offset_bytes..];
                for i in 0..count {
                    let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
                    // Entries that would run past the end of the data are silently skipped.
                    if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
                        let entry = WordEntry::deserialize(&data_slice[entry_offset..], true);
                        if start < matches_head.len() {
                            let next = matches_head[start];
                            matches_head[start] = matches_store.len();
                            matches_store.push((m.end(), entry, next));
                        }
                    }
                }
            }
        }

        // User dictionary scan (same decoding as above, entries tagged as user entries)
        if let Some(ud) = user_dict {
            for m in ud.da.find_overlapping_iter(text) {
                let start = m.start();
                let id = m.value();
                let count = id & ((1u32 << 5) - 1u32);
                let offset = id >> 5u32;
                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;

                if offset_bytes < ud.vals_data.len() {
                    let data_slice = &ud.vals_data[offset_bytes..];
                    for i in 0..count {
                        let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
                        if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
                            let entry = WordEntry::deserialize(&data_slice[entry_offset..], false);
                            if start < matches_head.len() {
                                let next = matches_head[start];
                                matches_head[start] = matches_store.len();
                                matches_store.push((m.end(), entry, next));
                            }
                        }
                    }
                }
            }
        }

        // Visit every real character (the sentinel is excluded) in text order.
        for char_idx in 0..self.char_info_buffer.len() - 1 {
            let start = self.char_info_buffer[char_idx].byte_offset as usize;

            // No arc is ending here.
            // No need to check if a valid word starts here.
            if self.ends_at[start].is_empty() {
                continue;
            }

            // Whether at least one dictionary entry starts at this position
            let mut found: bool = false;

            // Use cached matches
            if start < matches_head.len() {
                let mut match_idx = matches_head[start];
                while match_idx != usize::MAX {
                    let (end, word_entry, next) = matches_store[match_idx];

                    let prefix_len = end - start;
                    let kanji_only = self.is_kanji_all(char_idx, prefix_len);
                    let edge = Self::create_edge(
                        EdgeType::KNOWN,
                        word_entry, // WordEntry is Copy
                        start,
                        end,
                        kanji_only,
                    );
                    self.add_edge_in_lattice(edge, cost_matrix, search_mode);
                    found = true;

                    match_idx = next;
                }
            }

            // In the case of normal mode, it doesn't process unknown word greedily:
            // unless `search_mode.is_search()` holds, positions that lie before the end of
            // the previously processed unknown word are skipped.
            // The trailing `char_idx` comparison always holds inside this loop.
            if (search_mode.is_search()
                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
                && char_idx < self.char_info_buffer.len() - 1
            {
                let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
                for category_ord in 0..num_categories {
                    let category = self.get_cached_category(char_idx, category_ord);
                    unknown_word_end = self.process_unknown_word(
                        char_definitions,
                        unknown_dictionary,
                        cost_matrix,
                        search_mode,
                        category,
                        category_ord,
                        unknown_word_end,
                        start,
                        char_idx,
                        found,
                    );
                }
            }
        }

        // Connect EOS: only possible when at least one edge ends at the end of the text.
        if !self.ends_at[len].is_empty() {
            let mut eos_edge = Edge {
                start_index: len as u32,
                stop_index: len as u32,
                ..Default::default()
            };
            // Calculate cost for EOS: cheapest predecessor path cost plus the connection
            // cost into EOS. No word cost and no mode penalty are added.
            let left_edges = &self.ends_at[len];
            let mut best_cost = i32::MAX;
            let mut best_left = None;
            let right_left_id = 0; // EOS default left_id

            for (i, left_edge) in left_edges.iter().enumerate() {
                let left_right_id = left_edge.word_entry.right_id();
                let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
                let path_cost = left_edge.path_cost.saturating_add(conn_cost);
                if path_cost < best_cost {
                    best_cost = path_cost;
                    best_left = Some(i as u16);
                }
            }
            // The EOS edge is appended last, so it is the final element of `ends_at[len]`.
            if let Some(left_idx) = best_left {
                eos_edge.left_index = left_idx;
                eos_edge.path_cost = best_cost;
                self.ends_at[len].push(eos_edge);
            }
        }
    }
523
524    #[allow(clippy::too_many_arguments)]
525    fn process_unknown_word(
526        &mut self,
527        char_definitions: &CharacterDefinition,
528        unknown_dictionary: &UnknownDictionary,
529        cost_matrix: &ConnectionCostMatrix,
530        search_mode: &Mode,
531        category: CategoryId,
532        category_ord: usize,
533        unknown_word_index: Option<usize>,
534        start: usize,
535        char_idx: usize,
536        found: bool,
537    ) -> Option<usize> {
538        let mut unknown_word_num_chars: usize = 0;
539        let category_data = char_definitions.lookup_definition(category);
540        if category_data.invoke || !found {
541            unknown_word_num_chars = 1;
542            if category_data.group {
543                for i in 1.. {
544                    let next_idx = char_idx + i;
545                    if next_idx >= self.char_info_buffer.len() - 1 {
546                        break;
547                    }
548                    let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
549                    let mut found_cat = false;
550                    if category_ord < num_categories {
551                        let cat = self.get_cached_category(next_idx, category_ord);
552                        if cat == category {
553                            unknown_word_num_chars += 1;
554                            found_cat = true;
555                        }
556                    }
557                    if !found_cat {
558                        break;
559                    }
560                }
561            }
562        }
563        if unknown_word_num_chars > 0 {
564            let byte_end_offset =
565                self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
566            let byte_len = byte_end_offset as usize - start;
567
568            // Check Kanji status using pre-calculated buffer
569            let kanji_only = self.is_kanji_all(char_idx, byte_len);
570
571            for &word_id in unknown_dictionary.lookup_word_ids(category) {
572                let word_entry = unknown_dictionary.word_entry(word_id);
573                let edge = Self::create_edge(
574                    EdgeType::UNKNOWN,
575                    word_entry,
576                    start,
577                    start + byte_len,
578                    kanji_only,
579                );
580                self.add_edge_in_lattice(edge, cost_matrix, search_mode);
581            }
582            return Some(start + byte_len);
583        }
584        unknown_word_index
585    }
586
587    // Adds an edge to the lattice and calculates the minimum cost to reach it.
588    fn add_edge_in_lattice(
589        &mut self,
590        mut edge: Edge,
591        cost_matrix: &ConnectionCostMatrix,
592        mode: &Mode,
593    ) {
594        let start_index = edge.start_index as usize;
595        let stop_index = edge.stop_index as usize;
596        let right_left_id = edge.word_entry.left_id();
597
598        let left_edges = &self.ends_at[start_index];
599        if left_edges.is_empty() {
600            return;
601        }
602
603        let mut best_cost = i32::MAX;
604        let mut best_left = None;
605
606        match mode {
607            Mode::Normal => {
608                for (i, left_edge) in left_edges.iter().enumerate() {
609                    let left_right_id = left_edge.word_entry.right_id();
610                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
611                    let total_cost = left_edge.path_cost.saturating_add(conn_cost);
612
613                    if total_cost < best_cost {
614                        best_cost = total_cost;
615                        best_left = Some(i as u16);
616                    }
617                }
618            }
619            Mode::Decompose(penalty) => {
620                for (i, left_edge) in left_edges.iter().enumerate() {
621                    let left_right_id = left_edge.word_entry.right_id();
622                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
623                    let penalty_cost = penalty.penalty(left_edge);
624                    let total_cost = left_edge
625                        .path_cost
626                        .saturating_add(conn_cost)
627                        .saturating_add(penalty_cost);
628
629                    if total_cost < best_cost {
630                        best_cost = total_cost;
631                        best_left = Some(i as u16);
632                    }
633                }
634            }
635        }
636
637        if let Some(best_left_idx) = best_left {
638            edge.path_cost = best_cost.saturating_add(edge.word_entry.word_cost as i32);
639            edge.left_index = best_left_idx;
640            self.ends_at[stop_index].push(edge);
641        }
642    }
643
644    pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
645        let mut offsets = Vec::new();
646
647        if self.ends_at.is_empty() {
648            return offsets;
649        }
650
651        let mut last_idx = self.ends_at.len() - 1;
652        while last_idx > 0 && self.ends_at[last_idx].is_empty() {
653            last_idx -= 1;
654        }
655
656        if self.ends_at[last_idx].is_empty() {
657            return offsets;
658        }
659
660        let idx = self.ends_at[last_idx].len() - 1;
661        let mut edge = &self.ends_at[last_idx][idx];
662
663        if edge.left_index == u16::MAX {
664            return offsets;
665        }
666
667        loop {
668            if edge.left_index == u16::MAX {
669                break;
670            }
671
672            offsets.push((edge.start_index as usize, edge.word_entry.word_id));
673
674            let left_idx = edge.left_index as usize;
675            let start_idx = edge.start_index as usize;
676
677            edge = &self.ends_at[start_idx][left_idx];
678        }
679
680        offsets.reverse();
681        offsets.pop(); // Remove EOS
682
683        offsets
684    }
685
686    // --- N-Best support ---
687
    /// Returns the text length (in bytes) recorded by the most recent lattice reset,
    /// i.e. the last `set_text` / `set_text_nbest` call. It is 0 for a fresh lattice.
    pub fn text_len(&self) -> usize {
        self.last_text_len
    }
692
693    /// Returns the edges at a given byte position.
694    pub fn edges_at(&self, byte_pos: usize) -> &[Edge] {
695        &self.ends_at[byte_pos]
696    }
697
698    /// Returns the N-Best path entries at a given byte position.
699    pub fn paths_at(&self, byte_pos: usize) -> &[PathEntry] {
700        if byte_pos < self.all_paths.len() {
701            &self.all_paths[byte_pos]
702        } else {
703            &[]
704        }
705    }
706
707    /// Adds an edge to the lattice, recording ALL predecessor transitions for N-Best.
708    fn add_edge_in_lattice_nbest(
709        &mut self,
710        mut edge: Edge,
711        cost_matrix: &ConnectionCostMatrix,
712        mode: &Mode,
713    ) {
714        let start_index = edge.start_index as usize;
715        let stop_index = edge.stop_index as usize;
716        let right_left_id = edge.word_entry.left_id();
717
718        let left_edges = &self.ends_at[start_index];
719        if left_edges.is_empty() {
720            return;
721        }
722
723        let mut best_cost = i32::MAX;
724        let mut best_left = None;
725
726        // The edge_index of the new edge being added
727        let new_edge_index = self.ends_at[stop_index].len() as u16;
728
729        match mode {
730            Mode::Normal => {
731                for (i, left_edge) in left_edges.iter().enumerate() {
732                    let left_right_id = left_edge.word_entry.right_id();
733                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
734                    let total_cost = left_edge.path_cost.saturating_add(conn_cost);
735
736                    // Record ALL transitions for N-Best
737                    self.all_paths[stop_index].push(PathEntry {
738                        edge_index: new_edge_index,
739                        left_pos: start_index as u32,
740                        left_index: i as u16,
741                        cost: total_cost,
742                    });
743
744                    if total_cost < best_cost {
745                        best_cost = total_cost;
746                        best_left = Some(i as u16);
747                    }
748                }
749            }
750            Mode::Decompose(penalty) => {
751                for (i, left_edge) in left_edges.iter().enumerate() {
752                    let left_right_id = left_edge.word_entry.right_id();
753                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
754                    let penalty_cost = penalty.penalty(left_edge);
755                    let total_cost = left_edge
756                        .path_cost
757                        .saturating_add(conn_cost)
758                        .saturating_add(penalty_cost);
759
760                    // Record ALL transitions for N-Best
761                    self.all_paths[stop_index].push(PathEntry {
762                        edge_index: new_edge_index,
763                        left_pos: start_index as u32,
764                        left_index: i as u16,
765                        cost: total_cost,
766                    });
767
768                    if total_cost < best_cost {
769                        best_cost = total_cost;
770                        best_left = Some(i as u16);
771                    }
772                }
773            }
774        }
775
776        if let Some(best_left_idx) = best_left {
777            edge.path_cost = best_cost.saturating_add(edge.word_entry.word_cost as i32);
778            edge.left_index = best_left_idx;
779            self.ends_at[stop_index].push(edge);
780        }
781    }
782
783    #[allow(clippy::too_many_arguments)]
784    fn process_unknown_word_nbest(
785        &mut self,
786        char_definitions: &CharacterDefinition,
787        unknown_dictionary: &UnknownDictionary,
788        cost_matrix: &ConnectionCostMatrix,
789        search_mode: &Mode,
790        category: CategoryId,
791        category_ord: usize,
792        unknown_word_index: Option<usize>,
793        start: usize,
794        char_idx: usize,
795        found: bool,
796    ) -> Option<usize> {
797        let mut unknown_word_num_chars: usize = 0;
798        let category_data = char_definitions.lookup_definition(category);
799        if category_data.invoke || !found {
800            unknown_word_num_chars = 1;
801            if category_data.group {
802                for i in 1.. {
803                    let next_idx = char_idx + i;
804                    if next_idx >= self.char_info_buffer.len() - 1 {
805                        break;
806                    }
807                    let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
808                    let mut found_cat = false;
809                    if category_ord < num_categories {
810                        let cat = self.get_cached_category(next_idx, category_ord);
811                        if cat == category {
812                            unknown_word_num_chars += 1;
813                            found_cat = true;
814                        }
815                    }
816                    if !found_cat {
817                        break;
818                    }
819                }
820            }
821        }
822        if unknown_word_num_chars > 0 {
823            let byte_end_offset =
824                self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
825            let byte_len = byte_end_offset as usize - start;
826
827            let kanji_only = self.is_kanji_all(char_idx, byte_len);
828
829            for &word_id in unknown_dictionary.lookup_word_ids(category) {
830                let word_entry = unknown_dictionary.word_entry(word_id);
831                let edge = Self::create_edge(
832                    EdgeType::UNKNOWN,
833                    word_entry,
834                    start,
835                    start + byte_len,
836                    kanji_only,
837                );
838                self.add_edge_in_lattice_nbest(edge, cost_matrix, search_mode);
839            }
840            return Some(start + byte_len);
841        }
842        unknown_word_index
843    }
844
845    /// Forward Viterbi implementation for N-Best mode.
846    /// Same as set_text() but records ALL predecessor transitions in all_paths.
847    #[inline(never)]
848    #[allow(clippy::too_many_arguments)]
849    pub fn set_text_nbest(
850        &mut self,
851        dict: &PrefixDictionary,
852        user_dict: &Option<&PrefixDictionary>,
853        char_definitions: &CharacterDefinition,
854        unknown_dictionary: &UnknownDictionary,
855        cost_matrix: &ConnectionCostMatrix,
856        text: &str,
857        search_mode: &Mode,
858    ) {
859        let len = text.len();
860        self.set_capacity_nbest(len);
861
862        // Pre-calculate character information for the text
863        self.char_info_buffer.clear();
864        self.categories_buffer.clear();
865
866        if self.char_category_cache.is_empty() {
867            self.char_category_cache.resize(256, Vec::new());
868        }
869
870        for (byte_offset, c) in text.char_indices() {
871            let categories_start = self.categories_buffer.len() as u32;
872
873            if (c as u32) < 256 {
874                let cached = &mut self.char_category_cache[c as usize];
875                if cached.is_empty() {
876                    let cats = char_definitions.lookup_categories(c);
877                    for &category in cats {
878                        cached.push(category);
879                    }
880                }
881                for &category in cached.iter() {
882                    self.categories_buffer.push(category);
883                }
884            } else {
885                let categories = char_definitions.lookup_categories(c);
886                for &category in categories {
887                    self.categories_buffer.push(category);
888                }
889            }
890
891            let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
892
893            self.char_info_buffer.push(CharData {
894                byte_offset: byte_offset as u32,
895                is_kanji: is_kanji(c),
896                categories_start,
897                categories_len,
898                kanji_run_byte_len: 0,
899            });
900        }
901        // Sentinel for end of text
902        self.char_info_buffer.push(CharData {
903            byte_offset: len as u32,
904            is_kanji: false,
905            categories_start: 0,
906            categories_len: 0,
907            kanji_run_byte_len: 0,
908        });
909
910        // Pre-calculate Kanji run lengths (backwards)
911        for i in (0..self.char_info_buffer.len() - 1).rev() {
912            if self.char_info_buffer[i].is_kanji {
913                let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
914                let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
915                self.char_info_buffer[i].kanji_run_byte_len =
916                    char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
917            } else {
918                self.char_info_buffer[i].kanji_run_byte_len = 0;
919            }
920        }
921
922        let start_edge = Edge {
923            path_cost: 0,
924            left_index: u16::MAX,
925            ..Default::default()
926        };
927        self.ends_at[0].push(start_edge);
928
929        let mut unknown_word_end: Option<usize> = None;
930
931        // Pre-scan text with Aho-Corasick
932        let mut matches_head = vec![usize::MAX; len + 1];
933        let mut matches_store: Vec<(usize, WordEntry, usize)> = Vec::with_capacity(len * 10);
934
935        // System dictionary scan
936        for m in dict.da.find_overlapping_iter(text) {
937            let start = m.start();
938            let id = m.value();
939            let count = id & ((1u32 << 5) - 1u32);
940            let offset = id >> 5u32;
941            let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
942
943            if offset_bytes < dict.vals_data.len() {
944                let data_slice = &dict.vals_data[offset_bytes..];
945                for i in 0..count {
946                    let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
947                    if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
948                        let entry = WordEntry::deserialize(&data_slice[entry_offset..], true);
949                        if start < matches_head.len() {
950                            let next = matches_head[start];
951                            matches_head[start] = matches_store.len();
952                            matches_store.push((m.end(), entry, next));
953                        }
954                    }
955                }
956            }
957        }
958
959        // User dictionary scan
960        if let Some(ud) = user_dict {
961            for m in ud.da.find_overlapping_iter(text) {
962                let start = m.start();
963                let id = m.value();
964                let count = id & ((1u32 << 5) - 1u32);
965                let offset = id >> 5u32;
966                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
967
968                if offset_bytes < ud.vals_data.len() {
969                    let data_slice = &ud.vals_data[offset_bytes..];
970                    for i in 0..count {
971                        let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
972                        if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
973                            let entry = WordEntry::deserialize(&data_slice[entry_offset..], false);
974                            if start < matches_head.len() {
975                                let next = matches_head[start];
976                                matches_head[start] = matches_store.len();
977                                matches_store.push((m.end(), entry, next));
978                            }
979                        }
980                    }
981                }
982            }
983        }
984
985        for char_idx in 0..self.char_info_buffer.len() - 1 {
986            let start = self.char_info_buffer[char_idx].byte_offset as usize;
987
988            if self.ends_at[start].is_empty() {
989                continue;
990            }
991
992            let mut found: bool = false;
993
994            if start < matches_head.len() {
995                let mut match_idx = matches_head[start];
996                while match_idx != usize::MAX {
997                    let (end, word_entry, next) = matches_store[match_idx];
998
999                    let prefix_len = end - start;
1000                    let kanji_only = self.is_kanji_all(char_idx, prefix_len);
1001                    let edge =
1002                        Self::create_edge(EdgeType::KNOWN, word_entry, start, end, kanji_only);
1003                    self.add_edge_in_lattice_nbest(edge, cost_matrix, search_mode);
1004                    found = true;
1005
1006                    match_idx = next;
1007                }
1008            }
1009
1010            if (search_mode.is_search()
1011                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
1012                && char_idx < self.char_info_buffer.len() - 1
1013            {
1014                let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
1015                for category_ord in 0..num_categories {
1016                    let category = self.get_cached_category(char_idx, category_ord);
1017                    unknown_word_end = self.process_unknown_word_nbest(
1018                        char_definitions,
1019                        unknown_dictionary,
1020                        cost_matrix,
1021                        search_mode,
1022                        category,
1023                        category_ord,
1024                        unknown_word_end,
1025                        start,
1026                        char_idx,
1027                        found,
1028                    );
1029                }
1030            }
1031        }
1032
1033        // Connect EOS with all-path recording
1034        if !self.ends_at[len].is_empty() {
1035            let eos_edge_index = self.ends_at[len].len() as u16;
1036            let mut eos_edge = Edge {
1037                start_index: len as u32,
1038                stop_index: len as u32,
1039                ..Default::default()
1040            };
1041            let left_edges = &self.ends_at[len];
1042            let mut best_cost = i32::MAX;
1043            let mut best_left = None;
1044            let right_left_id = 0; // EOS default left_id
1045
1046            for (i, left_edge) in left_edges.iter().enumerate() {
1047                let left_right_id = left_edge.word_entry.right_id();
1048                let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
1049                let path_cost = left_edge.path_cost.saturating_add(conn_cost);
1050
1051                // Record all transitions to EOS
1052                self.all_paths[len].push(PathEntry {
1053                    edge_index: eos_edge_index,
1054                    left_pos: len as u32,
1055                    left_index: i as u16,
1056                    cost: path_cost,
1057                });
1058
1059                if path_cost < best_cost {
1060                    best_cost = path_cost;
1061                    best_left = Some(i as u16);
1062                }
1063            }
1064            if let Some(left_idx) = best_left {
1065                eos_edge.left_index = left_idx;
1066                eos_edge.path_cost = best_cost;
1067                self.ends_at[len].push(eos_edge);
1068            }
1069        }
1070    }
1071
1072    /// Returns the top-N paths through the lattice.
1073    /// Each result is a (path, cost) pair where path is a Vec of (byte_start, WordId) pairs.
1074    /// The first result (index 0) is the 1-best path.
1075    /// If `unique` is true, paths with the same segmentation (same byte_start sequence)
1076    /// are deduplicated, keeping only the first (lowest cost) variant.
1077    /// If `cost_threshold` is Some(t), paths whose cost exceeds best_cost + t are discarded.
1078    /// Requires set_text_nbest() to have been called first.
1079    pub fn nbest_tokens_offset(
1080        &self,
1081        n: usize,
1082        unique: bool,
1083        cost_threshold: Option<i64>,
1084    ) -> Vec<(Vec<(usize, WordId)>, i64)> {
1085        use std::collections::HashSet;
1086
1087        use crate::nbest::NBestGenerator;
1088        let mut generator = NBestGenerator::new(self);
1089        let mut results = Vec::with_capacity(n);
1090        let mut best_cost: Option<i64> = None;
1091
1092        if unique {
1093            let mut seen: HashSet<Vec<usize>> = HashSet::new();
1094            while results.len() < n {
1095                match generator.next() {
1096                    Some((path, cost)) => {
1097                        // Record best cost from first result
1098                        let bc = *best_cost.get_or_insert(cost);
1099                        // Skip if cost exceeds threshold
1100                        if let Some(threshold) = cost_threshold {
1101                            if cost > bc + threshold {
1102                                break;
1103                            }
1104                        }
1105                        let key: Vec<usize> = path.iter().map(|(start, _)| *start).collect();
1106                        if seen.insert(key) {
1107                            results.push((path, cost));
1108                        }
1109                    }
1110                    None => break,
1111                }
1112            }
1113        } else {
1114            while results.len() < n {
1115                match generator.next() {
1116                    Some((path, cost)) => {
1117                        let bc = *best_cost.get_or_insert(cost);
1118                        if let Some(threshold) = cost_threshold {
1119                            if cost > bc + threshold {
1120                                break;
1121                            }
1122                        }
1123                        results.push((path, cost));
1124                    }
1125                    None => break,
1126                }
1127            }
1128        }
1129        results
1130    }
1131}
1132
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// A system-dictionary `WordEntry` must serialize to exactly
    /// `WordEntry::SERIALIZED_LEN` bytes and deserialize back to an equal value.
    #[test]
    fn test_word_entry() {
        let original = WordEntry {
            // `WordId::new` sets `is_system` to true for `LexType::System`.
            word_id: WordId::new(LexType::System, 1u32),
            word_cost: -17i16,
            left_id: 1411u16,
            right_id: 1412u16,
        };

        let mut encoded = Vec::new();
        original.serialize(&mut encoded).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, encoded.len());

        let decoded = WordEntry::deserialize(&encoded[..], true);
        assert_eq!(original, decoded);
    }
}