use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Aggregated performance statistics for the episodes of a single domain.
///
/// A record starts out empty (see `DomainStatistics::new`) and is filled in
/// by `DomainStatisticsCache::calculate_from_episodes`, which writes every
/// field, or by `DomainStatisticsCache::update_incremental`, which only
/// touches the averages, the reward deviation, the counters and the timestamp
/// and leaves the percentile fields as they were.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainStatistics {
/// Name of the domain these statistics describe (also the cache key).
pub domain: String,
/// Number of episodes aggregated into this record.
pub episode_count: usize,
/// Mean episode duration, in seconds.
pub avg_duration_secs: f32,
/// 50th-percentile (median) episode duration, in seconds.
pub p50_duration_secs: f32,
/// 90th-percentile episode duration, in seconds.
pub p90_duration_secs: f32,
/// Mean number of steps per episode.
pub avg_step_count: f32,
/// 50th-percentile (median) number of steps per episode.
pub p50_step_count: usize,
/// 90th-percentile number of steps per episode.
pub p90_step_count: usize,
/// Mean episode reward.
pub avg_reward: f32,
/// 50th-percentile (median) episode reward.
pub p50_reward: f32,
/// Standard deviation of the episode rewards.
pub reward_std_dev: f32,
/// UTC time at which this record was created or last recomputed/updated.
pub last_updated: DateTime<Utc>,
/// Number of aggregated episodes that counted as successful.
pub success_count: usize,
}
impl DomainStatistics {
    /// Minimum number of episodes before the statistics count as reliable.
    const MIN_RELIABLE_EPISODES: usize = 5;

    /// Age, in whole days, that a record must exceed to count as stale.
    const STALE_AFTER_DAYS: i64 = 7;

    /// Creates an empty record for `domain`: all counters and aggregates are
    /// zero and `last_updated` is set to the current UTC time.
    #[must_use]
    pub fn new(domain: String) -> Self {
        Self {
            domain,
            episode_count: 0,
            avg_duration_secs: 0.0,
            p50_duration_secs: 0.0,
            p90_duration_secs: 0.0,
            avg_step_count: 0.0,
            p50_step_count: 0,
            p90_step_count: 0,
            avg_reward: 0.0,
            p50_reward: 0.0,
            reward_std_dev: 0.0,
            last_updated: Utc::now(),
            success_count: 0,
        }
    }

    /// Maps a percentile fraction onto a nearest-rank index into a sorted
    /// slice of `len` elements.
    ///
    /// Returns `None` for an empty slice. The fraction is clamped to
    /// `[0.0, 1.0]` and the resulting index is capped at `len - 1`, so the
    /// index is always in bounds even for out-of-range fractions or when
    /// `f32` rounding of a very large `len` would overshoot. A NaN fraction
    /// selects index 0.
    fn percentile_index(len: usize, percentile: f32) -> Option<usize> {
        let last = len.checked_sub(1)?;
        let fraction = percentile.clamp(0.0, 1.0);
        // Float-to-integer `as` casts saturate (and map NaN to 0), so this
        // cannot wrap; `min` then guards the upper bound.
        let rank = (last as f32 * fraction).round() as usize;
        Some(rank.min(last))
    }

    /// Returns the nearest-rank percentile of `sorted_values`.
    ///
    /// `sorted_values` must already be sorted in ascending order and
    /// `percentile` is a fraction (`0.5` is the median). Values outside
    /// `[0.0, 1.0]` are clamped. Returns `0.0` for an empty slice.
    fn percentile(sorted_values: &[f32], percentile: f32) -> f32 {
        Self::percentile_index(sorted_values.len(), percentile)
            .map_or(0.0, |index| sorted_values[index])
    }

    /// Integer counterpart of [`Self::percentile`]; returns `0` for an empty
    /// slice.
    fn percentile_usize(sorted_values: &[usize], percentile: f32) -> usize {
        Self::percentile_index(sorted_values.len(), percentile)
            .map_or(0, |index| sorted_values[index])
    }

    /// Returns `success_count / episode_count`, or `0.0` when no episodes
    /// have been recorded.
    #[must_use]
    pub fn success_rate(&self) -> f32 {
        if self.episode_count == 0 {
            0.0
        } else {
            self.success_count as f32 / self.episode_count as f32
        }
    }

    /// Returns `true` once at least five episodes have been aggregated.
    #[must_use]
    pub fn is_reliable(&self) -> bool {
        self.episode_count >= Self::MIN_RELIABLE_EPISODES
    }

