use crate::{TaskId, TaskResult};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Upper bound on the number of execution-time samples retained for
/// average/percentile calculations; once full, the oldest sample is
/// evicted (FIFO) to cap memory use.
const MAX_EXECUTION_TIMES: usize = 10_000;
/// Aggregate counters for task lifecycle states.
#[derive(Debug, Clone, Default)]
pub struct TaskCounts {
    /// Total number of tasks ever started.
    pub total: u64,
    /// Tasks that completed successfully.
    pub successful: u64,
    /// Tasks that completed with a failure.
    pub failed: u64,
    /// Tasks queued but not yet started.
    /// NOTE(review): no method in this file updates `pending` — confirm it is
    /// maintained elsewhere.
    pub pending: u64,
    /// Tasks currently executing.
    pub running: u64,
}
/// Per-worker throughput and timing statistics, stored verbatim as reported
/// via `TaskMetrics::record_worker_stats`.
#[derive(Debug, Clone)]
pub struct WorkerStats {
    /// Number of tasks this worker has processed.
    pub tasks_processed: u64,
    /// Mean execution time across the worker's processed tasks.
    pub average_execution_time: Duration,
    /// Cumulative time the worker has spent idle.
    pub idle_time: Duration,
}
impl Default for WorkerStats {
fn default() -> Self {
Self {
tasks_processed: 0,
average_execution_time: Duration::ZERO,
idle_time: Duration::ZERO,
}
}
}
/// Point-in-time, owned view of all collected metrics, produced by
/// `TaskMetrics::snapshot`.
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
    /// Lifecycle counters at snapshot time.
    pub task_counts: TaskCounts,
    /// Mean of the retained execution-time samples.
    pub average_execution_time: Duration,
    /// 50th-percentile (median) execution time.
    pub p50_execution_time: Duration,
    /// 95th-percentile execution time.
    pub p95_execution_time: Duration,
    /// 99th-percentile execution time.
    pub p99_execution_time: Duration,
    /// Latest reported depth per queue name.
    pub queue_depths: HashMap<String, usize>,
    /// Latest reported stats per worker id.
    pub worker_stats: HashMap<String, WorkerStats>,
}
/// Thread-safe, cloneable collector of task-scheduler metrics.
///
/// All state lives behind `Arc<RwLock<...>>`, so clones share the same
/// underlying counters and can be handed to concurrent tasks cheaply.
#[derive(Clone)]
pub struct TaskMetrics {
    /// Lifecycle counters (total/successful/failed/pending/running).
    task_counts: Arc<RwLock<TaskCounts>>,
    /// Sliding window of recent execution times, capped at
    /// `MAX_EXECUTION_TIMES` (oldest evicted first).
    execution_times: Arc<RwLock<VecDeque<Duration>>>,
    /// Latest reported depth per queue name.
    queue_depths: Arc<RwLock<HashMap<String, usize>>>,
    /// Latest reported stats per worker id.
    worker_stats: Arc<RwLock<HashMap<String, WorkerStats>>>,
}
impl TaskMetrics {
    /// Creates an empty metrics collector.
    pub fn new() -> Self {
        Self {
            task_counts: Arc::new(RwLock::new(TaskCounts::default())),
            execution_times: Arc::new(RwLock::new(VecDeque::new())),
            queue_depths: Arc::new(RwLock::new(HashMap::new())),
            worker_stats: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records that a task started: bumps `total` and `running`.
    pub async fn record_task_start(&self, _task_id: &TaskId) -> TaskResult<()> {
        let mut counts = self.task_counts.write().await;
        counts.running += 1;
        counts.total += 1;
        Ok(())
    }

    /// Records a successful completion together with its execution time.
    pub async fn record_task_success(
        &self,
        _task_id: &TaskId,
        duration: Duration,
    ) -> TaskResult<()> {
        {
            let mut counts = self.task_counts.write().await;
            counts.successful += 1;
            // saturating_sub: tolerate a completion reported without a
            // matching `record_task_start` instead of underflowing.
            counts.running = counts.running.saturating_sub(1);
        }
        self.push_execution_time(duration).await;
        Ok(())
    }

    /// Records a failed completion together with its execution time.
    pub async fn record_task_failure(
        &self,
        _task_id: &TaskId,
        duration: Duration,
    ) -> TaskResult<()> {
        {
            let mut counts = self.task_counts.write().await;
            counts.failed += 1;
            // Saturates for the same reason as in `record_task_success`.
            counts.running = counts.running.saturating_sub(1);
        }
        self.push_execution_time(duration).await;
        Ok(())
    }

    /// Appends one execution-time sample, evicting the oldest once the
    /// window holds `MAX_EXECUTION_TIMES` entries. Shared by the success
    /// and failure paths so the eviction policy cannot drift between them.
    async fn push_execution_time(&self, duration: Duration) {
        let mut times = self.execution_times.write().await;
        if times.len() >= MAX_EXECUTION_TIMES {
            times.pop_front();
        }
        times.push_back(duration);
    }

    /// Stores the latest observed depth for `queue_name`, overwriting any
    /// previous value.
    pub async fn record_queue_depth(&self, queue_name: String, depth: usize) -> TaskResult<()> {
        self.queue_depths.write().await.insert(queue_name, depth);
        Ok(())
    }

    /// Stores the latest stats reported by `worker_id`, overwriting any
    /// previous value.
    pub async fn record_worker_stats(
        &self,
        worker_id: String,
        stats: WorkerStats,
    ) -> TaskResult<()> {
        self.worker_stats.write().await.insert(worker_id, stats);
        Ok(())
    }

    /// Returns an owned, point-in-time view of all metrics.
    ///
    /// Each store is read-locked only long enough to copy its contents, so
    /// the percentile math runs without any lock held. The four reads are
    /// not atomic with respect to each other; concurrent writers may be
    /// interleaved between them.
    pub async fn snapshot(&self) -> MetricsSnapshot {
        let counts = self.task_counts.read().await.clone();
        // Copy straight into a Vec instead of cloning the VecDeque and
        // converting afterwards.
        let times: Vec<Duration> = self.execution_times.read().await.iter().copied().collect();
        let depths = self.queue_depths.read().await.clone();
        let workers = self.worker_stats.read().await.clone();
        let (average, p50, p95, p99) = Self::calculate_percentiles(&times);
        MetricsSnapshot {
            task_counts: counts,
            average_execution_time: average,
            p50_execution_time: p50,
            p95_execution_time: p95,
            p99_execution_time: p99,
            queue_depths: depths,
            worker_stats: workers,
        }
    }

    /// Clears every counter, sample, and map back to its initial state.
    pub async fn reset(&self) -> TaskResult<()> {
        *self.task_counts.write().await = TaskCounts::default();
        self.execution_times.write().await.clear();
        self.queue_depths.write().await.clear();
        self.worker_stats.write().await.clear();
        Ok(())
    }

    /// Computes `(average, p50, p95, p99)` over `times`; all zero when the
    /// slice is empty.
    fn calculate_percentiles(times: &[Duration]) -> (Duration, Duration, Duration, Duration) {
        if times.is_empty() {
            return (
                Duration::ZERO,
                Duration::ZERO,
                Duration::ZERO,
                Duration::ZERO,
            );
        }
        let mut sorted = times.to_vec();
        // Only ordering matters for percentiles, so the faster
        // non-allocating unstable sort is safe here.
        sorted.sort_unstable();
        let total: Duration = sorted.iter().sum();
        // The sample window is capped at MAX_EXECUTION_TIMES (10_000), so
        // casting the length to u32 cannot truncate.
        let average = total / sorted.len() as u32;
        let p50 = Self::percentile(&sorted, 0.50);
        let p95 = Self::percentile(&sorted, 0.95);
        let p99 = Self::percentile(&sorted, 0.99);
        (average, p50, p95, p99)
    }

    /// Lower-interpolation percentile over an already-sorted slice: returns
    /// `sorted[floor((len - 1) * percentile)]`.
    fn percentile(sorted: &[Duration], percentile: f64) -> Duration {
        if sorted.is_empty() {
            return Duration::ZERO;
        }
        let index = ((sorted.len() as f64 - 1.0) * percentile) as usize;
        sorted[index.min(sorted.len() - 1)]
    }
}
impl Default for TaskMetrics {
    /// Equivalent to [`TaskMetrics::new`]: an empty collector.
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for `TaskMetrics`. Percentile expectations below depend on the
// lower-interpolation index formula `floor((len - 1) * p)`.
#[cfg(test)]
mod tests {
    use super::*;

    // Starting a task bumps both `total` and `running`.
    #[tokio::test]
    async fn test_record_task_start() {
        let metrics = TaskMetrics::new();
        let task_id = TaskId::new();
        metrics.record_task_start(&task_id).await.unwrap();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.task_counts.total, 1);
        assert_eq!(snapshot.task_counts.running, 1);
    }

    // A success moves the task from `running` to `successful`, and its
    // duration becomes the single-sample average.
    #[tokio::test]
    async fn test_record_task_success() {
        let metrics = TaskMetrics::new();
        let task_id = TaskId::new();
        metrics.record_task_start(&task_id).await.unwrap();
        metrics
            .record_task_success(&task_id, Duration::from_millis(100))
            .await
            .unwrap();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.task_counts.total, 1);
        assert_eq!(snapshot.task_counts.successful, 1);
        assert_eq!(snapshot.task_counts.running, 0);
        assert_eq!(snapshot.average_execution_time, Duration::from_millis(100));
    }

    // A failure moves the task from `running` to `failed`.
    #[tokio::test]
    async fn test_record_task_failure() {
        let metrics = TaskMetrics::new();
        let task_id = TaskId::new();
        metrics.record_task_start(&task_id).await.unwrap();
        metrics
            .record_task_failure(&task_id, Duration::from_millis(50))
            .await
            .unwrap();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.task_counts.total, 1);
        assert_eq!(snapshot.task_counts.failed, 1);
        assert_eq!(snapshot.task_counts.running, 0);
    }

