Skip to main content

lindera_nodejs/
trainer.rs

1//! Training functionality for custom morphological models.
2//!
3//! This module provides functions for training custom morphological analysis models
4//! from annotated corpora. Requires the `train` feature.
5
6use std::fs::File;
7use std::path::Path;
8
9use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
10
11use crate::error::to_napi_error;
12
13/// Options for training a CRF model.
14#[napi(object)]
15pub struct TrainOptions {
16    /// Path to the seed lexicon file (CSV format).
17    pub seed: String,
18    /// Path to the annotated training corpus.
19    pub corpus: String,
20    /// Path to the character definition file (char.def).
21    pub char_def: String,
22    /// Path to the unknown word definition file (unk.def).
23    pub unk_def: String,
24    /// Path to the feature definition file (feature.def).
25    pub feature_def: String,
26    /// Path to the rewrite rule definition file (rewrite.def).
27    pub rewrite_def: String,
28    /// Output path for the trained model file.
29    pub output: String,
30    /// L1 regularization cost (0.0-1.0, default: 0.01).
31    pub lambda: Option<f64>,
32    /// Maximum number of training iterations (default: 100).
33    pub max_iter: Option<u32>,
34    /// Number of threads (undefined = auto-detect CPU cores).
35    pub max_threads: Option<u32>,
36}
37
38/// Trains a morphological analysis model from an annotated corpus.
39///
40/// # Arguments
41///
42/// * `options` - Training options containing file paths and parameters.
43#[napi]
44pub fn train(options: TrainOptions) -> napi::Result<()> {
45    let seed_path = Path::new(&options.seed);
46    let corpus_path = Path::new(&options.corpus);
47    let char_def_path = Path::new(&options.char_def);
48    let unk_def_path = Path::new(&options.unk_def);
49    let feature_def_path = Path::new(&options.feature_def);
50    let rewrite_def_path = Path::new(&options.rewrite_def);
51    let output_path = Path::new(&options.output);
52
53    // Validate input files
54    for (path, name) in [
55        (seed_path, "seed"),
56        (corpus_path, "corpus"),
57        (char_def_path, "charDef"),
58        (unk_def_path, "unkDef"),
59        (feature_def_path, "featureDef"),
60        (rewrite_def_path, "rewriteDef"),
61    ] {
62        if !path.exists() {
63            return Err(napi::Error::new(
64                napi::Status::InvalidArg,
65                format!("{} file does not exist: {}", name, path.display()),
66            ));
67        }
68    }
69
70    // Load configuration
71    let config = TrainerConfig::from_paths(
72        seed_path,
73        char_def_path,
74        unk_def_path,
75        feature_def_path,
76        rewrite_def_path,
77    )
78    .map_err(|e| to_napi_error(format!("Failed to load trainer configuration: {e}")))?;
79
80    // Initialize trainer
81    let lambda = options.lambda.unwrap_or(0.01);
82    let max_iter = options.max_iter.unwrap_or(100) as u64;
83    let num_threads = options
84        .max_threads
85        .map(|t| t as usize)
86        .unwrap_or_else(num_cpus::get);
87
88    let trainer = Trainer::new(config)
89        .map_err(|e| to_napi_error(format!("Failed to initialize trainer: {e}")))?
90        .regularization_cost(lambda)
91        .max_iter(max_iter)
92        .num_threads(num_threads);
93
94    // Load corpus
95    let corpus_file = File::open(corpus_path)
96        .map_err(|e| to_napi_error(format!("Failed to open corpus file: {e}")))?;
97    let corpus = Corpus::from_reader(corpus_file)
98        .map_err(|e| to_napi_error(format!("Failed to load corpus: {e}")))?;
99
100    println!("Training with {} examples...", corpus.len());
101
102    // Train model
103    let model = trainer
104        .train(corpus)
105        .map_err(|e| to_napi_error(format!("Training failed: {e}")))?;
106
107    // Save model
108    if let Some(parent) = output_path.parent() {
109        std::fs::create_dir_all(parent)
110            .map_err(|e| to_napi_error(format!("Failed to create output directory: {e}")))?;
111    }
112
113    let mut output_file = File::create(output_path)
114        .map_err(|e| to_napi_error(format!("Failed to create output file: {e}")))?;
115
116    model
117        .write_model(&mut output_file)
118        .map_err(|e| to_napi_error(format!("Failed to write model: {e}")))?;
119
120    println!("Model saved to {}", output_path.display());
121    Ok(())
122}
123
124/// Options for exporting a trained model.
125#[napi(object)]
126pub struct ExportOptions {
127    /// Path to the trained model file (.dat).
128    pub model: String,
129    /// Output directory for dictionary source files.
130    pub output: String,
131    /// Optional path to a base metadata.json file.
132    pub metadata: Option<String>,
133}
134
135/// Exports dictionary files from a trained model.
136///
137/// # Arguments
138///
139/// * `options` - Export options containing file paths.
140#[napi]
141pub fn export_model(options: ExportOptions) -> napi::Result<()> {
142    let model_path = Path::new(&options.model);
143    let output_path = Path::new(&options.output);
144
145    if !model_path.exists() {
146        return Err(napi::Error::new(
147            napi::Status::InvalidArg,
148            format!("Model file does not exist: {}", model_path.display()),
149        ));
150    }
151
152    // Load trained model
153    let model_file = File::open(model_path)
154        .map_err(|e| to_napi_error(format!("Failed to open model file: {e}")))?;
155
156    let serializable_model: SerializableModel = Model::read_model(model_file)
157        .map_err(|e| to_napi_error(format!("Failed to load model: {e}")))?;
158
159    // Create output directory
160    std::fs::create_dir_all(output_path)
161        .map_err(|e| to_napi_error(format!("Failed to create output directory: {e}")))?;
162
163    // Export dictionary files
164    let lexicon_path = output_path.join("lex.csv");
165    let connector_path = output_path.join("matrix.def");
166    let unk_path = output_path.join("unk.def");
167    let char_def_path = output_path.join("char.def");
168
169    // Write lexicon file
170    let mut lexicon_file = File::create(&lexicon_path)
171        .map_err(|e| to_napi_error(format!("Failed to create lexicon file: {e}")))?;
172    serializable_model
173        .write_lexicon(&mut lexicon_file)
174        .map_err(|e| to_napi_error(format!("Failed to write lexicon: {e}")))?;
175
176    // Write connection matrix
177    let mut connector_file = File::create(&connector_path)
178        .map_err(|e| to_napi_error(format!("Failed to create connection matrix file: {e}")))?;
179    serializable_model
180        .write_connection_costs(&mut connector_file)
181        .map_err(|e| to_napi_error(format!("Failed to write connection costs: {e}")))?;
182
183    // Write unknown word definitions
184    let mut unk_file = File::create(&unk_path)
185        .map_err(|e| to_napi_error(format!("Failed to create unknown word file: {e}")))?;
186    serializable_model
187        .write_unknown_dictionary(&mut unk_file)
188        .map_err(|e| to_napi_error(format!("Failed to write unknown dictionary: {e}")))?;
189
190    // Write character definition file
191    let mut char_def_file = File::create(&char_def_path)
192        .map_err(|e| to_napi_error(format!("Failed to create character definition file: {e}")))?;
193
194    use std::io::Write;
195    writeln!(
196        char_def_file,
197        "# Character definition file generated from trained model"
198    )
199    .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
200    writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
201        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
202    writeln!(char_def_file, "DEFAULT 0 1 0")
203        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
204    writeln!(char_def_file, "HIRAGANA 1 1 0")
205        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
206    writeln!(char_def_file, "KATAKANA 1 1 0")
207        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
208    writeln!(char_def_file, "KANJI 0 0 2")
209        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
210    writeln!(char_def_file, "ALPHA 1 1 0")
211        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
212    writeln!(char_def_file, "NUMERIC 1 1 0")
213        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
214    writeln!(char_def_file).map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
215
216    writeln!(char_def_file, "# Character mappings")
217        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
218    writeln!(char_def_file, "0x3041..0x3096 HIRAGANA  # Hiragana")
219        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
220    writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA  # Katakana")
221        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
222    writeln!(
223        char_def_file,
224        "0x4E00..0x9FAF KANJI     # CJK Unified Ideographs"
225    )
226    .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
227    writeln!(char_def_file, "0x0030..0x0039 NUMERIC   # ASCII Digits")
228        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
229    writeln!(char_def_file, "0x0041..0x005A ALPHA     # ASCII Uppercase")
230        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
231    writeln!(char_def_file, "0x0061..0x007A ALPHA     # ASCII Lowercase")
232        .map_err(|e| to_napi_error(format!("Failed to write char.def: {e}")))?;
233
234    let mut files_created = vec![
235        lexicon_path.clone(),
236        connector_path.clone(),
237        unk_path.clone(),
238        char_def_path.clone(),
239    ];
240
241    // Handle metadata.json update if provided
242    if let Some(metadata_str) = &options.metadata {
243        let metadata_path = Path::new(metadata_str);
244        if !metadata_path.exists() {
245            return Err(napi::Error::new(
246                napi::Status::InvalidArg,
247                format!("Metadata file does not exist: {}", metadata_path.display()),
248            ));
249        }
250
251        let output_metadata_path = output_path.join("metadata.json");
252        let mut metadata_file = File::create(&output_metadata_path)
253            .map_err(|e| to_napi_error(format!("Failed to create metadata file: {e}")))?;
254
255        serializable_model
256            .update_metadata_json(metadata_path, &mut metadata_file)
257            .map_err(|e| to_napi_error(format!("Failed to update metadata: {e}")))?;
258
259        files_created.push(output_metadata_path);
260        println!("Updated metadata.json with trained model values");
261    }
262
263    println!("Dictionary files exported to: {}", output_path.display());
264    println!("Files created:");
265    for file in &files_created {
266        println!("  - {}", file.display());
267    }
268
269    Ok(())
270}