use rand::distr::Distribution;
use rand::distr::weighted::WeightedIndex;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fs, io::Write};
/// A second-order, byte-level Markov chain over names.
///
/// `transitions` maps a two-byte context `(older byte, most recent byte)` to the probability of
/// each byte that followed that context. As built by [`Markov::train`], `b'^'` pads the context
/// at the start of a name and `b'$'` marks the end of a name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Markov {
    /// Context `(older byte, most recent byte)` -> (next byte -> probability).
    /// When produced by [`Markov::train`], each inner map's values sum to 1.0.
    pub transitions: HashMap<(u8, u8), HashMap<u8, f64>>,
}
impl Markov {
    /// Builds a second-order (two-byte context) Markov model from `names`.
    ///
    /// Each name is walked byte by byte. The context starts as `(b'^', b'^')`, and a terminating
    /// `b'$'` is appended, so the model learns how names start and how they end. After
    /// counting, each context's successor counts are normalised to sum to 1.0, so
    /// `transitions[&(a, b)][&c]` is the empirical probability of `c` following `a b`.
    ///
    /// Names are used as-is, with no case folding. [`Markov::precompute_distributions`] only
    /// offers `b'a'..=b'z'` and `b'$'` as successors, so other bytes are counted here but can
    /// never be generated.
    /// NOTE(review): callers presumably lowercase the training data — confirm.
    ///
    /// A name that itself contains `b'$'` is treated as ending at that byte.
    pub fn train(names: &[String]) -> Self {
        let mut transitions: HashMap<(u8, u8), HashMap<u8, f64>> = HashMap::new();

        for name in names {
            let (mut p1, mut p2) = (b'^', b'^');
            for &current in name.as_bytes().iter().chain(std::iter::once(&b'$')) {
                *transitions
                    .entry((p1, p2))
                    .or_default()
                    .entry(current)
                    .or_insert(0.0) += 1.0;

                // `$` is the end-of-name marker. Stop here even if it occurs inside the name,
                // so nothing is ever learned as following the terminator.
                if current == b'$' {
                    break;
                }
                p1 = p2;
                p2 = current;
            }
        }

        // Turn raw counts into per-context probabilities. Every stored context has at least
        // one observation, so `total` is never zero.
        for counts in transitions.values_mut() {
            let total: f64 = counts.values().sum();
            counts.values_mut().for_each(|v| *v /= total);
        }

        Self { transitions }
    }

    /// Converts the trained probabilities into ready-to-sample distributions.
    ///
    /// For every known context, the candidate successors are `b'a'..=b'z'` followed by `b'$'`
    /// (27 choices, in that order). Each candidate gets the weight
    /// `(p + smoothing).powf(1.0 / temperature)`, where `p` is its trained probability (0.0 if
    /// it was never observed in that context).
    ///
    /// * `smoothing` is added to the probability, not to a raw count. A positive value lets
    ///   unseen letters and the terminator be picked.
    /// * `temperature` below 1.0 sharpens the distribution; above 1.0 flattens it.
    ///   NOTE(review): presumably expected to be > 0.0. Zero or negative values give an
    ///   infinite or negative exponent, and the outcome then depends on how
    ///   `WeightedIndex::new` handles the resulting weights.
    ///
    /// Contexts whose weights `WeightedIndex::new` rejects (for example, all weights zero) are
    /// left out of the result. [`Markov::generate`] stops when it reaches such a context.
    pub fn precompute_distributions(
        &self,
        smoothing: f64,
        temperature: f64,
    ) -> HashMap<(u8, u8), (Vec<u8>, WeightedIndex<f64>)> {
        // The candidate set and the exponent are identical for every context; build them once.
        let alphabet: Vec<u8> = (b'a'..=b'z').chain(std::iter::once(b'$')).collect();
        let exponent = 1.0 / temperature;

        let mut distributions = HashMap::with_capacity(self.transitions.len());
        for (&state, probs) in &self.transitions {
            let weights = alphabet.iter().map(|c| {
                let p = probs.get(c).copied().unwrap_or(0.0);
                (p + smoothing).powf(exponent)
            });
            // Skipping unusable contexts is deliberate: `generate` treats a missing context
            // as the end of the name rather than failing.
            if let Ok(dist) = WeightedIndex::new(weights) {
                distributions.insert(state, (alphabet.clone(), dist));
            }
        }
        distributions
    }

    /// Writes the transition table to `file_name`, creating or truncating the file.
    ///
    /// The table is serialised with bincode, then compressed with zstd. Only `self.transitions`
    /// is written; [`Markov::read_transitions_from`] decodes the same map type.
    ///
    /// # Errors
    /// Returns an error if serialisation, compression, file creation, or the write fails.
    pub fn write_transitions_to_file(&self, file_name: &str) -> bincode::Result<()> {
        const ZSTD_LEVEL: i32 = 3;

        let encoded = bincode::serialize(&self.transitions)?;
        let compressed = zstd::encode_all(encoded.as_slice(), ZSTD_LEVEL)?;
        let mut file = fs::File::create(file_name)?;
        file.write_all(&compressed)?;
        Ok(())
    }

    /// Loads a model previously written by [`Markov::write_transitions_to_file`].
    ///
    /// The whole file is read, zstd-decompressed, and bincode-decoded as the transition table.
    /// Decoding the map, rather than the struct, mirrors the writer exactly.
    ///
    /// NOTE(review): neither the decompression nor the bincode decode appears to cap how much
    /// memory it allocates. Only load trusted files.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read, is not valid zstd data, or does not
    /// decode as a transition table.
    pub fn read_transitions_from(file_name: &str) -> bincode::Result<Self> {
        let compressed = fs::read(file_name)?;
        let decompressed = zstd::decode_all(compressed.as_slice())?;
        let transitions: HashMap<(u8, u8), HashMap<u8, f64>> =
            bincode::deserialize(&decompressed)?;
        Ok(Self { transitions })
    }

    /// Samples one name using tables from [`Markov::precompute_distributions`].
    ///
    /// Starts in context `(b'^', b'^')`. Each step samples a successor byte, appends it to
    /// the output, and slides the two-byte context forward. Generation stops when `b'$'` is
    /// sampled, or when the current context has no entry in `distributions`. The terminator
    /// is never included in the returned string.
    pub fn generate(
        &self,
        rng: &mut impl rand::Rng,
        distributions: &HashMap<(u8, u8), (Vec<u8>, WeightedIndex<f64>)>,
    ) -> String {
        let mut name = String::new();
        let (mut p1, mut p2) = (b'^', b'^');

        while let Some((choices, dist)) = distributions.get(&(p1, p2)) {
            let next = choices[dist.sample(rng)];
            if next == b'$' {
                break;
            }
            // With tables from `precompute_distributions`, `next` is an ASCII lowercase
            // letter, so the byte-to-char conversion is lossless.
            name.push(char::from(next));
            p1 = p2;
            p2 = next;
        }

        name
    }
}