use alloc::string::String;
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::io::{Read, Write};
use bincode::{Decode, Encode};
use crate::dict_model::{DictModel, WordWeightRecord};
use crate::errors::{Result, VaporettoError};
use crate::ngram_model::{NgramModel, TagNgramModel};
use crate::utils::VecWriter;
const MODEL_MAGIC: &[u8] = b"VaporettoTokenizer 0.5.0\n";
/// Tag prediction data stored together with a token string.
///
/// The derived `Encode`/`Decode` implementations serialize the fields in
/// declaration order, so reordering them changes the on-disk model layout.
#[derive(Debug, Decode, Encode)]
pub struct TagModel {
/// Token string, exposed through [`TagModel::token`].
pub(crate) token: String,
/// Nested lists of tag strings.
// NOTE(review): presumably the outer index is a tag slot and the inner list
// holds its candidate labels — confirm against the predictor.
pub(crate) tags: Vec<Vec<String>>,
/// Tag n-gram model keyed by `String` n-grams (character n-grams, per the name).
pub(crate) char_ngram_model: TagNgramModel<String>,
/// Tag n-gram model keyed by `Vec<u8>` n-grams (character-type n-grams, per the name).
pub(crate) type_ngram_model: TagNgramModel<Vec<u8>>,
/// Bias values.
// NOTE(review): presumably one entry per tag candidate — confirm.
pub(crate) bias: Vec<i32>,
}
impl TagModel {
    /// Returns the token string stored in this tag model.
    pub fn token(&self) -> &str {
        self.token.as_str()
    }
}
/// Model type wrapping the serializable [`ModelData`].
///
/// The inner data is only visible inside the crate; external code goes
/// through the methods defined on `Model`.
#[derive(Debug)]
pub struct Model(pub(crate) ModelData);
/// Serializable payload of a [`Model`].
///
/// The derived `Encode`/`Decode` implementations serialize the fields in
/// declaration order, so reordering them changes the on-disk model layout.
#[derive(Debug, Decode, Encode)]
pub struct ModelData {
/// N-gram model keyed by `String` n-grams (character n-grams, per the name).
pub(crate) char_ngram_model: NgramModel<String>,
/// N-gram model keyed by `Vec<u8>` n-grams (character-type n-grams, per the name).
pub(crate) type_ngram_model: NgramModel<Vec<u8>>,
/// Dictionary model; its records are exposed through `Model::dictionary`.
pub(crate) dict_model: DictModel,
/// Bias value.
// NOTE(review): presumably the constant term of the boundary score — confirm.
pub(crate) bias: i32,
/// Window size associated with the character n-gram model.
// NOTE(review): unit presumably characters on each side — confirm.
pub(crate) char_window_size: u8,
/// Window size associated with the character-type n-gram model.
pub(crate) type_window_size: u8,
/// Tag models, exposed through `Model::tag_models`.
pub(crate) tag_models: Vec<TagModel>,
}
impl Model {
    /// Builds a model from its individual components.
    ///
    /// Only compiled when the `train` or `kytea` feature is enabled, or in
    /// test builds.
    #[cfg(any(feature = "train", feature = "kytea", test))]
    pub(crate) const fn new(
        char_ngram_model: NgramModel<String>,
        type_ngram_model: NgramModel<Vec<u8>>,
        dict_model: DictModel,
        bias: i32,
        char_window_size: u8,
        type_window_size: u8,
        tag_models: Vec<TagModel>,
    ) -> Self {
        Self(ModelData {
            char_ngram_model,
            type_ngram_model,
            dict_model,
            bias,
            char_window_size,
            type_window_size,
            tag_models,
        })
    }

    /// Serializes the model into a byte vector.
    ///
    /// The output starts with the model magic header, followed by the
    /// bincode-encoded model data.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding the model data fails.
    pub fn to_vec(&self) -> Result<Vec<u8>> {
        // Seed the buffer with the magic so the payload is appended after it.
        let mut wtr = VecWriter(MODEL_MAGIC.to_vec());
        let config = bincode::config::standard();
        bincode::encode_into_writer(&self.0, &mut wtr, config)?;
        Ok(wtr.0)
    }

    /// Writes the model to the given writer.
    ///
    /// The magic header is written first, followed by the bincode-encoded
    /// model data.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to `wtr` or encoding the model data fails.
    #[cfg(feature = "std")]
    pub fn write<W>(&self, mut wtr: W) -> Result<()>
    where
        W: Write,
    {
        wtr.write_all(MODEL_MAGIC)?;
        let config = bincode::config::standard();
        bincode::encode_into_std_write(&self.0, &mut wtr, config)?;
        Ok(())
    }

    /// Reads a model from the beginning of a byte slice.
    ///
    /// On success, returns the decoded model together with the bytes of
    /// `slice` that follow the encoded model.
    ///
    /// # Errors
    ///
    /// Returns an invalid-model error if `slice` does not start with the
    /// expected magic header (including when it is shorter than the header),
    /// or a decoding error if the payload cannot be decoded.
    pub fn read_slice(slice: &[u8]) -> Result<(Self, &[u8])> {
        // `strip_prefix` yields `None` for inputs shorter than the magic,
        // so truncated data is reported as an error rather than panicking.
        let payload = match slice.strip_prefix(MODEL_MAGIC) {
            Some(payload) => payload,
            None => return Err(VaporettoError::invalid_model("model version mismatch")),
        };
        let config = bincode::config::standard();
        let (data, size) = bincode::decode_from_slice(payload, config)?;
        Ok((Self(data), &payload[size..]))
    }

    /// Reads a model from the given reader.
    ///
    /// # Errors
    ///
    /// Returns an error if reading from `rdr` fails, if the data does not
    /// start with the expected magic header, or if the payload cannot be
    /// decoded.
    #[cfg(feature = "std")]
    pub fn read<R>(mut rdr: R) -> Result<Self>
    where
        R: Read,
    {
        let mut magic = [0; MODEL_MAGIC.len()];
        rdr.read_exact(&mut magic)?;
        if magic != MODEL_MAGIC {
            return Err(VaporettoError::invalid_model("model version mismatch"));
        }
        let config = bincode::config::standard();
        Ok(Self(bincode::decode_from_std_read(&mut rdr, config)?))
    }

    /// Returns the records of the dictionary held by this model.
    pub fn dictionary(&self) -> &[WordWeightRecord] {
        self.0.dict_model.dictionary()
    }

    /// Replaces the model's dictionary with one built from `dict`.
    pub fn replace_dictionary(&mut self, dict: Vec<WordWeightRecord>) {
        self.0.dict_model = DictModel::new(dict);
    }

    /// Returns the tag models held by this model.
    pub fn tag_models(&self) -> &[TagModel] {
        &self.0.tag_models
    }
}