1use std::fs::File;
2use std::path::Path;
3
4use pyo3::{exceptions::PyValueError, prelude::*};
5
6use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
7
8#[pyfunction]
27#[pyo3(signature = (seed, corpus, char_def, unk_def, feature_def, rewrite_def, output, lambda_=0.01, max_iter=100, max_threads=None))]
28#[allow(clippy::too_many_arguments)]
29pub fn train(
30 seed: &str,
31 corpus: &str,
32 char_def: &str,
33 unk_def: &str,
34 feature_def: &str,
35 rewrite_def: &str,
36 output: &str,
37 lambda_: f64,
38 max_iter: u64,
39 max_threads: Option<usize>,
40) -> PyResult<()> {
41 let seed_path = Path::new(seed);
42 let corpus_path = Path::new(corpus);
43 let char_def_path = Path::new(char_def);
44 let unk_def_path = Path::new(unk_def);
45 let feature_def_path = Path::new(feature_def);
46 let rewrite_def_path = Path::new(rewrite_def);
47 let output_path = Path::new(output);
48
49 for (path, name) in [
51 (seed_path, "seed"),
52 (corpus_path, "corpus"),
53 (char_def_path, "char_def"),
54 (unk_def_path, "unk_def"),
55 (feature_def_path, "feature_def"),
56 (rewrite_def_path, "rewrite_def"),
57 ] {
58 if !path.exists() {
59 return Err(PyValueError::new_err(format!(
60 "{} file does not exist: {}",
61 name,
62 path.display()
63 )));
64 }
65 }
66
67 let config = TrainerConfig::from_paths(
69 seed_path,
70 char_def_path,
71 unk_def_path,
72 feature_def_path,
73 rewrite_def_path,
74 )
75 .map_err(|e| PyValueError::new_err(format!("Failed to load trainer configuration: {e}")))?;
76
77 let num_threads = max_threads.unwrap_or_else(num_cpus::get);
79 let trainer = Trainer::new(config)
80 .map_err(|e| PyValueError::new_err(format!("Failed to initialize trainer: {e}")))?
81 .regularization_cost(lambda_)
82 .max_iter(max_iter)
83 .num_threads(num_threads);
84
85 let corpus_file = File::open(corpus_path)
87 .map_err(|e| PyValueError::new_err(format!("Failed to open corpus file: {e}")))?;
88 let corpus = Corpus::from_reader(corpus_file)
89 .map_err(|e| PyValueError::new_err(format!("Failed to load corpus: {e}")))?;
90
91 println!("Training with {} examples...", corpus.len());
92
93 let model = trainer
95 .train(corpus)
96 .map_err(|e| PyValueError::new_err(format!("Training failed: {e}")))?;
97
98 if let Some(parent) = output_path.parent() {
100 std::fs::create_dir_all(parent).map_err(|e| {
101 PyValueError::new_err(format!("Failed to create output directory: {e}"))
102 })?;
103 }
104
105 let mut output_file = File::create(output_path)
106 .map_err(|e| PyValueError::new_err(format!("Failed to create output file: {e}")))?;
107
108 model
109 .write_model(&mut output_file)
110 .map_err(|e| PyValueError::new_err(format!("Failed to write model: {e}")))?;
111
112 println!("Model saved to {}", output_path.display());
113 Ok(())
114}
115
116#[pyfunction]
128#[pyo3(signature = (model, output, metadata=None))]
129pub fn export(model: &str, output: &str, metadata: Option<&str>) -> PyResult<()> {
130 let model_path = Path::new(model);
131 let output_path = Path::new(output);
132
133 if !model_path.exists() {
135 return Err(PyValueError::new_err(format!(
136 "Model file does not exist: {}",
137 model_path.display()
138 )));
139 }
140
141 let model_file = File::open(model_path)
143 .map_err(|e| PyValueError::new_err(format!("Failed to open model file: {e}")))?;
144
145 let serializable_model: SerializableModel = Model::read_model(model_file)
146 .map_err(|e| PyValueError::new_err(format!("Failed to load model: {e}")))?;
147
148 std::fs::create_dir_all(output_path)
150 .map_err(|e| PyValueError::new_err(format!("Failed to create output directory: {e}")))?;
151
152 let lexicon_path = output_path.join("lex.csv");
154 let connector_path = output_path.join("matrix.def");
155 let unk_path = output_path.join("unk.def");
156 let char_def_path = output_path.join("char.def");
157
158 let mut lexicon_file = File::create(&lexicon_path)
160 .map_err(|e| PyValueError::new_err(format!("Failed to create lexicon file: {e}")))?;
161 serializable_model
162 .write_lexicon(&mut lexicon_file)
163 .map_err(|e| PyValueError::new_err(format!("Failed to write lexicon: {e}")))?;
164
165 let mut connector_file = File::create(&connector_path).map_err(|e| {
167 PyValueError::new_err(format!("Failed to create connection matrix file: {e}"))
168 })?;
169 serializable_model
170 .write_connection_costs(&mut connector_file)
171 .map_err(|e| PyValueError::new_err(format!("Failed to write connection costs: {e}")))?;
172
173 let mut unk_file = File::create(&unk_path)
175 .map_err(|e| PyValueError::new_err(format!("Failed to create unknown word file: {e}")))?;
176 serializable_model
177 .write_unknown_dictionary(&mut unk_file)
178 .map_err(|e| PyValueError::new_err(format!("Failed to write unknown dictionary: {e}")))?;
179
180 let mut char_def_file = File::create(&char_def_path).map_err(|e| {
182 PyValueError::new_err(format!("Failed to create character definition file: {e}"))
183 })?;
184
185 use std::io::Write;
186 writeln!(
187 char_def_file,
188 "# Character definition file generated from trained model"
189 )
190 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
191 writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
192 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
193 writeln!(char_def_file, "DEFAULT 0 1 0")
194 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
195 writeln!(char_def_file, "HIRAGANA 1 1 0")
196 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
197 writeln!(char_def_file, "KATAKANA 1 1 0")
198 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
199 writeln!(char_def_file, "KANJI 0 0 2")
200 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
201 writeln!(char_def_file, "ALPHA 1 1 0")
202 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
203 writeln!(char_def_file, "NUMERIC 1 1 0")
204 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
205 writeln!(char_def_file)
206 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
207
208 writeln!(char_def_file, "# Character mappings")
209 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
210 writeln!(char_def_file, "0x3041..0x3096 HIRAGANA # Hiragana")
211 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
212 writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA # Katakana")
213 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
214 writeln!(
215 char_def_file,
216 "0x4E00..0x9FAF KANJI # CJK Unified Ideographs"
217 )
218 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
219 writeln!(char_def_file, "0x0030..0x0039 NUMERIC # ASCII Digits")
220 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
221 writeln!(char_def_file, "0x0041..0x005A ALPHA # ASCII Uppercase")
222 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
223 writeln!(char_def_file, "0x0061..0x007A ALPHA # ASCII Lowercase")
224 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
225
226 let mut files_created = vec![
227 lexicon_path.clone(),
228 connector_path.clone(),
229 unk_path.clone(),
230 char_def_path.clone(),
231 ];
232
233 if let Some(metadata_str) = metadata {
235 let metadata_path = Path::new(metadata_str);
236 if !metadata_path.exists() {
237 return Err(PyValueError::new_err(format!(
238 "Metadata file does not exist: {}",
239 metadata_path.display()
240 )));
241 }
242
243 let output_metadata_path = output_path.join("metadata.json");
244 let mut metadata_file = File::create(&output_metadata_path)
245 .map_err(|e| PyValueError::new_err(format!("Failed to create metadata file: {e}")))?;
246
247 serializable_model
248 .update_metadata_json(metadata_path, &mut metadata_file)
249 .map_err(|e| PyValueError::new_err(format!("Failed to update metadata: {e}")))?;
250
251 files_created.push(output_metadata_path);
252 println!("Updated metadata.json with trained model values");
253 }
254
255 println!("Dictionary files exported to: {}", output_path.display());
256 println!("Files created:");
257 for file in &files_created {
258 println!(" - {}", file.display());
259 }
260
261 Ok(())
262}