use std::collections::{BTreeMap, HashMap};
use std::io::{Read, Write};
use crate::persistence::ModelError;
use crate::seq_feature::SeqFeatureTemplate;
use crate::trie::CountTrie;
use crate::wordseg::{BaseHiddenMarkovModelSegmenter, HiddenMarkovModelSegmenter};
/// Writes one length-prefixed section: an 8-byte little-endian byte count
/// followed by the raw bytes of `data`. The inverse of `read_section`.
fn write_section<W: Write>(writer: &mut W, data: &[u8]) -> Result<(), ModelError> {
    // Fixed-width header so the reader knows exactly how much to consume.
    let header = (data.len() as u64).to_le_bytes();
    if let Err(e) = writer.write_all(&header) {
        return Err(ModelError::Io(format!("Failed to write section length: {e}")));
    }
    match writer.write_all(data) {
        Ok(()) => Ok(()),
        Err(e) => Err(ModelError::Io(format!("Failed to write section data: {e}"))),
    }
}
/// Reads one length-prefixed section written by `write_section`: an 8-byte
/// little-endian byte count, then exactly that many payload bytes.
fn read_section<R: Read>(reader: &mut R) -> Result<Vec<u8>, ModelError> {
    let mut header = [0u8; 8];
    if let Err(e) = reader.read_exact(&mut header) {
        return Err(ModelError::ParseError(format!("Failed to read section length: {e}")));
    }
    // Allocate the payload buffer up front from the declared length.
    let mut payload = vec![0u8; u64::from_le_bytes(header) as usize];
    match reader.read_exact(&mut payload) {
        Ok(()) => Ok(payload),
        Err(e) => Err(ModelError::ParseError(format!("Failed to read section data: {e}"))),
    }
}
/// Dictionary-DAG word segmenter with an HMM fallback for out-of-vocabulary
/// spans (jieba-style): known words are chosen via a maximum-probability path
/// over a word DAG, and runs of single characters are re-segmented by the HMM.
#[derive(Clone)]
pub struct DagHmmSegmenter {
    // Character trie holding word frequencies learned in `fit_segmented`.
    dict: CountTrie<char>,
    // Total word-token count; denominator for the log-probabilities in
    // `calc_route` (stored as f64 since it is only used via `ln()`).
    total: f64,
    // Fallback segmenter used when `dict` is empty or a span is OOV.
    hmm: HiddenMarkovModelSegmenter,
}
impl DagHmmSegmenter {
    /// Creates an untrained segmenter. All options are forwarded to the
    /// fallback `HiddenMarkovModelSegmenter`; the dictionary starts empty.
    pub fn new(
        n_iter: Option<usize>,
        tolerance: Option<f64>,
        gamma: Option<f64>,
        random_seed: Option<u64>,
        features: Option<Vec<SeqFeatureTemplate>>,
    ) -> Self {
        Self {
            dict: CountTrie::new(),
            total: 0.0,
            hmm: HiddenMarkovModelSegmenter::new(n_iter, tolerance, gamma, random_seed, features),
        }
    }

    /// Builds the word DAG: `dag[i]` lists every end index `j` such that
    /// `chars[i..=j]` is a word with a positive count in `dict`. When no
    /// dictionary word starts at `i`, the single character edge `[i]` is
    /// added so every position has at least one outgoing edge.
    #[allow(clippy::needless_range_loop)]
    fn get_dag(&self, chars: &[char]) -> Vec<Vec<usize>> {
        let n = chars.len();
        let mut dag: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 0..n {
            // Walk the trie from the root, extending one character at a time;
            // stop as soon as no dictionary word continues past position j.
            let mut node = self.dict.get_node(std::iter::empty::<char>());
            for j in i..n {
                node = node.and_then(|nd| nd.get_child(&chars[j]));
                match node {
                    None => break,
                    Some(nd) => {
                        // A terminal count > 0 marks chars[i..=j] as a known word.
                        if nd.terminal().copied().unwrap_or(0) > 0 {
                            dag[i].push(j);
                        }
                    }
                }
            }
            if dag[i].is_empty() {
                dag[i].push(i);
            }
        }
        dag
    }

    /// Dynamic program over the DAG, computed right to left.
    /// `route[i]` is the best `(log-probability, index one past the word end)`
    /// for the suffix starting at `i`; `route[n]` is the base case `(0.0, 0)`.
    /// Words absent from the dictionary get the smoothed value `-ln(total)`.
    fn calc_route(&self, chars: &[char], dag: &[Vec<usize>]) -> Vec<(f64, usize)> {
        let n = chars.len();
        let log_total = self.total.ln();
        let mut route: Vec<(f64, usize)> = vec![(0.0, 0); n + 1];
        for i in (0..n).rev() {
            let mut best = (f64::NEG_INFINITY, 0usize);
            for &j in &dag[i] {
                let word_freq = self.dict.get_count(chars[i..=j].iter().copied());
                let log_prob = if word_freq > 0 {
                    (word_freq as f64).ln() - log_total
                } else {
                    // Smoothing for the single-character fallback edges that
                    // `get_dag` inserts for positions with no dictionary word.
                    -log_total
                };
                let candidate = log_prob + route[j + 1].0;
                if candidate > best.0 {
                    best = (candidate, j + 1);
                }
            }
            route[i] = best;
        }
        route
    }

