use std::collections::HashMap;
use super::episode::Episode;
use super::learn_model::LearnError;
use super::learned_component::{
ComponentLearner, LearnedDepGraph, LearnedExploration, LearnedStrategy,
};
use super::record::ActionRecord;
use super::RecommendedPath;
use crate::exploration::DependencyGraph;
/// Learner that derives an action ordering and a ranked list of observed
/// action paths from recorded action sequences.
///
/// `Default::default()` is equivalent to [`DepGraphLearner::new`].
#[derive(Debug, Clone)]
pub struct DepGraphLearner {
    /// Reference number of successful episodes used when scaling the
    /// confidence of a learned result.
    pub min_episodes: usize,
    /// NOTE(review): this field is not read by any code in this file.
    /// Presumably it is meant as the minimum number of observations an
    /// ordering relation needs before it is trusted; confirm before relying
    /// on it.
    pub min_order_count: usize,
}

impl Default for DepGraphLearner {
    /// Returns the same configuration as [`DepGraphLearner::new`], so a
    /// defaulted learner never starts with zeroed thresholds.
    fn default() -> Self {
        Self::new()
    }
}

impl DepGraphLearner {
    /// Creates a learner with `min_episodes = 3` and `min_order_count = 2`.
    pub fn new() -> Self {
        Self {
            min_episodes: 3,
            min_order_count: 2,
        }
    }

    /// Builder-style setter that replaces `min_episodes` and returns the learner.
    pub fn with_min_episodes(mut self, n: usize) -> Self {
        self.min_episodes = n;
        self
    }

    /// Counts, over all sequences, how often each ordered pair
    /// `(earlier, later)` occurs.
    ///
    /// Every pair of positions `i < j` inside a sequence contributes one
    /// observation of `(sequence[i], sequence[j])`, so the work per sequence
    /// is quadratic in its length.
    fn extract_order_relations(
        &self,
        action_sequences: &[Vec<String>],
    ) -> HashMap<(String, String), usize> {
        let mut relations: HashMap<(String, String), usize> = HashMap::new();
        for sequence in action_sequences {
            for (position, earlier) in sequence.iter().enumerate() {
                for later in &sequence[position + 1..] {
                    *relations
                        .entry((earlier.clone(), later.clone()))
                        .or_insert(0) += 1;
                }
            }
        }
        relations
    }

    /// Orders actions by how strongly they tend to come first.
    ///
    /// An action's score is the number of times it was observed before
    /// another action minus the number of times it was observed after one.
    /// Higher scores sort first. Equal scores are ordered by action name so
    /// the result does not depend on `HashMap` iteration order.
    fn compute_action_order(&self, relations: &HashMap<(String, String), usize>) -> Vec<String> {
        // Borrow the names while scoring; only the final result is cloned.
        let mut scores: HashMap<&str, i64> = HashMap::new();
        for ((from, to), &count) in relations {
            let weight = i64::try_from(count).unwrap_or(i64::MAX);
            *scores.entry(from.as_str()).or_insert(0) += weight;
            *scores.entry(to.as_str()).or_insert(0) -= weight;
        }

        let mut ranked: Vec<(&str, i64)> = scores.into_iter().collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .map(|(action, _)| action.to_owned())
            .collect()
    }

    /// Builds at most ten [`RecommendedPath`] entries from per-path counts.
    ///
    /// `success_rate` is `count / total_success` (with a zero total treated
    /// as one to avoid dividing by zero), i.e. the share of the supplied
    /// total that followed exactly this path. Paths are sorted by that value
    /// in descending order; ties are broken by the action list so that both
    /// the order and the set kept after truncation are deterministic.
    fn compute_recommended_paths(
        &self,
        success_count: &HashMap<Vec<String>, usize>,
        total_success: usize,
    ) -> Vec<RecommendedPath> {
        let denominator = total_success.max(1) as f64;
        let mut paths: Vec<RecommendedPath> = success_count
            .iter()
            .map(|(actions, &count)| RecommendedPath {
                actions: actions.clone(),
                success_rate: count as f64 / denominator,
                // Saturate instead of silently wrapping on huge counts.
                observations: u32::try_from(count).unwrap_or(u32::MAX),
            })
            .collect();

        paths.sort_by(|a, b| {
            b.success_rate
                .partial_cmp(&a.success_rate)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.actions.cmp(&b.actions))
        });
        paths.truncate(10);
        paths
    }
}
/// `ComponentLearner` implementation that learns from successful episodes only.
impl ComponentLearner for DepGraphLearner {
    type Output = LearnedDepGraph;

    /// Identifier of this learner.
    fn name(&self) -> &str {
        "dep_graph_learner"
    }

    /// Human-readable description of what this learner produces.
    fn objective(&self) -> &str {
        "Learn action dependency graph from successful execution traces"
    }

