use serde::{de::DeserializeOwned, Serialize};
use super::episode::Episode;
use super::learn_model::LearnError;
/// A learner that turns recorded episodes into a single [`LearnedComponent`].
///
/// Implementors must be `Send + Sync`.
pub trait ComponentLearner: Send + Sync {
    /// The learned artifact this learner produces.
    type Output: LearnedComponent;

    /// Returns the learner's name.
    fn name(&self) -> &str;

    /// Returns the learner's objective as text (its meaning is defined by the implementor).
    fn objective(&self) -> &str;

    /// Learns a fresh component from `episodes`.
    ///
    /// # Errors
    ///
    /// Returns a [`LearnError`] when the implementor cannot learn from the input.
    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError>;

    /// Learns from `new_episodes` and then merges `existing` into the result.
    ///
    /// The default implementation runs [`learn`](Self::learn) on `new_episodes`
    /// only, then calls `merge` on the freshly learned value with `existing` as
    /// the argument, so the fresh value is the merge target.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`learn`](Self::learn); `existing` is not
    /// consulted in that case.
    fn update(
        &self,
        existing: &Self::Output,
        new_episodes: &[Episode],
    ) -> Result<Self::Output, LearnError> {
        self.learn(new_episodes).map(|mut fresh| {
            fresh.merge(existing);
            fresh
        })
    }
}
/// A serializable, cloneable artifact produced by learning.
///
/// Implementors must be `Send + Sync`, serde-(de)serializable and `Clone`.
pub trait LearnedComponent: Send + Sync + Serialize + DeserializeOwned + Clone {
    /// Returns the static identifier of this component type.
    fn component_id() -> &'static str
    where
        Self: Sized;

    /// Returns this component's confidence value.
    fn confidence(&self) -> f64;

    /// Returns the number of sessions this component reports.
    fn session_count(&self) -> usize;

    /// Returns this component's last-update timestamp.
    fn updated_at(&self) -> u64;

    /// Merges `other` into `self`.
    ///
    /// The default implementation replaces `self` with a clone of `other` when
    /// `other` reports a strictly higher confidence, and otherwise leaves
    /// `self` untouched (ties keep `self`).
    fn merge(&mut self, other: &Self)
    where
        Self: Sized,
    {
        let theirs = other.confidence();
        let ours = self.confidence();
        if theirs > ours {
            *self = other.clone();
        }
    }

    /// Returns the version number of this component type; `1` unless overridden.
    fn version() -> u32
    where
        Self: Sized,
    {
        1
    }
}
use crate::exploration::DependencyGraph;
pub use super::offline::RecommendedPath;
/// A learned dependency graph together with the action orderings and
/// bookkeeping stored alongside it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedDepGraph {
/// The dependency graph itself.
pub graph: DependencyGraph,
/// Ordered action names. `with_orders` builds this as `discover_order`
/// followed by `not_discover_order`; `new` stores the caller's list as-is.
pub action_order: Vec<String>,
/// Ordered action names of the "discover" group (exact semantics are set by
/// the producer — TODO confirm). Defaults to empty when absent from
/// serialized input.
#[serde(default)]
pub discover_order: Vec<String>,
/// Ordered action names of the "not discover" group (exact semantics are set
/// by the producer — TODO confirm). Defaults to empty when absent from
/// serialized input.
#[serde(default)]
pub not_discover_order: Vec<String>,
/// Recommended paths attached to this graph. Defaults to empty when absent
/// from serialized input.
#[serde(default)]
pub recommended_paths: Vec<RecommendedPath>,
/// Confidence value reported through `LearnedComponent::confidence`;
/// constructors start it at `0.0`.
pub confidence: f64,
/// Session identifiers this graph was learned from; its length is the
/// reported session count.
pub learned_from: Vec<String>,
/// Last-update timestamp; constructors set it to seconds since the Unix epoch.
pub updated_at: u64,
}
impl LearnedDepGraph {
    /// Creates a learned graph from `graph` and a single combined `action_order`.
    ///
    /// The split orders and recommended paths start empty, confidence starts at
    /// `0.0`, no sessions are recorded, and `updated_at` is the current time in
    /// seconds since the Unix epoch (`0` if the system clock reads earlier than
    /// the epoch).
    pub fn new(graph: DependencyGraph, action_order: Vec<String>) -> Self {
        let updated_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs());

