use axum::{body::Body, extract::Request, middleware::Next, response::Response};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
use tower::{Layer, Service};
use crate::AvxMetrics;
/// Thread-safe, bounded buffer of request-latency samples (in milliseconds)
/// for one named service.
///
/// Cloning is cheap and clones share state: the sample buffer and the
/// metrics handle are behind `Arc`, so every clone records into the same
/// buffer.
#[derive(Clone, Debug)]
pub struct LatencyCollector {
    // Human-readable service label; appears in log/trace output only.
    service_name: String,
    // Samples in arrival order (oldest first); trimmed by `record` once the
    // length exceeds `max_size`.
    latencies: Arc<Mutex<Vec<f64>>>,
    // Soft cap on the number of buffered samples.
    max_size: usize,
    // Anomaly-detection backend — semantics live in `crate::AvxMetrics`
    // (not visible here; see `detect_anomalies`).
    metrics: Arc<AvxMetrics>,
}
impl LatencyCollector {
    /// Creates a collector with the default buffer capacity of 1000 samples.
    pub fn new(service_name: impl Into<String>) -> Self {
        Self::with_capacity(service_name, 1000)
    }

    /// Creates a collector whose buffer holds at most `max_size` samples.
    pub fn with_capacity(service_name: impl Into<String>, max_size: usize) -> Self {
        Self {
            service_name: service_name.into(),
            latencies: Arc::new(Mutex::new(Vec::with_capacity(max_size))),
            max_size,
            metrics: Arc::new(AvxMetrics::new()),
        }
    }

    /// Records one latency sample (milliseconds).
    ///
    /// Once the buffer exceeds `max_size`, roughly the oldest 10% of samples
    /// are discarded. Samples above 1000 ms additionally emit a warning.
    /// A poisoned lock is silently skipped (best-effort telemetry).
    pub fn record(&self, latency_ms: f64) {
        if let Ok(mut latencies) = self.latencies.lock() {
            latencies.push(latency_ms);
            if latencies.len() > self.max_size {
                // BUGFIX: `max_size / 10` is 0 for max_size < 10, which let the
                // buffer grow without bound. Always drain at least one sample.
                let drain_count = (self.max_size / 10).max(1);
                latencies.drain(0..drain_count.min(latencies.len()));
            }
            if latency_ms > 1000.0 {
                tracing::warn!(
                    service = %self.service_name,
                    latency_ms = latency_ms,
                    "High latency detected"
                );
            }
        }
    }

    /// Returns a copy of the currently buffered samples (oldest first).
    /// Returns an empty `Vec` if the lock is poisoned.
    pub fn snapshot(&self) -> Vec<f64> {
        self.latencies.lock().map(|l| l.clone()).unwrap_or_default()
    }

    /// Computes summary statistics over the buffered samples.
    ///
    /// Returns `LatencyStatistics::default()` (all zeros) when no samples
    /// have been recorded. Variance is the population variance.
    pub fn statistics(&self) -> LatencyStatistics {
        // `snapshot` already hands us an owned Vec, so sort it in place
        // instead of cloning a second time.
        let mut sorted = self.snapshot();
        if sorted.is_empty() {
            return LatencyStatistics::default();
        }
        // BUGFIX: `partial_cmp(..).unwrap()` panics if any sample is NaN;
        // `total_cmp` is a total order over f64 and cannot panic.
        sorted.sort_unstable_by(|a, b| a.total_cmp(b));
        let count = sorted.len();
        let sum: f64 = sorted.iter().sum();
        let mean = sum / count as f64;
        let variance: f64 =
            sorted.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count as f64;
        LatencyStatistics {
            count,
            mean_ms: mean,
            std_dev_ms: variance.sqrt(),
            min_ms: sorted[0],
            max_ms: sorted[count - 1],
            p50_ms: percentile(&sorted, 0.50),
            p95_ms: percentile(&sorted, 0.95),
            p99_ms: percentile(&sorted, 0.99),
        }
    }

    /// Runs anomaly detection over a snapshot of the buffered samples.
    ///
    /// Returns `Ok(vec![])` when the buffer is empty; otherwise delegates to
    /// `AvxMetrics::track_latencies` (its detection semantics are defined
    /// elsewhere in the crate).
    pub fn detect_anomalies(&self) -> Result<Vec<crate::Anomaly>, crate::TelemetryError> {
        let latencies = self.snapshot();
        if latencies.is_empty() {
            return Ok(vec![]);
        }
        self.metrics.track_latencies(latencies)
    }

    /// Discards all buffered samples. No-op if the lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut latencies) = self.latencies.lock() {
            latencies.clear();
        }
    }

    /// The service label this collector was created with.
    pub fn service_name(&self) -> &str {
        &self.service_name
    }
}
/// Summary statistics over a set of latency samples, all in milliseconds.
///
/// `Default` yields the all-zero statistics used for an empty sample set.
#[derive(Debug, Clone, Default)]
pub struct LatencyStatistics {
    // Number of samples the statistics were computed from.
    pub count: usize,
    pub mean_ms: f64,
    // Population standard deviation (divides by `count`, not `count - 1`).
    pub std_dev_ms: f64,
    pub min_ms: f64,
    pub max_ms: f64,
    pub p50_ms: f64,
    pub p95_ms: f64,
    pub p99_ms: f64,
}
impl LatencyStatistics {
    /// Reports whether the p99 latency is within the given budget
    /// (`true` when `p99_ms <= p99_threshold_ms`).
    pub fn is_healthy(&self, p99_threshold_ms: f64) -> bool {
        self.p99_ms <= p99_threshold_ms
    }

