lindera/
trainer.rs

1//! Training functionality for custom morphological models.
2//!
3//! This module provides functions for training custom morphological analysis models
4//! from annotated corpora.
5//!
6//! # Examples
7//!
8//! ```python
9//! # Train a model
10//! lindera.train(
11//!     seed="seed.csv",
12//!     corpus="corpus.txt",
13//!     char_def="char.def",
14//!     unk_def="unk.def",
15//!     feature_def="feature.def",
16//!     rewrite_def="rewrite.def",
17//!     output="model.bin",
18//!     lambda_=0.01,
19//!     max_iter=100
20//! )
21//!
22//! # Export trained model
//! lindera.export(
//!     model="model.bin",
//!     output="/path/to/output"
//! )
27//! ```
28
29use std::fs::File;
30use std::path::Path;
31
32use pyo3::{exceptions::PyValueError, prelude::*};
33
34use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
35
36/// Trains a morphological analysis model from an annotated corpus.
37///
38/// # Arguments
39///
40/// * `seed` - Seed lexicon file path (CSV format)
41/// * `corpus` - Training corpus file path (annotated text)
42/// * `char_def` - Character definition file path (char.def)
43/// * `unk_def` - Unknown word definition file path (unk.def)
44/// * `feature_def` - Feature definition file path (feature.def)
45/// * `rewrite_def` - Rewrite rule definition file path (rewrite.def)
46/// * `output` - Output model file path
47/// * `lambda_` - L1 regularization (0.0-1.0), default: 0.01
48/// * `max_iter` - Maximum number of iterations, default: 100
49/// * `max_threads` - Maximum number of threads (None = auto-detect CPU cores)
50///
51/// # Returns
52///
53/// * `PyResult<()>` - Returns Ok(()) on success
54#[pyfunction]
55#[pyo3(signature = (seed, corpus, char_def, unk_def, feature_def, rewrite_def, output, lambda_=0.01, max_iter=100, max_threads=None))]
56#[allow(clippy::too_many_arguments)]
57pub fn train(
58    seed: &str,
59    corpus: &str,
60    char_def: &str,
61    unk_def: &str,
62    feature_def: &str,
63    rewrite_def: &str,
64    output: &str,
65    lambda_: f64,
66    max_iter: u64,
67    max_threads: Option<usize>,
68) -> PyResult<()> {
69    let seed_path = Path::new(seed);
70    let corpus_path = Path::new(corpus);
71    let char_def_path = Path::new(char_def);
72    let unk_def_path = Path::new(unk_def);
73    let feature_def_path = Path::new(feature_def);
74    let rewrite_def_path = Path::new(rewrite_def);
75    let output_path = Path::new(output);
76
77    // Validate input files
78    for (path, name) in [
79        (seed_path, "seed"),
80        (corpus_path, "corpus"),
81        (char_def_path, "char_def"),
82        (unk_def_path, "unk_def"),
83        (feature_def_path, "feature_def"),
84        (rewrite_def_path, "rewrite_def"),
85    ] {
86        if !path.exists() {
87            return Err(PyValueError::new_err(format!(
88                "{} file does not exist: {}",
89                name,
90                path.display()
91            )));
92        }
93    }
94
95    // Load configuration
96    let config = TrainerConfig::from_paths(
97        seed_path,
98        char_def_path,
99        unk_def_path,
100        feature_def_path,
101        rewrite_def_path,
102    )
103    .map_err(|e| PyValueError::new_err(format!("Failed to load trainer configuration: {e}")))?;
104
105    // Initialize trainer
106    let num_threads = max_threads.unwrap_or_else(num_cpus::get);
107    let trainer = Trainer::new(config)
108        .map_err(|e| PyValueError::new_err(format!("Failed to initialize trainer: {e}")))?
109        .regularization_cost(lambda_)
110        .max_iter(max_iter)
111        .num_threads(num_threads);
112
113    // Load corpus
114    let corpus_file = File::open(corpus_path)
115        .map_err(|e| PyValueError::new_err(format!("Failed to open corpus file: {e}")))?;
116    let corpus = Corpus::from_reader(corpus_file)
117        .map_err(|e| PyValueError::new_err(format!("Failed to load corpus: {e}")))?;
118
119    println!("Training with {} examples...", corpus.len());
120
121    // Train model
122    let model = trainer
123        .train(corpus)
124        .map_err(|e| PyValueError::new_err(format!("Training failed: {e}")))?;
125
126    // Save model
127    if let Some(parent) = output_path.parent() {
128        std::fs::create_dir_all(parent).map_err(|e| {
129            PyValueError::new_err(format!("Failed to create output directory: {e}"))
130        })?;
131    }
132
133    let mut output_file = File::create(output_path)
134        .map_err(|e| PyValueError::new_err(format!("Failed to create output file: {e}")))?;
135
136    model
137        .write_model(&mut output_file)
138        .map_err(|e| PyValueError::new_err(format!("Failed to write model: {e}")))?;
139
140    println!("Model saved to {}", output_path.display());
141    Ok(())
142}
143
144/// Export dictionary files from trained model
145///
146/// # Arguments
147///
148/// * `model` - Trained model file path (.dat format)
149/// * `output` - Output directory path (creates lex.csv, matrix.def, unk.def, char.def)
150/// * `metadata` - Optional base metadata.json file to update with trained model values
151///
152/// # Returns
153///
154/// * `PyResult<()>` - Returns Ok(()) on success
155#[pyfunction]
156#[pyo3(signature = (model, output, metadata=None))]
157pub fn export(model: &str, output: &str, metadata: Option<&str>) -> PyResult<()> {
158    let model_path = Path::new(model);
159    let output_path = Path::new(output);
160
161    // Validate input files
162    if !model_path.exists() {
163        return Err(PyValueError::new_err(format!(
164            "Model file does not exist: {}",
165            model_path.display()
166        )));
167    }
168
169    // Load trained model
170    let model_file = File::open(model_path)
171        .map_err(|e| PyValueError::new_err(format!("Failed to open model file: {e}")))?;
172
173    let serializable_model: SerializableModel = Model::read_model(model_file)
174        .map_err(|e| PyValueError::new_err(format!("Failed to load model: {e}")))?;
175
176    // Create output directory
177    std::fs::create_dir_all(output_path)
178        .map_err(|e| PyValueError::new_err(format!("Failed to create output directory: {e}")))?;
179
180    // Export dictionary files
181    let lexicon_path = output_path.join("lex.csv");
182    let connector_path = output_path.join("matrix.def");
183    let unk_path = output_path.join("unk.def");
184    let char_def_path = output_path.join("char.def");
185
186    // Write lexicon file
187    let mut lexicon_file = File::create(&lexicon_path)
188        .map_err(|e| PyValueError::new_err(format!("Failed to create lexicon file: {e}")))?;
189    serializable_model
190        .write_lexicon(&mut lexicon_file)
191        .map_err(|e| PyValueError::new_err(format!("Failed to write lexicon: {e}")))?;
192
193    // Write connection matrix
194    let mut connector_file = File::create(&connector_path).map_err(|e| {
195        PyValueError::new_err(format!("Failed to create connection matrix file: {e}"))
196    })?;
197    serializable_model
198        .write_connection_costs(&mut connector_file)
199        .map_err(|e| PyValueError::new_err(format!("Failed to write connection costs: {e}")))?;
200
201    // Write unknown word definitions
202    let mut unk_file = File::create(&unk_path)
203        .map_err(|e| PyValueError::new_err(format!("Failed to create unknown word file: {e}")))?;
204    serializable_model
205        .write_unknown_dictionary(&mut unk_file)
206        .map_err(|e| PyValueError::new_err(format!("Failed to write unknown dictionary: {e}")))?;
207
208    // Write character definition file
209    let mut char_def_file = File::create(&char_def_path).map_err(|e| {
210        PyValueError::new_err(format!("Failed to create character definition file: {e}"))
211    })?;
212
213    use std::io::Write;
214    writeln!(
215        char_def_file,
216        "# Character definition file generated from trained model"
217    )
218    .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
219    writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
220        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
221    writeln!(char_def_file, "DEFAULT 0 1 0")
222        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
223    writeln!(char_def_file, "HIRAGANA 1 1 0")
224        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
225    writeln!(char_def_file, "KATAKANA 1 1 0")
226        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
227    writeln!(char_def_file, "KANJI 0 0 2")
228        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
229    writeln!(char_def_file, "ALPHA 1 1 0")
230        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
231    writeln!(char_def_file, "NUMERIC 1 1 0")
232        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
233    writeln!(char_def_file)
234        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
235
236    writeln!(char_def_file, "# Character mappings")
237        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
238    writeln!(char_def_file, "0x3041..0x3096 HIRAGANA  # Hiragana")
239        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
240    writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA  # Katakana")
241        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
242    writeln!(
243        char_def_file,
244        "0x4E00..0x9FAF KANJI     # CJK Unified Ideographs"
245    )
246    .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
247    writeln!(char_def_file, "0x0030..0x0039 NUMERIC   # ASCII Digits")
248        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
249    writeln!(char_def_file, "0x0041..0x005A ALPHA     # ASCII Uppercase")
250        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
251    writeln!(char_def_file, "0x0061..0x007A ALPHA     # ASCII Lowercase")
252        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
253
254    let mut files_created = vec![
255        lexicon_path.clone(),
256        connector_path.clone(),
257        unk_path.clone(),
258        char_def_path.clone(),
259    ];
260
261    // Handle metadata.json update if provided
262    if let Some(metadata_str) = metadata {
263        let metadata_path = Path::new(metadata_str);
264        if !metadata_path.exists() {
265            return Err(PyValueError::new_err(format!(
266                "Metadata file does not exist: {}",
267                metadata_path.display()
268            )));
269        }
270
271        let output_metadata_path = output_path.join("metadata.json");
272        let mut metadata_file = File::create(&output_metadata_path)
273            .map_err(|e| PyValueError::new_err(format!("Failed to create metadata file: {e}")))?;
274
275        serializable_model
276            .update_metadata_json(metadata_path, &mut metadata_file)
277            .map_err(|e| PyValueError::new_err(format!("Failed to update metadata: {e}")))?;
278
279        files_created.push(output_metadata_path);
280        println!("Updated metadata.json with trained model values");
281    }
282
283    println!("Dictionary files exported to: {}", output_path.display());
284    println!("Files created:");
285    for file in &files_created {
286        println!("  - {}", file.display());
287    }
288
289    Ok(())
290}