    // Depths are stored per queue name and reported back in the snapshot.
    #[tokio::test]
    async fn test_record_queue_depth() {
        let metrics = TaskMetrics::new();
        metrics
            .record_queue_depth("default".to_string(), 42)
            .await
            .unwrap();
        metrics
            .record_queue_depth("priority".to_string(), 10)
            .await
            .unwrap();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.queue_depths.get("default"), Some(&42));
        assert_eq!(snapshot.queue_depths.get("priority"), Some(&10));
    }

    // Worker stats are stored per worker id and returned verbatim.
    #[tokio::test]
    async fn test_record_worker_stats() {
        let metrics = TaskMetrics::new();
        let stats = WorkerStats {
            tasks_processed: 10,
            average_execution_time: Duration::from_millis(100),
            idle_time: Duration::from_secs(5),
        };
        metrics
            .record_worker_stats("worker-1".to_string(), stats.clone())
            .await
            .unwrap();
        let snapshot = metrics.snapshot().await;
        let worker = snapshot.worker_stats.get("worker-1").unwrap();
        assert_eq!(worker.tasks_processed, 10);
        assert_eq!(worker.average_execution_time, Duration::from_millis(100));
        assert_eq!(worker.idle_time, Duration::from_secs(5));
    }

    // With samples 1..=100 ms, the average is 50.5 ms and the
    // lower-interpolation percentiles land exactly on 50/95/99 ms
    // (indices floor(99*0.5)=49, floor(99*0.95)=94, floor(99*0.99)=98).
    #[tokio::test]
    async fn test_percentile_calculation() {
        let metrics = TaskMetrics::new();
        let task_id = TaskId::new();
        for i in 1..=100 {
            metrics.record_task_start(&task_id).await.unwrap();
            metrics
                .record_task_success(&task_id, Duration::from_millis(i))
                .await
                .unwrap();
        }
        let snapshot = metrics.snapshot().await;
        let avg = snapshot.average_execution_time;
        assert!(
            avg >= Duration::from_millis(50) && avg <= Duration::from_millis(51),
            "Expected average around 50-51ms, got {:?}",
            avg
        );
        assert_eq!(snapshot.p50_execution_time, Duration::from_millis(50));
        assert_eq!(snapshot.p95_execution_time, Duration::from_millis(95));
        assert_eq!(snapshot.p99_execution_time, Duration::from_millis(99));
    }

    // reset() clears counters, execution-time samples, and queue depths.
    #[tokio::test]
    async fn test_reset() {
        let metrics = TaskMetrics::new();
        let task_id = TaskId::new();
        metrics.record_task_start(&task_id).await.unwrap();
        metrics
            .record_task_success(&task_id, Duration::from_millis(100))
            .await
            .unwrap();
        metrics
            .record_queue_depth("default".to_string(), 42)
            .await
            .unwrap();
        metrics.reset().await.unwrap();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.task_counts.total, 0);
        assert_eq!(snapshot.task_counts.successful, 0);
        assert_eq!(snapshot.queue_depths.len(), 0);
        assert_eq!(snapshot.average_execution_time, Duration::ZERO);
    }

    // Ten concurrently spawned tasks all land in the shared counters,
    // exercising the Arc<RwLock<...>> sharing across clones.
    #[tokio::test]
    async fn test_concurrent_access() {
        let metrics = Arc::new(TaskMetrics::new());
        let mut handles = vec![];
        for i in 0..10 {
            let metrics = Arc::clone(&metrics);
            let handle = tokio::spawn(async move {
                let task_id = TaskId::new();
                metrics.record_task_start(&task_id).await.unwrap();
                metrics
                    .record_task_success(&task_id, Duration::from_millis(i * 10))
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.task_counts.total, 10);
        assert_eq!(snapshot.task_counts.successful, 10);
    }

    // With no samples, average and all percentiles are zero.
    #[tokio::test]
    async fn test_empty_percentiles() {
        let metrics = TaskMetrics::new();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.average_execution_time, Duration::ZERO);
        assert_eq!(snapshot.p50_execution_time, Duration::ZERO);
        assert_eq!(snapshot.p95_execution_time, Duration::ZERO);
        assert_eq!(snapshot.p99_execution_time, Duration::ZERO);
    }

    // A single sample is simultaneously the average and every percentile.
    #[tokio::test]
    async fn test_single_value_percentiles() {
        let metrics = TaskMetrics::new();
        let task_id = TaskId::new();
        metrics.record_task_start(&task_id).await.unwrap();
        metrics
            .record_task_success(&task_id, Duration::from_millis(100))
            .await
            .unwrap();
        let snapshot = metrics.snapshot().await;
        assert_eq!(snapshot.average_execution_time, Duration::from_millis(100));
        assert_eq!(snapshot.p50_execution_time, Duration::from_millis(100));
        assert_eq!(snapshot.p95_execution_time, Duration::from_millis(100));
        assert_eq!(snapshot.p99_execution_time, Duration::from_millis(100));
    }
}