        Self {
            graph,
            action_order,
            discover_order: Vec::new(),
            not_discover_order: Vec::new(),
            recommended_paths: Vec::new(),
            confidence: 0.0,
            learned_from: Vec::new(),
            updated_at,
        }
    }

    /// Creates a learned graph from `graph` and two separate orderings.
    ///
    /// `action_order` is set to `discover_order` followed by
    /// `not_discover_order`; every other field is initialised as in
    /// [`LearnedDepGraph::new`].
    pub fn with_orders(
        graph: DependencyGraph,
        discover_order: Vec<String>,
        not_discover_order: Vec<String>,
    ) -> Self {
        let combined: Vec<String> = discover_order
            .iter()
            .chain(not_discover_order.iter())
            .cloned()
            .collect();

        Self {
            discover_order,
            not_discover_order,
            ..Self::new(graph, combined)
        }
    }

    /// Returns `self` with `confidence` replaced.
    pub fn with_confidence(self, confidence: f64) -> Self {
        Self { confidence, ..self }
    }

    /// Returns `self` with the recorded session identifiers replaced by `session_ids`.
    pub fn with_sessions(self, session_ids: Vec<String>) -> Self {
        Self {
            learned_from: session_ids,
            ..self
        }
    }

    /// Returns `self` with the recommended paths replaced by `paths`.
    pub fn with_recommended_paths(self, paths: Vec<RecommendedPath>) -> Self {
        Self {
            recommended_paths: paths,
            ..self
        }
    }
}
impl LearnedComponent for LearnedDepGraph {
    /// Returns the identifier `"dep_graph"`.
    fn component_id() -> &'static str {
        "dep_graph"
    }

    /// Returns the stored confidence value.
    fn confidence(&self) -> f64 {
        self.confidence
    }

    /// Returns how many session identifiers are recorded in `learned_from`.
    fn session_count(&self) -> usize {
        self.learned_from.len()
    }

    /// Returns the stored last-update timestamp.
    fn updated_at(&self) -> u64 {
        self.updated_at
    }

    /// Merges `other` into `self`.
    ///
    /// * When `other` has more recorded sessions or a strictly higher
    ///   confidence, its graph, all three action orderings and its confidence
    ///   replace those of `self`.
    /// * Session identifiers and recommended paths (compared by their
    ///   `actions`) are unioned, preserving `self`'s entries first.
    /// * `updated_at` becomes the later of the two timestamps.
    fn merge(&mut self, other: &Self) {
        // Decide before unioning session ids, since the union changes the length.
        let adopt_other = other.learned_from.len() > self.learned_from.len()
            || other.confidence > self.confidence;

        if adopt_other {
            self.graph.clone_from(&other.graph);
            self.action_order.clone_from(&other.action_order);
            // The split orders must travel with `action_order`; leaving them
            // behind would pair one graph's combined order with another
            // graph's discover / not-discover orders.
            self.discover_order.clone_from(&other.discover_order);
            self.not_discover_order.clone_from(&other.not_discover_order);
            self.confidence = other.confidence;
        }

        for id in &other.learned_from {
            if !self.learned_from.contains(id) {
                self.learned_from.push(id.clone());
            }
        }

        for path in &other.recommended_paths {
            let already_known = self
                .recommended_paths
                .iter()
                .any(|known| known.actions == path.actions);
            if !already_known {
                self.recommended_paths.push(path.clone());
            }
        }

        self.updated_at = self.updated_at.max(other.updated_at);
    }
}
/// Learned exploration parameters plus the bookkeeping reported through
/// `LearnedComponent`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedExploration {
/// The `ucb1_c` parameter (presumably the UCB1 exploration constant —
/// TODO confirm against the consumer). `Default` uses `1.414`.
pub ucb1_c: f64,
/// The learning weight (how it is applied is decided by the consumer —
/// TODO confirm). `Default` uses `0.3`.
pub learning_weight: f64,
/// The n-gram weight (how it is applied is decided by the consumer —
/// TODO confirm). `Default` uses `1.0`.
pub ngram_weight: f64,
/// Confidence value reported through `LearnedComponent::confidence`.
pub confidence: f64,
/// Session count reported through `LearnedComponent::session_count`.
pub session_count: usize,
/// Last-update timestamp; `new` sets it to seconds since the Unix epoch,
/// `Default` leaves it at `0`.
pub updated_at: u64,
}
impl Default for LearnedExploration {
    /// Returns the baseline parameters: `ucb1_c = 1.414`, `learning_weight = 0.3`,
    /// `ngram_weight = 1.0`, with zero confidence, no sessions and a zero timestamp.
    fn default() -> Self {
        Self {
            // Bookkeeping: nothing has been learned yet.
            confidence: 0.0,
            session_count: 0,
            updated_at: 0,
            // Baseline tuning values.
            ucb1_c: 1.414,
            learning_weight: 0.3,
            ngram_weight: 1.0,
        }
    }
}
impl LearnedExploration {
    /// Creates exploration parameters from explicit values.
    ///
    /// Confidence and session count start at zero. `updated_at` is the current
    /// time in seconds since the Unix epoch, or `0` if the system clock reads
    /// earlier than the epoch.
    pub fn new(ucb1_c: f64, learning_weight: f64, ngram_weight: f64) -> Self {
        let now = std::time::SystemTime::now();
        let updated_at = match now.duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };

