lindera_nodejs/
trainer.rs1use std::fs::File;
7use std::path::Path;
8
9use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
10
11use crate::error::to_napi_error;
12
13#[napi(object)]
15pub struct TrainOptions {
16 pub seed: String,
18 pub corpus: String,
20 pub char_def: String,
22 pub unk_def: String,
24 pub feature_def: String,
26 pub rewrite_def: String,
28 pub output: String,
30 pub lambda: Option<f64>,
32 pub max_iter: Option<u32>,
34 pub max_threads: Option<u32>,
36}
37
38#[napi]
44pub fn train(options: TrainOptions) -> napi::Result<()> {
45 let seed_path = Path::new(&options.seed);
46 let corpus_path = Path::new(&options.corpus);
47 let char_def_path = Path::new(&options.char_def);
48 let unk_def_path = Path::new(&options.unk_def);
49 let feature_def_path = Path::new(&options.feature_def);
50 let rewrite_def_path = Path::new(&options.rewrite_def);
51 let output_path = Path::new(&options.output);
52
53 for (path, name) in [
55 (seed_path, "seed"),
56 (corpus_path, "corpus"),
57 (char_def_path, "charDef"),
58 (unk_def_path, "unkDef"),
59 (feature_def_path, "featureDef"),
60 (rewrite_def_path, "rewriteDef"),
61 ] {
62 if !path.exists() {
63 return Err(napi::Error::new(
64 napi::Status::InvalidArg,
65 format!("{} file does not exist: {}", name, path.display()),
66 ));
67 }
68 }
69
70 let config = TrainerConfig::from_paths(
72 seed_path,
73 char_def_path,
74 unk_def_path,
75 feature_def_path,
76 rewrite_def_path,
77 )
78 .map_err(|e| to_napi_error(format!("Failed to load trainer configuration: {e}")))?;
79
80 let lambda = options.lambda.unwrap_or(0.01);
82 let max_iter = options.max_iter.unwrap_or(100) as u64;
83 let num_threads = options
84 .max_threads
85 .map(|t| t as usize)
86 .unwrap_or_else(num_cpus::get);
87
88 let trainer = Trainer::new(config)
89 .map_err(|e| to_napi_error(format!("Failed to initialize trainer: {e}")))?
90 .regularization_cost(lambda)
91 .max_iter(max_iter)
92 .num_threads(num_threads);
93
94 let corpus_file = File::open(corpus_path)
96 .map_err(|e| to_napi_error(format!("Failed to open corpus file: {e}")))?;
97 let corpus = Corpus::from_reader(corpus_file)
98 .map_err(|e| to_napi_error(format!("Failed to load corpus: {e}")))?;
99
100 println!("Training with {} examples...", corpus.len());
101
102 let model = trainer
104 .train(corpus)
105 .map_err(|e| to_napi_error(format!("Training failed: {e}")))?;
106
107 if let Some(parent) = output_path.parent() {
109 std::fs::create_dir_all(parent)
110 .map_err(|e| to_napi_error(format!("Failed to create output directory: {e}")))?;
111 }
112
113 let mut output_file = File::create(output_path)
114 .map_err(|e| to_napi_error(format!("Failed to create output file: {e}")))?;
115
116 model
117 .write_model(&mut output_file)
118 .map_err(|e| to_napi_error(format!("Failed to write model: {e}")))?;
119
120 println!("Model saved to {}", output_path.display());
121 Ok(())
122}
123
124#[napi(object)]
126pub struct ExportOptions {
127 pub model: String,
129 pub output: String,
131 pub metadata: Option<String>,
133}
134
135#[napi]
141pub fn export_model(options: ExportOptions) -> napi::Result<()> {
142 let model_path = Path::new(&options.model);
143 let output_path = Path::new(&options.output);
144
145 if !model_path.exists() {
146 return Err(napi::Error::new(
147 napi::Status::InvalidArg,
148 format!("Model file does not exist: {}", model_path.display()),
149 ));
150 }
151
152 let model_file = File::open(model_path)
154 .map_err(|e| to_napi_error(format!("Failed to open model file: {e}")))?;
155
156 let serializable_model: SerializableModel = Model::read_model(model_file)
157 .map_err(|e| to_napi_error(format!("Failed to load model: {e}")))?;
158
159 std::fs::create_dir_all(output_path)
161 .map_err(|e| to_napi_error(format!("Failed to create output directory: {e}")))?;
162
163 let lexicon_path = output_path.join("lex.csv");
165 let connector_path = output_path.join("matrix.def");
166 let unk_path = output_path.join("unk.def");
167 let char_def_path = output_path.join("char.def");
168
169 let mut lexicon_file = File::create(&lexicon_path)
171 .map_err(|e| to_napi_error(format!("Failed to create lexicon file: {e}")))?;
172 serializable_model
173 .write_lexicon(&mut lexicon_file)
174 .map_err(|e| to_napi_error(format!("Failed to write lexicon: {e}")))?;
175
176 let mut connector_file = File::create(&connector_path)
178 .map_err(|e| to_napi_error(format!("Failed to create connection matrix file: {e}")))?;
179 serializable_model
180 .write_connection_costs(&mut connector_file)
181 .map_err(|e| to_napi_error(format!("Failed to write connection costs: {e}")))?;
182
183 let mut unk_file = File::create(&unk_path)
185 .map_err(|e| to_napi_error(format!("Failed to create unknown word file: {e}")))?;
186 serializable_model
187 .write_unknown_dictionary(&mut unk_file)
188 .map_err(|e| to_napi_error(format!("Failed to write unknown dictionary: {e}")))?;
189
190 let mut char_def_file = File::create(&char_def_path)
192 .map_err(|e| to_napi_error(format!("Failed to create character definition file: {e}")))?;
193
194 use std::io::Write;
195 writeln!(
196 char_def_file,
197 "# Character definition file generated from trained model"
198 )
199 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
200 writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
201 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
202 writeln!(char_def_file, "DEFAULT 0 1 0")
203 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
204 writeln!(char_def_file, "HIRAGANA 1 1 0")
205 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
206 writeln!(char_def_file, "KATAKANA 1 1 0")
207 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
208 writeln!(char_def_file, "KANJI 0 0 2")
209 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
210 writeln!(char_def_file, "ALPHA 1 1 0")
211 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
212 writeln!(char_def_file, "NUMERIC 1 1 0")
213 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
214 writeln!(char_def_file).map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
215
216 writeln!(char_def_file, "# Character mappings")
217 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
218 writeln!(char_def_file, "0x3041..0x3096 HIRAGANA # Hiragana")
219 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
220 writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA # Katakana")
221 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
222 writeln!(
223 char_def_file,
224 "0x4E00..0x9FAF KANJI # CJK Unified Ideographs"
225 )
226 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
227 writeln!(char_def_file, "0x0030..0x0039 NUMERIC # ASCII Digits")
228 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
229 writeln!(char_def_file, "0x0041..0x005A ALPHA # ASCII Uppercase")
230 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
231 writeln!(char_def_file, "0x0061..0x007A ALPHA # ASCII Lowercase")
232 .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
233
234 let mut files_created = vec![
235 lexicon_path.clone(),
236 connector_path.clone(),
237 unk_path.clone(),
238 char_def_path.clone(),
239 ];
240
241 if let Some(metadata_str) = &options.metadata {
243 let metadata_path = Path::new(metadata_str);
244 if !metadata_path.exists() {
245 return Err(napi::Error::new(
246 napi::Status::InvalidArg,
247 format!("Metadata file does not exist: {}", metadata_path.display()),
248 ));
249 }
250
251 let output_metadata_path = output_path.join("metadata.json");
252 let mut metadata_file = File::create(&output_metadata_path)
253 .map_err(|e| to_napi_error(format!("Failed to create metadata file: {e}")))?;
254
255 serializable_model
256 .update_metadata_json(metadata_path, &mut metadata_file)
257 .map_err(|e| to_napi_error(format!("Failed to update metadata: {e}")))?;
258
259 files_created.push(output_metadata_path);
260 println!("Updated metadata.json with trained model values");
261 }
262
263 println!("Dictionary files exported to: {}", output_path.display());
264 println!("Files created:");
265 for file in &files_created {
266 println!(" - {}", file.display());
267 }
268
269 Ok(())
270}