use std::collections::HashMap;
use std::time::Duration;
use hdrhistogram::Histogram;
use serde::Serialize;
/// Aggregated request statistics, as produced by `MetricsCollector::compute_metrics`,
/// serializable with serde.
///
/// Latency fields are in milliseconds. They are read from an `hdrhistogram`
/// that records microseconds with 3 significant figures, so they are quantized
/// to that precision rather than exact.
#[derive(Debug, Serialize)]
pub struct PerfMetrics {
/// Number of recorded requests (successful + failed).
pub total_requests: usize,
/// Number of requests recorded as successful.
pub successful_requests: usize,
/// Number of requests recorded as failed.
pub failed_requests: usize,
/// Length of the measurement window, in milliseconds.
pub total_duration_ms: f64,
/// Smallest recorded latency, in milliseconds.
pub latency_min_ms: f64,
/// Largest recorded latency, in milliseconds.
pub latency_max_ms: f64,
/// Mean latency, in milliseconds.
pub latency_avg_ms: f64,
/// Median (50th percentile) latency, in milliseconds.
pub latency_p50_ms: f64,
/// 95th percentile latency, in milliseconds.
pub latency_p95_ms: f64,
/// 99th percentile latency, in milliseconds.
pub latency_p99_ms: f64,
/// Requests per second over the measurement window; 0 when the window has zero length.
pub requests_per_second: f64,
/// Failed requests as a percentage (0-100) of all requests; 0 when there are none.
pub error_rate_percent: f64,
/// Per-endpoint breakdown keyed by the label passed when recording; omitted
/// from serialized output when empty. Entries in this map carry an empty
/// `endpoints` map of their own (no further nesting).
// NOTE(review): `default` only affects deserialization, and this type derives
// `Serialize` but not `Deserialize` — confirm whether it is still wanted.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub endpoints: HashMap<String, PerfMetrics>,
}
/// Latency histogram plus success/failure counters for one group of requests
/// (the whole run, or a single endpoint label).
struct StatsBucket {
    /// Recorded request latencies, in microseconds.
    histogram: Histogram<u64>,
    /// Number of requests recorded via `record_success`.
    successful: usize,
    /// Number of requests recorded via `record_failure`.
    failed: usize,
}

impl StatsBucket {
    /// Smallest distinguishable latency, in microseconds.
    const LOWEST_TRACKABLE_MICROS: u64 = 1;
    /// Largest tracked latency (60 s, in microseconds). Slower requests are
    /// clamped to the histogram ceiling rather than dropped.
    const HIGHEST_TRACKABLE_MICROS: u64 = 60_000_000;
    /// Significant decimal digits preserved for each recorded value.
    const SIGNIFICANT_FIGURES: u8 = 3;

    /// Creates an empty bucket.
    ///
    /// # Panics
    ///
    /// If `hdrhistogram` rejects the bounds above. They are constants, so this
    /// indicates a programming error rather than a runtime condition.
    fn new() -> Self {
        let histogram = Histogram::new_with_bounds(
            Self::LOWEST_TRACKABLE_MICROS,
            Self::HIGHEST_TRACKABLE_MICROS,
            Self::SIGNIFICANT_FIGURES,
        )
        .expect("Failed to create histogram");
        Self {
            histogram,
            successful: 0,
            failed: 0,
        }
    }

    /// Records the latency of a successful request.
    fn record_success(&mut self, duration: Duration) {
        self.record_latency(duration);
        self.successful += 1;
    }

    /// Records the latency of a failed request.
    fn record_failure(&mut self, duration: Duration) {
        self.record_latency(duration);
        self.failed += 1;
    }

    /// Adds `duration` to the histogram, clamped to its highest trackable value.
    fn record_latency(&mut self, duration: Duration) {
        // `as_micros` returns u128. Clamp while still in u128 so the narrowing
        // cast is lossless; casting first would wrap huge durations to small values.
        let ceiling = self.histogram.high();
        let micros = duration.as_micros().min(u128::from(ceiling)) as u64;
        // The value is clamped into range, so recording is not expected to fail;
        // metrics collection stays best-effort and never aborts the caller.
        let _ = self.histogram.record(micros);
    }

