1use std::fs::File;
30use std::path::Path;
31
32use pyo3::{exceptions::PyValueError, prelude::*};
33
34use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
35
36#[pyfunction]
55#[pyo3(signature = (seed, corpus, char_def, unk_def, feature_def, rewrite_def, output, lambda_=0.01, max_iter=100, max_threads=None))]
56#[allow(clippy::too_many_arguments)]
57pub fn train(
58 seed: &str,
59 corpus: &str,
60 char_def: &str,
61 unk_def: &str,
62 feature_def: &str,
63 rewrite_def: &str,
64 output: &str,
65 lambda_: f64,
66 max_iter: u64,
67 max_threads: Option<usize>,
68) -> PyResult<()> {
69 let seed_path = Path::new(seed);
70 let corpus_path = Path::new(corpus);
71 let char_def_path = Path::new(char_def);
72 let unk_def_path = Path::new(unk_def);
73 let feature_def_path = Path::new(feature_def);
74 let rewrite_def_path = Path::new(rewrite_def);
75 let output_path = Path::new(output);
76
77 for (path, name) in [
79 (seed_path, "seed"),
80 (corpus_path, "corpus"),
81 (char_def_path, "char_def"),
82 (unk_def_path, "unk_def"),
83 (feature_def_path, "feature_def"),
84 (rewrite_def_path, "rewrite_def"),
85 ] {
86 if !path.exists() {
87 return Err(PyValueError::new_err(format!(
88 "{} file does not exist: {}",
89 name,
90 path.display()
91 )));
92 }
93 }
94
95 let config = TrainerConfig::from_paths(
97 seed_path,
98 char_def_path,
99 unk_def_path,
100 feature_def_path,
101 rewrite_def_path,
102 )
103 .map_err(|e| PyValueError::new_err(format!("Failed to load trainer configuration: {e}")))?;
104
105 let num_threads = max_threads.unwrap_or_else(num_cpus::get);
107 let trainer = Trainer::new(config)
108 .map_err(|e| PyValueError::new_err(format!("Failed to initialize trainer: {e}")))?
109 .regularization_cost(lambda_)
110 .max_iter(max_iter)
111 .num_threads(num_threads);
112
113 let corpus_file = File::open(corpus_path)
115 .map_err(|e| PyValueError::new_err(format!("Failed to open corpus file: {e}")))?;
116 let corpus = Corpus::from_reader(corpus_file)
117 .map_err(|e| PyValueError::new_err(format!("Failed to load corpus: {e}")))?;
118
119 println!("Training with {} examples...", corpus.len());
120
121 let model = trainer
123 .train(corpus)
124 .map_err(|e| PyValueError::new_err(format!("Training failed: {e}")))?;
125
126 if let Some(parent) = output_path.parent() {
128 std::fs::create_dir_all(parent).map_err(|e| {
129 PyValueError::new_err(format!("Failed to create output directory: {e}"))
130 })?;
131 }
132
133 let mut output_file = File::create(output_path)
134 .map_err(|e| PyValueError::new_err(format!("Failed to create output file: {e}")))?;
135
136 model
137 .write_model(&mut output_file)
138 .map_err(|e| PyValueError::new_err(format!("Failed to write model: {e}")))?;
139
140 println!("Model saved to {}", output_path.display());
141 Ok(())
142}
143
144#[pyfunction]
156#[pyo3(signature = (model, output, metadata=None))]
157pub fn export(model: &str, output: &str, metadata: Option<&str>) -> PyResult<()> {
158 let model_path = Path::new(model);
159 let output_path = Path::new(output);
160
161 if !model_path.exists() {
163 return Err(PyValueError::new_err(format!(
164 "Model file does not exist: {}",
165 model_path.display()
166 )));
167 }
168
169 let model_file = File::open(model_path)
171 .map_err(|e| PyValueError::new_err(format!("Failed to open model file: {e}")))?;
172
173 let serializable_model: SerializableModel = Model::read_model(model_file)
174 .map_err(|e| PyValueError::new_err(format!("Failed to load model: {e}")))?;
175
176 std::fs::create_dir_all(output_path)
178 .map_err(|e| PyValueError::new_err(format!("Failed to create output directory: {e}")))?;
179
180 let lexicon_path = output_path.join("lex.csv");
182 let connector_path = output_path.join("matrix.def");
183 let unk_path = output_path.join("unk.def");
184 let char_def_path = output_path.join("char.def");
185
186 let mut lexicon_file = File::create(&lexicon_path)
188 .map_err(|e| PyValueError::new_err(format!("Failed to create lexicon file: {e}")))?;
189 serializable_model
190 .write_lexicon(&mut lexicon_file)
191 .map_err(|e| PyValueError::new_err(format!("Failed to write lexicon: {e}")))?;
192
193 let mut connector_file = File::create(&connector_path).map_err(|e| {
195 PyValueError::new_err(format!("Failed to create connection matrix file: {e}"))
196 })?;
197 serializable_model
198 .write_connection_costs(&mut connector_file)
199 .map_err(|e| PyValueError::new_err(format!("Failed to write connection costs: {e}")))?;
200
201 let mut unk_file = File::create(&unk_path)
203 .map_err(|e| PyValueError::new_err(format!("Failed to create unknown word file: {e}")))?;
204 serializable_model
205 .write_unknown_dictionary(&mut unk_file)
206 .map_err(|e| PyValueError::new_err(format!("Failed to write unknown dictionary: {e}")))?;
207
208 let mut char_def_file = File::create(&char_def_path).map_err(|e| {
210 PyValueError::new_err(format!("Failed to create character definition file: {e}"))
211 })?;
212
213 use std::io::Write;
214 writeln!(
215 char_def_file,
216 "# Character definition file generated from trained model"
217 )
218 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
219 writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
220 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
221 writeln!(char_def_file, "DEFAULT 0 1 0")
222 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
223 writeln!(char_def_file, "HIRAGANA 1 1 0")
224 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
225 writeln!(char_def_file, "KATAKANA 1 1 0")
226 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
227 writeln!(char_def_file, "KANJI 0 0 2")
228 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
229 writeln!(char_def_file, "ALPHA 1 1 0")
230 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
231 writeln!(char_def_file, "NUMERIC 1 1 0")
232 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
233 writeln!(char_def_file)
234 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
235
236 writeln!(char_def_file, "# Character mappings")
237 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
238 writeln!(char_def_file, "0x3041..0x3096 HIRAGANA # Hiragana")
239 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
240 writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA # Katakana")
241 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
242 writeln!(
243 char_def_file,
244 "0x4E00..0x9FAF KANJI # CJK Unified Ideographs"
245 )
246 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
247 writeln!(char_def_file, "0x0030..0x0039 NUMERIC # ASCII Digits")
248 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
249 writeln!(char_def_file, "0x0041..0x005A ALPHA # ASCII Uppercase")
250 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
251 writeln!(char_def_file, "0x0061..0x007A ALPHA # ASCII Lowercase")
252 .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
253
254 let mut files_created = vec![
255 lexicon_path.clone(),
256 connector_path.clone(),
257 unk_path.clone(),
258 char_def_path.clone(),
259 ];
260
261 if let Some(metadata_str) = metadata {
263 let metadata_path = Path::new(metadata_str);
264 if !metadata_path.exists() {
265 return Err(PyValueError::new_err(format!(
266 "Metadata file does not exist: {}",
267 metadata_path.display()
268 )));
269 }
270
271 let output_metadata_path = output_path.join("metadata.json");
272 let mut metadata_file = File::create(&output_metadata_path)
273 .map_err(|e| PyValueError::new_err(format!("Failed to create metadata file: {e}")))?;
274
275 serializable_model
276 .update_metadata_json(metadata_path, &mut metadata_file)
277 .map_err(|e| PyValueError::new_err(format!("Failed to update metadata: {e}")))?;
278
279 files_created.push(output_metadata_path);
280 println!("Updated metadata.json with trained model values");
281 }
282
283 println!("Dictionary files exported to: {}", output_path.display());
284 println!("Files created:");
285 for file in &files_created {
286 println!(" - {}", file.display());
287 }
288
289 Ok(())
290}