use std::collections::HashMap;
use std::io::{BufReader, Read};
use std::path::Path;
use anyhow::Result;
use super::feature_extractor::FeatureExtractor;
use super::feature_rewriter::DictionaryRewriter;
use crate::dictionary::Dictionary;
use crate::dictionary::character_definition::CharacterDefinition;
use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
use crate::dictionary::metadata::Metadata;
use crate::dictionary::prefix_dictionary::PrefixDictionary;
use crate::dictionary::unknown_dictionary::UnknownDictionary;
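/// Training-time configuration assembled from a seed lexicon, `char.def`,
/// `unk.def`, feature-template, and rewrite-rule sources. The raw file
/// contents are retained next to their parsed forms (presumably so they can
/// be re-emitted alongside a trained model).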
pub struct TrainerConfig {
pub(crate) dict: Dictionary,
pub(crate) surfaces: Vec<String>,
pub(crate) features: Vec<String>,
pub(crate) surface_features: HashMap<String, String>,
pub(crate) user_lexicon: HashMap<String, String>,
pub(crate) feature_extractor: FeatureExtractor,
pub(crate) dictionary_rewriter: DictionaryRewriter,
pub(crate) cost_factor: i32,
pub(crate) metadata: Metadata,
pub(crate) unk_categories: HashMap<String, String>,
pub(crate) unk_costs: HashMap<String, i32>,
pub(crate) char_def_content: String,
pub(crate) feature_def_content: String,
pub(crate) rewrite_def_content: String,
}
impl TrainerConfig {
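    /// Returns the prefix (common-prefix lookup) dictionary built from the seed lexicon.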
pub fn system_lexicon(&self) -> &PrefixDictionary {
&self.dict.prefix_dictionary
}
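    /// Returns the dictionary assembled by [`Self::from_readers`].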
pub fn dict(&self) -> &Dictionary {
&self.dict
}
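    /// Returns the unknown-word dictionary (currently a placeholder; see
    /// `build_unknown_dict_from_content`).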
    pub fn unk_handler(&self) -> &UnknownDictionary {
        &self.dict.unknown_dictionary
    }
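    /// Builds a config from the five definition sources. The seed lexicon is
    /// CSV with at least five columns (`surface,left_id,right_id,cost,feature,...`);
    /// everything from the fifth column onward is re-joined into a single
    /// feature string, so IPADIC (13 columns), ko-dic/CC-CEDICT (8), and
    /// UniDic (21) layouts all parse. Empty lines and lines starting with
    /// `#` are skipped.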
pub fn from_readers<R1, R2, R3, R4, R5>(
lexicon_rdr: R1,
char_prop_rdr: R2,
unk_handler_rdr: R3,
feature_templates_rdr: R4,
rewrite_rules_rdr: R5,
) -> Result<Self>
where
R1: Read,
R2: Read,
R3: Read,
R4: Read,
R5: Read,
{
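        // 1) Seed lexicon: surface is column 0, features are columns 4 and up.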
let mut surfaces = Vec::new();
let mut features = Vec::new();
let mut surface_features = HashMap::new();
        let mut lexicon_content = String::new();
        BufReader::new(lexicon_rdr).read_to_string(&mut lexicon_content)?;
for line in lexicon_content.lines() {
if line.trim().is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<&str> = line.split(',').collect();
if parts.len() >= 5 {
let surface = parts[0].to_string();
let feature_str = parts[4..].join(",");
surfaces.push(surface.clone());
features.push(feature_str.clone());
surface_features.insert(surface, feature_str);
}
}
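        // 2) Feature templates: UNIGRAM/BIGRAM lines drive the FeatureExtractor.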
        let mut feature_content = String::new();
        BufReader::new(feature_templates_rdr).read_to_string(&mut feature_content)?;
let mut unigram_templates = Vec::new();
let mut bigram_templates = Vec::new();
        fn extract_template(rest: &str) -> &str {
            // Drop the optional ':' separator and any label before the first '%'.
            let rest = rest.trim_start().trim_start_matches(':').trim_start();
            match rest.find('%') {
                Some(idx) => &rest[idx..],
                None => rest,
            }
        }
        for line in feature_content.lines() {
            if line.trim().is_empty() || line.starts_with('#') {
                continue;
            }
            if let Some(rest) = line.strip_prefix("UNIGRAM") {
                unigram_templates.push(extract_template(rest).to_string());
            } else if let Some(rest) = line.strip_prefix("BIGRAM") {
                if let Some((left, right)) = extract_template(rest).split_once('/') {
                    bigram_templates.push((left.to_string(), right.to_string()));
                }
            } else {
                unigram_templates.push(line.to_string());
            }
        }
let feature_extractor =
FeatureExtractor::from_templates(&unigram_templates, &bigram_templates);
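        // 3) Rewrite rules are parsed into the DictionaryRewriter.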
        let mut rewrite_def_content = String::new();
        BufReader::new(rewrite_rules_rdr).read_to_string(&mut rewrite_def_content)?;
let dictionary_rewriter =
DictionaryRewriter::from_reader(std::io::Cursor::new(rewrite_def_content.as_bytes()))?;
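        // 4) unk.def: per-category feature strings and word costs.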
        let mut unk_content = String::new();
        BufReader::new(unk_handler_rdr).read_to_string(&mut unk_content)?;
let mut unk_categories = HashMap::new();
let mut unk_costs = HashMap::new();
for line in unk_content.lines() {
if line.trim().is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<&str> = line.split(',').collect();
if parts.len() >= 5 {
let category = parts[0].to_string();
let features = parts[4..].join(",");
unk_categories.insert(category.clone(), features);
if let Ok(cost) = parts[3].parse::<i32>() {
unk_costs.insert(category, cost);
}
}
}
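        // 5) char.def is kept verbatim and re-parsed below when building the dictionary.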
        let mut char_def_content = String::new();
        BufReader::new(char_prop_rdr).read_to_string(&mut char_def_content)?;
use std::io::Cursor;
let dict = Self::build_dictionary_from_readers(
&lexicon_content,
Cursor::new(char_def_content.as_bytes()),
Cursor::new(unk_content.as_bytes()),
)?;
Ok(Self {
dict,
surfaces,
features,
surface_features,
            user_lexicon: HashMap::new(),
            feature_extractor,
dictionary_rewriter,
            cost_factor: 700,
            metadata: Metadata::default(),
            unk_categories,
unk_costs,
char_def_content,
feature_def_content: feature_content,
rewrite_def_content,
})
}
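    /// All surface forms from the seed lexicon, in file order.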
pub fn surfaces(&self) -> &[String] {
&self.surfaces
}
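    /// Maps each seed-lexicon surface form to its feature string.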
pub fn surface_features(&self) -> &HashMap<String, String> {
&self.surface_features
}
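    /// User-supplied lexicon entries layered on top of the seed lexicon.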
pub fn user_lexicon(&self) -> &HashMap<String, String> {
&self.user_lexicon
}
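    /// Adds (or overwrites) a single user-lexicon entry.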
pub fn add_user_lexicon_entry(&mut self, surface: String, features: String) {
self.user_lexicon.insert(surface, features);
}
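    /// Looks up the feature string for a surface form; user-lexicon entries
    /// take precedence over the seed lexicon.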
pub fn get_features(&self, surface: &str) -> Option<String> {
self.user_lexicon
.get(surface)
.or_else(|| self.surface_features.get(surface))
.cloned()
}
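    /// Parses user-lexicon CSV content (same column layout as the seed
    /// lexicon) into `user_lexicon`.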
pub fn load_user_lexicon_from_content(&mut self, content: &str) -> Result<()> {
for line in content.lines() {
if line.trim().is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<&str> = line.split(',').collect();
if parts.len() >= 5 {
let surface = parts[0].to_string();
let features = parts[4..].join(",");
self.user_lexicon.insert(surface, features);
}
}
Ok(())
}
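    /// Convenience wrapper around [`Self::from_readers`] that opens the five
    /// definition files from disk.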
pub fn from_paths(
lexicon_path: &Path,
char_prop_path: &Path,
unk_handler_path: &Path,
feature_templates_path: &Path,
rewrite_rules_path: &Path,
) -> Result<Self> {
use std::fs::File;
Self::from_readers(
File::open(lexicon_path)?,
File::open(char_prop_path)?,
File::open(unk_handler_path)?,
File::open(feature_templates_path)?,
File::open(rewrite_rules_path)?,
)
}
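    /// Returns the metadata (currently always `Metadata::default()`).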
pub fn metadata(&self) -> &Metadata {
&self.metadata
}
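    /// Assembles a `Dictionary` from the already-read lexicon content plus the
    /// char.def and unk.def readers. Only the character definition is fully
    /// parsed here; the prefix and unknown dictionaries are minimal
    /// placeholders, and the connection matrix is all zeros.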
fn build_dictionary_from_readers<R2, R3>(
lexicon_content: &str,
char_prop_rdr: R2,
unk_handler_rdr: R3,
) -> Result<Dictionary>
where
R2: Read,
R3: Read,
{
        let mut char_prop_content = String::new();
        BufReader::new(char_prop_rdr).read_to_string(&mut char_prop_content)?;
        let mut unk_content = String::new();
        BufReader::new(unk_handler_rdr).read_to_string(&mut unk_content)?;
let char_def = Self::build_char_def_from_content(&char_prop_content)?;
let unknown_dict = Self::build_unknown_dict_from_content(&unk_content, &char_def)?;
let prefix_dict = Self::build_prefix_dict_from_content(lexicon_content)?;
let conn_matrix = Self::create_minimal_connection_matrix()?;
Ok(Dictionary {
prefix_dictionary: prefix_dict,
connection_cost_matrix: conn_matrix,
character_definition: char_def,
unknown_dictionary: unknown_dict,
metadata: Metadata::default(),
})
}
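    /// Parses `char.def` content: category lines (`NAME invoke group length`)
    /// and code-point range lines (`0xSTART..0xEND CATEGORY`). A `DEFAULT`
    /// category is pre-seeded at index 0 as the fallback for unmapped code
    /// points.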
fn build_char_def_from_content(content: &str) -> Result<CharacterDefinition> {
        use crate::dictionary::character_definition::{CategoryData, CategoryId, LookupTable};
let mut category_definitions = Vec::new();
let mut category_names = Vec::new();
        let mut category_map = HashMap::new();
        let mut char_ranges = Vec::new();
category_names.push("DEFAULT".to_string());
category_map.insert("DEFAULT".to_string(), 0);
category_definitions.push(CategoryData {
invoke: false,
group: true,
length: 0,
});
for line in content.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if line.starts_with("0x") {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 2 {
let range_str = parts[0];
let category = parts[1];
if let Some(range_parts) = range_str.split_once("..") {
let start = u32::from_str_radix(&range_parts.0[2..], 16)?;
let end = u32::from_str_radix(&range_parts.1[2..], 16)?;
let cat_idx =
*category_map.entry(category.to_string()).or_insert_with(|| {
let idx = category_names.len();
category_names.push(category.to_string());
category_definitions.push(CategoryData {
invoke: true,
group: true,
length: 0,
});
idx
});
char_ranges.push((start, end, cat_idx));
}
}
} else {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 4 {
let name = parts[0];
let invoke = parts[1] != "0";
let group = parts[2] != "0";
let length = parts[3].parse::<u8>().unwrap_or(0);
let cat_idx = *category_map.entry(name.to_string()).or_insert_with(|| {
let idx = category_names.len();
category_names.push(name.to_string());
category_definitions.push(CategoryData {
invoke,
group,
length: length.into(),
});
idx
});
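                    // Overwrite so a pre-seeded or re-declared category
                    // (e.g. DEFAULT) picks up the flags from this line.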
                    category_definitions[cat_idx] = CategoryData {
                        invoke,
                        group,
                        length: length.into(),
                    };
}
}
}
char_ranges.sort_by_key(|&(start, _, _)| start);
        let mut boundaries = vec![0u32];
        for &(start, end, _) in &char_ranges {
            if start > *boundaries.last().unwrap() {
                boundaries.push(start);
            }
            boundaries.push(end + 1);
        }
        if *boundaries.last().unwrap() < 0x10FFFF {
            boundaries.push(0x10FFFF);
        }
        let mapping = LookupTable::from_fn(boundaries, &|c, buff| {
            for &(start, end, cat_idx) in &char_ranges {
                if c >= start && c <= end {
                    buff.push(CategoryId(cat_idx));
                    return;
                }
            }
            buff.push(CategoryId(0));
        });
Ok(CharacterDefinition {
category_definitions,
category_names,
mapping,
})
}
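    /// Placeholder: the parsed `unk.def` data lives in
    /// `unk_categories`/`unk_costs`, so this returns an empty unknown
    /// dictionary with six category slots, each referencing word id 0
    /// (presumably sized to match the 6x6 placeholder connection matrix).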
fn build_unknown_dict_from_content(
_content: &str,
_char_def: &CharacterDefinition,
) -> Result<UnknownDictionary> {
Ok(UnknownDictionary {
            category_references: vec![vec![0]; 6],
            costs: vec![],
            words_idx_data: vec![],
words_data: vec![],
})
}
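    /// Placeholder: builds a trie over a single NUL key so the
    /// `PrefixDictionary` is structurally valid but never matches real text.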
fn build_prefix_dict_from_content(_content: &str) -> Result<PrefixDictionary> {
use crate::util::Data;
use daachorse::DoubleArrayAhoCorasickBuilder;
        let keys: &[&str] = &["\0"];
        let da = DoubleArrayAhoCorasickBuilder::new()
            .build(keys)
            .map_err(|err| anyhow::anyhow!("failed to build placeholder trie: {err}"))?;
Ok(PrefixDictionary {
da,
vals_data: Data::from(vec![]),
words_idx_data: Data::from(vec![]),
words_data: Data::from(vec![]),
is_system: true,
})
}
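    /// Builds a zeroed 6x6 connection matrix in the serialized layout that
    /// `ConnectionCostMatrix::load` appears to expect: two little-endian u16
    /// dimensions followed by one i16 cost per (left, right) id pair.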
fn create_minimal_connection_matrix() -> Result<ConnectionCostMatrix> {
let matrix_size = 6u16;
let mut matrix_data = vec![0u8; 4];
matrix_data[0..2].copy_from_slice(&matrix_size.to_le_bytes());
matrix_data[2..4].copy_from_slice(&matrix_size.to_le_bytes());
        let cost_data_size = (matrix_size as usize) * (matrix_size as usize) * 2;
        matrix_data.extend(vec![0u8; cost_data_size]);
Ok(ConnectionCostMatrix::load(matrix_data)?)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
#[test]
fn test_ipadic_format_13_columns() {
let seed_csv = "東京,0,0,5000,名詞,固有名詞,地域,一般,*,*,東京,トウキョウ,トーキョー\n\
行く,1,1,4000,動詞,自立,*,*,五段・カ行促音便,基本形,行く,イク,イク\n";
let char_def = "DEFAULT 0 1 0\nHIRAGANA 1 1 0\n0x3042..0x3096 HIRAGANA\n";
let unk_def = "DEFAULT,0,0,1500,名詞,一般,*,*,*,*,*,*,*\n";
let feature_def = "UNIGRAM:%F[0]\nUNIGRAM:%F[1]\n";
let rewrite_def = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(seed_csv),
Cursor::new(char_def),
Cursor::new(unk_def),
Cursor::new(feature_def),
Cursor::new(rewrite_def),
)
.unwrap();
assert_eq!(config.surfaces().len(), 2);
assert!(config.surfaces().contains(&"東京".to_string()));
assert!(config.surfaces().contains(&"行く".to_string()));
let tokyo_features = config.surface_features().get("東京").unwrap();
assert_eq!(
tokyo_features,
"名詞,固有名詞,地域,一般,*,*,東京,トウキョウ,トーキョー"
);
}
#[test]
fn test_ko_dic_format_8_columns() {
let seed_csv = "한국,0,0,5000,NNG,Korea,F,han-guk\n\
안녕,1,1,4000,NNG,hello,F,an-nyeong\n";
let char_def = "DEFAULT 0 1 0\nHANGUL 1 1 0\n0xAC00..0xD7A3 HANGUL\n";
let unk_def = "DEFAULT,0,0,1500,NNG,unknown,F,*\n";
let feature_def = "UNIGRAM:%F[0]\n";
let rewrite_def = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(seed_csv),
Cursor::new(char_def),
Cursor::new(unk_def),
Cursor::new(feature_def),
Cursor::new(rewrite_def),
)
.unwrap();
assert_eq!(config.surfaces().len(), 2);
assert!(config.surfaces().contains(&"한국".to_string()));
assert!(config.surfaces().contains(&"안녕".to_string()));
let korea_features = config.surface_features().get("한국").unwrap();
assert_eq!(korea_features, "NNG,Korea,F,han-guk");
}
#[test]
fn test_cc_cedict_format_8_columns() {
let seed_csv = "中国,0,0,5000,n,China,*,zhong1guo2\n\
你好,1,1,4000,x,hello,*,ni3hao3\n";
let char_def = "DEFAULT 0 1 0\nHANZI 1 1 0\n0x4E00..0x9FFF HANZI\n";
let unk_def = "DEFAULT,0,0,1500,n,unknown,*,*\n";
let feature_def = "UNIGRAM:%F[0]\n";
let rewrite_def = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(seed_csv),
Cursor::new(char_def),
Cursor::new(unk_def),
Cursor::new(feature_def),
Cursor::new(rewrite_def),
)
.unwrap();
assert_eq!(config.surfaces().len(), 2);
assert!(config.surfaces().contains(&"中国".to_string()));
assert!(config.surfaces().contains(&"你好".to_string()));
let china_features = config.surface_features().get("中国").unwrap();
assert_eq!(china_features, "n,China,*,zhong1guo2");
}
#[test]
fn test_unidic_format_21_columns() {
let seed_csv = "東京,0,0,5000,名詞,固有名詞,地名,一般,*,*,トウキョウ,東京,東京,東京,東京,東京,トウキョウ,トーキョー,東京,東京,1\n";
let char_def = "DEFAULT 0 1 0\nKANJI 0 0 2\n0x4E00..0x9FFF KANJI\n";
let unk_def = "DEFAULT,0,0,1500,名詞,普通名詞,一般,*,*,*,*,*,*,*,*,*,*,*,*,*,*\n";
let feature_def = "UNIGRAM:%F[0]\nUNIGRAM:%F[1]\n";
let rewrite_def = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(seed_csv),
Cursor::new(char_def),
Cursor::new(unk_def),
Cursor::new(feature_def),
Cursor::new(rewrite_def),
)
.unwrap();
assert_eq!(config.surfaces().len(), 1);
assert!(config.surfaces().contains(&"東京".to_string()));
let tokyo_features = config.surface_features().get("東京").unwrap();
assert_eq!(
tokyo_features,
"名詞,固有名詞,地名,一般,*,*,トウキョウ,東京,東京,東京,東京,東京,トウキョウ,トーキョー,東京,東京,1"
);
}
#[test]
fn test_mixed_column_counts() {
let seed_csv = "東京,0,0,5000,名詞,固有名詞,地域,一般,*,*,東京,トウキョウ,トーキョー\n\
한국,1,1,4000,NNG,Korea,F,han-guk\n\
中国,2,2,3000,n,China,*,zhong1guo2\n";
let char_def = "DEFAULT 0 1 0\n";
let unk_def = "DEFAULT,0,0,1500,*,*,*,*\n";
let feature_def = "UNIGRAM:%F[0]\n";
let rewrite_def = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(seed_csv),
Cursor::new(char_def),
Cursor::new(unk_def),
Cursor::new(feature_def),
Cursor::new(rewrite_def),
)
.unwrap();
assert_eq!(config.surfaces().len(), 3);
assert_eq!(
config.surface_features().get("東京").unwrap(),
"名詞,固有名詞,地域,一般,*,*,東京,トウキョウ,トーキョー"
);
assert_eq!(
config.surface_features().get("한국").unwrap(),
"NNG,Korea,F,han-guk"
);
assert_eq!(
config.surface_features().get("中国").unwrap(),
"n,China,*,zhong1guo2"
);
}
#[test]
fn test_trainer_config_creation() {
let lexicon_data = "外国,0,0,5000,名詞,一般,*,*,*,*,外国,ガイコク,ガイコク\n人,1,1,5000,名詞,接尾,一般,*,*,*,人,ジン,ジン\n";
let char_data = "# char.def placeholder\n";
let unk_data = "# unk.def placeholder\n";
let feature_data = "UNIGRAM:%F[0]\nLEFT:%L[0]\nRIGHT:%R[0]\n";
let rewrite_data = "# rewrite.def placeholder\n";
let result = TrainerConfig::from_readers(
Cursor::new(lexicon_data.as_bytes()),
Cursor::new(char_data.as_bytes()),
Cursor::new(unk_data.as_bytes()),
Cursor::new(feature_data.as_bytes()),
Cursor::new(rewrite_data.as_bytes()),
);
assert!(result.is_ok());
let config = result.unwrap();
assert_eq!(config.surfaces().len(), 2);
assert!(config.surfaces().contains(&"外国".to_string()));
assert!(config.surfaces().contains(&"人".to_string()));
}
#[test]
fn test_unk_categories_ipadic() {
let lexicon_data = "東京,0,0,5000,名詞,固有名詞,地域,一般,*,*,東京,トウキョウ,トーキョー\n";
let char_data = "DEFAULT 0 1 0\nHIRAGANA 1 1 0\n";
let unk_data = "DEFAULT,0,0,1500,名詞,一般,*,*,*,*,*,*,*\nHIRAGANA,1,1,2000,名詞,代名詞,一般,*,*,*,*,*,*\n";
let feature_data = "UNIGRAM:%F[0]\n";
let rewrite_data = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(lexicon_data),
Cursor::new(char_data),
Cursor::new(unk_data),
Cursor::new(feature_data),
Cursor::new(rewrite_data),
)
.unwrap();
assert_eq!(config.unk_categories.len(), 2);
assert_eq!(
config.unk_categories.get("DEFAULT").unwrap(),
"名詞,一般,*,*,*,*,*,*,*"
);
assert_eq!(
config.unk_categories.get("HIRAGANA").unwrap(),
"名詞,代名詞,一般,*,*,*,*,*,*"
);
}
#[test]
fn test_unk_categories_ko_dic() {
let lexicon_data = "한국,0,0,5000,NNG,Korea,F,han-guk\n";
let char_data = "DEFAULT 0 1 0\n";
let unk_data = "DEFAULT,0,0,1500,NNG,unknown,F,*\n";
let feature_data = "UNIGRAM:%F[0]\n";
let rewrite_data = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(lexicon_data),
Cursor::new(char_data),
Cursor::new(unk_data),
Cursor::new(feature_data),
Cursor::new(rewrite_data),
)
.unwrap();
assert_eq!(config.unk_categories.len(), 1);
assert_eq!(
config.unk_categories.get("DEFAULT").unwrap(),
"NNG,unknown,F,*"
);
}
#[test]
fn test_unk_categories_cc_cedict() {
let lexicon_data = "中国,0,0,5000,n,China,*,zhong1guo2\n";
let char_data = "DEFAULT 0 1 0\n";
let unk_data = "DEFAULT,0,0,1500,n,unknown,*,*\n";
let feature_data = "UNIGRAM:%F[0]\n";
let rewrite_data = "*\tUNK\n";
let config = TrainerConfig::from_readers(
Cursor::new(lexicon_data),
Cursor::new(char_data),
Cursor::new(unk_data),
Cursor::new(feature_data),
Cursor::new(rewrite_data),
)
.unwrap();
assert_eq!(config.unk_categories.len(), 1);
assert_eq!(
config.unk_categories.get("DEFAULT").unwrap(),
"n,unknown,*,*"
);
}
}