use std::collections::HashMap;
use crate::embed::{Embeddings, Vector, cosine};
/// One n-gram table: maps a context (a sequence of tokens) to every token
/// observed directly after that context, together with its accumulated count.
type Chain = HashMap<Vec<String>, HashMap<String, f32>>;
/// Variable-order Markov model over whitespace-separated tokens.
///
/// Built by `Markov::train`; sampled from by `Markov::generate`.
pub struct Markov {
/// One table per context length: `chains[i]` is keyed by contexts of
/// `i + 1` tokens.
chains: Vec<Chain>,
/// Longest context length (in tokens) that was trained; `train` creates
/// exactly this many chains.
max_order: usize,
/// Occurrence count of every token seen in training. `generate` samples
/// from this table when no stored context matches.
unigrams: HashMap<String, f32>,
}
/// Splits `text` on Unicode whitespace and returns the pieces as owned tokens.
///
/// Runs of whitespace act as a single separator and leading/trailing
/// whitespace is ignored, so the result never contains an empty token; an
/// empty or all-whitespace input yields an empty vector.
fn tokenize(text: &str) -> Vec<String> {
    // `split_whitespace` already skips empty pieces, so no extra filtering
    // is required.
    text.split_whitespace().map(str::to_owned).collect()
}
/// Joins `tokens` back into one string, placing a single space between
/// neighbouring tokens. An empty slice produces an empty string.
fn detokenize(tokens: &[String]) -> String {
    let mut text = String::new();
    for (position, token) in tokens.iter().enumerate() {
        // The separator goes before every token except the first.
        if position != 0 {
            text.push(' ');
        }
        text.push_str(token);
    }
    text
}
impl Markov {
pub fn train(texts: &[&str], max_order: usize) -> Self {
let mut chains: Vec<Chain> = (0..max_order).map(|_| HashMap::new()).collect();
let mut unigrams: HashMap<String, f32> = HashMap::new();
for text in texts {
let tokens = tokenize(text);
for tok in &tokens {
*unigrams.entry(tok.clone()).or_default() += 1.0;
}
for order in 1..=max_order {
for window in tokens.windows(order + 1) {
let ctx = window[..order].to_vec();
let next = window[order].clone();
*chains[order - 1]
.entry(ctx)
.or_default()
.entry(next)
.or_default() += 1.0;
}
}
}
Self { chains, max_order, unigrams }
}
fn candidates(&self, ctx: &[String]) -> Option<&HashMap<String, f32>> {
for order in (0..self.max_order).rev() {
if ctx.len() < order + 1 { continue; }
let key = ctx[ctx.len() - order - 1..].to_vec();
if let Some(nexts) = self.chains[order].get(&key) {
return Some(nexts);
}
}
None
}
pub fn generate(
&self,
seed: &[String],
max_tokens: usize,
style_vec: Option<&Vector>,
embeddings: Option<&Embeddings>,
temperature: f32,
) -> Vec<String> {
let mut output = seed.to_vec();
for _ in 0..max_tokens {
let ctx = &output;
let raw = match self.candidates(ctx) {
Some(c) => c,
None => &self.unigrams,
};
let mut scored: Vec<(String, f32)> = raw
.iter()
.map(|(tok, &freq)| {
let mut score = freq;
if let (Some(sv), Some(emb)) = (style_vec, embeddings) {
if let Some(&idx) = emb.vocab.get(tok) {
let sim = cosine(&emb.vectors[idx], sv);
score *= (1.0 + sim).max(0.01);
}
}
(tok.clone(), score)
})
.collect();
let total: f32 = scored.iter().map(|(_, s)| s / temperature).sum();
if total <= 0.0 { break; }
let mut r = rand_f32() * total;
let mut chosen = scored[0].0.clone();
for (tok, score) in &scored {
r -= score / temperature;
if r <= 0.0 {
chosen = tok.clone();
break;
}
}
let stop = [".", "!", "?"].contains(&chosen.as_str());
output.push(chosen);
if stop { break; }
}
output[seed.len()..].to_vec()
}
pub fn generate_from(
&self,
prompt: &str,
context_msgs: &[&str],
style_vec: Option<&Vector>,
embeddings: Option<&Embeddings>,
max_tokens: usize,
temperature: f32,
) -> String {
let mut seed: Vec<String> = context_msgs
.iter()
.flat_map(|m| tokenize(m))
.collect();
seed.extend(tokenize(prompt));
let generated = self.generate(&seed, max_tokens, style_vec, embeddings, temperature);
detokenize(&generated)
}
}
/// Returns a pseudo-random value in the half-open range `[0.0, 1.0)`.
///
/// Uses a per-thread xorshift64 generator whose initial state comes from the
/// standard library's randomly keyed hasher. Reading the clock on every call
/// is deliberately avoided: consecutive reads inside a tight loop are strongly
/// correlated, and a coarse clock yields only a handful of distinct values.
/// Not suitable for cryptographic purposes.
fn rand_f32() -> f32 {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    thread_local! {
        // `| 1` guarantees a non-zero seed; xorshift never leaves state zero.
        static STATE: Cell<u64> =
            Cell::new(RandomState::new().build_hasher().finish() | 1);
    }

    STATE.with(|state| {
        let mut x = state.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        state.set(x);
        // Keep the top 24 bits: they fit an f32 mantissa exactly, so the
        // quotient is always strictly below 1.0.
        (x >> 40) as f32 / (1u32 << 24) as f32
    })
}