use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::snapshot::LearningSnapshot;
/// Result of an offline analysis pass over recorded learning snapshots.
///
/// Built by `OfflineAnalyzer::analyze`. It derives `Serialize`/`Deserialize`, so it is
/// presumably persisted and reloaded elsewhere — TODO confirm the storage format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OfflineModel {
    /// Model format version; both `Default` and `OfflineAnalyzer::analyze` set it to 1.
    pub version: u32,
    /// Tuned numeric parameters (see `OfflineAnalyzer::analyze_parameters`).
    pub parameters: OptimalParameters,
    /// Highest-success action trigrams (see `OfflineAnalyzer::extract_paths`).
    pub recommended_paths: Vec<RecommendedPath>,
    /// Strategy-selection settings (see `OfflineAnalyzer::analyze_strategy`).
    pub strategy_config: StrategyConfig,
    /// Number of snapshots the model was computed from.
    pub analyzed_sessions: usize,
    /// Unix time in seconds when the analyzer built the model (0 for `Default`).
    pub updated_at: u64,
    /// Optional learned action ordering. `#[serde(default)]` lets serialized data that
    /// lacks this field load as `None`; `OfflineAnalyzer::analyze` always leaves it `None`.
    #[serde(default)]
    pub action_order: Option<LearnedActionOrder>,
}
/// Actions split into a `discover` and a `not_discover` group, tied to the action set
/// they were learned for via `action_set_hash`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedActionOrder {
    /// Actions in the "discover" group.
    /// NOTE(review): the meaning of the discover / not-discover split is not visible in
    /// this file — confirm against the producer of these values.
    pub discover: Vec<String>,
    /// Actions in the "not discover" group.
    pub not_discover: Vec<String>,
    /// Order-insensitive hash of the action set passed to `LearnedActionOrder::new`
    /// (via `compute_hash`); compared by `is_exact_match`.
    pub action_set_hash: u64,
    /// Provenance of the order. When absent from serialized data it falls back to
    /// `ActionOrderSource::default()`, i.e. `Llm`.
    #[serde(default)]
    pub source: ActionOrderSource,
    /// Optional `LoraConfig` attached with `with_lora`; `None` when absent from data.
    #[serde(default)]
    pub lora: Option<crate::types::LoraConfig>,
    /// Optional accuracy recorded with `with_accuracy`; `None` when absent from data.
    #[serde(default)]
    pub validated_accuracy: Option<f64>,
}
/// Where a `LearnedActionOrder` came from.
///
/// `Llm` is the `Default` variant. It is therefore also what serde fills in when a
/// persisted `LearnedActionOrder` has no `source` field, because that field is
/// `#[serde(default)]`.
///
/// `PartialEq`/`Eq` are derived so callers can compare sources with `==` instead of
/// `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActionOrderSource {
    /// Default provenance; `LearnedActionOrder::new` sets this. Presumably "produced by an
    /// LLM" per the name — TODO confirm.
    #[default]
    Llm,
    /// Presumably a built-in/static ordering — TODO confirm against the code that sets it.
    Static,
    /// Presumably supplied by hand — TODO confirm against the code that sets it.
    Manual,
}
impl LearnedActionOrder {
    /// Creates an order learned against `actions`.
    ///
    /// `source` is set to `ActionOrderSource::Llm`, and there is no LoRA config and no
    /// validated accuracy. `action_set_hash` is computed from `actions`, not from
    /// `discover`/`not_discover`, so it identifies the action set the order was learned for.
    pub fn new(discover: Vec<String>, not_discover: Vec<String>, actions: &[String]) -> Self {
        Self {
            discover,
            not_discover,
            action_set_hash: Self::compute_hash(actions),
            source: ActionOrderSource::Llm,
            lora: None,
            validated_accuracy: None,
        }
    }

    /// Builder: attaches `lora`, replacing any previous value.
    pub fn with_lora(mut self, lora: crate::types::LoraConfig) -> Self {
        self.lora = Some(lora);
        self
    }

    /// Builder: records `accuracy` as the validated accuracy, replacing any previous value.
    pub fn with_accuracy(mut self, accuracy: f64) -> Self {
        self.validated_accuracy = Some(accuracy);
        self
    }

    /// Builder: overrides the provenance.
    pub fn with_source(mut self, source: ActionOrderSource) -> Self {
        self.source = source;
        self
    }

