use crate::errors::{MarkovError, Result};
use fxhash::FxHashMap;
use rand::Rng;
use serde::{Deserialize, Serialize};
/// Sentinel token used to left-pad every run (`state_size` copies) so that the first
/// real tokens of a run have a complete state preceding them.
pub const BEGIN: &str = "___BEGIN__";
/// Sentinel token appended to the end of every run; generation stops when it is drawn.
pub const END: &str = "___END__";
/// Precomputed successor table for a single state, laid out for weighted random
/// selection by binary search over the cumulative weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompiledNext {
    /// Candidate next tokens.
    pub words: Vec<String>,
    /// Running sum of the occurrence counts of `words` (parallel to `words`); the last
    /// entry is the total weight of the state.
    pub cumulative_weights: Vec<usize>,
}
/// A Markov chain over string tokens.
///
/// The model maps every state (a window of `state_size` consecutive tokens) to the
/// counts of the tokens observed immediately after it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chain {
    /// Number of tokens in each state.
    state_size: usize,
    /// Transition counts: state -> (next token -> number of occurrences).
    model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
    /// Whether this chain was produced by `Chain::compile`.
    /// NOTE(review): this flag is serialized but `compiled_model` is not, so a
    /// deserialized compiled chain has `compiled == true` with an empty table.
    compiled: bool,
    /// Per-state cumulative-weight tables built by `Chain::compile`.
    /// Skipped by serde: empty (`Default`) after deserialization.
    #[serde(skip)]
    compiled_model: FxHashMap<Vec<String>, CompiledNext>,
    /// Cached successors of the all-`BEGIN` state, filled by `precompute_begin_state`.
    /// Skipped by serde: `None` after deserialization.
    #[serde(skip)]
    begin_choices: Option<Vec<String>>,
    /// Cumulative weights parallel to `begin_choices`.
    /// Skipped by serde: `None` after deserialization.
    #[serde(skip)]
    begin_cumdist: Option<Vec<usize>>,
}
impl Chain {
    /// Builds a chain from a tokenized corpus.
    ///
    /// Each element of `corpus` is one run (e.g. a sentence) of tokens. Every run is
    /// left-padded with `state_size` copies of [`BEGIN`] and terminated with [`END`];
    /// each window of `state_size` tokens is recorded as a state followed by the next
    /// token.
    pub fn new(corpus: &[Vec<String>], state_size: usize) -> Self {
        Self::from_parts(state_size, Self::build(corpus, state_size))
    }

    /// Assembles an uncompiled chain around an existing model and fills the
    /// begin-state cache. Shared by every public constructor.
    fn from_parts(
        state_size: usize,
        model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
    ) -> Self {
        let mut chain = Chain {
            state_size,
            model,
            compiled: false,
            compiled_model: FxHashMap::default(),
            begin_choices: None,
            begin_cumdist: None,
        };
        chain.precompute_begin_state();
        chain
    }

    /// Counts, for every state of `state_size` consecutive tokens, how often each
    /// token follows it across all runs of `corpus`.
    fn build(
        corpus: &[Vec<String>],
        state_size: usize,
    ) -> FxHashMap<Vec<String>, FxHashMap<String, usize>> {
        let mut model: FxHashMap<Vec<String>, FxHashMap<String, usize>> = FxHashMap::default();
        for run in corpus {
            let mut items: Vec<String> = vec![BEGIN.to_string(); state_size];
            items.reserve(run.len() + 1);
            items.extend(run.iter().cloned());
            items.push(END.to_string());
            // Every window is one state plus the token observed after it; a run of
            // length L yields exactly L + 1 windows (the last one ends in END).
            for window in items.windows(state_size + 1) {
                let (state, follow) = (&window[..state_size], &window[state_size]);
                *model
                    .entry(state.to_vec())
                    .or_default()
                    .entry(follow.clone())
                    .or_insert(0) += 1;
            }
        }
        model
    }

    /// Caches the successor table of the all-[`BEGIN`] state, which every default
    /// walk starts from. Leaves the cache untouched if the model has no such state.
    fn precompute_begin_state(&mut self) {
        let begin_state: Vec<String> = vec![BEGIN.to_string(); self.state_size];
        if let Some(next_dict) = self.model.get(&begin_state) {
            let (choices, cumdist) = Self::compile_next_dict(next_dict);
            self.begin_choices = Some(choices);
            self.begin_cumdist = Some(cumdist);
        }
    }

