use crate::discovery::SymbolicRegression;
use crate::equation::Expression;
use crate::error::{LmmError, Result};
use std::collections::HashMap;
/// Predicts a plausible continuation of a text by blending Markov-chain
/// statistics with symbolic-regression fits over a trailing token window.
pub struct TextPredictor {
/// Number of trailing tokens used as the modelling window.
pub window_size: usize,
/// Iteration budget passed to the symbolic-regression search.
pub iterations: usize,
/// Maximum expression depth for the symbolic-regression search.
pub depth: usize,
}
/// Interns whitespace-separated words, assigning each distinct word a dense id.
struct WordVocab {
    /// Words in order of first appearance; index == id.
    words: Vec<String>,
    /// Reverse lookup from word text to its id.
    word_to_id: HashMap<String, usize>,
}

impl WordVocab {
    /// Builds a vocabulary from `text`, numbering distinct whitespace-separated
    /// tokens in order of first appearance.
    fn build(text: &str) -> Self {
        let mut vocab = Self {
            words: Vec::new(),
            word_to_id: HashMap::new(),
        };
        for token in text.split_whitespace() {
            if vocab.word_to_id.contains_key(token) {
                continue;
            }
            let id = vocab.words.len();
            vocab.word_to_id.insert(token.to_owned(), id);
            vocab.words.push(token.to_owned());
        }
        vocab
    }

    /// Returns the id assigned to `word`, or `None` if it was never seen.
    fn id_of(&self, word: &str) -> Option<usize> {
        self.word_to_id.get(word).copied()
    }
}
/// Word-transition statistics over the window: second-order (bigram-context)
/// probabilities with a first-order (unigram-context) fallback.
struct MarkovChain {
/// P(next | prev2, prev1): successor distribution for each ordered id pair.
bigram: HashMap<(usize, usize), HashMap<usize, f64>>,
/// P(next | prev1): successor distribution for each single word id.
unigram: HashMap<usize, HashMap<usize, f64>>,
}
impl MarkovChain {
    /// Counts successor words for single-word and word-pair contexts over the
    /// window tokens, then normalizes each context's counts into probabilities.
    fn build(tokens: &[String], vocab: &WordVocab) -> Self {
        let ids: Vec<usize> = tokens.iter().filter_map(|t| vocab.id_of(t)).collect();

        // Successor counts keyed by a single preceding word id.
        let mut uni_counts: HashMap<usize, HashMap<usize, usize>> = HashMap::new();
        for w in ids.windows(2) {
            *uni_counts.entry(w[0]).or_default().entry(w[1]).or_default() += 1;
        }

        // Successor counts keyed by an ordered pair of preceding word ids.
        let mut bi_counts: HashMap<(usize, usize), HashMap<usize, usize>> = HashMap::new();
        for w in ids.windows(3) {
            *bi_counts.entry((w[0], w[1])).or_default().entry(w[2]).or_default() += 1;
        }

        Self {
            bigram: Self::normalize(bi_counts),
            unigram: Self::normalize(uni_counts),
        }
    }

    /// Turns raw successor counts into per-context probability distributions,
    /// dropping any context whose total count is zero.
    fn normalize<K>(counts: HashMap<K, HashMap<usize, usize>>) -> HashMap<K, HashMap<usize, f64>>
    where
        K: std::hash::Hash + Eq,
    {
        counts
            .into_iter()
            .filter_map(|(key, nexts)| {
                let total: usize = nexts.values().sum();
                if total == 0 {
                    return None;
                }
                let dist: HashMap<usize, f64> = nexts
                    .into_iter()
                    .map(|(to, c)| (to, c as f64 / total as f64))
                    .collect();
                Some((key, dist))
            })
            .collect()
    }

    /// Probability of `next` given up to two preceding ids: bigram table first,
    /// then the unigram table, then 0.0 for an unseen transition.
    fn prob(&self, prev2: Option<usize>, prev1: usize, next: usize) -> f64 {
        if let Some(p2) = prev2 {
            if let Some(&p) = self.bigram.get(&(p2, prev1)).and_then(|row| row.get(&next)) {
                return p;
            }
        }
        self.unigram
            .get(&prev1)
            .and_then(|row| row.get(&next))
            .copied()
            .unwrap_or(0.0)
    }

