Skip to main content

lindera_dictionary/
viterbi.rs

1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13/// Type of lexicon containing the word
14#[derive(
15    Clone,
16    Copy,
17    Debug,
18    Eq,
19    PartialEq,
20    Serialize,
21    Deserialize,
22    Default,
23    Archive,
24    RkyvSerialize,
25    RkyvDeserialize,
26)]
27
28pub enum LexType {
29    /// System dictionary (base dictionary)
30    #[default]
31    System,
32    /// User dictionary (additional vocabulary)
33    User,
34    /// Unknown words (OOV handling)
35    Unknown,
36}
37
38#[derive(
39    Clone,
40    Copy,
41    Debug,
42    Eq,
43    PartialEq,
44    Serialize,
45    Deserialize,
46    Archive,
47    RkyvDeserialize,
48    RkyvSerialize,
49)]
50
51pub struct WordId {
52    pub id: u32,
53    pub is_system: bool,
54    pub lex_type: LexType,
55}
56
57impl WordId {
58    /// Creates a new WordId with specified lexicon type
59    pub fn new(lex_type: LexType, id: u32) -> Self {
60        WordId {
61            id,
62            is_system: matches!(lex_type, LexType::System),
63            lex_type,
64        }
65    }
66
67    pub fn is_unknown(&self) -> bool {
68        matches!(self.lex_type, LexType::Unknown)
69    }
70
71    pub fn is_system(&self) -> bool {
72        self.is_system
73    }
74
75    pub fn lex_type(&self) -> LexType {
76        self.lex_type
77    }
78}
79
80impl Default for WordId {
81    fn default() -> Self {
82        WordId {
83            id: u32::MAX,
84            is_system: true,
85            lex_type: LexType::System,
86        }
87    }
88}
89
90#[derive(
91    Default,
92    Clone,
93    Copy,
94    Debug,
95    Eq,
96    PartialEq,
97    Serialize,
98    Deserialize,
99    Archive,
100    RkyvSerialize,
101    RkyvDeserialize,
102)]
103
104pub struct WordEntry {
105    pub word_id: WordId,
106    pub word_cost: i16,
107    pub left_id: u16,
108    pub right_id: u16,
109}
110
111impl WordEntry {
112    pub const SERIALIZED_LEN: usize = 10;
113
114    pub fn left_id(&self) -> u32 {
115        self.left_id as u32
116    }
117
118    pub fn right_id(&self) -> u32 {
119        self.right_id as u32
120    }
121
122    pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
123        wtr.write_u32::<LittleEndian>(self.word_id.id)?;
124        wtr.write_i16::<LittleEndian>(self.word_cost)?;
125        wtr.write_u16::<LittleEndian>(self.left_id)?;
126        wtr.write_u16::<LittleEndian>(self.right_id)?;
127        Ok(())
128    }
129
130    pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
131        let word_id = WordId::new(
132            if is_system_entry {
133                LexType::System
134            } else {
135                LexType::User
136            },
137            LittleEndian::read_u32(&data[0..4]),
138        );
139        let word_cost = LittleEndian::read_i16(&data[4..6]);
140        let left_id = LittleEndian::read_u16(&data[6..8]);
141        let right_id = LittleEndian::read_u16(&data[8..10]);
142        WordEntry {
143            word_id,
144            word_cost,
145            left_id,
146            right_id,
147        }
148    }
149}
150
/// Origin of a lattice edge.
#[derive(Debug, Default, Clone, Copy)]
pub enum EdgeType {
    /// Word matched in a prefix dictionary (the default). In the visible code
    /// this tag is used for both system- and user-dictionary matches.
    #[default]
    KNOWN,
    /// Word generated by unknown-word processing.
    UNKNOWN,
    /// NOTE(review): not produced by the visible code; confirm where it is used.
    USER,
    /// NOTE(review): not produced by the visible code; confirm its meaning.
    INSERTED,
}
159
160#[derive(Default, Clone, Debug)]
161pub struct Edge {
162    pub edge_type: EdgeType,
163    pub word_entry: WordEntry,
164
165    pub path_cost: i32,
166    pub left_index: u16, // Index in the previous position's vector
167
168    pub start_index: u32,
169    pub stop_index: u32,
170
171    pub kanji_only: bool,
172}
173
174impl Edge {
175    pub fn num_chars(&self) -> usize {
176        (self.stop_index - self.start_index) as usize / 3
177    }
178}
179
/// One candidate transition into an edge, kept for N-Best search.
///
/// In 1-best mode only the cheapest predecessor is stored on the `Edge`;
/// N-Best mode records every predecessor transition as a `PathEntry`.
#[derive(Debug, Clone)]
pub struct PathEntry {
    /// Position of the target edge within `ends_at[stop_index]`.
    pub edge_index: u16,
    /// Byte position where the predecessor ends (equals the target's `start_index`).
    pub left_pos: u32,
    /// Position of the predecessor within `ends_at[left_pos]`.
    pub left_index: u16,
    /// Transition cost: predecessor `path_cost` + connection cost (+ decompose
    /// penalty), not including the target's own word cost.
    pub cost: i32,
}
194
195#[derive(Clone, Default)]
196pub struct Lattice {
197    capacity: usize,
198    ends_at: Vec<Vec<Edge>>, // Now stores edges directly
199    char_info_buffer: Vec<CharData>,
200    categories_buffer: Vec<CategoryId>,
201    char_category_cache: Vec<Vec<CategoryId>>,
202
203    // N-Best fields (only populated when set_text_nbest is called)
204    all_paths: Vec<Vec<PathEntry>>,
205    nbest_capacity: usize,
206    /// The text length (in bytes) of the last set_text/set_text_nbest call
207    last_text_len: usize,
208}
209
/// Precomputed information about one character of the input text.
#[derive(Debug, Default, Clone, Copy)]
struct CharData {
    /// Byte offset of the character in the text.
    byte_offset: u32,
    /// Result of `is_kanji` for the character.
    is_kanji: bool,
    /// First index of this character's categories in `Lattice::categories_buffer`.
    categories_start: u32,
    /// Number of categories of this character.
    categories_len: u16,
    /// Byte length of the run of consecutive kanji starting here (0 if not kanji).
    kanji_run_byte_len: u32,
}
218
/// Returns `true` if `c` is treated as a kanji.
///
/// Accepted ranges are CJK Unified Ideographs Extension A (U+3400..=U+4DBF) and
/// U+4E00..=U+9FAF of the CJK Unified Ideographs block. The upper bound is
/// U+9FAF, not the end of that block (U+9FFF).
#[inline]
pub fn is_kanji(c: char) -> bool {
    matches!(u32::from(c), 0x3400..=0x4DBF | 0x4E00..=0x9FAF)
}
225
226impl Lattice {
227    /// Helper method to create an edge efficiently
228    #[inline]
229    fn create_edge(
230        edge_type: EdgeType,
231        word_entry: WordEntry,
232        start: usize,
233        stop: usize,
234        kanji_only: bool,
235    ) -> Edge {
236        Edge {
237            edge_type,
238            word_entry,
239            left_index: u16::MAX,
240            start_index: start as u32,
241            stop_index: stop as u32,
242            path_cost: i32::MAX,
243            kanji_only,
244        }
245    }
246
247    pub fn clear(&mut self) {
248        for edge_vec in &mut self.ends_at {
249            edge_vec.clear();
250        }
251        for path_vec in &mut self.all_paths {
252            path_vec.clear();
253        }
254        self.char_info_buffer.clear();
255        self.categories_buffer.clear();
256    }
257
258    #[inline]
259    fn is_kanji_all(&self, char_idx: usize, byte_len: usize) -> bool {
260        self.char_info_buffer[char_idx].kanji_run_byte_len >= byte_len as u32
261    }
262
263    #[inline]
264    fn get_cached_category(&self, char_idx: usize, category_ord: usize) -> CategoryId {
265        let char_data = &self.char_info_buffer[char_idx];
266        self.categories_buffer[char_data.categories_start as usize + category_ord]
267    }
268
269    fn set_capacity(&mut self, text_len: usize) {
270        self.clear();
271        self.last_text_len = text_len;
272        if self.capacity <= text_len {
273            self.capacity = text_len;
274            self.ends_at.resize(text_len + 1, Vec::new());
275        }
276        for vec in &mut self.ends_at {
277            vec.clear();
278        }
279    }
280
281    fn set_capacity_nbest(&mut self, text_len: usize) {
282        self.set_capacity(text_len);
283        if self.nbest_capacity <= text_len {
284            self.nbest_capacity = text_len;
285            self.all_paths.resize(text_len + 1, Vec::new());
286        }
287        for vec in &mut self.all_paths {
288            vec.clear();
289        }
290    }
291
292    #[inline(never)]
293    // Forward Viterbi implementation:
294    // Constructs the lattice and calculates the path costs simultaneously.
295    // This improves performance by avoiding a separate lattice traversal pass.
296    #[allow(clippy::too_many_arguments)]
297    pub fn set_text(
298        &mut self,
299        dict: &PrefixDictionary,
300        user_dict: &Option<&PrefixDictionary>,
301        char_definitions: &CharacterDefinition,
302        unknown_dictionary: &UnknownDictionary,
303        cost_matrix: &ConnectionCostMatrix,
304        text: &str,
305        search_mode: &Mode,
306    ) {
307        let len = text.len();
308        self.set_capacity(len);
309
310        // Pre-calculate character information for the text
311        self.char_info_buffer.clear();
312        self.categories_buffer.clear();
313
314        if self.char_category_cache.is_empty() {
315            self.char_category_cache.resize(256, Vec::new());
316        }
317
318        for (byte_offset, c) in text.char_indices() {
319            let categories_start = self.categories_buffer.len() as u32;
320
321            if (c as u32) < 256 {
322                let cached = &mut self.char_category_cache[c as usize];
323                if cached.is_empty() {
324                    let cats = char_definitions.lookup_categories(c);
325                    for &category in cats {
326                        cached.push(category);
327                    }
328                }
329                for &category in cached.iter() {
330                    self.categories_buffer.push(category);
331                }
332            } else {
333                let categories = char_definitions.lookup_categories(c);
334                for &category in categories {
335                    self.categories_buffer.push(category);
336                }
337            }
338
339            let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
340
341            self.char_info_buffer.push(CharData {
342                byte_offset: byte_offset as u32,
343                is_kanji: is_kanji(c),
344                categories_start,
345                categories_len,
346                kanji_run_byte_len: 0,
347            });
348        }
349        // Sentinel for end of text
350        self.char_info_buffer.push(CharData {
351            byte_offset: len as u32,
352            is_kanji: false,
353            categories_start: 0,
354            categories_len: 0,
355            kanji_run_byte_len: 0,
356        });
357
358        // Pre-calculate Kanji run lengths (backwards)
359        for i in (0..self.char_info_buffer.len() - 1).rev() {
360            if self.char_info_buffer[i].is_kanji {
361                let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
362                let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
363                self.char_info_buffer[i].kanji_run_byte_len =
364                    char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
365            } else {
366                self.char_info_buffer[i].kanji_run_byte_len = 0;
367            }
368        }
369
370        let start_edge = Edge {
371            path_cost: 0,
372            left_index: u16::MAX,
373            ..Default::default()
374        };
375        self.ends_at[0].push(start_edge);
376
377        // Index of the last character of unknown word
378        let mut unknown_word_end: Option<usize> = None;
379
380        // Pre-scan text with Aho-Corasick to report all matches
381        // Optimization: Use flat vectors instead of Vec<Vec<_>> to avoid many small allocations.
382        // Linked list structure: matches_head[start_idx] -> index in matches_store
383        let mut matches_head = vec![usize::MAX; len + 1];
384        let mut matches_store: Vec<(usize, WordEntry, usize)> = Vec::with_capacity(len * 10);
385
386        // System dictionary scan (8-bit variant-count encoding)
387        for m in dict.da.find_overlapping_iter(text) {
388            let start = m.start();
389            let (offset, count) = dict.decode_val(m.value());
390            let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
391
392            // Bounds check for safety, though daachorse should guarantee valid ids if built correctly
393            if offset_bytes < dict.vals_data.len() {
394                let data_slice = &dict.vals_data[offset_bytes..];
395                for i in 0..count {
396                    let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
397                    if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
398                        let entry = WordEntry::deserialize(&data_slice[entry_offset..], true);
399                        if start < matches_head.len() {
400                            let next = matches_head[start];
401                            matches_head[start] = matches_store.len();
402                            matches_store.push((m.end(), entry, next));
403                        }
404                    }
405                }
406            }
407        }
408
409        // User dictionary scan (5-bit variant-count encoding for bwd compat)
410        if let Some(ud) = user_dict {
411            for m in ud.da.find_overlapping_iter(text) {
412                let start = m.start();
413                let (offset, count) = ud.decode_val(m.value());
414                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
415
416                if offset_bytes < ud.vals_data.len() {
417                    let data_slice = &ud.vals_data[offset_bytes..];
418                    for i in 0..count {
419                        let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
420                        if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
421                            let entry = WordEntry::deserialize(&data_slice[entry_offset..], false);
422                            if start < matches_head.len() {
423                                let next = matches_head[start];
424                                matches_head[start] = matches_store.len();
425                                matches_store.push((m.end(), entry, next));
426                            }
427                        }
428                    }
429                }
430            }
431        }
432
433        for char_idx in 0..self.char_info_buffer.len() - 1 {
434            let start = self.char_info_buffer[char_idx].byte_offset as usize;
435
436            // No arc is ending here.
437            // No need to check if a valid word starts here.
438            if self.ends_at[start].is_empty() {
439                continue;
440            }
441
442            let mut found: bool = false;
443
444            // Use cached matches
445            if start < matches_head.len() {
446                let mut match_idx = matches_head[start];
447                while match_idx != usize::MAX {
448                    let (end, word_entry, next) = matches_store[match_idx];
449
450                    let prefix_len = end - start;
451                    let kanji_only = self.is_kanji_all(char_idx, prefix_len);
452                    let edge = Self::create_edge(
453                        EdgeType::KNOWN,
454                        word_entry, // WordEntry is Copy
455                        start,
456                        end,
457                        kanji_only,
458                    );
459                    self.add_edge_in_lattice(edge, cost_matrix, search_mode);
460                    found = true;
461
462                    match_idx = next;
463                }
464            }
465
466            // In the case of normal mode, it doesn't process unknown word greedily.
467            if (search_mode.is_search()
468                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
469                && char_idx < self.char_info_buffer.len() - 1
470            {
471                let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
472                for category_ord in 0..num_categories {
473                    let category = self.get_cached_category(char_idx, category_ord);
474                    unknown_word_end = self.process_unknown_word(
475                        char_definitions,
476                        unknown_dictionary,
477                        cost_matrix,
478                        search_mode,
479                        category,
480                        category_ord,
481                        unknown_word_end,
482                        start,
483                        char_idx,
484                        found,
485                    );
486                }
487            }
488        }
489
490        // Connect EOS
491        if !self.ends_at[len].is_empty() {
492            let mut eos_edge = Edge {
493                start_index: len as u32,
494                stop_index: len as u32,
495                ..Default::default()
496            };
497            // Calculate cost for EOS
498            let left_edges = &self.ends_at[len];
499            let mut best_cost = i32::MAX;
500            let mut best_left = None;
501            let right_left_id = 0; // EOS default left_id
502
503            for (i, left_edge) in left_edges.iter().enumerate() {
504                let left_right_id = left_edge.word_entry.right_id();
505                let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
506                let path_cost = left_edge.path_cost.saturating_add(conn_cost);
507                if path_cost < best_cost {
508                    best_cost = path_cost;
509                    best_left = Some(i as u16);
510                }
511            }
512            if let Some(left_idx) = best_left {
513                eos_edge.left_index = left_idx;
514                eos_edge.path_cost = best_cost;
515                self.ends_at[len].push(eos_edge);
516            }
517        }
518    }
519
520    #[allow(clippy::too_many_arguments)]
521    fn process_unknown_word(
522        &mut self,
523        char_definitions: &CharacterDefinition,
524        unknown_dictionary: &UnknownDictionary,
525        cost_matrix: &ConnectionCostMatrix,
526        search_mode: &Mode,
527        category: CategoryId,
528        category_ord: usize,
529        unknown_word_index: Option<usize>,
530        start: usize,
531        char_idx: usize,
532        found: bool,
533    ) -> Option<usize> {
534        let mut unknown_word_num_chars: usize = 0;
535        let category_data = char_definitions.lookup_definition(category);
536        if category_data.invoke || !found {
537            unknown_word_num_chars = 1;
538            if category_data.group {
539                for i in 1.. {
540                    let next_idx = char_idx + i;
541                    if next_idx >= self.char_info_buffer.len() - 1 {
542                        break;
543                    }
544                    let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
545                    let mut found_cat = false;
546                    if category_ord < num_categories {
547                        let cat = self.get_cached_category(next_idx, category_ord);
548                        if cat == category {
549                            unknown_word_num_chars += 1;
550                            found_cat = true;
551                        }
552                    }
553                    if !found_cat {
554                        break;
555                    }
556                }
557            }
558        }
559        if unknown_word_num_chars > 0 {
560            let byte_end_offset =
561                self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
562            let byte_len = byte_end_offset as usize - start;
563
564            // Check Kanji status using pre-calculated buffer
565            let kanji_only = self.is_kanji_all(char_idx, byte_len);
566
567            for &word_id in unknown_dictionary.lookup_word_ids(category) {
568                let word_entry = unknown_dictionary.word_entry(word_id);
569                let edge = Self::create_edge(
570                    EdgeType::UNKNOWN,
571                    word_entry,
572                    start,
573                    start + byte_len,
574                    kanji_only,
575                );
576                self.add_edge_in_lattice(edge, cost_matrix, search_mode);
577            }
578            return Some(start + byte_len);
579        }
580        unknown_word_index
581    }
582
583    // Adds an edge to the lattice and calculates the minimum cost to reach it.
584    fn add_edge_in_lattice(
585        &mut self,
586        mut edge: Edge,
587        cost_matrix: &ConnectionCostMatrix,
588        mode: &Mode,
589    ) {
590        let start_index = edge.start_index as usize;
591        let stop_index = edge.stop_index as usize;
592        let right_left_id = edge.word_entry.left_id();
593
594        let left_edges = &self.ends_at[start_index];
595        if left_edges.is_empty() {
596            return;
597        }
598
599        let mut best_cost = i32::MAX;
600        let mut best_left = None;
601
602        match mode {
603            Mode::Normal => {
604                for (i, left_edge) in left_edges.iter().enumerate() {
605                    let left_right_id = left_edge.word_entry.right_id();
606                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
607                    let total_cost = left_edge.path_cost.saturating_add(conn_cost);
608
609                    if total_cost < best_cost {
610                        best_cost = total_cost;
611                        best_left = Some(i as u16);
612                    }
613                }
614            }
615            Mode::Decompose(penalty) => {
616                for (i, left_edge) in left_edges.iter().enumerate() {
617                    let left_right_id = left_edge.word_entry.right_id();
618                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
619                    let penalty_cost = penalty.penalty(left_edge);
620                    let total_cost = left_edge
621                        .path_cost
622                        .saturating_add(conn_cost)
623                        .saturating_add(penalty_cost);
624
625                    if total_cost < best_cost {
626                        best_cost = total_cost;
627                        best_left = Some(i as u16);
628                    }
629                }
630            }
631        }
632
633        if let Some(best_left_idx) = best_left {
634            edge.path_cost = best_cost.saturating_add(edge.word_entry.word_cost as i32);
635            edge.left_index = best_left_idx;
636            self.ends_at[stop_index].push(edge);
637        }
638    }
639
640    pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
641        let mut offsets = Vec::new();
642
643        if self.ends_at.is_empty() {
644            return offsets;
645        }
646
647        let mut last_idx = self.ends_at.len() - 1;
648        while last_idx > 0 && self.ends_at[last_idx].is_empty() {
649            last_idx -= 1;
650        }
651
652        if self.ends_at[last_idx].is_empty() {
653            return offsets;
654        }
655
656        let idx = self.ends_at[last_idx].len() - 1;
657        let mut edge = &self.ends_at[last_idx][idx];
658
659        if edge.left_index == u16::MAX {
660            return offsets;
661        }
662
663        loop {
664            if edge.left_index == u16::MAX {
665                break;
666            }
667
668            offsets.push((edge.start_index as usize, edge.word_entry.word_id));
669
670            let left_idx = edge.left_index as usize;
671            let start_idx = edge.start_index as usize;
672
673            edge = &self.ends_at[start_idx][left_idx];
674        }
675
676        offsets.reverse();
677        offsets.pop(); // Remove EOS
678
679        offsets
680    }
681
682    // --- N-Best support ---
683
684    /// Returns the text length (in bytes) from the last set_text/set_text_nbest call.
685    pub fn text_len(&self) -> usize {
686        self.last_text_len
687    }
688
689    /// Returns the edges at a given byte position.
690    pub fn edges_at(&self, byte_pos: usize) -> &[Edge] {
691        &self.ends_at[byte_pos]
692    }
693
694    /// Returns the N-Best path entries at a given byte position.
695    pub fn paths_at(&self, byte_pos: usize) -> &[PathEntry] {
696        if byte_pos < self.all_paths.len() {
697            &self.all_paths[byte_pos]
698        } else {
699            &[]
700        }
701    }
702
703    /// Adds an edge to the lattice, recording ALL predecessor transitions for N-Best.
704    fn add_edge_in_lattice_nbest(
705        &mut self,
706        mut edge: Edge,
707        cost_matrix: &ConnectionCostMatrix,
708        mode: &Mode,
709    ) {
710        let start_index = edge.start_index as usize;
711        let stop_index = edge.stop_index as usize;
712        let right_left_id = edge.word_entry.left_id();
713
714        let left_edges = &self.ends_at[start_index];
715        if left_edges.is_empty() {
716            return;
717        }
718
719        let mut best_cost = i32::MAX;
720        let mut best_left = None;
721
722        // The edge_index of the new edge being added
723        let new_edge_index = self.ends_at[stop_index].len() as u16;
724
725        match mode {
726            Mode::Normal => {
727                for (i, left_edge) in left_edges.iter().enumerate() {
728                    let left_right_id = left_edge.word_entry.right_id();
729                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
730                    let total_cost = left_edge.path_cost.saturating_add(conn_cost);
731
732                    // Record ALL transitions for N-Best
733                    self.all_paths[stop_index].push(PathEntry {
734                        edge_index: new_edge_index,
735                        left_pos: start_index as u32,
736                        left_index: i as u16,
737                        cost: total_cost,
738                    });
739
740                    if total_cost < best_cost {
741                        best_cost = total_cost;
742                        best_left = Some(i as u16);
743                    }
744                }
745            }
746            Mode::Decompose(penalty) => {
747                for (i, left_edge) in left_edges.iter().enumerate() {
748                    let left_right_id = left_edge.word_entry.right_id();
749                    let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
750                    let penalty_cost = penalty.penalty(left_edge);
751                    let total_cost = left_edge
752                        .path_cost
753                        .saturating_add(conn_cost)
754                        .saturating_add(penalty_cost);
755
756                    // Record ALL transitions for N-Best
757                    self.all_paths[stop_index].push(PathEntry {
758                        edge_index: new_edge_index,
759                        left_pos: start_index as u32,
760                        left_index: i as u16,
761                        cost: total_cost,
762                    });
763
764                    if total_cost < best_cost {
765                        best_cost = total_cost;
766                        best_left = Some(i as u16);
767                    }
768                }
769            }
770        }
771
772        if let Some(best_left_idx) = best_left {
773            edge.path_cost = best_cost.saturating_add(edge.word_entry.word_cost as i32);
774            edge.left_index = best_left_idx;
775            self.ends_at[stop_index].push(edge);
776        }
777    }
778
779    #[allow(clippy::too_many_arguments)]
780    fn process_unknown_word_nbest(
781        &mut self,
782        char_definitions: &CharacterDefinition,
783        unknown_dictionary: &UnknownDictionary,
784        cost_matrix: &ConnectionCostMatrix,
785        search_mode: &Mode,
786        category: CategoryId,
787        category_ord: usize,
788        unknown_word_index: Option<usize>,
789        start: usize,
790        char_idx: usize,
791        found: bool,
792    ) -> Option<usize> {
793        let mut unknown_word_num_chars: usize = 0;
794        let category_data = char_definitions.lookup_definition(category);
795        if category_data.invoke || !found {
796            unknown_word_num_chars = 1;
797            if category_data.group {
798                for i in 1.. {
799                    let next_idx = char_idx + i;
800                    if next_idx >= self.char_info_buffer.len() - 1 {
801                        break;
802                    }
803                    let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
804                    let mut found_cat = false;
805                    if category_ord < num_categories {
806                        let cat = self.get_cached_category(next_idx, category_ord);
807                        if cat == category {
808                            unknown_word_num_chars += 1;
809                            found_cat = true;
810                        }
811                    }
812                    if !found_cat {
813                        break;
814                    }
815                }
816            }
817        }
818        if unknown_word_num_chars > 0 {
819            let byte_end_offset =
820                self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
821            let byte_len = byte_end_offset as usize - start;
822
823            let kanji_only = self.is_kanji_all(char_idx, byte_len);
824
825            for &word_id in unknown_dictionary.lookup_word_ids(category) {
826                let word_entry = unknown_dictionary.word_entry(word_id);
827                let edge = Self::create_edge(
828                    EdgeType::UNKNOWN,
829                    word_entry,
830                    start,
831                    start + byte_len,
832                    kanji_only,
833                );
834                self.add_edge_in_lattice_nbest(edge, cost_matrix, search_mode);
835            }
836            return Some(start + byte_len);
837        }
838        unknown_word_index
839    }
840
841    /// Forward Viterbi implementation for N-Best mode.
842    /// Same as set_text() but records ALL predecessor transitions in all_paths.
843    #[inline(never)]
844    #[allow(clippy::too_many_arguments)]
845    pub fn set_text_nbest(
846        &mut self,
847        dict: &PrefixDictionary,
848        user_dict: &Option<&PrefixDictionary>,
849        char_definitions: &CharacterDefinition,
850        unknown_dictionary: &UnknownDictionary,
851        cost_matrix: &ConnectionCostMatrix,
852        text: &str,
853        search_mode: &Mode,
854    ) {
855        let len = text.len();
856        self.set_capacity_nbest(len);
857
858        // Pre-calculate character information for the text
859        self.char_info_buffer.clear();
860        self.categories_buffer.clear();
861
862        if self.char_category_cache.is_empty() {
863            self.char_category_cache.resize(256, Vec::new());
864        }
865
866        for (byte_offset, c) in text.char_indices() {
867            let categories_start = self.categories_buffer.len() as u32;
868
869            if (c as u32) < 256 {
870                let cached = &mut self.char_category_cache[c as usize];
871                if cached.is_empty() {
872                    let cats = char_definitions.lookup_categories(c);
873                    for &category in cats {
874                        cached.push(category);
875                    }
876                }
877                for &category in cached.iter() {
878                    self.categories_buffer.push(category);
879                }
880            } else {
881                let categories = char_definitions.lookup_categories(c);
882                for &category in categories {
883                    self.categories_buffer.push(category);
884                }
885            }
886
887            let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
888
889            self.char_info_buffer.push(CharData {
890                byte_offset: byte_offset as u32,
891                is_kanji: is_kanji(c),
892                categories_start,
893                categories_len,
894                kanji_run_byte_len: 0,
895            });
896        }
897        // Sentinel for end of text
898        self.char_info_buffer.push(CharData {
899            byte_offset: len as u32,
900            is_kanji: false,
901            categories_start: 0,
902            categories_len: 0,
903            kanji_run_byte_len: 0,
904        });
905
906        // Pre-calculate Kanji run lengths (backwards)
907        for i in (0..self.char_info_buffer.len() - 1).rev() {
908            if self.char_info_buffer[i].is_kanji {
909                let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
910                let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
911                self.char_info_buffer[i].kanji_run_byte_len =
912                    char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
913            } else {
914                self.char_info_buffer[i].kanji_run_byte_len = 0;
915            }
916        }
917
918        let start_edge = Edge {
919            path_cost: 0,
920            left_index: u16::MAX,
921            ..Default::default()
922        };
923        self.ends_at[0].push(start_edge);
924
925        let mut unknown_word_end: Option<usize> = None;
926
927        // Pre-scan text with Aho-Corasick
928        let mut matches_head = vec![usize::MAX; len + 1];
929        let mut matches_store: Vec<(usize, WordEntry, usize)> = Vec::with_capacity(len * 10);
930
931        // System dictionary scan (8-bit variant-count encoding)
932        for m in dict.da.find_overlapping_iter(text) {
933            let start = m.start();
934            let (offset, count) = dict.decode_val(m.value());
935            let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
936
937            if offset_bytes < dict.vals_data.len() {
938                let data_slice = &dict.vals_data[offset_bytes..];
939                for i in 0..count {
940                    let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
941                    if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
942                        let entry = WordEntry::deserialize(&data_slice[entry_offset..], true);
943                        if start < matches_head.len() {
944                            let next = matches_head[start];
945                            matches_head[start] = matches_store.len();
946                            matches_store.push((m.end(), entry, next));
947                        }
948                    }
949                }
950            }
951        }
952
953        // User dictionary scan (5-bit variant-count encoding for bwd compat)
954        if let Some(ud) = user_dict {
955            for m in ud.da.find_overlapping_iter(text) {
956                let start = m.start();
957                let (offset, count) = ud.decode_val(m.value());
958                let offset_bytes = (offset as usize) * WordEntry::SERIALIZED_LEN;
959
960                if offset_bytes < ud.vals_data.len() {
961                    let data_slice = &ud.vals_data[offset_bytes..];
962                    for i in 0..count {
963                        let entry_offset = WordEntry::SERIALIZED_LEN * (i as usize);
964                        if entry_offset + WordEntry::SERIALIZED_LEN <= data_slice.len() {
965                            let entry = WordEntry::deserialize(&data_slice[entry_offset..], false);
966                            if start < matches_head.len() {
967                                let next = matches_head[start];
968                                matches_head[start] = matches_store.len();
969                                matches_store.push((m.end(), entry, next));
970                            }
971                        }
972                    }
973                }
974            }
975        }
976
977        for char_idx in 0..self.char_info_buffer.len() - 1 {
978            let start = self.char_info_buffer[char_idx].byte_offset as usize;
979
980            if self.ends_at[start].is_empty() {
981                continue;
982            }
983
984            let mut found: bool = false;
985
986            if start < matches_head.len() {
987                let mut match_idx = matches_head[start];
988                while match_idx != usize::MAX {
989                    let (end, word_entry, next) = matches_store[match_idx];
990
991                    let prefix_len = end - start;
992                    let kanji_only = self.is_kanji_all(char_idx, prefix_len);
993                    let edge =
994                        Self::create_edge(EdgeType::KNOWN, word_entry, start, end, kanji_only);
995                    self.add_edge_in_lattice_nbest(edge, cost_matrix, search_mode);
996                    found = true;
997
998                    match_idx = next;
999                }
1000            }
1001
1002            if (search_mode.is_search()
1003                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
1004                && char_idx < self.char_info_buffer.len() - 1
1005            {
1006                let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
1007                for category_ord in 0..num_categories {
1008                    let category = self.get_cached_category(char_idx, category_ord);
1009                    unknown_word_end = self.process_unknown_word_nbest(
1010                        char_definitions,
1011                        unknown_dictionary,
1012                        cost_matrix,
1013                        search_mode,
1014                        category,
1015                        category_ord,
1016                        unknown_word_end,
1017                        start,
1018                        char_idx,
1019                        found,
1020                    );
1021                }
1022            }
1023        }
1024
1025        // Connect EOS with all-path recording
1026        if !self.ends_at[len].is_empty() {
1027            let eos_edge_index = self.ends_at[len].len() as u16;
1028            let mut eos_edge = Edge {
1029                start_index: len as u32,
1030                stop_index: len as u32,
1031                ..Default::default()
1032            };
1033            let left_edges = &self.ends_at[len];
1034            let mut best_cost = i32::MAX;
1035            let mut best_left = None;
1036            let right_left_id = 0; // EOS default left_id
1037
1038            for (i, left_edge) in left_edges.iter().enumerate() {
1039                let left_right_id = left_edge.word_entry.right_id();
1040                let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
1041                let path_cost = left_edge.path_cost.saturating_add(conn_cost);
1042
1043                // Record all transitions to EOS
1044                self.all_paths[len].push(PathEntry {
1045                    edge_index: eos_edge_index,
1046                    left_pos: len as u32,
1047                    left_index: i as u16,
1048                    cost: path_cost,
1049                });
1050
1051                if path_cost < best_cost {
1052                    best_cost = path_cost;
1053                    best_left = Some(i as u16);
1054                }
1055            }
1056            if let Some(left_idx) = best_left {
1057                eos_edge.left_index = left_idx;
1058                eos_edge.path_cost = best_cost;
1059                self.ends_at[len].push(eos_edge);
1060            }
1061        }
1062    }
1063
1064    /// Returns the top-N paths through the lattice.
1065    /// Each result is a (path, cost) pair where path is a Vec of (byte_start, WordId) pairs.
1066    /// The first result (index 0) is the 1-best path.
1067    /// If `unique` is true, paths with the same segmentation (same byte_start sequence)
1068    /// are deduplicated, keeping only the first (lowest cost) variant.
1069    /// If `cost_threshold` is Some(t), paths whose cost exceeds best_cost + t are discarded.
1070    /// Requires set_text_nbest() to have been called first.
1071    pub fn nbest_tokens_offset(
1072        &self,
1073        n: usize,
1074        unique: bool,
1075        cost_threshold: Option<i64>,
1076    ) -> Vec<(Vec<(usize, WordId)>, i64)> {
1077        use std::collections::HashSet;
1078
1079        use crate::nbest::NBestGenerator;
1080        let mut generator = NBestGenerator::new(self);
1081        let mut results = Vec::with_capacity(n);
1082        let mut best_cost: Option<i64> = None;
1083
1084        if unique {
1085            let mut seen: HashSet<Vec<usize>> = HashSet::new();
1086            while results.len() < n {
1087                match generator.next() {
1088                    Some((path, cost)) => {
1089                        // Record best cost from first result
1090                        let bc = *best_cost.get_or_insert(cost);
1091                        // Skip if cost exceeds threshold
1092                        if let Some(threshold) = cost_threshold
1093                            && cost > bc + threshold
1094                        {
1095                            break;
1096                        }
1097                        let key: Vec<usize> = path.iter().map(|(start, _)| *start).collect();
1098                        if seen.insert(key) {
1099                            results.push((path, cost));
1100                        }
1101                    }
1102                    None => break,
1103                }
1104            }
1105        } else {
1106            while results.len() < n {
1107                match generator.next() {
1108                    Some((path, cost)) => {
1109                        let bc = *best_cost.get_or_insert(cost);
1110                        if let Some(threshold) = cost_threshold
1111                            && cost > bc + threshold
1112                        {
1113                            break;
1114                        }
1115                        results.push((path, cost));
1116                    }
1117                    None => break,
1118                }
1119            }
1120        }
1121        results
1122    }
1123}
1124
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// Round-trips a system `WordEntry` through `serialize` / `deserialize` and
    /// checks that the encoded form is exactly `WordEntry::SERIALIZED_LEN` bytes.
    #[test]
    fn test_word_entry() {
        let original = WordEntry {
            word_id: WordId {
                id: 1u32,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17i16,
            left_id: 1411u16,
            right_id: 1412u16,
        };

        let mut encoded = Vec::new();
        original.serialize(&mut encoded).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, encoded.len());

        let decoded = WordEntry::deserialize(&encoded[..], true);
        assert_eq!(original, decoded);
    }
}