    /// Serializes the statistics into a JSON object; percentiles are
    /// nested under a `"percentiles"` key.
    pub fn to_json(&self) -> serde_json::Value {
        let percentiles = serde_json::json!({
            "p50": self.p50_ms,
            "p95": self.p95_ms,
            "p99": self.p99_ms,
        });
        serde_json::json!({
            "count": self.count,
            "mean_ms": self.mean_ms,
            "std_dev_ms": self.std_dev_ms,
            "min_ms": self.min_ms,
            "max_ms": self.max_ms,
            "percentiles": percentiles,
        })
    }
}
/// Convenience wrapper that pairs a `LatencyCollector` with the tower
/// layer/service machinery below (see `into_layer`).
#[derive(Clone)]
pub struct LatencyMiddleware {
    // Shared collector; clones of the middleware record into the same buffer.
    collector: LatencyCollector,
}
impl LatencyMiddleware {
    /// Builds a middleware backed by a fresh collector for `service_name`.
    pub fn new(service_name: impl Into<String>) -> Self {
        let collector = LatencyCollector::new(service_name);
        Self { collector }
    }

    /// Wraps an existing collector (e.g. one shared with a reporting task).
    pub fn with_collector(collector: LatencyCollector) -> Self {
        Self { collector }
    }

    /// Read-only access to the underlying collector.
    pub fn collector(&self) -> &LatencyCollector {
        &self.collector
    }

    /// Consumes the middleware, producing a tower `Layer`.
    pub fn into_layer(self) -> LatencyLayer {
        let Self { collector } = self;
        LatencyLayer { collector }
    }
}
/// Tower `Layer` that wraps services in a `LatencyService` sharing this
/// layer's collector.
#[derive(Clone)]
pub struct LatencyLayer {
    // Cloned into every wrapped service; all record into the same buffer.
    collector: LatencyCollector,
}
impl<S> Layer<S> for LatencyLayer {
    type Service = LatencyService<S>;