    /// Summarizes this bucket over a measurement window of `total_duration`.
    ///
    /// Throughput and error rate are reported as 0 when the window has zero
    /// length or nothing was recorded, respectively. The returned
    /// `endpoints` map is always empty.
    fn compute_metrics(&self, total_duration: Duration) -> PerfMetrics {
        let total = self.successful + self.failed;
        let window_secs = total_duration.as_secs_f64();
        let requests_per_second = if window_secs > 0.0 {
            total as f64 / window_secs
        } else {
            0.0
        };
        let error_rate_percent = if total > 0 {
            self.failed as f64 / total as f64 * 100.0
        } else {
            0.0
        };
        let micros_to_ms = |micros: u64| micros as f64 / 1000.0;
        PerfMetrics {
            total_requests: total,
            successful_requests: self.successful,
            failed_requests: self.failed,
            total_duration_ms: window_secs * 1000.0,
            latency_min_ms: micros_to_ms(self.histogram.min()),
            latency_max_ms: micros_to_ms(self.histogram.max()),
            // `mean()` is already f64; converting directly keeps sub-microsecond
            // precision that truncating to u64 first would discard.
            latency_avg_ms: self.histogram.mean() / 1000.0,
            latency_p50_ms: micros_to_ms(self.histogram.value_at_percentile(50.0)),
            latency_p95_ms: micros_to_ms(self.histogram.value_at_percentile(95.0)),
            latency_p99_ms: micros_to_ms(self.histogram.value_at_percentile(99.0)),
            requests_per_second,
            error_rate_percent,
            endpoints: HashMap::new(),
        }
    }
}
/// Records request outcomes and latencies, both globally and per endpoint
/// label, and turns them into a [`PerfMetrics`] report.
pub struct MetricsCollector {
    /// Aggregate over every recorded request, labelled or not.
    global: StatsBucket,
    /// Per-label buckets, created the first time a label is seen.
    endpoints: HashMap<String, StatsBucket>,
    /// Set by [`MetricsCollector::start`].
    start_time: Option<std::time::Instant>,
    /// Set by [`MetricsCollector::finish`]; cleared by `start`.
    end_time: Option<std::time::Instant>,
}

impl MetricsCollector {
    /// Creates a collector with no recorded requests and no measurement window.
    pub fn new() -> Self {
        Self {
            global: StatsBucket::new(),
            endpoints: HashMap::new(),
            start_time: None,
            end_time: None,
        }
    }

    /// Opens the measurement window at the current instant.
    ///
    /// Any end instant left over from a previous `start`/`finish` cycle is
    /// discarded, so the new window is not measured against a stale end time.
    pub fn start(&mut self) {
        self.start_time = Some(std::time::Instant::now());
        self.end_time = None;
    }

    /// Closes the measurement window at the current instant.
    pub fn finish(&mut self) {
        self.end_time = Some(std::time::Instant::now());
    }

    /// Records a successful request that took `duration`.
    ///
    /// The request always counts towards the global statistics. When `label`
    /// is `Some`, it is also recorded under that endpoint label.
    pub fn record_success(&mut self, duration: Duration, label: Option<&str>) {
        self.global.record_success(duration);
        if let Some(bucket) = self.endpoint_bucket(label) {
            bucket.record_success(duration);
        }
    }

    /// Records a failed request that took `duration`.
    ///
    /// The request always counts towards the global statistics. When `label`
    /// is `Some`, it is also recorded under that endpoint label.
    pub fn record_failure(&mut self, duration: Duration, label: Option<&str>) {
        self.global.record_failure(duration);
        if let Some(bucket) = self.endpoint_bucket(label) {
            bucket.record_failure(duration);
        }
    }

    /// Builds a report of everything recorded so far, with a per-label breakdown.
    ///
    /// Throughput is computed over the measurement window, which is:
    /// - the time between `start` and `finish`, when both were called;
    /// - the time elapsed since `start`, when `finish` has not been called yet
    ///   (so mid-run snapshots report live throughput);
    /// - zero when `start` was never called, in which case throughput is 0.
    ///
    /// The same window is used for every endpoint breakdown.
    pub fn compute_metrics(&self) -> PerfMetrics {
        let window = self.measurement_window();
        let mut metrics = self.global.compute_metrics(window);
        metrics.endpoints = self
            .endpoints
            .iter()
            .map(|(label, bucket)| (label.clone(), bucket.compute_metrics(window)))
            .collect();
        metrics
    }

    /// Returns the bucket for `label`, creating it on first use.
    /// Returns `None` for unlabelled requests.
    fn endpoint_bucket(&mut self, label: Option<&str>) -> Option<&mut StatsBucket> {
        let label = label?;
        Some(
            self.endpoints
                .entry(label.to_owned())
                .or_insert_with(StatsBucket::new),
        )
    }