    /// Cost form of `prob`: 0.0 for a certain transition, 1.0 for an unseen one.
    fn score(&self, prev2: Option<usize>, prev1: usize, next: usize) -> f64 {
        match self.prob(prev2, prev1, next) {
            p if p > 0.0 => 1.0 - p,
            _ => 1.0,
        }
    }
}
impl TextPredictor {
pub fn new(window_size: usize, iterations: usize, depth: usize) -> Self {
Self {
window_size,
iterations,
depth,
}
}
/// Fits a symbolic expression mapping window position (`x`) to word id.
///
/// # Errors
/// `LmmError::Discovery` when fewer than 2 samples are provided, or whatever
/// error the underlying regression `fit` reports.
fn fit_trajectory(&self, positions: &[f64], word_ids: &[f64]) -> Result<Expression> {
    if positions.len() >= 2 {
        // Each sample is a single-feature row: the token position.
        let inputs: Vec<Vec<f64>> = positions.iter().copied().map(|p| vec![p]).collect();
        SymbolicRegression::new(self.depth, self.iterations)
            .with_variables(vec!["x".into()])
            .with_population(60)
            .fit(&inputs, word_ids)
    } else {
        Err(LmmError::Discovery("Need at least 2 tokens".into()))
    }
}
/// Fits a symbolic expression mapping window position (`x`) to word length,
/// using a shallower depth and smaller budget than the trajectory fit.
///
/// # Errors
/// `LmmError::Discovery` when fewer than 2 samples are provided, or whatever
/// error the underlying regression `fit` reports.
fn fit_rhythm(&self, positions: &[f64], lengths: &[f64]) -> Result<Expression> {
    if positions.len() < 2 {
        return Err(LmmError::Discovery("Need at least 2 tokens".into()));
    }
    // Each sample is a single-feature row: the token position.
    let mut inputs = Vec::with_capacity(positions.len());
    for &p in positions {
        inputs.push(vec![p]);
    }
    let regressor = SymbolicRegression::new(self.depth.min(3), self.iterations / 2)
        .with_variables(vec!["x".into()])
        .with_population(40);
    regressor.fit(&inputs, lengths)
}
/// Searches the window for the longest suffix of `recent` (up to `max_suffix`
/// words) that reappears earlier in `window_tokens`, and returns the word that
/// followed that occurrence there. Longer suffixes win; within a length, the
/// earliest occurrence wins. Returns `None` when nothing matches.
fn suffix_match(
    window_tokens: &[String],
    recent: &[String],
    max_suffix: usize,
) -> Option<String> {
    let n = window_tokens.len();
    let longest = max_suffix.min(recent.len());
    // Try suffix lengths from longest to shortest; take the first hit.
    (1..=longest).rev().find_map(|len| {
        let suffix = &recent[recent.len() - len..];
        // Only consider starts that leave at least one following word.
        (0..n.saturating_sub(len))
            .find(|&start| window_tokens[start..start + len] == *suffix)
            .and_then(|start| window_tokens.get(start + len).cloned())
    })
}
/// Absolute error between the expression evaluated at `x = pos` and `target`,
/// divided by `scale` (clamped to at least 1.0). Evaluation failures count as
/// a prediction of 0.0.
fn eval_score(eq: &Expression, pos: f64, target: f64, scale: f64) -> f64 {
    let mut bindings: HashMap<String, f64> = HashMap::new();
    bindings.insert(String::from("x"), pos);
    let predicted = eq.evaluate(&bindings).unwrap_or(0.0);
    (predicted - target).abs() / scale.max(1.0)
}
/// Generates a continuation of `text` that is roughly `predict_length`
/// characters long (each emitted word is preceded by a space; the loop stops
/// once the accumulated string reaches `predict_length`, or earlier — see the
/// break condition below).
///
/// Pipeline: model the last `window_size` tokens with a Markov chain plus two
/// symbolic-regression fits (position -> word id, position -> word length),
/// then emit words one at a time, preferring an exact suffix match against
/// the window and otherwise scoring every vocabulary word with a weighted
/// blend of Markov, trajectory, rhythm, and repetition terms (lowest wins).
///
/// # Errors
/// `LmmError::Perception` if `text` is empty or has fewer than 2 words;
/// propagates `LmmError::Discovery` when either regression fit fails.
pub fn predict_continuation(
&self,
text: &str,
predict_length: usize,
) -> Result<PredictedContinuation> {
if text.is_empty() {
return Err(LmmError::Perception("Input text is empty".into()));
}
let all_tokens: Vec<String> = text.split_whitespace().map(String::from).collect();
if all_tokens.len() < 2 {
return Err(LmmError::Perception("Need at least 2 words".into()));
}
// Vocabulary covers the whole text; the models only see the trailing window.
let vocab = WordVocab::build(text);
let window_start = all_tokens.len().saturating_sub(self.window_size);
let window_tokens = &all_tokens[window_start..];
let markov = MarkovChain::build(window_tokens, &vocab);
// Regression training data: x = position in window, y = word id / word length.
let positions: Vec<f64> = (0..window_tokens.len()).map(|i| i as f64).collect();
let word_ids: Vec<f64> = window_tokens
.iter()
.map(|t| vocab.id_of(t).unwrap_or(0) as f64)
.collect();
let lengths: Vec<f64> = window_tokens.iter().map(|t| t.len() as f64).collect();
let trajectory_eq = self.fit_trajectory(&positions, &word_ids)?;
let rhythm_eq = self.fit_rhythm(&positions, &lengths)?;
let vocab_size = vocab.words.len();
// Damp the id-trajectory term for tiny vocabularies.
let traj_weight = if vocab_size <= 8 { 0.05 } else { 0.20 };
// Markov gets whatever mass remains after trajectory, rhythm (0.10) and
// repetition-penalty (0.15) weights.
let markov_weight = 1.0 - traj_weight - 0.10 - 0.15;
let mut continuation = String::new();
// Seed the "recently generated" buffer with the window's last <=3 tokens.
let mut generated: Vec<String> =
window_tokens[window_tokens.len().saturating_sub(3)..].to_vec();
let mut prev2_id: Option<usize> = if window_tokens.len() >= 2 {
vocab.id_of(&window_tokens[window_tokens.len() - 2])
} else {
None
};
let mut prev1_id = vocab.id_of(window_tokens.last().unwrap()).unwrap_or(0);
let mut pos = window_tokens.len() as f64;
// How many times each word id has been emitted (feeds the recency penalty).
let mut recency_counts: HashMap<usize, usize> = HashMap::new();
while continuation.len() < predict_length {
let max_suffix = 2.min(generated.len());
let chosen_word =
// First choice: copy the word that followed a matching recent suffix
// somewhere in the window.
if let Some(w) = Self::suffix_match(window_tokens, &generated, max_suffix) {
w
} else {
// Fallback: score every vocabulary word, keep the cheapest.
let mut best_id = 0;
let mut best_score = f64::MAX;
for (id, word) in vocab.words.iter().enumerate() {
// Skip words that would overshoot the target length by more than ~6 chars.
if continuation.len() + word.len() + 1 > predict_length + 6 {
continue;
}
let m_score = markov.score(prev2_id, prev1_id, id);
let t_score =
Self::eval_score(&trajectory_eq, pos, id as f64, vocab_size as f64);
let r_score = Self::eval_score(&rhythm_eq, pos, word.len() as f64, 20.0);
let recency = *recency_counts.get(&id).unwrap_or(&0) as f64;
// Flat surcharge for repeating the immediately preceding word.
let consec = if id == prev1_id { 2.0 } else { 0.0 };
let penalty = recency * 0.4 + consec;
let composite = markov_weight * m_score
+ traj_weight * t_score
+ 0.10 * r_score
+ 0.15 * penalty;
if composite < best_score {
best_score = composite;
best_id = id;
}
}
// NOTE(review): if the length filter skips every candidate, this
// silently falls back to word id 0 — confirm that is intended.
vocab.words[best_id].clone()
};
let chosen_id = vocab.id_of(&chosen_word).unwrap_or(0);
continuation.push(' ');
continuation.push_str(&chosen_word);
*recency_counts.entry(chosen_id).or_insert(0) += 1;
generated.push(chosen_word);
// Keep only the last 4 generated words for suffix matching.
if generated.len() > 4 {
generated.remove(0);
}
prev2_id = Some(prev1_id);
prev1_id = chosen_id;
pos += 1.0;
// Stop early once every vocabulary word has been emitted at least twice.
if recency_counts.len() == vocab_size && recency_counts.values().all(|&c| c >= 2) {
break;
}
}
Ok(PredictedContinuation {
trajectory_equation: trajectory_eq,
rhythm_equation: rhythm_eq,
window_used: window_tokens.len(),
continuation,
})
}
}
/// Result of [`TextPredictor::predict_continuation`]: the generated text plus
/// the fitted models and the window size actually used.
pub struct PredictedContinuation {
/// Fitted expression mapping window position to word id.
pub trajectory_equation: Expression,
/// Fitted expression mapping window position to word length.
pub rhythm_equation: Expression,
/// Number of trailing tokens the models were trained on.
pub window_used: usize,
/// Generated continuation; each emitted word is preceded by a space.
pub continuation: String,
}