Skip to main content

lindera_ruby/
trainer.rs

1//! Training functionality for custom morphological models.
2//!
3//! This module provides functions for training custom morphological analysis models
4//! from annotated corpora.
5
6use std::fs::File;
7use std::path::Path;
8
9use magnus::{Error, Ruby, function};
10
11use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
12
13use crate::error::to_magnus_error;
14
15/// Trains a morphological analysis model from an annotated corpus.
16///
17/// # Arguments
18///
19/// * `seed` - Seed lexicon file path (CSV format).
20/// * `corpus` - Training corpus file path (annotated text).
21/// * `char_def` - Character definition file path (char.def).
22/// * `unk_def` - Unknown word definition file path (unk.def).
23/// * `feature_def` - Feature definition file path (feature.def).
24/// * `rewrite_def` - Rewrite rule definition file path (rewrite.def).
25/// * `output` - Output model file path.
26/// * `lambda` - L1 regularization (0.0-1.0), default: 0.01.
27/// * `max_iter` - Maximum number of iterations, default: 100.
28/// * `max_threads` - Maximum number of threads (nil = auto-detect CPU cores).
29#[allow(clippy::too_many_arguments)]
30fn train(
31    seed: String,
32    corpus: String,
33    char_def: String,
34    unk_def: String,
35    feature_def: String,
36    rewrite_def: String,
37    output: String,
38    lambda: Option<f64>,
39    max_iter: Option<u64>,
40    max_threads: Option<usize>,
41) -> Result<(), Error> {
42    let ruby = Ruby::get().expect("Ruby runtime not initialized");
43
44    let seed_path = Path::new(&seed);
45    let corpus_path = Path::new(&corpus);
46    let char_def_path = Path::new(&char_def);
47    let unk_def_path = Path::new(&unk_def);
48    let feature_def_path = Path::new(&feature_def);
49    let rewrite_def_path = Path::new(&rewrite_def);
50    let output_path = Path::new(&output);
51
52    // Validate input files
53    for (path, name) in [
54        (seed_path, "seed"),
55        (corpus_path, "corpus"),
56        (char_def_path, "char_def"),
57        (unk_def_path, "unk_def"),
58        (feature_def_path, "feature_def"),
59        (rewrite_def_path, "rewrite_def"),
60    ] {
61        if !path.exists() {
62            return Err(Error::new(
63                ruby.exception_arg_error(),
64                format!("{} file does not exist: {}", name, path.display()),
65            ));
66        }
67    }
68
69    // Load configuration
70    let config = TrainerConfig::from_paths(
71        seed_path,
72        char_def_path,
73        unk_def_path,
74        feature_def_path,
75        rewrite_def_path,
76    )
77    .map_err(|e| to_magnus_error(&ruby, format!("Failed to load trainer configuration: {e}")))?;
78
79    // Initialize trainer
80    let num_threads = max_threads.unwrap_or_else(num_cpus::get);
81    let lambda_val = lambda.unwrap_or(0.01);
82    let max_iter_val = max_iter.unwrap_or(100);
83
84    let trainer = Trainer::new(config)
85        .map_err(|e| to_magnus_error(&ruby, format!("Failed to initialize trainer: {e}")))?
86        .regularization_cost(lambda_val)
87        .max_iter(max_iter_val)
88        .num_threads(num_threads);
89
90    // Load corpus
91    let corpus_file = File::open(corpus_path)
92        .map_err(|e| to_magnus_error(&ruby, format!("Failed to open corpus file: {e}")))?;
93    let corpus_data = Corpus::from_reader(corpus_file)
94        .map_err(|e| to_magnus_error(&ruby, format!("Failed to load corpus: {e}")))?;
95
96    println!("Training with {} examples...", corpus_data.len());
97
98    // Train model
99    let model = trainer
100        .train(corpus_data)
101        .map_err(|e| to_magnus_error(&ruby, format!("Training failed: {e}")))?;
102
103    // Save model
104    if let Some(parent) = output_path.parent() {
105        std::fs::create_dir_all(parent).map_err(|e| {
106            to_magnus_error(&ruby, format!("Failed to create output directory: {e}"))
107        })?;
108    }
109
110    let mut output_file = File::create(output_path)
111        .map_err(|e| to_magnus_error(&ruby, format!("Failed to create output file: {e}")))?;
112
113    model
114        .write_model(&mut output_file)
115        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write model: {e}")))?;
116
117    println!("Model saved to {}", output_path.display());
118    Ok(())
119}
120
121/// Export dictionary files from trained model.
122///
123/// # Arguments
124///
125/// * `model` - Trained model file path (.dat format).
126/// * `output` - Output directory path (creates lex.csv, matrix.def, unk.def, char.def).
127/// * `metadata` - Optional base metadata.json file to update with trained model values.
128fn export(model: String, output: String, metadata: Option<String>) -> Result<(), Error> {
129    let ruby = Ruby::get().expect("Ruby runtime not initialized");
130    let model_path = Path::new(&model);
131    let output_path = Path::new(&output);
132
133    if !model_path.exists() {
134        return Err(Error::new(
135            ruby.exception_arg_error(),
136            format!("Model file does not exist: {}", model_path.display()),
137        ));
138    }
139
140    // Load trained model
141    let model_file = File::open(model_path)
142        .map_err(|e| to_magnus_error(&ruby, format!("Failed to open model file: {e}")))?;
143
144    let serializable_model: SerializableModel = Model::read_model(model_file)
145        .map_err(|e| to_magnus_error(&ruby, format!("Failed to load model: {e}")))?;
146
147    // Create output directory
148    std::fs::create_dir_all(output_path)
149        .map_err(|e| to_magnus_error(&ruby, format!("Failed to create output directory: {e}")))?;
150
151    // Export dictionary files
152    let lexicon_path = output_path.join("lex.csv");
153    let connector_path = output_path.join("matrix.def");
154    let unk_path = output_path.join("unk.def");
155    let char_def_path = output_path.join("char.def");
156
157    // Write lexicon file
158    let mut lexicon_file = File::create(&lexicon_path)
159        .map_err(|e| to_magnus_error(&ruby, format!("Failed to create lexicon file: {e}")))?;
160    serializable_model
161        .write_lexicon(&mut lexicon_file)
162        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write lexicon: {e}")))?;
163
164    // Write connection matrix
165    let mut connector_file = File::create(&connector_path).map_err(|e| {
166        to_magnus_error(
167            &ruby,
168            format!("Failed to create connection matrix file: {e}"),
169        )
170    })?;
171    serializable_model
172        .write_connection_costs(&mut connector_file)
173        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write connection costs: {e}")))?;
174
175    // Write unknown word definitions
176    let mut unk_file = File::create(&unk_path)
177        .map_err(|e| to_magnus_error(&ruby, format!("Failed to create unknown word file: {e}")))?;
178    serializable_model
179        .write_unknown_dictionary(&mut unk_file)
180        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write unknown dictionary: {e}")))?;
181
182    // Write character definition file
183    let mut char_def_file = File::create(&char_def_path).map_err(|e| {
184        to_magnus_error(
185            &ruby,
186            format!("Failed to create character definition file: {e}"),
187        )
188    })?;
189
190    use std::io::Write;
191    writeln!(
192        char_def_file,
193        "# Character definition file generated from trained model"
194    )
195    .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
196    writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
197        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
198    writeln!(char_def_file, "DEFAULT 0 1 0")
199        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
200    writeln!(char_def_file, "HIRAGANA 1 1 0")
201        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
202    writeln!(char_def_file, "KATAKANA 1 1 0")
203        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
204    writeln!(char_def_file, "KANJI 0 0 2")
205        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
206    writeln!(char_def_file, "ALPHA 1 1 0")
207        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
208    writeln!(char_def_file, "NUMERIC 1 1 0")
209        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
210    writeln!(char_def_file)
211        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
212
213    writeln!(char_def_file, "# Character mappings")
214        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
215    writeln!(char_def_file, "0x3041..0x3096 HIRAGANA  # Hiragana")
216        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
217    writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA  # Katakana")
218        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
219    writeln!(
220        char_def_file,
221        "0x4E00..0x9FAF KANJI     # CJK Unified Ideographs"
222    )
223    .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
224    writeln!(char_def_file, "0x0030..0x0039 NUMERIC   # ASCII Digits")
225        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
226    writeln!(char_def_file, "0x0041..0x005A ALPHA     # ASCII Uppercase")
227        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
228    writeln!(char_def_file, "0x0061..0x007A ALPHA     # ASCII Lowercase")
229        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
230
231    let mut files_created = vec![
232        lexicon_path.clone(),
233        connector_path.clone(),
234        unk_path.clone(),
235        char_def_path.clone(),
236    ];
237
238    // Handle metadata.json update if provided
239    if let Some(metadata_str) = metadata {
240        let metadata_path = Path::new(&metadata_str);
241        if !metadata_path.exists() {
242            return Err(Error::new(
243                ruby.exception_arg_error(),
244                format!("Metadata file does not exist: {}", metadata_path.display()),
245            ));
246        }
247
248        let output_metadata_path = output_path.join("metadata.json");
249        let mut metadata_file = File::create(&output_metadata_path)
250            .map_err(|e| to_magnus_error(&ruby, format!("Failed to create metadata file: {e}")))?;
251
252        serializable_model
253            .update_metadata_json(metadata_path, &mut metadata_file)
254            .map_err(|e| to_magnus_error(&ruby, format!("Failed to update metadata: {e}")))?;
255
256        files_created.push(output_metadata_path);
257        println!("Updated metadata.json with trained model values");
258    }
259
260    println!("Dictionary files exported to: {}", output_path.display());
261    println!("Files created:");
262    for file in &files_created {
263        println!("  - {}", file.display());
264    }
265
266    Ok(())
267}
268
269/// Defines train and export module functions in the given Ruby module.
270///
271/// # Arguments
272///
273/// * `_ruby` - Ruby runtime handle.
274/// * `module` - Parent Ruby module.
275///
276/// # Returns
277///
278/// `Ok(())` on success, or a Magnus `Error` on failure.
279pub fn define(_ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
280    module.define_module_function("train", function!(train, 10))?;
281    module.define_module_function("export", function!(export, 3))?;
282    Ok(())
283}