    /// Hashes an action set independently of element order.
    ///
    /// Actions are sorted before hashing, so any permutation gives the same hash.
    /// Duplicates are *not* collapsed, so repeating an action changes the hash.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is explicitly unspecified and may change
    /// between Rust releases. Because `action_set_hash` is serialized, a toolchain upgrade
    /// could make previously persisted orders stop matching. Moving to a fixed algorithm
    /// (e.g. FNV-1a) would invalidate existing data once, so that is deliberately left as a
    /// follow-up decision.
    pub fn compute_hash(actions: &[String]) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut sorted: Vec<&str> = actions.iter().map(String::as_str).collect();
        // Equal `&str`s are indistinguishable, so an unstable sort gives the same sequence.
        sorted.sort_unstable();
        let mut hasher = DefaultHasher::new();
        for action in sorted {
            action.hash(&mut hasher);
        }
        hasher.finish()
    }

    /// Returns `true` when `actions` is the same set (ignoring order) this order was built for.
    pub fn is_exact_match(&self, actions: &[String]) -> bool {
        self.action_set_hash == Self::compute_hash(actions)
    }

    /// Alias of [`Self::is_exact_match`].
    #[inline]
    pub fn matches_actions(&self, actions: &[String]) -> bool {
        self.is_exact_match(actions)
    }

    /// Jaccard similarity (|A ∩ B| / |A ∪ B|) between two sets of distinct actions.
    ///
    /// The first set is the distinct actions in `discover` and `not_discover`; the second
    /// is the distinct entries of `actions`. Returns 1.0 when both are empty and 0.0 when
    /// exactly one is empty.
    pub fn match_rate(&self, actions: &[String]) -> f64 {
        use std::collections::HashSet;
        // Borrow instead of cloning both vectors just to build a set of references.
        let learned: HashSet<&str> = self
            .discover
            .iter()
            .chain(&self.not_discover)
            .map(String::as_str)
            .collect();
        let current: HashSet<&str> = actions.iter().map(String::as_str).collect();
        match (learned.is_empty(), current.is_empty()) {
            (true, true) => return 1.0,
            (true, false) | (false, true) => return 0.0,
            (false, false) => {}
        }
        let intersection = learned.intersection(&current).count();
        // Inclusion–exclusion: avoids a second pass over both sets for the union.
        let union = learned.len() + current.len() - intersection;
        intersection as f64 / union as f64
    }
}
/// Numeric tuning parameters produced by `OfflineAnalyzer::analyze_parameters`.
///
/// NOTE(review): how consumers apply these weights is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalParameters {
    /// UCB1 constant `c` (per the field name). The analyzer picks 1.0, √2 or 2.0 from the
    /// historical success rate; the default is √2.
    pub ucb1_c: f64,
    /// Weight for learned statistics; both the default and the analyzer set 0.3.
    pub learning_weight: f64,
    /// Weight for n-gram statistics. The analyzer picks 0.5, 1.0 or 1.5; the default is 1.0.
    pub ngram_weight: f64,
}
impl Default for OptimalParameters {
fn default() -> Self {
Self {
ucb1_c: std::f64::consts::SQRT_2,
learning_weight: 0.3,
ngram_weight: 1.0,
}
}
}
/// An action sequence recommended from historical outcome counts.
///
/// `OfflineAnalyzer::extract_paths` builds these from trigram statistics, so when
/// produced there `actions` has exactly three entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendedPath {
    /// Actions in the order of the source trigram key.
    pub actions: Vec<String>,
    /// Successes divided by `observations`, in `[0.0, 1.0]`.
    pub success_rate: f64,
    /// Total successes plus failures the rate was computed from.
    pub observations: u32,
}
/// Strategy-selection settings produced by `OfflineAnalyzer::analyze_strategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategyConfig {
    /// The analyzer derives this as 10% of the mean actions per snapshot, clamped to
    /// `5..=50`; the default is 10.
    /// NOTE(review): what "maturity" gates is decided by consumers not shown here.
    pub maturity_threshold: u32,
    /// The analyzer sets 1.5× the observed failure rate, capped at 0.5; the default is 0.3.
    pub error_rate_threshold: f64,
    /// Name of the strategy to start with. The analyzer emits `"ucb1"`, `"thompson"` or
    /// `"greedy"`; the default is `"ucb1"`.
    pub initial_strategy: String,
}
impl Default for StrategyConfig {
fn default() -> Self {
Self {
maturity_threshold: 10,
error_rate_threshold: 0.3,
initial_strategy: "ucb1".to_string(),
}
}
}
impl Default for OfflineModel {
fn default() -> Self {
Self {
version: 1,
parameters: OptimalParameters::default(),
recommended_paths: Vec::new(),
strategy_config: StrategyConfig::default(),
analyzed_sessions: 0,
updated_at: 0,
action_order: None,
}
}
}
/// Derives an `OfflineModel` from a borrowed slice of `LearningSnapshot`s.
///
/// Holds only the borrow; each analysis method recomputes from the snapshots.
pub struct OfflineAnalyzer<'a> {
    /// Snapshots to analyze; each one counts as one session in `analyzed_sessions`.
    snapshots: &'a [LearningSnapshot],
}
impl<'a> OfflineAnalyzer<'a> {
    /// Creates an analyzer over `snapshots` (borrowed, not copied).
    pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
        Self { snapshots }
    }

    /// Runs every analysis pass and bundles the results into an [`OfflineModel`].
    ///
    /// `updated_at` is the current Unix time in seconds, or 0 if the system clock
    /// reports a time before the epoch. `action_order` is always `None`, because this
    /// analyzer does not learn action orders.
    pub fn analyze(&self) -> OfflineModel {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        OfflineModel {
            version: 1,
            parameters: self.analyze_parameters(),
            recommended_paths: self.extract_paths(),
            strategy_config: self.analyze_strategy(),
            analyzed_sessions: self.snapshots.len(),
            updated_at: now,
            action_order: None,
        }
    }

    /// Sums `(success_episodes, failure_episodes)` across all snapshots.
    ///
    /// Accumulates in `u64`. Summing per-snapshot `u32` counts in `u32` can overflow,
    /// which panics in debug builds and silently wraps in release builds.
    fn episode_totals(&self) -> (u64, u64) {
        self.snapshots
            .iter()
            .fold((0u64, 0u64), |(success, failure), s| {
                (
                    success + u64::from(s.episode_transitions.success_episodes),
                    failure + u64::from(s.episode_transitions.failure_episodes),
                )
            })
    }

    /// Chooses [`OptimalParameters`] from aggregate outcomes.
    ///
    /// * `ucb1_c` is 1.0 when the overall episode success rate is above 0.8, 2.0 when
    ///   it is below 0.5, and √2 otherwise. With no recorded episodes the rate is
    ///   taken as 0.5.
    /// * `ngram_weight` is 1.5, 0.5 or 1.0 when trigram effectiveness is above 0.7,
    ///   below 0.3, or in between (see `evaluate_ngram_effectiveness`).
    /// * `learning_weight` is always 0.3.
    ///
    /// Returns `OptimalParameters::default()` when there are no snapshots.
    pub fn analyze_parameters(&self) -> OptimalParameters {
        if self.snapshots.is_empty() {
            return OptimalParameters::default();
        }
        let (total_success, total_failure) = self.episode_totals();
        let total = total_success + total_failure;
        let success_rate = if total > 0 {
            total_success as f64 / total as f64
        } else {
            0.5
        };
        let ucb1_c = if success_rate > 0.8 {
            1.0
        } else if success_rate < 0.5 {
            2.0
        } else {
            std::f64::consts::SQRT_2
        };
        let ngram_effectiveness = self.evaluate_ngram_effectiveness();
        let ngram_weight = if ngram_effectiveness > 0.7 {
            1.5
        } else if ngram_effectiveness < 0.3 {
            0.5
        } else {
            1.0
        };
        OptimalParameters {
            ucb1_c,
            learning_weight: 0.3,
            ngram_weight,
        }
    }

    /// Scores how discriminative trigram statistics are, in `[0.0, 1.0]`.
    ///
    /// Takes the per-snapshot success rate of every trigram with at least 3
    /// observations and computes the population variance of those rates. The variance
    /// is divided by 0.25, the largest possible variance for values in `[0, 1]`, and
    /// the result is capped at 1.0. Returns 0.5 when no trigram qualifies.
    fn evaluate_ngram_effectiveness(&self) -> f64 {
        let mut all_rates: Vec<f64> = Vec::new();
        for snapshot in self.snapshots {
            for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
                // Widen before adding so two large u32 counts cannot overflow.
                let total = u64::from(success) + u64::from(failure);
                if total >= 3 {
                    all_rates.push(success as f64 / total as f64);
                }
            }
        }
        if all_rates.is_empty() {
            return 0.5;
        }
        let count = all_rates.len() as f64;
        let mean = all_rates.iter().sum::<f64>() / count;
        let variance = all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / count;
        (variance / 0.25).min(1.0)
    }

    /// Returns up to 10 trigram paths with the highest success rate.
    ///
    /// Trigram counts are summed across all snapshots, and trigrams with fewer than 5
    /// total observations are dropped. Results are sorted by success rate (highest
    /// first). Ties are broken by observation count (more first) and then by the action
    /// names, so the output, and which entries survive the top-10 cut, is deterministic
    /// regardless of `HashMap` iteration order. `observations` saturates at `u32::MAX`.
    pub fn extract_paths(&self) -> Vec<RecommendedPath> {
        const MIN_OBSERVATIONS: u64 = 5;
        const MAX_PATHS: usize = 10;

        // Counts are summed in u64: many snapshots of the same trigram can exceed u32.
        let mut path_stats: HashMap<Vec<String>, (u64, u64)> = HashMap::new();
        for snapshot in self.snapshots {
            for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
                let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
                let entry = path_stats.entry(path).or_insert((0, 0));
                entry.0 += u64::from(success);
                entry.1 += u64::from(failure);
            }
        }
        let mut paths: Vec<RecommendedPath> = path_stats
            .into_iter()
            .filter(|(_, (s, f))| s + f >= MIN_OBSERVATIONS)
            .map(|(actions, (success, failure))| {
                let total = success + failure;
                RecommendedPath {
                    actions,
                    success_rate: success as f64 / total as f64,
                    observations: u32::try_from(total).unwrap_or(u32::MAX),
                }
            })
            .collect();
        paths.sort_by(|a, b| {
            // Rates are finite (total >= MIN_OBSERVATIONS > 0), so total_cmp matches
            // the previous partial_cmp ordering.
            b.success_rate
                .total_cmp(&a.success_rate)
                .then_with(|| b.observations.cmp(&a.observations))
                .then_with(|| a.actions.cmp(&b.actions))
        });
        paths.truncate(MAX_PATHS);
        paths
    }

    /// Derives a [`StrategyConfig`] from aggregate outcomes and action volume.
    ///
    /// * `initial_strategy` is `"thompson"` when the failure rate is above 0.4,
    ///   `"greedy"` when it is below 0.1, and `"ucb1"` otherwise. With no recorded
    ///   episodes the rate is taken as 0.3.
    /// * `error_rate_threshold` is 1.5× the failure rate, capped at 0.5.
    /// * `maturity_threshold` is 10% of the mean `total_actions` per snapshot, clamped
    ///   to `5..=50`.
    ///
    /// Returns `StrategyConfig::default()` when there are no snapshots.
    pub fn analyze_strategy(&self) -> StrategyConfig {
        if self.snapshots.is_empty() {
            return StrategyConfig::default();
        }
        let (total_success, total_failure) = self.episode_totals();
        let total = total_success + total_failure;
        let avg_error_rate = if total > 0 {
            total_failure as f64 / total as f64
        } else {
            0.3
        };
        let total_actions: u64 = self
            .snapshots
            .iter()
            .map(|s| s.metadata.total_actions as u64)
            .sum();
        // `snapshots` is non-empty here (early return above), so the division is safe.
        let avg_actions = total_actions as f64 / self.snapshots.len() as f64;
        // `as u32` saturates for large/NaN-free values; the clamp bounds the result.
        let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
        let initial_strategy = if avg_error_rate > 0.4 {
            "thompson"
        } else if avg_error_rate < 0.1 {
            "greedy"
        } else {
            "ucb1"
        };
        StrategyConfig {
            maturity_threshold,
            error_rate_threshold: (avg_error_rate * 1.5).min(0.5),
            initial_strategy: initial_strategy.to_string(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot with the given episode outcome counts and five recorded actions
    /// per episode. Every other field comes from `LearningSnapshot::empty()`.
    fn snapshot_with(success: u32, failure: u32) -> LearningSnapshot {
        let mut snap = LearningSnapshot::empty();
        let transitions = &mut snap.episode_transitions;
        transitions.success_episodes = success;
        transitions.failure_episodes = failure;
        snap.metadata.total_actions = 5 * (success + failure);
        snap
    }

    /// Builds one snapshot per `(success, failure)` pair.
    fn snapshots_from(counts: &[(u32, u32)]) -> Vec<LearningSnapshot> {
        counts
            .iter()
            .map(|&(success, failure)| snapshot_with(success, failure))
            .collect()
    }

    #[test]
    fn test_analyzer_empty_snapshots() {
        let snapshots = snapshots_from(&[]);
        let model = OfflineAnalyzer::new(&snapshots).analyze();
        assert_eq!(model.analyzed_sessions, 0);
        let drift = (model.parameters.ucb1_c - std::f64::consts::SQRT_2).abs();
        assert!(drift < 0.01);
    }

    #[test]
    fn test_analyzer_high_success_rate() {
        let snapshots = snapshots_from(&[(9, 1), (8, 2), (10, 0)]);
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();
        assert!(params.ucb1_c < std::f64::consts::SQRT_2);
    }

    #[test]
    fn test_analyzer_low_success_rate() {
        let snapshots = snapshots_from(&[(3, 7), (4, 6), (2, 8)]);
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();
        assert!(params.ucb1_c > std::f64::consts::SQRT_2);
    }

    #[test]
    fn test_strategy_config_high_error() {
        let snapshots = snapshots_from(&[(3, 7), (4, 6)]);
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();
        assert_eq!(config.initial_strategy, "thompson");
    }

    #[test]
    fn test_strategy_config_low_error() {
        let snapshots = snapshots_from(&[(19, 1), (18, 2)]);
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();
        assert_eq!(config.initial_strategy, "greedy");
    }
}