        Self {
            ucb1_c,
            learning_weight,
            ngram_weight,
            confidence: 0.0,
            session_count: 0,
            updated_at,
        }
    }
}
impl LearnedComponent for LearnedExploration {
    /// Returns the stored last-update timestamp.
    fn updated_at(&self) -> u64 {
        self.updated_at
    }

    /// Returns the stored session count.
    fn session_count(&self) -> usize {
        self.session_count
    }

    /// Returns the stored confidence value.
    fn confidence(&self) -> f64 {
        self.confidence
    }

    /// Returns the identifier `"exploration"`.
    fn component_id() -> &'static str {
        "exploration"
    }
}
/// Learned strategy settings plus the bookkeeping reported through
/// `LearnedComponent`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedStrategy {
/// Name of the initial strategy; `Default` uses `"ucb1"`.
pub initial_strategy: String,
/// Maturity threshold (how it is applied is decided by the consumer —
/// TODO confirm). `Default` uses `5`.
pub maturity_threshold: usize,
/// Error-rate threshold (how it is applied is decided by the consumer —
/// TODO confirm). `Default` uses `0.45`.
pub error_rate_threshold: f64,
/// Confidence value reported through `LearnedComponent::confidence`.
pub confidence: f64,
/// Session count reported through `LearnedComponent::session_count`.
pub session_count: usize,
/// Last-update timestamp; `Default` leaves it at `0`.
pub updated_at: u64,
}
impl Default for LearnedStrategy {
    /// Returns the baseline settings: initial strategy `"ucb1"`, maturity
    /// threshold `5`, error-rate threshold `0.45`, with zero confidence, no
    /// sessions and a zero timestamp.
    fn default() -> Self {
        Self {
            // Bookkeeping: nothing has been learned yet.
            confidence: 0.0,
            session_count: 0,
            updated_at: 0,
            // Baseline settings.
            initial_strategy: String::from("ucb1"),
            maturity_threshold: 5,
            error_rate_threshold: 0.45,
        }
    }
}
impl LearnedComponent for LearnedStrategy {
    /// Returns the stored last-update timestamp.
    fn updated_at(&self) -> u64 {
        self.updated_at
    }

    /// Returns the stored session count.
    fn session_count(&self) -> usize {
        self.session_count
    }

    /// Returns the stored confidence value.
    fn confidence(&self) -> f64 {
        self.confidence
    }

    /// Returns the identifier `"strategy"`.
    fn component_id() -> &'static str {
        "strategy"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraph;

    /// Turns string literals into the owned `Vec<String>` the constructors expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builder methods set confidence and sessions; the component id is fixed.
    #[test]
    fn test_learned_dep_graph_creation() {
        let graph = DependencyGraph::new();
        let built = LearnedDepGraph::new(graph, owned(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(owned(&["s1", "s2"]));

        assert_eq!(built.confidence(), 0.8);
        assert_eq!(built.session_count(), 2);
        assert_eq!(LearnedDepGraph::component_id(), "dep_graph");
    }

    /// Merging a stronger graph adopts its order and confidence and unions sessions.
    #[test]
    fn test_learned_dep_graph_merge() {
        let graph = DependencyGraph::new();
        let mut target = LearnedDepGraph::new(graph.clone(), owned(&["A"]))
            .with_confidence(0.5)
            .with_sessions(owned(&["s1"]));
        let stronger = LearnedDepGraph::new(graph, owned(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(owned(&["s2", "s3"]));

        target.merge(&stronger);

        assert_eq!(target.confidence, 0.8);
        assert_eq!(target.action_order.len(), 2);
        assert_eq!(target.learned_from.len(), 3);
    }

    /// The default exploration parameters use the documented `ucb1_c`.
    #[test]
    fn test_learned_exploration_default() {
        let defaults = LearnedExploration::default();

        assert_eq!(defaults.ucb1_c, 1.414);
        assert_eq!(LearnedExploration::component_id(), "exploration");
    }

    /// The default strategy starts with `"ucb1"`.
    #[test]
    fn test_learned_strategy_default() {
        let defaults = LearnedStrategy::default();

        assert_eq!(defaults.initial_strategy, "ucb1");
        assert_eq!(LearnedStrategy::component_id(), "strategy");
    }

    /// A JSON round trip preserves the exploration parameters.
    #[test]
    fn test_serialization() {
        let original = LearnedExploration::new(2.0, 0.5, 1.5);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LearnedExploration = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.ucb1_c, 2.0);
    }
}