    /// Returns `true` when the record was last updated more than seven whole
    /// days ago (the age is truncated to whole days before comparing).
    #[must_use]
    pub fn is_stale(&self) -> bool {
        let age = Utc::now() - self.last_updated;
        age.num_days() > Self::STALE_AFTER_DAYS
    }
}
/// Collection of [`DomainStatistics`] records, one per domain.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DomainStatisticsCache {
/// Statistics keyed by domain name.
pub stats: HashMap<String, DomainStatistics>,
}
impl DomainStatisticsCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        Self {
            stats: HashMap::new(),
        }
    }

    /// Returns the statistics stored for `domain`, if any.
    #[must_use]
    pub fn get(&self, domain: &str) -> Option<&DomainStatistics> {
        self.stats.get(domain)
    }

    /// Returns the statistics for `domain`, inserting an empty record first
    /// when the domain is not yet present.
    pub fn get_or_create(&mut self, domain: String) -> &mut DomainStatistics {
        // `or_insert_with_key` lets the map own `domain` as the key and only
        // clones it when a new record actually has to be built.
        self.stats
            .entry(domain)
            .or_insert_with_key(|key| DomainStatistics::new(key.clone()))
    }

    /// Arithmetic mean of `values`, or `0.0` for an empty slice.
    fn mean(values: &[f32]) -> f32 {
        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<f32>() / values.len() as f32
        }
    }

    /// Recomputes every statistic for `domain` from scratch.
    ///
    /// Only episodes that are complete, belong to `domain` and report a
    /// duration are aggregated, so that `episode_count`, the step statistics
    /// and `success_count` all describe the same set of episodes. Reward
    /// statistics cover the subset of those episodes that carry a reward;
    /// the deviation is the population standard deviation.
    ///
    /// When no episode qualifies the cache is left untouched (no record is
    /// created or modified).
    pub fn calculate_from_episodes(
        &mut self,
        domain: String,
        episodes: &[crate::episode::Episode],
    ) {
        use crate::types::TaskOutcome;

        if episodes.is_empty() {
            return;
        }

        let mut durations: Vec<f32> = Vec::with_capacity(episodes.len());
        let mut step_counts: Vec<usize> = Vec::with_capacity(episodes.len());
        let mut rewards: Vec<f32> = Vec::with_capacity(episodes.len());
        let mut success_count = 0usize;

        for episode in episodes {
            if !episode.is_complete() || episode.context.domain != domain {
                continue;
            }
            // Episodes without a duration are skipped entirely; counting
            // their steps/successes while leaving them out of
            // `episode_count` could push the success rate above 1.0.
            let duration = match episode.duration() {
                Some(duration) => duration,
                None => continue,
            };
            durations.push(duration.num_seconds() as f32);
            step_counts.push(episode.steps.len());
            if let Some(reward) = &episode.reward {
                rewards.push(reward.total);
            }
            if matches!(episode.outcome, Some(TaskOutcome::Success { .. })) {
                success_count += 1;
            }
        }

        if durations.is_empty() {
            return;
        }

        // `total_cmp` is a genuine total order even when NaN is present,
        // which the slice sorting routines require of their comparator.
        durations.sort_unstable_by(f32::total_cmp);
        step_counts.sort_unstable();
        rewards.sort_unstable_by(f32::total_cmp);

        let episode_count = durations.len();
        let avg_duration = Self::mean(&durations);
        let avg_steps = step_counts.iter().sum::<usize>() as f32 / step_counts.len() as f32;
        let avg_reward = Self::mean(&rewards);
        let reward_std_dev = if rewards.is_empty() {
            0.0
        } else {
            let sum_sq_diff: f32 = rewards.iter().map(|r| (r - avg_reward).powi(2)).sum();
            (sum_sq_diff / rewards.len() as f32).sqrt()
        };

        let stats = self.get_or_create(domain);
        stats.episode_count = episode_count;
        stats.avg_duration_secs = avg_duration;
        stats.p50_duration_secs = DomainStatistics::percentile(&durations, 0.5);
        stats.p90_duration_secs = DomainStatistics::percentile(&durations, 0.9);
        stats.avg_step_count = avg_steps;
        stats.p50_step_count = DomainStatistics::percentile_usize(&step_counts, 0.5);
        stats.p90_step_count = DomainStatistics::percentile_usize(&step_counts, 0.9);
        stats.avg_reward = avg_reward;
        // `percentile` already yields 0.0 when there are no rewards.
        stats.p50_reward = DomainStatistics::percentile(&rewards, 0.5);
        stats.reward_std_dev = reward_std_dev;
        stats.last_updated = Utc::now();
        stats.success_count = success_count;
    }

    /// Folds one more episode into the statistics for `domain`, creating the
    /// record when it does not exist yet.
    ///
    /// Running means are updated in place and the reward deviation is
    /// advanced with the population-variance recurrence, matching what
    /// [`Self::calculate_from_episodes`] stores, so the two update paths can
    /// be mixed. Percentile fields cannot be updated without the underlying
    /// samples and are left unchanged.
    pub fn update_incremental(
        &mut self,
        domain: &str,
        duration_secs: f32,
        step_count: usize,
        reward: f32,
        is_success: bool,
    ) {
        let stats = self.get_or_create(domain.to_string());
        let n = stats.episode_count as f32;
        let new_n = n + 1.0;

        stats.avg_duration_secs = (stats.avg_duration_secs * n + duration_secs) / new_n;
        stats.avg_step_count = (stats.avg_step_count * n + step_count as f32) / new_n;

        let old_mean = stats.avg_reward;
        let new_mean = (old_mean * n + reward) / new_n;
        stats.avg_reward = new_mean;

        // Welford-style update on the sum of squared deviations (M2):
        // population variance = M2 / count, hence M2_old = n * variance_old.
        let old_m2 = stats.reward_std_dev.powi(2) * n;
        let new_m2 = old_m2 + (reward - old_mean) * (reward - new_mean);
        // Rounding can leave a tiny negative value; clamp so `sqrt` never
        // produces NaN.
        stats.reward_std_dev = (new_m2 / new_n).max(0.0).sqrt();

        stats.episode_count += 1;
        if is_success {
            stats.success_count += 1;
        }
        stats.last_updated = Utc::now();
    }

    /// Drops every record whose `last_updated` is not strictly newer than
    /// `max_age_days` days before the current UTC time.
    pub fn prune_stale(&mut self, max_age_days: i64) {
        let cutoff = Utc::now() - chrono::Duration::days(max_age_days);
        self.stats.retain(|_, stats| stats.last_updated > cutoff);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created record carries its domain name and no episode data.
    #[test]
    fn test_domain_statistics_creation() {
        let fresh = DomainStatistics::new(String::from("web-api"));

        assert_eq!(fresh.domain, "web-api");
        assert_eq!(fresh.episode_count, 0);
        assert_eq!(fresh.success_rate(), 0.0);
        assert!(!fresh.is_reliable());
    }

    /// Median, minimum and maximum are picked from a sorted five-element slice.
    #[test]
    fn test_percentile_calculation() {
        let samples: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        let cases: [(f32, f32); 3] = [(0.5, 3.0), (0.0, 1.0), (1.0, 5.0)];

        for (fraction, expected) in cases {
            assert_eq!(DomainStatistics::percentile(&samples, fraction), expected);
        }
    }

    /// Running averages and counters follow each incremental update.
    #[test]
    fn test_incremental_update() {
        let domain = "web-api";
        let mut cache = DomainStatisticsCache::new();

        // First episode: the averages equal the single sample.
        cache.update_incremental(domain, 60.0, 10, 0.8, true);
        let after_first = cache.get(domain).unwrap();
        assert_eq!(after_first.episode_count, 1);
        assert_eq!(after_first.avg_duration_secs, 60.0);
        assert_eq!(after_first.avg_step_count, 10.0);
        assert_eq!(after_first.avg_reward, 0.8);
        assert_eq!(after_first.success_count, 1);

        // Second episode: the averages move to the midpoint of both samples.
        cache.update_incremental(domain, 120.0, 15, 0.9, true);
        let after_second = cache.get(domain).unwrap();
        assert_eq!(after_second.episode_count, 2);
        assert_eq!(after_second.avg_duration_secs, 90.0);
        assert_eq!(after_second.avg_step_count, 12.5);
        assert!((after_second.avg_reward - 0.85).abs() < 0.01);
    }

    /// Two successes out of three episodes give a success rate of 2/3.
    #[test]
    fn test_success_rate() {
        let mut cache = DomainStatisticsCache::new();
        let episodes: [(f32, bool); 3] = [(0.8, true), (0.8, true), (0.3, false)];

        for (reward, succeeded) in episodes {
            cache.update_incremental("test", 60.0, 10, reward, succeeded);
        }

        let rate = cache.get("test").unwrap().success_rate();
        assert_eq!(rate, 2.0 / 3.0);
    }

    /// A record becomes reliable once it has five episodes.
    #[test]
    fn test_reliability() {
        let mut record = DomainStatistics::new(String::from("test"));
        assert!(!record.is_reliable());

        record.episode_count = 5;
        assert!(record.is_reliable());
    }
}