    /// Learns an action order and recommended paths from the successful
    /// episodes in `episodes`.
    ///
    /// # Errors
    ///
    /// Returns `LearnError::InsufficientData` when no episode has a
    /// successful outcome (this includes an empty slice).
    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
        let success_episodes: Vec<_> = episodes.iter().filter(|e| e.outcome.is_success()).collect();
        if success_episodes.is_empty() {
            return Err(LearnError::InsufficientData(
                "No successful episodes to learn from".into(),
            ));
        }

        let mut action_sequences: Vec<Vec<String>> = Vec::new();
        let mut success_count: HashMap<Vec<String>, usize> = HashMap::new();
        let mut session_ids: Vec<String> = Vec::new();

        for episode in &success_episodes {
            let actions: Vec<String> = episode
                .context
                .iter::<ActionRecord>()
                .map(|r| r.action.clone())
                .collect();
            // Episodes without any recorded action contribute no path.
            if !actions.is_empty() {
                *success_count.entry(actions.clone()).or_insert(0) += 1;
                action_sequences.push(actions);
            }

            // Keep ids unique while preserving first-seen order.
            let episode_id = episode.id.to_string();
            if !session_ids.contains(&episode_id) {
                session_ids.push(episode_id);
            }
        }

        let relations = self.extract_order_relations(&action_sequences);
        let action_order = self.compute_action_order(&relations);
        let recommended_paths =
            self.compute_recommended_paths(&success_count, success_episodes.len());

        // Confidence grows linearly with the number of successful episodes
        // and saturates at 1.0 once twice `min_episodes` have been seen.
        // A single formula is used on both sides of `min_episodes` so that
        // fewer episodes can never yield a higher confidence than more
        // episodes (the previous below-minimum branch `n / min_episodes`
        // produced e.g. 0.667 for 2 episodes but 0.5 for 3 when the minimum
        // was 3). With `min_episodes == 0` the division yields +inf, which
        // `min(1.0)` clamps to 1.0.
        let confidence =
            (success_episodes.len() as f64 / (self.min_episodes as f64 * 2.0)).min(1.0);

        // NOTE(review): the graph is created empty here; the extracted
        // relations only feed `action_order`.
        let graph = DependencyGraph::new();
        Ok(LearnedDepGraph::new(graph, action_order)
            .with_confidence(confidence)
            .with_sessions(session_ids)
            .with_recommended_paths(recommended_paths))
    }
}
/// Learner for exploration parameters.
///
/// `Default::default()` is equivalent to [`ExplorationLearner::new`].
#[derive(Debug, Clone)]
pub struct ExplorationLearner {
    /// Initial UCB1 exploration constant.
    ///
    /// NOTE(review): the `learn` implementation in this file does not read
    /// this field; it selects among fixed constants instead. Confirm whether
    /// it is consumed elsewhere.
    pub initial_ucb1_c: f64,
}

impl Default for ExplorationLearner {
    /// Returns the same configuration as [`ExplorationLearner::new`] rather
    /// than a zeroed exploration constant.
    fn default() -> Self {
        Self::new()
    }
}

impl ExplorationLearner {
    /// Creates a learner with `initial_ucb1_c = 1.414`.
    pub fn new() -> Self {
        Self {
            initial_ucb1_c: 1.414,
        }
    }
}
/// `ComponentLearner` implementation that picks exploration parameters from
/// the overall success rate of the supplied episodes.
impl ComponentLearner for ExplorationLearner {
    type Output = LearnedExploration;

    /// Identifier of this learner.
    fn name(&self) -> &str {
        "exploration_learner"
    }

    /// Human-readable description of what this learner produces.
    fn objective(&self) -> &str {
        "Optimize exploration parameters from session statistics"
    }

    /// Chooses a UCB1 constant from the fraction of successful episodes.
    ///
    /// A success rate below 0.3 selects 2.0, below 0.7 selects 1.414, and
    /// anything else selects 1.0. Confidence is `episodes.len() / 10`,
    /// capped at 1.0.
    ///
    /// # Errors
    ///
    /// Returns `LearnError::InsufficientData` when `episodes` is empty.
    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
        let session_count = episodes.len();
        if session_count == 0 {
            return Err(LearnError::InsufficientData(
                "No episodes to learn from".into(),
            ));
        }

        let succeeded = episodes
            .iter()
            .filter(|episode| episode.outcome.is_success())
            .count();

        let ucb1_c = match succeeded as f64 / session_count as f64 {
            rate if rate < 0.3 => 2.0,
            rate if rate < 0.7 => 1.414,
            _ => 1.0,
        };

        let confidence = f64::min(session_count as f64 / 10.0, 1.0);

        // Seconds since the Unix epoch; falls back to 0 if the system clock
        // reports a time before the epoch.
        let updated_at = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };

        Ok(LearnedExploration {
            ucb1_c,
            learning_weight: 0.3,
            ngram_weight: 1.0,
            confidence,
            session_count,
            updated_at,
        })
    }
}
/// Stateless learner for strategy-selection settings.
#[derive(Debug, Clone, Default)]
pub struct StrategyLearner;

impl StrategyLearner {
    /// Creates a new learner; the type carries no configuration.
    pub fn new() -> Self {
        StrategyLearner
    }
}
/// `ComponentLearner` implementation that derives strategy settings from the
/// overall success rate of the supplied episodes.
impl ComponentLearner for StrategyLearner {
    type Output = LearnedStrategy;