    /// Flattens a successor-count map into parallel `(words, cumulative_weights)`
    /// vectors suitable for [`Chain::select_random`].
    fn compile_next_dict(next_dict: &FxHashMap<String, usize>) -> (Vec<String>, Vec<usize>) {
        let mut running_total: usize = 0;
        next_dict
            .iter()
            .map(|(word, &weight)| {
                // Saturate so the sequence stays monotonic even for absurd counts
                // loaded from JSON; a wrapped sum would break the binary search.
                running_total = running_total.saturating_add(weight);
                (word.clone(), running_total)
            })
            .unzip()
    }

    /// Returns a compiled copy of this chain with a precomputed successor table for
    /// every state, so generation does not rebuild tables on each step.
    pub fn compile(&self) -> Self {
        let compiled_model: FxHashMap<Vec<String>, CompiledNext> = self
            .model
            .iter()
            .map(|(state, next_dict)| {
                let (words, cumulative_weights) = Self::compile_next_dict(next_dict);
                (
                    state.clone(),
                    CompiledNext {
                        words,
                        cumulative_weights,
                    },
                )
            })
            .collect();
        Chain {
            state_size: self.state_size,
            model: self.model.clone(),
            compiled: true,
            compiled_model,
            begin_choices: self.begin_choices.clone(),
            begin_cumdist: self.begin_cumdist.clone(),
        }
    }

    /// Draws a weighted-random successor of `state`, or `None` if the state is unknown
    /// or has no selectable successor.
    ///
    /// Precomputed tables (`compiled_model`, begin-state cache) are used when present.
    /// Otherwise the raw `model` is consulted. This fallback matters because those
    /// caches are `#[serde(skip)]` and therefore empty after deserialization.
    fn move_state(&self, state: &[String]) -> Option<String> {
        if self.compiled {
            if let Some(next) = self.compiled_model.get(state) {
                return Self::select_random(&next.words, &next.cumulative_weights);
            }
        } else if state.len() == self.state_size && state.iter().all(|s| s == BEGIN) {
            // Only the exact begin state (right length) may use the begin cache.
            if let (Some(choices), Some(cumdist)) = (&self.begin_choices, &self.begin_cumdist) {
                return Self::select_random(choices, cumdist);
            }
        }
        // Slow path / cache miss: build the table for this one state on the fly.
        let next_dict = self.model.get(state)?;
        let (choices, cumdist) = Self::compile_next_dict(next_dict);
        Self::select_random(&choices, &cumdist)
    }

    /// Picks an entry of `choices` with probability proportional to its weight, where
    /// `cumdist` holds the cumulative weights (parallel to `choices`).
    ///
    /// Returns `None` for empty input or a zero total weight; `gen_range` would panic
    /// on the empty range `0..0`.
    fn select_random(choices: &[String], cumdist: &[usize]) -> Option<String> {
        let total = *cumdist.last()?;
        if total == 0 || choices.is_empty() {
            return None;
        }
        let r = rand::thread_rng().gen_range(0..total);
        // First index whose cumulative weight exceeds r (equivalent to bisect_right).
        let idx = cumdist.partition_point(|&c| c <= r);
        choices.get(idx).or_else(|| choices.last()).cloned()
    }

