use crate::cognition::knowledge::KnowledgeIndex;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashMap;
/// Learns informal concept associations from rewarded text observations.
///
/// High-reward texts are tokenised and token / token-pair counts accumulated;
/// pairs with a high PMI-style score can later be synthesised into a
/// `KnowledgeIndex` as natural-language facts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InformalLearner {
    /// Minimum reward for a text to be counted at all.
    pub reward_threshold: f64,
    /// Minimum co-occurrence count before a pair is eligible for PMI scoring
    /// (clamped to at least 1 in `new`).
    pub min_count: usize,
    /// Default PMI threshold; a per-call override can only raise it.
    pub pmi_threshold: f64,
    // Per-token occurrence counts across all accepted observations.
    token_counts: HashMap<String, usize>,
    // Co-occurrence counts keyed by canonically-ordered (lexicographic) pair.
    pair_counts: HashMap<(String, String), usize>,
    /// Number of accepted observations (reward passed, >= 2 tokens).
    pub observation_count: usize,
    /// Total facts ingested into a knowledge index via `synthesise_into`.
    pub synthesis_count: usize,
}
impl InformalLearner {
    /// Creates a learner. `min_count` is clamped to at least 1 so the
    /// frequency filter in `high_pmi_pairs` can never admit zero-count pairs.
    pub fn new(reward_threshold: f64, min_count: usize, pmi_threshold: f64) -> Self {
        Self {
            reward_threshold,
            min_count: min_count.max(1),
            pmi_threshold,
            token_counts: HashMap::new(),
            pair_counts: HashMap::new(),
            observation_count: 0,
            synthesis_count: 0,
        }
    }

    /// Accumulates token and token-pair counts from `text`, but only when
    /// `reward` meets `reward_threshold`. Texts that tokenise to fewer than
    /// two tokens carry no pair signal and are ignored entirely (they do not
    /// bump `observation_count` either).
    pub fn observe(&mut self, text: &str, reward: f64) {
        if reward < self.reward_threshold {
            return;
        }
        let tokens: Vec<String> = tokenise(text);
        if tokens.len() < 2 {
            return;
        }
        for t in &tokens {
            *self.token_counts.entry(t.clone()).or_insert(0) += 1;
        }
        for i in 0..tokens.len() {
            for j in (i + 1)..tokens.len() {
                // Skip degenerate self-pairs: a token repeated within one
                // text would otherwise "co-occur" with itself, and
                // `synthesise_into` could then emit facts of the form
                // "X and X are strongly related concepts".
                if tokens[i] == tokens[j] {
                    continue;
                }
                let pair = canonical_pair(&tokens[i], &tokens[j]);
                *self.pair_counts.entry(pair).or_insert(0) += 1;
            }
        }
        self.observation_count += 1;
    }

    /// Returns pairs scoring `ln(c_ab * n / (c_a * c_b))` at or above the
    /// effective threshold `max(threshold_override, self.pmi_threshold)`,
    /// after filtering out pairs seen fewer than `min_count` times.
    ///
    /// Results are sorted by score descending; ties are broken
    /// lexicographically by pair so the ordering is deterministic despite
    /// `HashMap`'s arbitrary iteration order.
    pub fn high_pmi_pairs(&self, threshold_override: f64) -> Vec<(String, String, f64)> {
        let n = self.observation_count as f64;
        if n < 1.0 {
            // No observations yet: every score would be meaningless.
            return Vec::new();
        }
        let eff_threshold = threshold_override.max(self.pmi_threshold);
        let mut results: Vec<(String, String, f64)> = self
            .pair_counts
            .iter()
            .filter(|&(_, count)| *count >= self.min_count)
            .filter_map(|((a, b), &c_ab)| {
                // Token counts are guaranteed present for any counted pair,
                // but use `?` defensively rather than indexing.
                let c_a = *self.token_counts.get(a)? as f64;
                let c_b = *self.token_counts.get(b)? as f64;
                let pmi = (c_ab as f64 * n / (c_a * c_b)).ln();
                if pmi >= eff_threshold {
                    Some((a.clone(), b.clone(), pmi))
                } else {
                    None
                }
            })
            .collect();
        results.sort_by(|x, y| {
            y.2.partial_cmp(&x.2)
                .unwrap_or(Ordering::Equal)
                .then_with(|| (&x.0, &x.1).cmp(&(&y.0, &y.1)))
        });
        results
    }

    /// Turns the top `top_k` high-PMI pairs into natural-language facts and
    /// ingests them into `index`. Returns the number of entries the index
    /// reports as added; `synthesis_count` is advanced by the same amount.
    pub fn synthesise_into(
        &mut self,
        index: &mut KnowledgeIndex,
        top_k: usize,
        threshold_override: f64,
    ) -> usize {
        let pairs = self.high_pmi_pairs(threshold_override);
        let mut added = 0;
        for (a, b, _) in pairs.into_iter().take(top_k) {
            let fact = format!("{a} and {b} are strongly related concepts in this domain");
            let n = index.ingest_text("informal", &fact);
            added += n;
            self.synthesis_count += n;
        }
        added
    }

    /// Number of distinct tokens seen across all accepted observations.
    pub fn vocabulary_size(&self) -> usize {
        self.token_counts.len()
    }

    /// Number of distinct (canonical) token pairs seen so far.
    pub fn pair_count(&self) -> usize {
        self.pair_counts.len()
    }
}
/// Splits `text` on whitespace, strips each word down to its alphabetic
/// characters, ASCII-lowercases it, and keeps only words of at least three
/// characters.
///
/// The length filter counts `char`s, not bytes: the previous byte-based
/// check (`s.len() >= 3`) let two-character non-ASCII words (each char being
/// multi-byte in UTF-8) slip through the minimum-length filter.
fn tokenise(text: &str) -> Vec<String> {
    text.split_whitespace()
        .map(|w| {
            w.chars()
                .filter(|c| c.is_alphabetic())
                .collect::<String>()
                .to_ascii_lowercase()
        })
        .filter(|s| s.chars().count() >= 3)
        .collect()
}
/// Orders the two tokens lexicographically so that (a, b) and (b, a)
/// map to the same `HashMap` key.
fn canonical_pair(a: &str, b: &str) -> (String, String) {
    let (first, second) = if a <= b { (a, b) } else { (b, a) };
    (first.to_owned(), second.to_owned())
}