//! Performance monitoring for the RAG pipeline: periodic benchmarks,
//! rolling per-metric statistics, and statistical regression detection.

use crate::rag::RagEngine;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time;
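/// Aggregated performance metrics for one benchmark pass over the RAG pipeline.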
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagPerformanceMetrics {
pub retrieval: RetrievalMetrics,
pub generation: GenerationMetrics,
pub context: ContextMetrics,
pub end_to_end: EndToEndMetrics,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
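/// Timing and relevance-score statistics for the retrieval stage.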
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalMetrics {
pub retrieval_time_ms: f64,
pub chunks_retrieved: usize,
pub chunks_used: usize,
pub score_stats: ScoreStats,
}
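/// Timing and token-throughput figures for the generation stage.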
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationMetrics {
pub generation_time_ms: f64,
pub tokens_used: u32,
pub tokens_per_second: f64,
}
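/// Cost of assembling the prompt context from retrieved chunks.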
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMetrics {
pub context_build_time_ms: f64,
pub context_tokens: usize,
pub token_efficiency: f64,
}
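/// Whole-pipeline latency, throughput, memory, and reliability figures.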
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndToEndMetrics {
pub total_time_ms: f64,
pub queries_per_second: f64,
pub memory_usage_mb: f64,
pub success_rate: f64,
}
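/// Summary statistics over retrieval relevance scores.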
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreStats {
pub mean: f32,
pub std_dev: f32,
pub min: f32,
pub max: f32,
pub median: f32,
}
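/// A statistically significant degradation detected in one tracked metric.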
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceRegression {
pub metric: String,
pub change_absolute: f64,
pub change_percent: f64,
pub p_value: f64,
pub confidence: f64,
pub baseline_value: f64,
pub current_value: f64,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
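/// Tuning knobs for benchmarking and regression detection. `alert_threshold`
/// is a relative change (0.05 = 5%); `confidence_level` sets the significance
/// cutoff for the z-test used in regression detection.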
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
pub alert_threshold: f64,
pub sample_size: usize,
pub confidence_level: f64,
pub history_window: usize,
pub benchmark_queries: Vec<String>,
pub enable_memory_monitoring: bool,
pub monitoring_interval: Duration,
}
impl Default for PerformanceConfig {
fn default() -> Self {
Self {
            alert_threshold: 0.05,
            sample_size: 100,
confidence_level: 0.95,
history_window: 1000,
benchmark_queries: vec![
"What is machine learning?".to_string(),
"Explain neural networks".to_string(),
"How does backpropagation work?".to_string(),
"What are transformers in AI?".to_string(),
"Explain the concept of overfitting".to_string(),
"What is gradient descent?".to_string(),
"How do convolutional neural networks work?".to_string(),
"What is the difference between supervised and unsupervised learning?".to_string(),
"Explain reinforcement learning".to_string(),
"What are attention mechanisms?".to_string(),
],
enable_memory_monitoring: true,
            monitoring_interval: Duration::from_secs(300),
        }
}
}
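/// Benchmark results keyed by Unix timestamp, plus per-metric rolling stats.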
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PerformanceHistory {
metrics: HashMap<i64, RagPerformanceMetrics>,
rolling_stats: HashMap<String, RollingStats>,
}
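/// Incrementally maintained mean, standard deviation, and trend for one metric.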
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RollingStats {
values: Vec<f64>,
mean: f64,
std_dev: f64,
trend_slope: f64,
last_updated: chrono::DateTime<chrono::Utc>,
}
impl RollingStats {
fn new() -> Self {
Self {
values: Vec::new(),
mean: 0.0,
std_dev: 0.0,
trend_slope: 0.0,
last_updated: chrono::Utc::now(),
}
}
fn add_value(&mut self, value: f64) {
self.values.push(value);
self.last_updated = chrono::Utc::now();
self.update_stats();
}
fn update_stats(&mut self) {
if self.values.is_empty() {
return;
}
self.mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
self.std_dev = if self.values.len() > 1 {
let variance = self
.values
.iter()
.map(|v| (v - self.mean).powi(2))
.sum::<f64>()
/ (self.values.len() - 1) as f64;
variance.sqrt()
} else {
0.0
};
self.trend_slope = self.calculate_trend_slope();
}
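    /// Ordinary least-squares slope of the values against their sample index,
    /// i.e. the average change per new measurement.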
fn calculate_trend_slope(&self) -> f64 {
if self.values.len() < 2 {
return 0.0;
}
let n = self.values.len() as f64;
let x_mean = (n - 1.0) / 2.0;
let numerator: f64 = self
.values
.iter()
.enumerate()
.map(|(i, &y)| (i as f64 - x_mean) * (y - self.mean))
.sum();
let denominator: f64 = self
.values
.iter()
.enumerate()
.map(|(i, _)| (i as f64 - x_mean).powi(2))
.sum();
if denominator.abs() < f64::EPSILON {
0.0
} else {
numerator / denominator
}
}
}
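/// Benchmarks a `RagEngine` over time and flags performance regressions.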
pub struct RagPerformanceMonitor {
rag_engine: Arc<RagEngine>,
config: PerformanceConfig,
history: Arc<RwLock<PerformanceHistory>>,
}
impl RagPerformanceMonitor {
pub fn new(rag_engine: RagEngine, config: PerformanceConfig) -> Self {
Self {
rag_engine: Arc::new(rag_engine),
config,
history: Arc::new(RwLock::new(PerformanceHistory::default())),
}
}
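    /// Runs each configured benchmark query once and returns metrics averaged
    /// over the successful runs; failures are logged and counted against the
    /// success rate.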
pub async fn run_benchmark(&self) -> Result<RagPerformanceMetrics> {
let mut retrieval_times = Vec::new();
let mut generation_times = Vec::new();
let mut context_times = Vec::new();
let mut total_times = Vec::new();
let mut scores = Vec::new();
let mut tokens_used = Vec::new();
let mut context_tokens = Vec::new();
let mut success_count = 0;
let start_memory = if self.config.enable_memory_monitoring {
self.get_memory_usage().unwrap_or(0.0)
} else {
0.0
};
for query in &self.config.benchmark_queries {
match self.benchmark_single_query(query).await {
Ok(metrics) => {
retrieval_times.push(metrics.retrieval.retrieval_time_ms);
generation_times.push(metrics.generation.generation_time_ms);
context_times.push(metrics.context.context_build_time_ms);
total_times.push(metrics.end_to_end.total_time_ms);
tokens_used.push(metrics.generation.tokens_used as f64);
context_tokens.push(metrics.context.context_tokens as f64);
if metrics.retrieval.score_stats.mean > 0.0 {
scores.push(metrics.retrieval.score_stats.mean as f64);
}
success_count += 1;
}
Err(e) => {
tracing::warn!("Benchmark query failed: {} - {}", query, e);
}
}
}
let end_memory = if self.config.enable_memory_monitoring {
self.get_memory_usage().unwrap_or(0.0)
} else {
0.0
};
let avg_retrieval_time = self.average(&retrieval_times);
let avg_generation_time = self.average(&generation_times);
let avg_context_time = self.average(&context_times);
let avg_total_time = self.average(&total_times);
let avg_tokens = self.average(&tokens_used) as u32;
let avg_context_tokens = self.average(&context_tokens) as usize;
let score_stats = if scores.is_empty() {
ScoreStats {
mean: 0.0,
std_dev: 0.0,
min: 0.0,
max: 0.0,
median: 0.0,
}
} else {
self.calculate_score_stats(&scores)
};
let queries_per_second = if avg_total_time > 0.0 {
1000.0 / avg_total_time
} else {
0.0
};
let tokens_per_second = if avg_generation_time > 0.0 {
(avg_tokens as f64 * 1000.0) / avg_generation_time
} else {
0.0
};
let token_efficiency = if avg_context_time > 0.0 {
avg_context_tokens as f64 / avg_context_time
} else {
0.0
};
        let success_rate = if self.config.benchmark_queries.is_empty() {
            0.0
        } else {
            success_count as f64 / self.config.benchmark_queries.len() as f64
        };
let metrics = RagPerformanceMetrics {
retrieval: RetrievalMetrics {
retrieval_time_ms: avg_retrieval_time,
                // Aggregate view: the number of benchmark queries run stands in here.
                chunks_retrieved: self.config.benchmark_queries.len(),
                // Rough estimate assuming ~50 tokens per chunk actually used.
                chunks_used: (avg_context_tokens as f64 / 50.0).max(1.0) as usize,
                score_stats,
},
generation: GenerationMetrics {
generation_time_ms: avg_generation_time,
tokens_used: avg_tokens,
tokens_per_second,
},
context: ContextMetrics {
context_build_time_ms: avg_context_time,
context_tokens: avg_context_tokens,
token_efficiency,
},
end_to_end: EndToEndMetrics {
total_time_ms: avg_total_time,
queries_per_second,
memory_usage_mb: end_memory - start_memory,
success_rate,
},
timestamp: chrono::Utc::now(),
};
Ok(metrics)
}
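    /// Times each pipeline stage for a single query. Generation is measured
    /// around the full `query` call, so it may include the engine's own
    /// internal retrieval pass.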
async fn benchmark_single_query(&self, query: &str) -> Result<RagPerformanceMetrics> {
let total_start = Instant::now();
let retrieval_start = Instant::now();
let results = self.rag_engine.retrieve(query, 10).await?;
let retrieval_time = retrieval_start.elapsed().as_millis() as f64;
let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let score_stats =
self.calculate_score_stats(&scores.iter().map(|&s| s as f64).collect::<Vec<_>>());
let context_start = Instant::now();
let context = self.build_context_for_timing(&results);
let context_time = context_start.elapsed().as_millis() as f64;
let generation_start = Instant::now();
let response = self.rag_engine.query(query).await?;
let generation_time = generation_start.elapsed().as_millis() as f64;
let total_time = total_start.elapsed().as_millis() as f64;
Ok(RagPerformanceMetrics {
retrieval: RetrievalMetrics {
retrieval_time_ms: retrieval_time,
chunks_retrieved: results.len(),
chunks_used: results.len(),
score_stats,
},
generation: GenerationMetrics {
generation_time_ms: generation_time,
                // Fall back to a nominal 100 tokens when the engine does not report usage.
                tokens_used: response.tokens_used.unwrap_or(100),
                tokens_per_second: if generation_time > 0.0 {
                    (response.tokens_used.unwrap_or(100) as f64 * 1000.0) / generation_time
                } else {
                    0.0
                },
},
context: ContextMetrics {
context_build_time_ms: context_time,
                // Approximate token count as ~4 characters per token.
                context_tokens: context.len() / 4,
                token_efficiency: (context.len() / 4) as f64 / context_time.max(1.0),
},
end_to_end: EndToEndMetrics {
total_time_ms: total_time,
queries_per_second: if total_time > 0.0 {
1000.0 / total_time
} else {
0.0
},
                // Per-query memory deltas are not tracked; only full benchmark runs report them.
                memory_usage_mb: 0.0,
                success_rate: 1.0,
},
timestamp: chrono::Utc::now(),
})
}
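    /// Spawns a background task that re-runs the benchmark suite at the
    /// configured interval, storing results and alerting on any detected
    /// regressions. Returns immediately after spawning.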
pub async fn start_continuous_monitoring(self: Arc<Self>) -> Result<()> {
let monitor = Arc::clone(&self);
tokio::spawn(async move {
let mut interval = time::interval(monitor.config.monitoring_interval);
loop {
interval.tick().await;
match monitor.run_benchmark().await {
Ok(metrics) => {
monitor.store_metrics(metrics.clone()).await;
match monitor.detect_regressions().await {
Ok(regressions) => {
for regression in regressions {
                                    monitor.alert_regression(&regression).await;
}
}
Err(e) => {
tracing::error!("Failed to detect regressions: {}", e);
}
}
tracing::info!("Performance benchmark completed at {}", metrics.timestamp);
}
Err(e) => {
tracing::error!("Performance benchmark failed: {}", e);
}
}
}
});
Ok(())
}
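    /// Compares each tracked metric's latest value against its rolling mean,
    /// treating increases as regressions for latency metrics and decreases as
    /// regressions for throughput metrics. A regression is reported when the
    /// relative change exceeds `alert_threshold` and a two-sided z-test is
    /// significant at the configured confidence level; metrics with fewer
    /// than 10 samples are skipped.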
pub async fn detect_regressions(&self) -> Result<Vec<PerformanceRegression>> {
let history = self.history.read().await;
let mut regressions = Vec::new();
let metrics_to_check = vec![
("retrieval_time_ms", "Retrieval Time"),
("generation_time_ms", "Generation Time"),
("total_time_ms", "Total Query Time"),
("tokens_per_second", "Token Generation Rate"),
];
for (metric_key, display_name) in metrics_to_check {
if let Some(stats) = history.rolling_stats.get(metric_key) {
if stats.values.len() < 10 {
                    continue;
                }
let current_value = stats.values.last().copied().unwrap_or(0.0);
let baseline_value = stats.mean;
let is_latency_metric = metric_key.contains("time_ms");
let change = if is_latency_metric {
current_value - baseline_value
} else {
baseline_value - current_value
};
let change_percent = if baseline_value > 0.0 {
change / baseline_value
} else {
0.0
};
if change_percent.abs() > self.config.alert_threshold {
let (p_value, confidence) = self.statistical_test(
&stats.values,
current_value,
self.config.sample_size,
);
if p_value < (1.0 - self.config.confidence_level) {
regressions.push(PerformanceRegression {
metric: display_name.to_string(),
change_absolute: change,
change_percent,
p_value,
confidence,
baseline_value,
current_value,
timestamp: chrono::Utc::now(),
});
}
}
}
}
Ok(regressions)
}
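    /// Records one benchmark result and folds its headline values into the
    /// per-metric rolling statistics.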
async fn store_metrics(&self, metrics: RagPerformanceMetrics) {
let mut history = self.history.write().await;
let timestamp_key = metrics.timestamp.timestamp();
history.metrics.insert(timestamp_key, metrics.clone());
self.update_rolling_stats(
&mut history,
"retrieval_time_ms",
metrics.retrieval.retrieval_time_ms,
);
self.update_rolling_stats(
&mut history,
"generation_time_ms",
metrics.generation.generation_time_ms,
);
self.update_rolling_stats(
&mut history,
"total_time_ms",
metrics.end_to_end.total_time_ms,
);
self.update_rolling_stats(
&mut history,
"tokens_per_second",
metrics.generation.tokens_per_second,
);
self.update_rolling_stats(
&mut history,
"queries_per_second",
metrics.end_to_end.queries_per_second,
);
self.update_rolling_stats(
&mut history,
"memory_usage_mb",
metrics.end_to_end.memory_usage_mb,
);
self.trim_history(&mut history);
}
fn update_rolling_stats(&self, history: &mut PerformanceHistory, metric: &str, value: f64) {
history
.rolling_stats
.entry(metric.to_string())
.or_insert_with(RollingStats::new)
.add_value(value);
}
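    /// Caps both the stored metrics and each rolling-stat window at
    /// `history_window` entries, keeping the most recent values.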
fn trim_history(&self, history: &mut PerformanceHistory) {
let mut timestamps: Vec<_> = history.metrics.keys().cloned().collect();
timestamps.sort_by(|a, b| b.cmp(a));
if timestamps.len() > self.config.history_window {
let to_remove: Vec<_> = timestamps
.iter()
.skip(self.config.history_window)
.cloned()
.collect();
for ts in to_remove {
history.metrics.remove(&ts);
}
}
for stats in history.rolling_stats.values_mut() {
if stats.values.len() > self.config.history_window {
stats.values = stats
.values
.iter()
.rev()
.take(self.config.history_window)
.cloned()
.collect();
                stats.values.reverse();
                stats.update_stats();
}
}
}
async fn alert_regression(&self, regression: &PerformanceRegression) {
tracing::error!(
"🚨 PERFORMANCE REGRESSION DETECTED 🚨\n\
Metric: {}\n\
Change: {:.2}% ({:.2} absolute)\n\
Baseline: {:.2}, Current: {:.2}\n\
Confidence: {:.1}%, p-value: {:.4}\n\
Timestamp: {}",
regression.metric,
regression.change_percent * 100.0,
regression.change_absolute,
regression.baseline_value,
regression.current_value,
regression.confidence * 100.0,
regression.p_value,
regression.timestamp
);
}
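    /// Two-sided z-test of `current_value` against the most recent
    /// `sample_size` baseline values: z = (x - mean) / s, with the p-value
    /// taken from the standard normal CDF. Returns `(p_value, confidence)`.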
fn statistical_test(
&self,
historical_values: &[f64],
current_value: f64,
sample_size: usize,
) -> (f64, f64) {
if historical_values.len() < 2 {
            // Not enough samples for a meaningful test.
            return (1.0, 0.0);
        }
let baseline_values: Vec<f64> = historical_values
.iter()
.rev()
.take(sample_size.min(historical_values.len()))
.cloned()
.collect();
let baseline_mean = self.average(&baseline_values);
let baseline_std = self.standard_deviation(&baseline_values, baseline_mean);
if baseline_std < f64::EPSILON {
            // Zero variance in the baseline; a z-test is undefined.
            return (1.0, 0.0);
        }
let z_score = (current_value - baseline_mean) / baseline_std;
let p_value = 2.0 * (1.0 - self.normal_cdf(z_score.abs()));
let confidence = 1.0 - p_value;
(p_value, confidence)
}
fn average(&self, values: &[f64]) -> f64 {
if values.is_empty() {
0.0
} else {
values.iter().sum::<f64>() / values.len() as f64
}
}
fn standard_deviation(&self, values: &[f64], mean: f64) -> f64 {
if values.len() < 2 {
0.0
} else {
let variance =
values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
variance.sqrt()
}
}
fn calculate_score_stats(&self, scores: &[f64]) -> ScoreStats {
if scores.is_empty() {
return ScoreStats {
mean: 0.0,
std_dev: 0.0,
min: 0.0,
max: 0.0,
median: 0.0,
};
}
let mean = self.average(scores);
let std_dev = self.standard_deviation(scores, mean);
let min = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut sorted_scores = scores.to_vec();
        sorted_scores.sort_by(|a, b| a.total_cmp(b));
let median = if sorted_scores.len() % 2 == 0 {
(sorted_scores[sorted_scores.len() / 2 - 1] + sorted_scores[sorted_scores.len() / 2])
/ 2.0
} else {
sorted_scores[sorted_scores.len() / 2]
};
ScoreStats {
mean: mean as f32,
std_dev: std_dev as f32,
min: min as f32,
max: max as f32,
median: median as f32,
}
}
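    /// Standard normal CDF computed from the Abramowitz & Stegun 7.1.26
    /// polynomial approximation of erf (absolute error below 1.5e-7).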
fn normal_cdf(&self, x: f64) -> f64 {
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let p = 0.3275911;
let sign = if x < 0.0 { -1.0 } else { 1.0 };
let x = x.abs() / 2.0_f64.sqrt();
let t = 1.0 / (1.0 + p * x);
let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
0.5 * (1.0 + sign * y)
}
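    /// Current process memory usage in megabytes.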
fn get_memory_usage(&self) -> Result<f64> {
        // Placeholder: platform-specific memory sampling is not implemented yet.
        Ok(0.0)
    }
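    /// Joins retrieved chunk texts to approximate context-assembly cost;
    /// used only for timing.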
fn build_context_for_timing(
&self,
results: &[reasonkit_mem::retrieval::HybridResult],
) -> String {
let mut context = String::new();
for result in results {
context.push_str(&result.text);
context.push_str("\n\n");
}
context
}
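    /// Returns a JSON summary of stored measurements, per-metric rolling
    /// statistics, and the active configuration.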
pub async fn get_history_summary(&self) -> Result<serde_json::Value> {
let history = self.history.read().await;
let summary = serde_json::json!({
"total_measurements": history.metrics.len(),
"metrics_tracked": history.rolling_stats.len(),
"rolling_stats": history.rolling_stats.iter()
.map(|(k, v)| (k.clone(), serde_json::json!({
"count": v.values.len(),
"mean": v.mean,
"std_dev": v.std_dev,
"trend_slope": v.trend_slope,
"latest_value": v.values.last().copied().unwrap_or(0.0),
"last_updated": v.last_updated
})))
.collect::<serde_json::Map<String, serde_json::Value>>(),
"config": self.config
});
Ok(summary)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::rag::RagEngine;
#[tokio::test]
async fn test_performance_monitor_creation() {
let engine = RagEngine::in_memory().expect("Failed to create RAG engine");
let config = PerformanceConfig::default();
let monitor = RagPerformanceMonitor::new(engine, config);
        assert!(!monitor.config.benchmark_queries.is_empty());
assert_eq!(monitor.config.alert_threshold, 0.05);
}
#[tokio::test]
async fn test_rolling_stats() {
let mut stats = RollingStats::new();
stats.add_value(100.0);
stats.add_value(105.0);
stats.add_value(95.0);
assert_eq!(stats.values.len(), 3);
assert!((stats.mean - 100.0).abs() < 0.1);
assert!(stats.std_dev > 0.0);
}
#[test]
fn test_score_stats_calculation() {
let monitor = RagPerformanceMonitor::new(
RagEngine::in_memory().unwrap(),
PerformanceConfig::default(),
);
let scores = vec![0.8, 0.9, 0.7, 0.85, 0.95];
let stats = monitor.calculate_score_stats(&scores);
assert!((stats.mean - 0.85).abs() < 0.1);
assert!(stats.std_dev > 0.0);
assert_eq!(stats.min, 0.7);
assert_eq!(stats.max, 0.95);
}
#[test]
fn test_statistical_test() {
let monitor = RagPerformanceMonitor::new(
RagEngine::in_memory().unwrap(),
PerformanceConfig::default(),
);
let historical = vec![100.0, 102.0, 98.0, 101.0, 99.0];
let current = 120.0;
let (p_value, confidence) = monitor.statistical_test(&historical, current, 5);
        assert!(p_value < 0.05);
        assert!(confidence > 0.8);
}
#[tokio::test]
async fn test_performance_monitor_integration() {
let rag_engine = RagEngine::in_memory().expect("Failed to create RAG engine");
let config = PerformanceConfig {
alert_threshold: 0.05,
sample_size: 5,
benchmark_queries: vec!["Test query 1".to_string(), "Test query 2".to_string()],
..Default::default()
};
let monitor = RagPerformanceMonitor::new(rag_engine, config);
let metrics = monitor.run_benchmark().await.expect("Benchmark failed");
assert!(metrics.retrieval.retrieval_time_ms >= 0.0);
assert!(metrics.generation.generation_time_ms >= 0.0);
assert!(metrics.end_to_end.total_time_ms >= 0.0);
assert!(metrics.end_to_end.success_rate >= 0.0);
assert!(metrics.end_to_end.success_rate <= 1.0);
let history = monitor.get_history_summary().await.expect("History failed");
assert!(history.is_object());
}
#[test]
fn test_performance_config_defaults() {
let config = PerformanceConfig::default();
assert_eq!(config.alert_threshold, 0.05);
assert_eq!(config.sample_size, 100);
assert_eq!(config.confidence_level, 0.95);
assert_eq!(config.history_window, 1000);
assert!(!config.benchmark_queries.is_empty());
assert!(config.enable_memory_monitoring);
}
}