    /// Segments one sentence. With an empty dictionary the whole sentence is
    /// delegated to the HMM. Otherwise the maximum-probability DAG path is
    /// followed: multi-character steps are emitted directly as words, while
    /// runs of single-character steps are buffered and later resolved by
    /// `flush_oov_buf`.
    fn segment_one(&self, sent_str: &str) -> Vec<String> {
        let chars: Vec<char> = sent_str.chars().collect();
        if chars.is_empty() {
            return Vec::new();
        }
        if self.dict.is_empty() {
            let result =
                BaseHiddenMarkovModelSegmenter::predict(&self.hmm, vec![sent_str.to_string()]);
            return result.into_iter().next().unwrap_or_default();
        }
        let dag = self.get_dag(&chars);
        let route = self.calc_route(&chars, &dag);
        let mut words: Vec<String> = Vec::new();
        let mut oov_buf: Vec<char> = Vec::new();
        let mut i = 0;
        while i < chars.len() {
            // `route[i].1` is one past the end of the best word starting at i.
            let next = route[i].1;
            let word_len = next - i;
            if word_len == 1 {
                // Buffer single-character steps; consecutive ones may form an
                // OOV span that the HMM can segment better.
                oov_buf.push(chars[i]);
            } else {
                if !oov_buf.is_empty() {
                    Self::flush_oov_buf(&mut oov_buf, &self.dict, &self.hmm, &mut words);
                }
                let word: String = chars[i..next].iter().collect();
                words.push(word);
            }
            i = next;
        }
        if !oov_buf.is_empty() {
            Self::flush_oov_buf(&mut oov_buf, &self.dict, &self.hmm, &mut words);
        }
        words
    }

    /// Resolves a buffered run of single characters, clearing the buffer:
    /// a lone character is emitted as-is; a multi-char span that is itself a
    /// dictionary word is emitted character by character (the DAG path
    /// already chose single characters over the whole word, so the per-char
    /// split is kept); any other span is re-segmented by the HMM.
    /// NOTE(review): this mirrors the jieba-style OOV handling — confirm the
    /// char-by-char branch is the intended policy for in-dict spans.
    fn flush_oov_buf(
        oov_buf: &mut Vec<char>,
        dict: &CountTrie<char>,
        hmm: &HiddenMarkovModelSegmenter,
        words: &mut Vec<String>,
    ) {
        if oov_buf.is_empty() {
            return;
        }
        let span: String = oov_buf.iter().collect();
        oov_buf.clear();
        if span.chars().count() == 1 {
            words.push(span);
        } else if dict.get_count(span.chars()) > 0 {
            words.extend(span.chars().map(|c| c.to_string()));
        } else {
            let hmm_result = BaseHiddenMarkovModelSegmenter::predict(hmm, vec![span]);
            if let Some(hmm_words) = hmm_result.into_iter().next() {
                words.extend(hmm_words);
            }
        }
    }

    /// Scores pre-segmented sentences with the underlying HMM.
    /// Errors are propagated from the HMM (e.g. when it is not fitted).
    pub fn score(&self, sents: Vec<Vec<String>>) -> Result<Vec<f64>, ModelError> {
        BaseHiddenMarkovModelSegmenter::score(&self.hmm, sents)
    }

    /// Segments each input sentence; one `Vec<String>` of words per input.
    pub fn predict(&self, sent_strs: Vec<String>) -> Vec<Vec<String>> {
        sent_strs.iter().map(|s| self.segment_one(s)).collect()
    }

    /// Like `predict`, but pairs each word with its `(start, end)` offsets
    /// as computed by `crate::wordseg::attach_offsets`.
    pub fn predict_with_offsets(
        &self,
        sent_strs: Vec<String>,
    ) -> Vec<Vec<(String, (usize, usize))>> {
        let words = self.predict(sent_strs);
        crate::wordseg::attach_offsets(words)
    }

    /// Trains from pre-segmented sentences: rebuilds the frequency dictionary
    /// (one count per word token, `total` = number of word tokens) and also
    /// fits the fallback HMM on the same data.
    pub fn fit_segmented(&mut self, sents: Vec<Vec<String>>) {
        let mut dict = CountTrie::new();
        let mut total: u64 = 0;
        for sent in &sents {
            for word in sent {
                dict.increment(word.chars());
                total += 1;
            }
        }
        self.dict = dict;
        self.total = total as f64;
        BaseHiddenMarkovModelSegmenter::fit_segmented(&mut self.hmm, sents);
    }

