use std::fs::File;
use std::path::Path;
use magnus::{Error, Ruby, function};
use lindera::dictionary::trainer::{Corpus, Model, SerializableModel, Trainer, TrainerConfig};
use crate::error::to_magnus_error;
/// Trains a dictionary model from a seed lexicon and an annotated corpus, then writes the
/// serialized model to `output`.
///
/// Exposed to Ruby as a 10-argument module function. `lambda`, `max_iter` and `max_threads`
/// accept `nil`, in which case they default to `0.01`, `100` and `num_cpus::get()`
/// respectively. Progress messages are printed to the process's stdout via `println!`.
///
/// # Errors
///
/// - `ArgumentError` if `max_threads` is `0`, if `lambda` is negative or not finite, if any of
///   the six input files does not exist, or if `output` is an existing directory. All of these
///   are checked before any configuration is loaded or training starts.
/// - The error built by `to_magnus_error` if loading the trainer configuration or the corpus,
///   training, creating the output directory/file, or writing the model fails.
///
/// # Panics
///
/// Panics if the Ruby runtime is not available on the calling thread.
// The argument count mirrors the Ruby-facing signature, so it cannot be reduced here.
#[allow(clippy::too_many_arguments)]
fn train(
    seed: String,
    corpus: String,
    char_def: String,
    unk_def: String,
    feature_def: String,
    rewrite_def: String,
    output: String,
    lambda: Option<f64>,
    max_iter: Option<u64>,
    max_threads: Option<usize>,
) -> Result<(), Error> {
    let ruby = Ruby::get().expect("Ruby runtime not initialized");

    // Reject meaningless numeric options before doing any expensive work: a zero thread
    // count or a negative / NaN / infinite regularization cost cannot yield a sane run.
    let num_threads = match max_threads {
        Some(0) => {
            return Err(Error::new(
                ruby.exception_arg_error(),
                "max_threads must be at least 1".to_string(),
            ))
        }
        Some(n) => n,
        None => num_cpus::get(),
    };
    let lambda_val = lambda.unwrap_or(0.01);
    if !lambda_val.is_finite() || lambda_val < 0.0 {
        return Err(Error::new(
            ruby.exception_arg_error(),
            format!("lambda must be a finite, non-negative number, got {lambda_val}"),
        ));
    }
    let max_iter_val = max_iter.unwrap_or(100);

    let seed_path = Path::new(&seed);
    let corpus_path = Path::new(&corpus);
    let char_def_path = Path::new(&char_def);
    let unk_def_path = Path::new(&unk_def);
    let feature_def_path = Path::new(&feature_def);
    let rewrite_def_path = Path::new(&rewrite_def);
    let output_path = Path::new(&output);

    for (path, name) in [
        (seed_path, "seed"),
        (corpus_path, "corpus"),
        (char_def_path, "char_def"),
        (unk_def_path, "unk_def"),
        (feature_def_path, "feature_def"),
        (rewrite_def_path, "rewrite_def"),
    ] {
        if !path.exists() {
            return Err(Error::new(
                ruby.exception_arg_error(),
                format!("{} file does not exist: {}", name, path.display()),
            ));
        }
    }

    // `File::create` below would fail on a directory, but only after training has finished;
    // detect that up front instead of discarding a potentially long run.
    if output_path.is_dir() {
        return Err(Error::new(
            ruby.exception_arg_error(),
            format!("output path is a directory: {}", output_path.display()),
        ));
    }

    let config = TrainerConfig::from_paths(
        seed_path,
        char_def_path,
        unk_def_path,
        feature_def_path,
        rewrite_def_path,
    )
    .map_err(|e| to_magnus_error(&ruby, format!("Failed to load trainer configuration: {e}")))?;

    let trainer = Trainer::new(config)
        .map_err(|e| to_magnus_error(&ruby, format!("Failed to initialize trainer: {e}")))?
        .regularization_cost(lambda_val)
        .max_iter(max_iter_val)
        .num_threads(num_threads);

    let corpus_file = File::open(corpus_path)
        .map_err(|e| to_magnus_error(&ruby, format!("Failed to open corpus file: {e}")))?;
    let corpus_data = Corpus::from_reader(corpus_file)
        .map_err(|e| to_magnus_error(&ruby, format!("Failed to load corpus: {e}")))?;

    println!("Training with {} examples...", corpus_data.len());

    let model = trainer
        .train(corpus_data)
        .map_err(|e| to_magnus_error(&ruby, format!("Training failed: {e}")))?;

    // Only create the output directory once training succeeded, so failures leave no trace.
    if let Some(parent) = output_path.parent() {
        std::fs::create_dir_all(parent).map_err(|e| {
            to_magnus_error(&ruby, format!("Failed to create output directory: {e}"))
        })?;
    }

    let mut output_file = File::create(output_path)
        .map_err(|e| to_magnus_error(&ruby, format!("Failed to create output file: {e}")))?;
    model
        .write_model(&mut output_file)
        .map_err(|e| to_magnus_error(&ruby, format!("Failed to write model: {e}")))?;

    println!("Model saved to {}", output_path.display());
    Ok(())
}
fn export(model: String, output: String, metadata: Option<String>) -> Result<(), Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let model_path = Path::new(&model);
let output_path = Path::new(&output);
if !model_path.exists() {
return Err(Error::new(
ruby.exception_arg_error(),
format!("Model file does not exist: {}", model_path.display()),
));
}
let model_file = File::open(model_path)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to open model file: {e}")))?;
let serializable_model: SerializableModel = Model::read_model(model_file)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to load model: {e}")))?;
std::fs::create_dir_all(output_path)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to create output directory: {e}")))?;
let lexicon_path = output_path.join("lex.csv");
let connector_path = output_path.join("matrix.def");
let unk_path = output_path.join("unk.def");
let char_def_path = output_path.join("char.def");
let mut lexicon_file = File::create(&lexicon_path)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to create lexicon file: {e}")))?;
serializable_model
.write_lexicon(&mut lexicon_file)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write lexicon: {e}")))?;
let mut connector_file = File::create(&connector_path).map_err(|e| {
to_magnus_error(
&ruby,
format!("Failed to create connection matrix file: {e}"),
)
})?;
serializable_model
.write_connection_costs(&mut connector_file)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write connection costs: {e}")))?;
let mut unk_file = File::create(&unk_path)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to create unknown word file: {e}")))?;
serializable_model
.write_unknown_dictionary(&mut unk_file)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write unknown dictionary: {e}")))?;
let mut char_def_file = File::create(&char_def_path).map_err(|e| {
to_magnus_error(
&ruby,
format!("Failed to create character definition file: {e}"),
)
})?;
use std::io::Write;
writeln!(
char_def_file,
"# Character definition file generated from trained model"
)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "# Format: CATEGORY_NAME invoke group length")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "DEFAULT 0 1 0")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "HIRAGANA 1 1 0")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "KATAKANA 1 1 0")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "KANJI 0 0 2")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "ALPHA 1 1 0")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "NUMERIC 1 1 0")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "# Character mappings")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "0x3041..0x3096 HIRAGANA # Hiragana")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "0x30A1..0x30F6 KATAKANA # Katakana")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(
char_def_file,
"0x4E00..0x9FAF KANJI # CJK Unified Ideographs"
)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "0x0030..0x0039 NUMERIC # ASCII Digits")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "0x0041..0x005A ALPHA # ASCII Uppercase")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
writeln!(char_def_file, "0x0061..0x007A ALPHA # ASCII Lowercase")
.map_err(|e| to_magnus_error(&ruby, format!("Failed to write char.def: {e}")))?;
let mut files_created = vec![
lexicon_path.clone(),
connector_path.clone(),
unk_path.clone(),
char_def_path.clone(),
];
if let Some(metadata_str) = metadata {
let metadata_path = Path::new(&metadata_str);
if !metadata_path.exists() {
return Err(Error::new(
ruby.exception_arg_error(),
format!("Metadata file does not exist: {}", metadata_path.display()),
));
}
let output_metadata_path = output_path.join("metadata.json");
let mut metadata_file = File::create(&output_metadata_path)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to create metadata file: {e}")))?;
serializable_model
.update_metadata_json(metadata_path, &mut metadata_file)
.map_err(|e| to_magnus_error(&ruby, format!("Failed to update metadata: {e}")))?;
files_created.push(output_metadata_path);
println!("Updated metadata.json with trained model values");
}
println!("Dictionary files exported to: {}", output_path.display());
println!("Files created:");
for file in &files_created {
println!(" - {}", file.display());
}
Ok(())
}
/// Registers this file's Ruby-facing functions on `module`.
///
/// `train` is bound with arity 10 and `export` with arity 3, matching the number of
/// parameters in their Rust signatures. Registration stops at, and returns, the first error.
pub fn define(_ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
    module.define_module_function("train", function!(train, 10))?;
    // The result of the final registration is the result of the whole call.
    module.define_module_function("export", function!(export, 3))
}