lindera/
trainer.rs

1use std::fs::File;
2use std::path::Path;
3
4use pyo3::{exceptions::PyValueError, prelude::*};
5
6use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
7
8/// Train a morphological analysis model from corpus
9///
10/// # Arguments
11///
12/// * `seed` - Seed lexicon file path (CSV format)
13/// * `corpus` - Training corpus file path (annotated text)
14/// * `char_def` - Character definition file path (char.def)
15/// * `unk_def` - Unknown word definition file path (unk.def)
16/// * `feature_def` - Feature definition file path (feature.def)
17/// * `rewrite_def` - Rewrite rule definition file path (rewrite.def)
18/// * `output` - Output model file path
19/// * `lambda_` - L1 regularization (0.0-1.0), default: 0.01
20/// * `max_iter` - Maximum number of iterations, default: 100
21/// * `max_threads` - Maximum number of threads (None = auto-detect CPU cores)
22///
23/// # Returns
24///
25/// * `PyResult<()>` - Returns Ok(()) on success
26#[pyfunction]
27#[pyo3(signature = (seed, corpus, char_def, unk_def, feature_def, rewrite_def, output, lambda_=0.01, max_iter=100, max_threads=None))]
28#[allow(clippy::too_many_arguments)]
29pub fn train(
30    seed: &str,
31    corpus: &str,
32    char_def: &str,
33    unk_def: &str,
34    feature_def: &str,
35    rewrite_def: &str,
36    output: &str,
37    lambda_: f64,
38    max_iter: u64,
39    max_threads: Option<usize>,
40) -> PyResult<()> {
41    let seed_path = Path::new(seed);
42    let corpus_path = Path::new(corpus);
43    let char_def_path = Path::new(char_def);
44    let unk_def_path = Path::new(unk_def);
45    let feature_def_path = Path::new(feature_def);
46    let rewrite_def_path = Path::new(rewrite_def);
47    let output_path = Path::new(output);
48
49    // Validate input files
50    for (path, name) in [
51        (seed_path, "seed"),
52        (corpus_path, "corpus"),
53        (char_def_path, "char_def"),
54        (unk_def_path, "unk_def"),
55        (feature_def_path, "feature_def"),
56        (rewrite_def_path, "rewrite_def"),
57    ] {
58        if !path.exists() {
59            return Err(PyValueError::new_err(format!(
60                "{} file does not exist: {}",
61                name,
62                path.display()
63            )));
64        }
65    }
66
67    // Load configuration
68    let config = TrainerConfig::from_paths(
69        seed_path,
70        char_def_path,
71        unk_def_path,
72        feature_def_path,
73        rewrite_def_path,
74    )
75    .map_err(|e| PyValueError::new_err(format!("Failed to load trainer configuration: {e}")))?;
76
77    // Initialize trainer
78    let num_threads = max_threads.unwrap_or_else(num_cpus::get);
79    let trainer = Trainer::new(config)
80        .map_err(|e| PyValueError::new_err(format!("Failed to initialize trainer: {e}")))?
81        .regularization_cost(lambda_)
82        .max_iter(max_iter)
83        .num_threads(num_threads);
84
85    // Load corpus
86    let corpus_file = File::open(corpus_path)
87        .map_err(|e| PyValueError::new_err(format!("Failed to open corpus file: {e}")))?;
88    let corpus = Corpus::from_reader(corpus_file)
89        .map_err(|e| PyValueError::new_err(format!("Failed to load corpus: {e}")))?;
90
91    println!("Training with {} examples...", corpus.len());
92
93    // Train model
94    let model = trainer
95        .train(corpus)
96        .map_err(|e| PyValueError::new_err(format!("Training failed: {e}")))?;
97
98    // Save model
99    if let Some(parent) = output_path.parent() {
100        std::fs::create_dir_all(parent).map_err(|e| {
101            PyValueError::new_err(format!("Failed to create output directory: {e}"))
102        })?;
103    }
104
105    let mut output_file = File::create(output_path)
106        .map_err(|e| PyValueError::new_err(format!("Failed to create output file: {e}")))?;
107
108    model
109        .write_model(&mut output_file)
110        .map_err(|e| PyValueError::new_err(format!("Failed to write model: {e}")))?;
111
112    println!("Model saved to {}", output_path.display());
113    Ok(())
114}
115
116/// Export dictionary files from trained model
117///
118/// # Arguments
119///
120/// * `model` - Trained model file path (.dat format)
121/// * `output` - Output directory path (creates lex.csv, matrix.def, unk.def, char.def)
122/// * `metadata` - Optional base metadata.json file to update with trained model values
123///
124/// # Returns
125///
126/// * `PyResult<()>` - Returns Ok(()) on success
127#[pyfunction]
128#[pyo3(signature = (model, output, metadata=None))]
129pub fn export(model: &str, output: &str, metadata: Option<&str>) -> PyResult<()> {
130    let model_path = Path::new(model);
131    let output_path = Path::new(output);
132
133    // Validate input files
134    if !model_path.exists() {
135        return Err(PyValueError::new_err(format!(
136            "Model file does not exist: {}",
137            model_path.display()
138        )));
139    }
140
141    // Load trained model
142    let model_file = File::open(model_path)
143        .map_err(|e| PyValueError::new_err(format!("Failed to open model file: {e}")))?;
144
145    let serializable_model: SerializableModel = Model::read_model(model_file)
146        .map_err(|e| PyValueError::new_err(format!("Failed to load model: {e}")))?;
147
148    // Create output directory
149    std::fs::create_dir_all(output_path)
150        .map_err(|e| PyValueError::new_err(format!("Failed to create output directory: {e}")))?;
151
152    // Export dictionary files
153    let lexicon_path = output_path.join("lex.csv");
154    let connector_path = output_path.join("matrix.def");
155    let unk_path = output_path.join("unk.def");
156    let char_def_path = output_path.join("char.def");
157
158    // Write lexicon file
159    let mut lexicon_file = File::create(&lexicon_path)
160        .map_err(|e| PyValueError::new_err(format!("Failed to create lexicon file: {e}")))?;
161    serializable_model
162        .write_lexicon(&mut lexicon_file)
163        .map_err(|e| PyValueError::new_err(format!("Failed to write lexicon: {e}")))?;
164
165    // Write connection matrix
166    let mut connector_file = File::create(&connector_path).map_err(|e| {
167        PyValueError::new_err(format!("Failed to create connection matrix file: {e}"))
168    })?;
169    serializable_model
170        .write_connection_costs(&mut connector_file)
171        .map_err(|e| PyValueError::new_err(format!("Failed to write connection costs: {e}")))?;
172
173    // Write unknown word definitions
174    let mut unk_file = File::create(&unk_path)
175        .map_err(|e| PyValueError::new_err(format!("Failed to create unknown word file: {e}")))?;
176    serializable_model
177        .write_unknown_dictionary(&mut unk_file)
178        .map_err(|e| PyValueError::new_err(format!("Failed to write unknown dictionary: {e}")))?;
179
180    // Write character definition file
181    let mut char_def_file = File::create(&char_def_path).map_err(|e| {
182        PyValueError::new_err(format!("Failed to create character definition file: {e}"))
183    })?;
184
185    use std::io::Write;
186    writeln!(
187        char_def_file,
188        "# Character definition file generated from trained model"
189    )
190    .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
191    writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
192        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
193    writeln!(char_def_file, "DEFAULT 0 1 0")
194        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
195    writeln!(char_def_file, "HIRAGANA 1 1 0")
196        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
197    writeln!(char_def_file, "KATAKANA 1 1 0")
198        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
199    writeln!(char_def_file, "KANJI 0 0 2")
200        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
201    writeln!(char_def_file, "ALPHA 1 1 0")
202        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
203    writeln!(char_def_file, "NUMERIC 1 1 0")
204        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
205    writeln!(char_def_file)
206        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
207
208    writeln!(char_def_file, "# Character mappings")
209        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
210    writeln!(char_def_file, "0x3041..0x3096 HIRAGANA  # Hiragana")
211        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
212    writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA  # Katakana")
213        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
214    writeln!(
215        char_def_file,
216        "0x4E00..0x9FAF KANJI     # CJK Unified Ideographs"
217    )
218    .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
219    writeln!(char_def_file, "0x0030..0x0039 NUMERIC   # ASCII Digits")
220        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
221    writeln!(char_def_file, "0x0041..0x005A ALPHA     # ASCII Uppercase")
222        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
223    writeln!(char_def_file, "0x0061..0x007A ALPHA     # ASCII Lowercase")
224        .map_err(|e| PyValueError::new_err(format!("Failed to write char.def: {e}")))?;
225
226    let mut files_created = vec![
227        lexicon_path.clone(),
228        connector_path.clone(),
229        unk_path.clone(),
230        char_def_path.clone(),
231    ];
232
233    // Handle metadata.json update if provided
234    if let Some(metadata_str) = metadata {
235        let metadata_path = Path::new(metadata_str);
236        if !metadata_path.exists() {
237            return Err(PyValueError::new_err(format!(
238                "Metadata file does not exist: {}",
239                metadata_path.display()
240            )));
241        }
242
243        let output_metadata_path = output_path.join("metadata.json");
244        let mut metadata_file = File::create(&output_metadata_path)
245            .map_err(|e| PyValueError::new_err(format!("Failed to create metadata file: {e}")))?;
246
247        serializable_model
248            .update_metadata_json(metadata_path, &mut metadata_file)
249            .map_err(|e| PyValueError::new_err(format!("Failed to update metadata: {e}")))?;
250
251        files_created.push(output_metadata_path);
252        println!("Updated metadata.json with trained model values");
253    }
254
255    println!("Dictionary files exported to: {}", output_path.display());
256    println!("Files created:");
257    for file in &files_created {
258        println!("  - {}", file.display());
259    }
260
261    Ok(())
262}