    /// Trains only the fallback HMM on raw (unsegmented) text.
    /// NOTE: the dictionary is left untouched, so prediction will rely on
    /// whatever `dict` currently holds (HMM-only if it is empty).
    pub fn fit_unsegmented(&mut self, sent_strs: Vec<String>) {
        BaseHiddenMarkovModelSegmenter::fit_unsegmented(&mut self.hmm, sent_strs);
    }

    /// Saves the model to `path`. File layout: the 4-byte magic `PSEG`, then
    /// three length-prefixed sections (see `write_section`):
    /// 1. JSON map of word -> count (BTreeMap for deterministic output),
    /// 2. zstd-compressed (level 19) serialized HMM,
    /// 3. JSON metadata map supplied by the caller.
    pub fn save(&self, path: &str, metadata: &HashMap<String, String>) -> Result<(), ModelError> {
        let mut file = std::fs::File::create(path)
            .map_err(|e| ModelError::Io(format!("Failed to create file: {e}")))?;
        file.write_all(b"PSEG")
            .map_err(|e| ModelError::Io(format!("Failed to write magic: {e}")))?;
        let freq_map: BTreeMap<String, u64> = self
            .dict
            .all_counts()
            .into_iter()
            .map(|(chars, count)| (chars.into_iter().collect::<String>(), count))
            .collect();
        let dict_bytes = serde_json::to_vec(&freq_map)
            .map_err(|e| ModelError::Io(format!("Failed to serialize dict: {e}")))?;
        write_section(&mut file, &dict_bytes)?;
        let mut hmm_bytes = Vec::new();
        BaseHiddenMarkovModelSegmenter::save_to_writer(&self.hmm, &mut hmm_bytes)
            .map_err(|e| ModelError::Io(format!("Failed to serialize HMM: {e:?}")))?;
        let hmm_compressed = zstd::bulk::compress(&hmm_bytes, 19)
            .map_err(|e| ModelError::Io(format!("Failed to compress HMM: {e}")))?;
        write_section(&mut file, &hmm_compressed)?;
        let meta_bytes = serde_json::to_vec(metadata)
            .map_err(|e| ModelError::Io(format!("Failed to serialize metadata: {e}")))?;
        write_section(&mut file, &meta_bytes)?;
        Ok(())
    }

    /// Loads a model previously written by `save`, replacing this segmenter's
    /// dictionary, `total`, and HMM state. Returns the stored metadata map.
    /// `total` is recomputed as the sum of the loaded word counts.
    pub fn load(&mut self, path: &str) -> Result<HashMap<String, String>, ModelError> {
        let mut file = std::fs::File::open(path)
            .map_err(|_| ModelError::FileNotFound(format!("Can't locate model {path}")))?;
        let mut magic = [0u8; 4];
        file.read_exact(&mut magic)
            .map_err(|e| ModelError::ParseError(format!("Failed to read magic: {e}")))?;
        if &magic != b"PSEG" {
            return Err(ModelError::ParseError(format!(
                "Not a valid PSEG model file (bad magic: {:?})",
                magic
            )));
        }
        let dict_bytes = read_section(&mut file)?;
        let word_counts: HashMap<String, u64> = serde_json::from_slice(&dict_bytes)
            .map_err(|e| ModelError::ParseError(format!("Failed to parse dict: {e}")))?;
        let mut dict = CountTrie::new();
        let mut total: u64 = 0;
        for (word, count) in word_counts {
            dict.insert_count(word.chars(), count);
            total += count;
        }
        self.dict = dict;
        self.total = total as f64;
        let hmm_compressed = read_section(&mut file)?;
        // 256 MiB decompression cap guards against corrupt/malicious lengths.
        let hmm_bytes = zstd::bulk::decompress(&hmm_compressed, 256 * 1024 * 1024)
            .map_err(|e| ModelError::ParseError(format!("Failed to decompress HMM: {e}")))?;
        BaseHiddenMarkovModelSegmenter::load_from_reader(&mut self.hmm, hmm_bytes.as_slice())
            .map_err(|e| ModelError::ParseError(format!("Failed to load HMM: {e:?}")))?;
        let meta_bytes = read_section(&mut file)?;
        let metadata: HashMap<String, String> = serde_json::from_slice(&meta_bytes)
            .map_err(|e| ModelError::ParseError(format!("Failed to parse metadata: {e}")))?;
        Ok(metadata)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small segmented corpus shared by most tests below.
    fn training_data() -> Vec<Vec<String>> {
        vec![
            vec!["你好".into(), "世界".into()],
            vec!["我".into(), "喜歡".into(), "你".into()],
            vec!["你好".into(), "我".into(), "喜歡".into(), "世界".into()],
        ]
    }

    /// The DAG records both the two-char word (0..=1) and single-char edges.
    #[test]
    fn test_get_dag() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.dict.insert_count("你好".chars(), 2);
        seg.dict.insert_count("你".chars(), 5);
        seg.dict.insert_count("好".chars(), 3);
        seg.total = 10.0;
        let chars: Vec<char> = "你好".chars().collect();
        let dag = seg.get_dag(&chars);
        assert!(dag[0].contains(&0));
        assert!(dag[0].contains(&1));
        assert!(dag[1].contains(&1));
    }