    /// Returns a lazy generator of tokens starting from `init_state`, or from the
    /// all-[`BEGIN`] state when `init_state` is `None`.
    pub fn gen(&self, init_state: Option<&[String]>) -> ChainGenerator<'_> {
        let state = init_state
            .map(|s| s.to_vec())
            .unwrap_or_else(|| vec![BEGIN.to_string(); self.state_size]);
        ChainGenerator {
            chain: self,
            state,
            done: false,
        }
    }

    /// Generates one complete run of tokens (excluding [`END`]) and collects it.
    pub fn walk(&self, init_state: Option<&[String]>) -> Vec<String> {
        self.gen(init_state).collect()
    }

    /// Number of tokens in each state.
    pub fn state_size(&self) -> usize {
        self.state_size
    }

    /// The raw transition counts: state -> (next token -> occurrences).
    pub fn model(&self) -> &FxHashMap<Vec<String>, FxHashMap<String, usize>> {
        &self.model
    }

    /// Whether this chain was produced by [`Chain::compile`].
    pub fn is_compiled(&self) -> bool {
        self.compiled
    }

    /// Serializes the model as a JSON array of `[state, {token: count}]` pairs.
    ///
    /// # Errors
    /// Propagates any `serde_json` serialization error.
    pub fn to_json(&self) -> Result<String> {
        // Serialize borrowed pairs: the JSON is identical to serializing owned clones,
        // without deep-copying the whole model first.
        let items: Vec<(&Vec<String>, &FxHashMap<String, usize>)> = self.model.iter().collect();
        Ok(serde_json::to_string(&items)?)
    }

    /// Rebuilds an uncompiled chain from the JSON produced by [`Chain::to_json`].
    ///
    /// The state size is taken from the stored states.
    ///
    /// # Errors
    /// Returns a JSON error for malformed input, and `ModelFormatError` when the model
    /// is empty or its states do not all have the same length.
    pub fn from_json(json_str: &str) -> Result<Self> {
        let items: Vec<(Vec<String>, FxHashMap<String, usize>)> = serde_json::from_str(json_str)?;
        let state_size = match items.first() {
            Some((state, _)) => state.len(),
            None => return Err(MarkovError::ModelFormatError("Empty model".to_string())),
        };
        if items.iter().any(|(state, _)| state.len() != state_size) {
            return Err(MarkovError::ModelFormatError(
                "Inconsistent state sizes".to_string(),
            ));
        }
        Ok(Self::from_parts(state_size, items.into_iter().collect()))
    }

    /// Wraps an already-combined transition model (e.g. merged from several chains)
    /// in an uncompiled chain.
    pub fn from_combined_model(
        model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
        state_size: usize,
    ) -> Self {
        Self::from_parts(state_size, model)
    }
}
/// Iterator returned by `Chain::gen` that yields generated tokens until `END` is
/// drawn or the current state has no recorded successor.
pub struct ChainGenerator<'a> {
    /// Chain being walked.
    chain: &'a Chain,
    /// Current state: the window of most recent tokens looked up in the chain.
    state: Vec<String>,
    /// Set once generation has finished; every later `next` call returns `None`.
    done: bool,
}
impl Iterator for ChainGenerator<'_> {
    type Item = String;

    /// Draws the next token from the chain and slides the state window forward.
    ///
    /// Returns `None` (and stays finished) once `END` is drawn or the current state
    /// has no successor.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let next_word = match self.chain.move_state(&self.state) {
            Some(word) if word != END => word,
            _ => {
                self.done = true;
                return None;
            }
        };
        // Slide the window: drop the oldest token, append the new one. A zero-size
        // state (order-0 chain) has nothing to slide; `remove(0)` would panic on it.
        if !self.state.is_empty() {
            self.state.remove(0);
            self.state.push(next_word.clone());
        }
        Some(next_word)
    }
}

// `done` is sticky, so once `next` returns `None` it always will.
impl std::iter::FusedIterator for ChainGenerator<'_> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned tokens the chain API expects.
    fn tokens(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_chain_creation() {
        let corpus = vec![tokens(&["hello", "world"]), tokens(&["hello", "rust"])];
        let chain = Chain::new(&corpus, 1);
        assert_eq!(chain.state_size(), 1);
        assert!(!chain.is_compiled());
    }

    #[test]
    fn test_chain_walk() {
        // A single run gives every state exactly one successor, so the walk is
        // deterministic and must reproduce the run.
        let corpus = vec![tokens(&["the", "cat", "sat"])];
        let chain = Chain::new(&corpus, 1);
        assert_eq!(chain.walk(None), tokens(&["the", "cat", "sat"]));
    }

    #[test]
    fn test_compiled_walk() {
        let corpus = vec![tokens(&["the", "cat", "sat"])];
        let compiled = Chain::new(&corpus, 2).compile();
        assert!(compiled.is_compiled());
        assert_eq!(compiled.walk(None), tokens(&["the", "cat", "sat"]));
    }

    #[test]
    fn test_chain_json_serialization() {
        let corpus = vec![tokens(&["hello", "world"])];
        let chain = Chain::new(&corpus, 1);
        let json = chain.to_json().unwrap();
        let restored = Chain::from_json(&json).unwrap();
        assert_eq!(chain.state_size(), restored.state_size());
        assert_eq!(chain.model(), restored.model());
        assert_eq!(restored.walk(None), tokens(&["hello", "world"]));
    }
}