1use std::fs::File;
7use std::path::Path;
8
9use magnus::{Error, Ruby, function};
10
11use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
12
13use crate::error::to_magnus_error;
14
15#[allow(clippy::too_many_arguments)]
30fn train(
31 seed: String,
32 corpus: String,
33 char_def: String,
34 unk_def: String,
35 feature_def: String,
36 rewrite_def: String,
37 output: String,
38 lambda: Option<f64>,
39 max_iter: Option<u64>,
40 max_threads: Option<usize>,
41) -> Result<(), Error> {
42 let ruby = Ruby::get().expect("Ruby runtime not initialized");
43
44 let seed_path = Path::new(&seed);
45 let corpus_path = Path::new(&corpus);
46 let char_def_path = Path::new(&char_def);
47 let unk_def_path = Path::new(&unk_def);
48 let feature_def_path = Path::new(&feature_def);
49 let rewrite_def_path = Path::new(&rewrite_def);
50 let output_path = Path::new(&output);
51
52 for (path, name) in [
54 (seed_path, "seed"),
55 (corpus_path, "corpus"),
56 (char_def_path, "char_def"),
57 (unk_def_path, "unk_def"),
58 (feature_def_path, "feature_def"),
59 (rewrite_def_path, "rewrite_def"),
60 ] {
61 if !path.exists() {
62 return Err(Error::new(
63 ruby.exception_arg_error(),
64 format!("{} file does not exist: {}", name, path.display()),
65 ));
66 }
67 }
68
69 let config = TrainerConfig::from_paths(
71 seed_path,
72 char_def_path,
73 unk_def_path,
74 feature_def_path,
75 rewrite_def_path,
76 )
77 .map_err(|e| to_magnus_error(&ruby, format!("Failed to load trainer configuration: {e}")))?;
78
79 let num_threads = max_threads.unwrap_or_else(num_cpus::get);
81 let lambda_val = lambda.unwrap_or(0.01);
82 let max_iter_val = max_iter.unwrap_or(100);
83
84 let trainer = Trainer::new(config)
85 .map_err(|e| to_magnus_error(&ruby, format!("Failed to initialize trainer: {e}")))?
86 .regularization_cost(lambda_val)
87 .max_iter(max_iter_val)
88 .num_threads(num_threads);
89
90 let corpus_file = File::open(corpus_path)
92 .map_err(|e| to_magnus_error(&ruby, format!("Failed to open corpus file: {e}")))?;
93 let corpus_data = Corpus::from_reader(corpus_file)
94 .map_err(|e| to_magnus_error(&ruby, format!("Failed to load corpus: {e}")))?;
95
96 println!("Training with {} examples...", corpus_data.len());
97
98 let model = trainer
100 .train(corpus_data)
101 .map_err(|e| to_magnus_error(&ruby, format!("Training failed: {e}")))?;
102
103 if let Some(parent) = output_path.parent() {
105 std::fs::create_dir_all(parent).map_err(|e| {
106 to_magnus_error(&ruby, format!("Failed to create output directory: {e}"))
107 })?;
108 }
109
110 let mut output_file = File::create(output_path)
111 .map_err(|e| to_magnus_error(&ruby, format!("Failed to create output file: {e}")))?;
112
113 model
114 .write_model(&mut output_file)
115 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write model: {e}")))?;
116
117 println!("Model saved to {}", output_path.display());
118 Ok(())
119}
120
121fn export(model: String, output: String, metadata: Option<String>) -> Result<(), Error> {
129 let ruby = Ruby::get().expect("Ruby runtime not initialized");
130 let model_path = Path::new(&model);
131 let output_path = Path::new(&output);
132
133 if !model_path.exists() {
134 return Err(Error::new(
135 ruby.exception_arg_error(),
136 format!("Model file does not exist: {}", model_path.display()),
137 ));
138 }
139
140 let model_file = File::open(model_path)
142 .map_err(|e| to_magnus_error(&ruby, format!("Failed to open model file: {e}")))?;
143
144 let serializable_model: SerializableModel = Model::read_model(model_file)
145 .map_err(|e| to_magnus_error(&ruby, format!("Failed to load model: {e}")))?;
146
147 std::fs::create_dir_all(output_path)
149 .map_err(|e| to_magnus_error(&ruby, format!("Failed to create output directory: {e}")))?;
150
151 let lexicon_path = output_path.join("lex.csv");
153 let connector_path = output_path.join("matrix.def");
154 let unk_path = output_path.join("unk.def");
155 let char_def_path = output_path.join("char.def");
156
157 let mut lexicon_file = File::create(&lexicon_path)
159 .map_err(|e| to_magnus_error(&ruby, format!("Failed to create lexicon file: {e}")))?;
160 serializable_model
161 .write_lexicon(&mut lexicon_file)
162 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write lexicon: {e}")))?;
163
164 let mut connector_file = File::create(&connector_path).map_err(|e| {
166 to_magnus_error(
167 &ruby,
168 format!("Failed to create connection matrix file: {e}"),
169 )
170 })?;
171 serializable_model
172 .write_connection_costs(&mut connector_file)
173 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write connection costs: {e}")))?;
174
175 let mut unk_file = File::create(&unk_path)
177 .map_err(|e| to_magnus_error(&ruby, format!("Failed to create unknown word file: {e}")))?;
178 serializable_model
179 .write_unknown_dictionary(&mut unk_file)
180 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write unknown dictionary: {e}")))?;
181
182 let mut char_def_file = File::create(&char_def_path).map_err(|e| {
184 to_magnus_error(
185 &ruby,
186 format!("Failed to create character definition file: {e}"),
187 )
188 })?;
189
190 use std::io::Write;
191 writeln!(
192 char_def_file,
193 "# Character definition file generated from trained model"
194 )
195 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
196 writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
197 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
198 writeln!(char_def_file, "DEFAULT 0 1 0")
199 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
200 writeln!(char_def_file, "HIRAGANA 1 1 0")
201 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
202 writeln!(char_def_file, "KATAKANA 1 1 0")
203 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
204 writeln!(char_def_file, "KANJI 0 0 2")
205 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
206 writeln!(char_def_file, "ALPHA 1 1 0")
207 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
208 writeln!(char_def_file, "NUMERIC 1 1 0")
209 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
210 writeln!(char_def_file)
211 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
212
213 writeln!(char_def_file, "# Character mappings")
214 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
215 writeln!(char_def_file, "0x3041..0x3096 HIRAGANA # Hiragana")
216 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
217 writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA # Katakana")
218 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
219 writeln!(
220 char_def_file,
221 "0x4E00..0x9FAF KANJI # CJK Unified Ideographs"
222 )
223 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
224 writeln!(char_def_file, "0x0030..0x0039 NUMERIC # ASCII Digits")
225 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
226 writeln!(char_def_file, "0x0041..0x005A ALPHA # ASCII Uppercase")
227 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
228 writeln!(char_def_file, "0x0061..0x007A ALPHA # ASCII Lowercase")
229 .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
230
231 let mut files_created = vec![
232 lexicon_path.clone(),
233 connector_path.clone(),
234 unk_path.clone(),
235 char_def_path.clone(),
236 ];
237
238 if let Some(metadata_str) = metadata {
240 let metadata_path = Path::new(&metadata_str);
241 if !metadata_path.exists() {
242 return Err(Error::new(
243 ruby.exception_arg_error(),
244 format!("Metadata file does not exist: {}", metadata_path.display()),
245 ));
246 }
247
248 let output_metadata_path = output_path.join("metadata.json");
249 let mut metadata_file = File::create(&output_metadata_path)
250 .map_err(|e| to_magnus_error(&ruby, format!("Failed to create metadata file: {e}")))?;
251
252 serializable_model
253 .update_metadata_json(metadata_path, &mut metadata_file)
254 .map_err(|e| to_magnus_error(&ruby, format!("Failed to update metadata: {e}")))?;
255
256 files_created.push(output_metadata_path);
257 println!("Updated metadata.json with trained model values");
258 }
259
260 println!("Dictionary files exported to: {}", output_path.display());
261 println!("Files created:");
262 for file in &files_created {
263 println!(" - {}", file.display());
264 }
265
266 Ok(())
267}
268
269pub fn define(_ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
280 module.define_module_function("train", function!(train, 10))?;
281 module.define_module_function("export", function!(export, 3))?;
282 Ok(())
283}