    /// Identifier of this learner.
    fn name(&self) -> &str {
        "strategy_learner"
    }

    /// Human-readable description of what this learner produces.
    fn objective(&self) -> &str {
        "Determine optimal strategy selection settings"
    }

    /// Picks the initial strategy and error-rate threshold from the fraction
    /// of successful episodes.
    ///
    /// A success rate below 0.5 selects `"ucb1"`, otherwise `"greedy"`. The
    /// error-rate threshold is 0.6 below a 0.3 success rate and 0.45
    /// otherwise. Confidence is `episodes.len() / 10`, capped at 1.0.
    ///
    /// # Errors
    ///
    /// Returns `LearnError::InsufficientData` when `episodes` is empty.
    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
        let session_count = episodes.len();
        if session_count == 0 {
            return Err(LearnError::InsufficientData(
                "No episodes to learn from".into(),
            ));
        }

        let succeeded = episodes
            .iter()
            .filter(|episode| episode.outcome.is_success())
            .count();
        let success_rate = succeeded as f64 / session_count as f64;

        let strategy_name = if success_rate < 0.5 { "ucb1" } else { "greedy" };
        let initial_strategy = String::from(strategy_name);

        let error_rate_threshold = match success_rate {
            rate if rate < 0.3 => 0.6,
            _ => 0.45,
        };

        let confidence = f64::min(session_count as f64 / 10.0, 1.0);

        // Seconds since the Unix epoch; falls back to 0 if the system clock
        // reports a time before the epoch.
        let updated_at = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };

        Ok(LearnedStrategy {
            initial_strategy,
            maturity_threshold: 5,
            error_rate_threshold,
            confidence,
            session_count,
            updated_at,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::{Episode, EpisodeContext, Outcome};

    /// Builds a successful episode with a fresh, empty context.
    ///
    /// The `_actions` argument is accepted but not recorded into the
    /// context, so episodes built here carry no action sequence.
    fn make_success_episode(_actions: Vec<&str>) -> Episode {
        let empty_context = EpisodeContext::new();
        let builder = Episode::builder()
            .learn_model("test")
            .context(empty_context);
        builder.outcome(Outcome::success(1.0)).build()
    }

    /// Builds a failed episode with a fresh, empty context.
    fn make_failure_episode() -> Episode {
        let builder = Episode::builder()
            .learn_model("test")
            .context(EpisodeContext::new());
        builder.outcome(Outcome::failure("test failure")).build()
    }

    /// Learning from no episodes at all must fail.
    #[test]
    fn test_dep_graph_learner_empty() {
        let outcome = DepGraphLearner::new().learn(&[]);
        assert!(outcome.is_err());
    }

    /// Learning from failed episodes only must fail.
    #[test]
    fn test_dep_graph_learner_no_success() {
        let failures = vec![make_failure_episode(), make_failure_episode()];
        let outcome = DepGraphLearner::new().learn(&failures);
        assert!(outcome.is_err());
    }

    /// Three successful episodes produce a result with positive confidence.
    #[test]
    fn test_dep_graph_learner_with_success() {
        let successes: Vec<Episode> = (0..3)
            .map(|_| make_success_episode(vec!["A", "B", "C"]))
            .collect();
        let outcome = DepGraphLearner::new().learn(&successes);
        assert!(outcome.is_ok());
        let learned = outcome.unwrap();
        assert!(learned.confidence > 0.0);
    }

    /// A mixed batch yields a positive UCB1 constant and counts every episode.
    #[test]
    fn test_exploration_learner() {
        let mixed = vec![
            make_success_episode(vec![]),
            make_success_episode(vec![]),
            make_failure_episode(),
        ];
        let outcome = ExplorationLearner::new().learn(&mixed);
        assert!(outcome.is_ok());
        let learned = outcome.unwrap();
        assert!(learned.ucb1_c > 0.0);
        assert_eq!(learned.session_count, 3);
    }

    /// A mixed batch yields a non-empty initial strategy name.
    #[test]
    fn test_strategy_learner() {
        let mixed = vec![make_success_episode(vec![]), make_failure_episode()];
        let outcome = StrategyLearner::new().learn(&mixed);
        assert!(outcome.is_ok());
        let learned = outcome.unwrap();
        assert!(!learned.initial_strategy.is_empty());
    }

    /// Two identical A-B-C sequences give a count of two for each ordered pair.
    #[test]
    fn test_extract_order_relations() {
        let learner = DepGraphLearner::new().with_min_episodes(1);
        let abc: Vec<String> = ["A", "B", "C"].iter().map(|name| name.to_string()).collect();
        let sequences = vec![abc.clone(), abc];

        let relations = learner.extract_order_relations(&sequences);

        for (earlier, later) in [("A", "B"), ("A", "C"), ("B", "C")] {
            let key = (earlier.to_string(), later.to_string());
            assert_eq!(relations.get(&key), Some(&2));
        }
    }
}