use std::io::{BufWriter, Read, Write};
use std::num::NonZeroU32;
use bincode::{Decode, Encode};
use hashbrown::HashMap;
use crate::common;
use crate::dictionary::lexicon::Lexicon;
use crate::dictionary::word_idx::WordIdx;
use crate::dictionary::{LexType, WordParam};
use crate::errors::Result;
pub use crate::trainer::config::TrainerConfig;
use crate::trainer::corpus::Word;
pub use crate::trainer::Trainer;
use crate::utils::{self, FromU32};
/// Serializable part of a trained model.
///
/// This is the value that `Model::write_model` encodes with bincode and that
/// `Model::read_model` decodes again.
#[derive(Decode, Encode)]
pub struct ModelData {
/// Trainer configuration the model was built with.
pub config: TrainerConfig,
/// Raw (not yet merged) model provided by `rucrf`.
pub raw_model: rucrf::RawModel,
}
/// Model handle combining the serializable data with state that is built after loading.
pub struct Model {
/// Configuration and raw model; the only part written by `write_model`.
pub(crate) data: ModelData,
/// Cached result of `data.raw_model.merge()`. It is `None` until one of the writer
/// methods needs it, and is reset to `None` whenever a user lexicon is read.
pub(crate) merged_model: Option<rucrf::MergedModel>,
/// User lexicon entries: the word, the parameter given in the lexicon, and the label id
/// returned by the feature provider when the entry's feature set was added.
pub(crate) user_entries: Vec<(Word, WordParam, NonZeroU32)>,
}
impl Model {
pub fn read_user_lexicon<R>(&mut self, mut rdr: R) -> Result<()>
where
R: Read,
{
let mut bytes = vec![];
rdr.read_to_end(&mut bytes)?;
self.merged_model = None;
let entries = Lexicon::parse_csv(&bytes, "user.csv")?;
for entry in entries {
let first_char = entry.surface.chars().next().unwrap();
let cate_id = self
.data
.config
.dict
.char_prop()
.char_info(first_char)
.base_id();
let feature_set = Trainer::extract_feature_set(
&mut self.data.config.feature_extractor,
&self.data.config.unigram_rewriter,
&self.data.config.left_rewriter,
&self.data.config.right_rewriter,
entry.feature,
cate_id,
);
let label_id = self
.data
.raw_model
.feature_provider()
.add_feature_set(feature_set)?;
self.user_entries.push((
Word::new(&entry.surface, entry.feature),
entry.param,
label_id,
));
}
Ok(())
}
pub fn write_bigram_details<L, R, C>(
&mut self,
left_wtr: L,
right_wtr: R,
cost_wtr: C,
) -> Result<()>
where
L: Write,
R: Write,
C: Write,
{
if self.merged_model.is_none() {
self.merged_model = Some(self.data.raw_model.merge()?);
}
let merged_model = self.merged_model.as_ref().unwrap();
let mut weight_abs_max = 0f64;
for feature_set in &merged_model.feature_sets {
weight_abs_max = weight_abs_max.max(feature_set.weight.abs());
}
for hm in &merged_model.matrix {
for &w in hm.values() {
weight_abs_max = weight_abs_max.max(w.abs());
}
}
let weight_scale_factor = f64::from(i16::MAX) / weight_abs_max;
let feature_extractor = &self.data.config.feature_extractor;
let mut right_features = HashMap::new();
for (feature, idx) in feature_extractor.right_feature_ids().iter() {
right_features.insert(idx.get(), feature);
}
let feature_list = &merged_model.left_conn_to_right_feats;
let mut left_wtr = BufWriter::new(left_wtr);
for (conn_id, feat_ids) in feature_list[..feature_list.len()].iter().enumerate() {
write!(&mut left_wtr, "{}\t", conn_id + 1)?;
for (i, feat_id) in feat_ids.iter().enumerate() {
if i != 0 {
write!(&mut left_wtr, ",")?;
}
if let Some(feat_id) = feat_id {
let feat_str = right_features.get(&feat_id.get()).unwrap();
utils::quote_csv_cell(&mut left_wtr, feat_str.as_bytes())?;
} else {
write!(&mut left_wtr, "*")?;
}
}
writeln!(&mut left_wtr)?;
}
let mut left_features = HashMap::new();
for (feature, idx) in feature_extractor.left_feature_ids().iter() {
left_features.insert(idx.get(), feature);
}
let feature_list = &merged_model.right_conn_to_left_feats;
let mut right_wtr = BufWriter::new(right_wtr);
for (conn_id, feat_ids) in feature_list[..feature_list.len()].iter().enumerate() {
write!(&mut right_wtr, "{}\t", conn_id + 1)?;
for (i, feat_id) in feat_ids.iter().enumerate() {
if i != 0 {
write!(&mut right_wtr, ",")?;
}
if let Some(feat_id) = feat_id {
let feat_str = left_features.get(&feat_id.get()).unwrap();
utils::quote_csv_cell(&mut right_wtr, feat_str.as_bytes())?;
} else {
write!(&mut right_wtr, "*")?;
}
}
writeln!(&mut right_wtr)?;
}
let mut cost_wtr = BufWriter::new(cost_wtr);
for (left_feat_id, hm) in self
.data
.raw_model
.bigram_weight_indices()
.iter()
.enumerate()
{
let left_feat_str = left_features
.get(&u32::try_from(left_feat_id).unwrap())
.map_or("", |x| x.as_str());
for (right_feat_id, widx) in hm {
let right_feat_str = right_features.get(right_feat_id).map_or("", |x| x.as_str());
let w = self.data.raw_model.weights()[usize::from_u32(*widx)];
let cost = (-w * weight_scale_factor) as i32;
writeln!(&mut cost_wtr, "{left_feat_str}/{right_feat_str}\t{cost}")?;
}
}
Ok(())
}
pub fn write_dictionary<L, C, U, S>(
&mut self,
lexicon_wtr: L,
connector_wtr: C,
unk_handler_wtr: U,
user_lexicon_wtr: S,
) -> Result<()>
where
L: Write,
C: Write,
U: Write,
S: Write,
{
if self.merged_model.is_none() {
self.merged_model = Some(self.data.raw_model.merge()?);
}
let merged_model = self.merged_model.as_ref().unwrap();
let mut lexicon_wtr = BufWriter::new(lexicon_wtr);
let mut unk_handler_wtr = BufWriter::new(unk_handler_wtr);
let mut connector_wtr = BufWriter::new(connector_wtr);
let mut user_lexicon_wtr = BufWriter::new(user_lexicon_wtr);
let mut weight_abs_max = 0f64;
for feature_set in &merged_model.feature_sets {
weight_abs_max = weight_abs_max.max(feature_set.weight.abs());
}
for hm in &merged_model.matrix {
for &w in hm.values() {
weight_abs_max = weight_abs_max.max(w.abs());
}
}
let weight_scale_factor = f64::from(i16::MAX) / weight_abs_max;
let config = &self.data.config;
for i in 0..config.surfaces.len() {
let feature_set = merged_model.feature_sets[i];
let word_idx = WordIdx::new(LexType::System, u32::try_from(i).unwrap());
let feature = config.dict.system_lexicon().word_feature(word_idx);
utils::quote_csv_cell(&mut lexicon_wtr, config.surfaces[i].as_bytes())?;
writeln!(
&mut lexicon_wtr,
",{},{},{},{}",
feature_set.left_id,
feature_set.right_id,
(-feature_set.weight * weight_scale_factor) as i16,
feature,
)?;
}
for i in 0..config.dict.unk_handler().len() {
let word_idx = WordIdx::new(LexType::Unknown, u32::try_from(i).unwrap());
let cate_id = config.dict.unk_handler().word_cate_id(word_idx);
let feature = config.dict.unk_handler().word_feature(word_idx);
let cate_string = config
.dict
.char_prop()
.cate_str(u32::from(cate_id))
.unwrap();
let feature_set = merged_model.feature_sets[config.surfaces.len() + i];
writeln!(
&mut unk_handler_wtr,
"{},{},{},{},{}",
cate_string,
feature_set.left_id,
feature_set.right_id,
(-feature_set.weight * weight_scale_factor) as i16,
feature,
)?;
}
writeln!(
&mut connector_wtr,
"{} {}",
merged_model.right_conn_to_left_feats.len() + 1,
merged_model.left_conn_to_right_feats.len() + 1,
)?;
for (right_conn_id, hm) in merged_model.matrix.iter().enumerate() {
let mut pairs: Vec<_> = hm.iter().map(|(&j, &w)| (j, w)).collect();
pairs.sort_unstable_by_key(|&(k, _)| k);
for (left_conn_id, w) in pairs {
writeln!(
&mut connector_wtr,
"{} {} {}",
right_conn_id,
left_conn_id,
(-w * weight_scale_factor) as i16
)?;
}
}
for (word, param, label_id) in &self.user_entries {
let feature_set = merged_model.feature_sets[usize::from_u32(label_id.get() - 1)];
utils::quote_csv_cell(&mut user_lexicon_wtr, word.surface().as_bytes())?;
if *param == WordParam::default() {
writeln!(
&mut user_lexicon_wtr,
",{},{},{},{}",
feature_set.left_id,
feature_set.right_id,
(-feature_set.weight * weight_scale_factor) as i16,
word.feature(),
)?;
} else {
writeln!(
&mut user_lexicon_wtr,
",{},{},{},{}",
param.left_id,
param.right_id,
param.word_cost,
word.feature(),
)?;
}
}
Ok(())
}
pub fn write_model<W>(&self, mut wtr: W) -> Result<usize>
where
W: Write,
{
let num_bytes =
bincode::encode_into_std_write(&self.data, &mut wtr, common::bincode_config())?;
Ok(num_bytes)
}
pub fn read_model<R>(mut rdr: R) -> Result<Self>
where
R: Read,
{
let data = bincode::decode_from_std_read(&mut rdr, common::bincode_config())?;
Ok(Self {
data,
merged_model: None,
user_entries: vec![],
})
}
}