1use crate::dictionary::lexicon_set::LexiconSet;
7use crate::lattice::node::LatticeNode;
8use crate::lattice::Lattice;
9use crate::oov::mecab::{provide_mecab_oov, MecabOovConfig};
10use crate::oov::simple::{provide_oov, OovProviderConfig};
11use crate::types::ConnectionCosts;
12
/// One node of a solved lattice path, detached from the lattice that produced it.
///
/// Every field is copied verbatim from the same-named field of a [`LatticeNode`] by the
/// `From<&LatticeNode>` conversion. `PartialEq`/`Eq` are derived so that paths can be
/// compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PathNode {
    /// Start position of the node (presumably a byte offset into the input — confirm
    /// against `Lattice::insert`).
    pub begin: usize,
    /// End position of the node (presumably an exclusive byte offset — confirm against
    /// `Lattice::insert`).
    pub end: usize,
    /// Identifier of the word this node stands for.
    pub word_id: i32,
    /// Left connection id of the word.
    pub left_id: i16,
    /// Right connection id of the word.
    pub right_id: i16,
    /// Cost of the word itself.
    pub cost: i16,
    /// Total cost recorded on the lattice node (presumably accumulated along the path —
    /// verify in `Lattice`).
    pub total_cost: i32,
    /// Whether the node was flagged as out-of-vocabulary.
    pub is_oov: bool,
    /// Part-of-speech id attached to an out-of-vocabulary node, if any.
    pub oov_pos_id: Option<i16>,
}
26
27impl From<&LatticeNode> for PathNode {
28 fn from(n: &LatticeNode) -> Self {
29 Self {
30 begin: n.begin,
31 end: n.end,
32 word_id: n.word_id,
33 left_id: n.left_id,
34 right_id: n.right_id,
35 cost: n.cost,
36 total_cost: n.total_cost,
37 is_oov: n.is_oov,
38 oov_pos_id: n.oov_pos_id,
39 }
40 }
41}
42
/// Borrowed, pre-computed views of the input text consumed by the lattice builder.
///
/// The struct only holds shared slices, so it is cheap to copy and safe to print.
#[derive(Debug, Clone, Copy)]
pub struct LatticeInput<'a> {
    /// Raw input bytes; their count determines how many offsets the builder visits.
    pub bytes: &'a [u8],
    /// Per-offset flag; the builder skips every offset whose entry is `0`
    /// (presumably "can begin a word" — confirm with the producer of this data).
    pub can_bow: &'a [u8],
    /// Per-offset character category value handed to the MeCab-style OOV provider.
    pub char_categories: &'a [u32],
    /// Per-offset length used for the simple OOV fallback node.
    pub word_candidate_lengths: &'a [u32],
    /// Per-offset run length handed to the MeCab-style OOV provider and used to find
    /// the end of the code-point window.
    pub continuous_lengths: &'a [u32],
    /// Flat list of per-code-point values (presumably UTF-8 byte lengths — confirm);
    /// a window of it is passed to the MeCab-style OOV provider.
    pub code_point_byte_lengths_flat: &'a [u8],
    /// Per-offset index into `code_point_byte_lengths_flat`.
    pub code_point_offsets: &'a [u32],
}
53
54pub struct DictionaryCtx<'a> {
56 pub lexicon: &'a LexiconSet,
57 pub connection: ConnectionCosts<'a>,
58}
59
60pub struct OovCtx<'a> {
62 pub simple: &'a OovProviderConfig,
63 pub mecab: Option<&'a MecabOovConfig>,
64}
65
66pub fn build_lattice_and_solve(
68 dict: &DictionaryCtx<'_>,
69 input: &LatticeInput<'_>,
70 oov: &OovCtx<'_>,
71) -> Result<Vec<PathNode>, String> {
72 let byte_length = input.bytes.len();
73 let mut lattice = Lattice::new();
74 lattice.resize(byte_length);
75
76 for i in 0..byte_length {
77 if input.can_bow[i] == 0 {
79 continue;
80 }
81
82 if !lattice.has_previous_node(i) {
84 continue;
85 }
86
87 let mut has_words = false;
88
89 let matches = dict.lexicon.lookup(input.bytes, i, byte_length - i);
91 for m in &matches {
92 for &word_id in &m.word_ids {
93 let node = LatticeNode {
94 word_id,
95 left_id: dict.lexicon.get_left_id(word_id),
96 right_id: dict.lexicon.get_right_id(word_id),
97 cost: dict.lexicon.get_cost(word_id),
98 ..Default::default()
99 };
100 lattice.insert(i, i + m.length, node, &dict.connection);
101 has_words = true;
102 }
103 }
104
105 if let Some(mecab_cfg) = oov.mecab {
107 let cat = input.char_categories.get(i).copied().unwrap_or(1);
108 let cont_len = input.continuous_lengths.get(i).copied().unwrap_or(0) as usize;
109
110 let cp_start = input.code_point_offsets.get(i).copied().unwrap_or(0) as usize;
112 let cp_end = if i + cont_len < input.code_point_offsets.len() {
113 input.code_point_offsets[i + cont_len] as usize
114 } else {
115 input.code_point_byte_lengths_flat.len()
116 };
117 let cp_bytes: Vec<usize> = input
119 .code_point_byte_lengths_flat
120 .get(cp_start..cp_end)
121 .unwrap_or(&[])
122 .iter()
123 .map(|&b| b as usize)
124 .collect();
125
126 let mecab_results = provide_mecab_oov(cat, cont_len, &cp_bytes, has_words, mecab_cfg);
127 for result in mecab_results {
128 lattice.insert(i, i + result.byte_length, result.node, &dict.connection);
129 has_words = true;
130 }
131 }
132
133 if !has_words {
135 let wc_len = input.word_candidate_lengths.get(i).copied().unwrap_or(0) as usize;
136 if let Some(node) = provide_oov(wc_len, false, oov.simple) {
137 lattice.insert(i, i + wc_len, node, &dict.connection);
138 }
139 }
140 }
141
142 lattice.connect_eos(&dict.connection);
143 let best_path = lattice.get_best_path()?;
144 Ok(best_path.iter().map(PathNode::from).collect())
145}