use crate::storage::VectorStore;
use crate::types::Episode;
use std::collections::HashMap;
use uuid::Uuid;
/// Summary statistics for a single utility-propagation pass.
///
/// Produced by `bellman_propagate` together with the map of new utilities.
/// `Default` yields an all-zero summary and `PartialEq` allows summaries to be
/// compared directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PropagationResult {
    /// Number of episodes whose utility changed by more than the reporting
    /// threshold used by the propagation pass.
    pub episodes_updated: usize,
    /// Sum of the absolute utility changes over the counted episodes.
    pub total_change: f64,
    /// Largest absolute utility change among the counted episodes.
    pub max_change: f64,
}
/// Runs one Bellman-style pass that pulls each episode's utility towards the
/// utilities of its most similar neighbours.
///
/// For every episode that has an embedding in `vectors`, the nearest stored
/// vectors are fetched and the similarity-weighted mean utility of the
/// qualifying neighbours is computed. The episode's utility is then updated as
///
/// `new = clamp(old + learning_rate * (discount * mean - old), 0.0, 1.0)`
///
/// Neighbour utilities are always read from the input `episodes`, so updates
/// computed earlier in the pass never feed into later ones.
///
/// A neighbour qualifies when it is not the episode itself, its id belongs to
/// `episodes`, and its similarity is strictly positive and at least
/// `similarity_threshold`. Non-positive (or NaN) similarities are rejected even
/// when the threshold would admit them, because they are used as averaging
/// weights: a negative weight can push the "mean" outside the range of the
/// neighbour utilities and a NaN would discard every other neighbour.
///
/// Episodes without an embedding, or without any qualifying neighbour, keep
/// their current utility.
///
/// # Returns
///
/// A map from episode id to its new utility (one entry per distinct id in
/// `episodes`) and a [`PropagationResult`] summarising the changes that
/// exceeded `0.001`.
pub fn bellman_propagate(
    episodes: &[Episode],
    vectors: &VectorStore,
    learning_rate: f64,
    discount: f64,
    similarity_threshold: f64,
) -> (HashMap<Uuid, f64>, PropagationResult) {
    // Changes at or below this magnitude are applied but not reported.
    const MIN_REPORTED_CHANGE: f64 = 0.001;

    let mut new_utilities: HashMap<Uuid, f64> = HashMap::with_capacity(episodes.len());
    let mut total_change: f64 = 0.0;
    let mut max_change: f64 = 0.0;
    let mut updates = 0;

    // Snapshot of the utilities before the pass; all neighbour reads use it.
    let utility_map: HashMap<Uuid, f64> = episodes.iter().map(|e| (e.id, e.utility)).collect();

    for episode in episodes {
        let Some(embedding) = vectors.get(episode.id) else {
            // No embedding means no neighbours can be found: carry the utility over.
            new_utilities.insert(episode.id, episode.utility);
            continue;
        };

        // Ask for the 10 closest stored vectors.
        let similar = vectors.search(embedding, 10);

        let mut weighted_utility = 0.0;
        let mut weight_sum = 0.0;
        for (similar_id, similarity) in similar {
            if similar_id == episode.id {
                continue;
            }
            // Written as a positive test so that NaN fails it and is skipped.
            let usable_weight = similarity > 0.0 && similarity >= similarity_threshold;
            if !usable_weight {
                continue;
            }
            // Search hits that are not part of `episodes` have no known utility.
            if let Some(&other_utility) = utility_map.get(&similar_id) {
                weighted_utility += similarity * other_utility;
                weight_sum += similarity;
            }
        }

        let new_utility = if weight_sum > 0.0 {
            let neighbor_contribution = weighted_utility / weight_sum;
            let target = discount * neighbor_contribution;
            let update = learning_rate * (target - episode.utility);
            (episode.utility + update).clamp(0.0, 1.0)
        } else {
            episode.utility
        };

        let change = (new_utility - episode.utility).abs();
        if change > MIN_REPORTED_CHANGE {
            updates += 1;
            total_change += change;
            max_change = max_change.max(change);
        }
        new_utilities.insert(episode.id, new_utility);
    }

    let result = PropagationResult {
        episodes_updated: updates,
        total_change,
        max_change,
    };
    (new_utilities, result)
}
/// Assigns temporal credit by letting each episode inherit a discounted share
/// of the credit given to the episode that follows it in `episodes`.
///
/// The slice is walked from its last element to its first. Each episode's
/// credit is its own utility plus `temporal_discount` times the credit already
/// computed for the next element of the slice (the last element receives no
/// bonus). Credits are clamped to `[0.0, 1.0]`, the same range
/// `bellman_propagate` enforces for utilities.
///
/// NOTE(review): this presumably expects `episodes` in chronological order so
/// that credit flows from later episodes to earlier ones — confirm with callers.
///
/// # Returns
///
/// A map from episode id to its credit, with one entry per distinct id in
/// `episodes`.
pub fn temporal_credit_assignment(
    episodes: &[Episode],
    temporal_discount: f64,
) -> HashMap<Uuid, f64> {
    // Every episode is written exactly once by the reverse walk below, so no
    // separate pre-population pass is needed.
    let mut credits: HashMap<Uuid, f64> = HashMap::with_capacity(episodes.len());

    // Credit of the element processed just before this one, i.e. the next
    // element in slice order; zero for the last element.
    let mut prev_utility = 0.0;
    for episode in episodes.iter().rev() {
        let temporal_bonus = temporal_discount * prev_utility;
        let new_utility = (episode.utility + temporal_bonus).clamp(0.0, 1.0);
        credits.insert(episode.id, new_utility);
        prev_utility = new_utility;
    }
    credits
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::EpisodeOutcome;

    /// Creates a successful episode carrying the requested id and utility.
    fn make_episode(id: Uuid, utility: f64) -> Episode {
        let mut episode = Episode::new(
            format!("Episode {}", id),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        episode.id = id;
        episode.utility = utility;
        episode
    }

    /// Fixture shared by the propagation tests: a high-utility (1.0) episode
    /// followed by a low-utility (0.2) one, returned with their ids.
    fn high_low_pair() -> (Uuid, Uuid, Vec<Episode>) {
        let high = Uuid::new_v4();
        let low = Uuid::new_v4();
        let episodes = vec![make_episode(high, 1.0), make_episode(low, 0.2)];
        (high, low, episodes)
    }

    /// A low-utility episode next to a similar high-utility one gains utility.
    #[test]
    fn test_propagation_basic() {
        let (high, low, episodes) = high_low_pair();

        let mut store = VectorStore::new(3);
        store.store(high, vec![1.0, 0.0, 0.0]).unwrap();
        store.store(low, vec![0.9, 0.1, 0.0]).unwrap();

        let (utilities, summary) = bellman_propagate(&episodes, &store, 0.5, 0.9, 0.5);

        assert!(utilities[&low] > 0.2);
        assert!(summary.episodes_updated > 0);
    }

    /// Episodes with orthogonal embeddings leave each other's utility
    /// (almost) untouched.
    #[test]
    fn test_no_propagation_for_dissimilar() {
        let (high, low, episodes) = high_low_pair();

        let mut store = VectorStore::new(3);
        store.store(high, vec![1.0, 0.0, 0.0]).unwrap();
        store.store(low, vec![0.0, 1.0, 0.0]).unwrap();

        let (utilities, _) = bellman_propagate(&episodes, &store, 0.5, 0.9, 0.5);

        let drift = (utilities[&low] - 0.2).abs();
        assert!(drift < 0.1);
    }

    /// Earlier episodes receive a bonus from the episodes that follow them.
    #[test]
    fn test_temporal_credit() {
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        let third = Uuid::new_v4();
        let episodes = vec![
            make_episode(first, 0.3),
            make_episode(second, 0.5),
            make_episode(third, 1.0),
        ];

        let credits = temporal_credit_assignment(&episodes, 0.5);

        assert!(credits[&first] > 0.3);
        assert!(credits[&second] > 0.5);
    }
}