    /// A frequent two-char word should beat two rare single-char words,
    /// so the route at position 0 jumps straight to index 2.
    #[test]
    fn test_calc_route_prefers_longer_frequent_word() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.dict.insert_count("你好".chars(), 10);
        seg.dict.insert_count("你".chars(), 1);
        seg.dict.insert_count("好".chars(), 1);
        seg.total = 12.0;
        let chars: Vec<char> = "你好".chars().collect();
        let dag = seg.get_dag(&chars);
        let route = seg.calc_route(&chars, &dag);
        assert_eq!(route[0].1, 2);
    }

    /// End-to-end: known multi-char words are recovered after fitting.
    #[test]
    fn test_fit_and_predict_known_words() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let result = seg.predict(vec!["你好世界".into()]);
        assert_eq!(result[0], vec!["你好", "世界"]);
    }

    /// Empty input yields an empty word list (early return in segment_one).
    #[test]
    fn test_predict_empty_string() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let result = seg.predict(vec!["".into()]);
        assert_eq!(result[0], Vec::<String>::new());
    }

    /// Single-character dictionary words survive the OOV-buffer path.
    #[test]
    fn test_predict_single_char_word() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let result = seg.predict(vec!["我喜歡你".into()]);
        assert_eq!(result[0], vec!["我", "喜歡", "你"]);
    }

    /// When single chars are far more frequent than the combined word, the
    /// DAG path picks them, and flush_oov_buf keeps the char-by-char split
    /// because the buffered span is itself in the dictionary.
    #[test]
    fn test_predict_in_dict_span_split_char_by_char() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.dict.insert_count("你好".chars(), 1);
        seg.dict.insert_count("你".chars(), 100);
        seg.dict.insert_count("好".chars(), 100);
        seg.total = 201.0;
        let result = seg.predict(vec!["你好".into()]);
        assert_eq!(result[0], vec!["你", "好"]);
    }

    /// "地球" is OOV, so it must be handed to the HMM; we only require that
    /// the HMM output re-concatenates to the original span.
    #[test]
    fn test_predict_oov_delegated_to_hmm() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let result = seg.predict(vec!["你好地球".into()]);
        assert_eq!(result[0][0], "你好");
        let rest: String = result[0][1..].join("");
        assert_eq!(rest, "地球");
    }

    /// Scoring before fitting propagates the HMM's not-fitted error.
    #[test]
    fn test_score_not_fitted() {
        let seg = DagHmmSegmenter::new(None, None, None, None, None);
        let result = seg.score(vec![vec!["你好".into(), "世界".into()]]);
        assert!(result.is_err());
    }

    /// Scoring an empty batch succeeds with an empty result.
    #[test]
    fn test_score_empty_input() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let result = seg.score(vec![]).unwrap();
        assert!(result.is_empty());
    }

    /// Scores for in-corpus sentences should be finite log-likelihoods.
    #[test]
    fn test_score_returns_finite_values() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let scores = seg
            .score(vec![
                vec!["你好".into(), "世界".into()],
                vec!["我".into(), "喜歡".into(), "你".into()],
            ])
            .unwrap();
        assert_eq!(scores.len(), 2);
        for s in &scores {
            assert!(s.is_finite(), "score should be finite, got {}", s);
        }
    }

    /// save -> load round trip: metadata survives and the loaded model
    /// produces identical segmentations.
    #[test]
    fn test_save_and_load_round_trip() {
        let mut seg = DagHmmSegmenter::new(None, None, None, None, None);
        seg.fit_segmented(training_data());
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.bin");
        let path_str = path.to_str().unwrap();
        let metadata: HashMap<String, String> =
            [("foo".into(), "bar".into()), ("baz".into(), "qux".into())]
                .into_iter()
                .collect();
        seg.save(path_str, &metadata).unwrap();
        let mut loaded = DagHmmSegmenter::new(None, None, None, None, None);
        let loaded_metadata = loaded.load(path_str).unwrap();
        assert_eq!(metadata, loaded_metadata);
        let input = vec!["你好世界".into(), "我喜歡你".into()];
        assert_eq!(seg.predict(input.clone()), loaded.predict(input));
    }
}