    /// Length of the measurement window, as described on `compute_metrics`.
    fn measurement_window(&self) -> Duration {
        match (self.start_time, self.end_time) {
            // Saturate explicitly rather than rely on `duration_since`'s
            // handling of reversed instants.
            (Some(start), Some(end)) => end.saturating_duration_since(start),
            (Some(start), None) => start.elapsed(),
            (None, _) => Duration::ZERO,
        }
    }
}
impl Default for MetricsCollector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within `tolerance` of `expected`.
    ///
    /// Latency values come from a bucketed histogram, so they are only
    /// accurate to the histogram's precision, not exact.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() <= tolerance,
            "expected {} +/- {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[test]
    fn test_new_collector() {
        let metrics = MetricsCollector::new().compute_metrics();
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.successful_requests, 0);
        assert_eq!(metrics.failed_requests, 0);
        assert_close(metrics.error_rate_percent, 0.0, 0.0);
        assert_close(metrics.requests_per_second, 0.0, 0.0);
        assert!(metrics.endpoints.is_empty());
    }

    #[test]
    fn test_record_success_global() {
        let mut collector = MetricsCollector::new();
        collector.record_success(Duration::from_millis(100), None);
        collector.record_success(Duration::from_millis(200), None);
        let metrics = collector.compute_metrics();
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.successful_requests, 2);
        assert_eq!(metrics.failed_requests, 0);
        assert_close(metrics.latency_avg_ms, 150.0, 0.5);
        // Unlabelled requests must not create endpoint entries.
        assert!(metrics.endpoints.is_empty());
    }

    #[test]
    fn test_record_failure_global() {
        let mut collector = MetricsCollector::new();
        collector.record_failure(Duration::from_millis(100), None);
        let metrics = collector.compute_metrics();
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.failed_requests, 1);
        assert_close(metrics.error_rate_percent, 100.0, 1e-9);
    }

    #[test]
    fn test_error_rate_percent() {
        let mut collector = MetricsCollector::new();
        collector.record_success(Duration::from_millis(10), None);
        for _ in 0..3 {
            collector.record_failure(Duration::from_millis(10), None);
        }
        let metrics = collector.compute_metrics();
        assert_eq!(metrics.total_requests, 4);
        assert_close(metrics.error_rate_percent, 75.0, 1e-9);
    }

    #[test]
    fn test_latency_statistics() {
        let mut collector = MetricsCollector::new();
        for ms in 1..=100 {
            collector.record_success(Duration::from_millis(ms), None);
        }
        let metrics = collector.compute_metrics();
        assert_close(metrics.latency_min_ms, 1.0, 0.01);
        assert_close(metrics.latency_max_ms, 100.0, 0.1);
        assert_close(metrics.latency_avg_ms, 50.5, 0.1);
        assert_close(metrics.latency_p50_ms, 50.0, 1.5);
        assert_close(metrics.latency_p95_ms, 95.0, 1.5);
        assert_close(metrics.latency_p99_ms, 99.0, 1.5);
        assert!(metrics.latency_min_ms <= metrics.latency_p50_ms);
        assert!(metrics.latency_p50_ms <= metrics.latency_p95_ms);
        assert!(metrics.latency_p95_ms <= metrics.latency_p99_ms);
        assert!(metrics.latency_p99_ms <= metrics.latency_max_ms);
    }

    #[test]
    fn test_record_with_endpoints() {
        let mut collector = MetricsCollector::new();
        collector.record_success(Duration::from_millis(100), Some("GET /api"));
        collector.record_success(Duration::from_millis(200), Some("GET /api"));
        collector.record_failure(Duration::from_millis(50), Some("POST /login"));
        let metrics = collector.compute_metrics();
        assert_eq!(metrics.total_requests, 3);
        assert_eq!(metrics.successful_requests, 2);
        assert_eq!(metrics.failed_requests, 1);
        assert_eq!(metrics.endpoints.len(), 2);
        let api_metrics = metrics.endpoints.get("GET /api").unwrap();
        assert_eq!(api_metrics.total_requests, 2);
        assert_eq!(api_metrics.successful_requests, 2);
        // Per-endpoint reports are flat: no nested breakdown.
        assert!(api_metrics.endpoints.is_empty());
        let login_metrics = metrics.endpoints.get("POST /login").unwrap();
        assert_eq!(login_metrics.total_requests, 1);
        assert_eq!(login_metrics.failed_requests, 1);
        assert_close(login_metrics.error_rate_percent, 100.0, 1e-9);
    }

    #[test]
    fn test_endpoint_latencies_are_isolated() {
        let mut collector = MetricsCollector::new();
        collector.record_success(Duration::from_millis(10), Some("fast"));
        collector.record_success(Duration::from_millis(500), Some("slow"));
        let metrics = collector.compute_metrics();
        assert_close(metrics.endpoints["fast"].latency_max_ms, 10.0, 0.05);
        assert_close(metrics.endpoints["slow"].latency_min_ms, 500.0, 0.5);
        assert_close(metrics.latency_min_ms, 10.0, 0.05);
        assert_close(metrics.latency_max_ms, 500.0, 0.5);
    }

    #[test]
    fn test_throughput_is_zero_without_timing_window() {
        let mut collector = MetricsCollector::new();
        collector.record_success(Duration::from_millis(5), None);
        let metrics = collector.compute_metrics();
        assert_close(metrics.total_duration_ms, 0.0, 0.0);
        assert_close(metrics.requests_per_second, 0.0, 0.0);
    }

    #[test]
    fn test_throughput_with_timing_window() {
        let mut collector = MetricsCollector::new();
        collector.start();
        collector.record_success(Duration::from_millis(5), None);
        collector.finish();
        let metrics = collector.compute_metrics();
        assert!(metrics.total_duration_ms >= 0.0);
        assert!(metrics.requests_per_second.is_finite());
        assert!(metrics.requests_per_second >= 0.0);
    }
}