    /// Wraps `inner` so every request it serves is timed and recorded by
    /// this layer's (shared) collector.
    fn layer(&self, inner: S) -> Self::Service {
        let collector = self.collector.clone();
        LatencyService { inner, collector }
    }
}
/// Service wrapper that times each request through `inner` and records the
/// elapsed milliseconds into `collector`.
#[derive(Clone)]
pub struct LatencyService<S> {
    inner: S,
    collector: LatencyCollector,
}
impl<S> Service<Request<Body>> for LatencyService<S>
where
    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Times the inner call and records the latency; also emits a debug
    /// trace with method, URI, status, and elapsed milliseconds.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let collector = self.collector.clone();
        // BUGFIX: `poll_ready` was driven on `self.inner`, not on a clone —
        // per the tower `Service` docs, a clone is NOT guaranteed to be
        // ready. Swap the ready instance out and leave the fresh clone in
        // `self` for the next `poll_ready`/`call` cycle.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        let start = Instant::now();
        let method = request.method().clone();
        let uri = request.uri().clone();
        Box::pin(async move {
            let response = inner.call(request).await?;
            let duration = start.elapsed();
            let latency_ms = duration.as_secs_f64() * 1000.0;
            collector.record(latency_ms);
            tracing::debug!(
                service = %collector.service_name(),
                method = %method,
                uri = %uri,
                status = response.status().as_u16(),
                latency_ms = latency_ms,
                "Request completed"
            );
            Ok(response)
        })
    }
}
/// Nearest-rank percentile of an ascending-sorted slice.
///
/// `p` is a fraction in `[0.0, 1.0]`. Uses the nearest-rank method
/// (1-based rank = ceil(p * n)), so the p50 of ten evenly spaced samples
/// is the 5th value.
///
/// BUGFIX: the previous formula, `round(p * (n - 1))`, returned the 6th
/// value for that case — `f64::round` rounds 4.5 away from zero to 5 —
/// which contradicts this file's own `test_percentile_calculation`
/// expectation of 5.0 at p = 0.50.
///
/// Returns 0.0 for an empty slice.
fn percentile(sorted_data: &[f64], p: f64) -> f64 {
    if sorted_data.is_empty() {
        return 0.0;
    }
    let n = sorted_data.len();
    let rank = (p * n as f64).ceil() as usize;
    // Clamp: p = 0.0 yields rank 0, which must still select the first
    // sample; p = 1.0 yields rank n, the last sample.
    let index = rank.max(1).min(n) - 1;
    sorted_data[index]
}
/// Axum middleware function: times the request through `next`, records the
/// elapsed milliseconds into `collector`, and emits a debug trace.
pub async fn latency_middleware(
    request: Request,
    next: Next,
    collector: Arc<LatencyCollector>,
) -> Response {
    let started = Instant::now();
    let method = request.method().clone();
    let uri = request.uri().clone();

    let response = next.run(request).await;

    let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
    collector.record(latency_ms);
    tracing::debug!(
        service = %collector.service_name(),
        method = %method,
        uri = %uri,
        status = response.status().as_u16(),
        latency_ms = latency_ms,
        "Request completed"
    );
    response
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fresh collectors start empty and keep their service label.
    #[test]
    fn test_latency_collector_creation() {
        let collector = LatencyCollector::new("test-service");
        assert_eq!(collector.service_name(), "test-service");
        assert_eq!(collector.snapshot().len(), 0);
    }

    // Samples are kept in arrival order while under capacity.
    #[test]
    fn test_latency_recording() {
        let collector = LatencyCollector::new("test");
        collector.record(10.0);
        collector.record(20.0);
        collector.record(30.0);
        let snapshot = collector.snapshot();
        assert_eq!(snapshot.len(), 3);
        assert_eq!(snapshot, vec![10.0, 20.0, 30.0]);
    }

    // Overfilling past max_size must trim the buffer back to the cap.
    #[test]
    fn test_buffer_circular_behavior() {
        let collector = LatencyCollector::with_capacity("test", 10);
        for i in 0..15 {
            collector.record(i as f64);
        }
        let snapshot = collector.snapshot();
        assert!(snapshot.len() <= 10);
    }

    // Mean/min/max/p50 of five evenly spaced samples. Note: p50 == 30.0
    // here holds under both round(p*(n-1)) and nearest-rank percentile
    // formulas (n = 5 is odd).
    #[test]
    fn test_latency_statistics() {
        let collector = LatencyCollector::new("test");
        let latencies = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        for latency in &latencies {
            collector.record(*latency);
        }
        let stats = collector.statistics();
        assert_eq!(stats.count, 5);
        assert_eq!(stats.mean_ms, 30.0);
        assert_eq!(stats.min_ms, 10.0);
        assert_eq!(stats.max_ms, 50.0);
        assert_eq!(stats.p50_ms, 30.0);
    }

    // An empty collector yields the all-zero default statistics.
    #[test]
    fn test_statistics_empty_collector() {
        let collector = LatencyCollector::new("test");
        let stats = collector.statistics();
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean_ms, 0.0);
    }

    // All samples are < 20 ms, so p99 must sit well under a 50 ms budget.
    #[test]
    fn test_statistics_is_healthy() {
        let collector = LatencyCollector::new("test");
        for i in 0..100 {
            collector.record((i % 20) as f64); }
        let stats = collector.statistics();
        assert!(stats.is_healthy(50.0));
        assert!(stats.p99_ms < 50.0);
    }

    // clear() discards all buffered samples.
    #[test]
    fn test_clear_collector() {
        let collector = LatencyCollector::new("test");
        collector.record(10.0);
        collector.record(20.0);
        assert_eq!(collector.snapshot().len(), 2);
        collector.clear();
        assert_eq!(collector.snapshot().len(), 0);
    }

    // NOTE(review): p50 == 5.0 on ten samples implies the nearest-rank
    // method (rank = ceil(p*n)); the round(p*(n-1)) formula would select
    // data[5] = 6.0 because 4.5_f64.round() rounds away from zero.
    // Verify the percentile() implementation matches this expectation.
    #[test]
    fn test_percentile_calculation() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        assert_eq!(percentile(&data, 0.50), 5.0);
        assert_eq!(percentile(&data, 0.95), 10.0);
        assert_eq!(percentile(&data, 0.00), 1.0);
        assert_eq!(percentile(&data, 1.00), 10.0);
    }

    // JSON shape: flat scalar fields plus a nested "percentiles" object.
    #[test]
    fn test_statistics_to_json() {
        let stats = LatencyStatistics {
            count: 100,
            mean_ms: 15.5,
            std_dev_ms: 3.2,
            min_ms: 10.0,
            max_ms: 25.0,
            p50_ms: 15.0,
            p95_ms: 22.0,
            p99_ms: 24.0,
        };
        let json = stats.to_json();
        assert_eq!(json["count"], 100);
        assert_eq!(json["mean_ms"], 15.5);
        assert_eq!(json["percentiles"]["p50"], 15.0);
        assert_eq!(json["percentiles"]["p99"], 24.0);
    }

    // A 100x outlier among uniform samples should be flagged — this
    // presumes AvxMetrics::track_latencies detects such outliers; its
    // semantics are defined elsewhere in the crate.
    #[tokio::test]
    async fn test_detect_anomalies() {
        let collector = LatencyCollector::new("test");
        for _ in 0..100 {
            collector.record(10.0);
        }
        collector.record(1000.0);
        let anomalies = collector.detect_anomalies().unwrap();
        assert!(!anomalies.is_empty());
    }

    // Middleware exposes the collector it was constructed with.
    #[test]
    fn test_middleware_creation() {
        let middleware = LatencyMiddleware::new("test-service");
        assert_eq!(middleware.collector().service_name(), "test-service");
    }
}