use super::request::{McpRequest, McpResponse};
use futures_util::future::BoxFuture;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tower_layer::Layer;
use tower_service::Service;
use turbomcp_protocol::McpError;
/// Shared collector for request/response metrics.
///
/// The three global counters are atomics; the aggregated timing data and the
/// per-method map are each guarded by their own `RwLock`, so every method on
/// this type takes `&self`.
#[derive(Debug)]
pub struct Metrics {
/// Incremented once per `record_request` call.
total_requests: AtomicU64,
/// Incremented by `record_response` when `is_success` is true.
successful_responses: AtomicU64,
/// Incremented by `record_response` when `is_success` is false.
error_responses: AtomicU64,
/// Aggregated response-time statistics across all methods.
response_times: RwLock<ResponseTimeStats>,
/// Per-method statistics, keyed by the method name given to `record_response`.
method_metrics: RwLock<HashMap<String, MethodMetrics>>,
/// Instant captured in `new`; `snapshot` derives uptime from it.
start_time: Instant,
}
/// Aggregated response-time figures, all in whole milliseconds.
///
/// `Default` yields the empty state: zero totals, no minimum, no samples.
#[derive(Debug, Default)]
struct ResponseTimeStats {
/// Sum of every recorded duration.
total_ms: u64,
/// Number of recorded durations.
count: u64,
/// Smallest recorded duration; `None` until the first sample arrives.
min_ms: Option<u64>,
/// Largest recorded duration (0 while empty).
max_ms: u64,
/// Most recent durations, oldest first; `Metrics::record_response` drops the
/// oldest entry before pushing once the length limit is reached.
recent: Vec<u64>,
}
/// Statistics accumulated for a single method name.
#[derive(Debug, Clone, Default)]
pub struct MethodMetrics {
/// Total number of responses recorded for the method.
pub count: u64,
/// Running mean of the recorded durations, in milliseconds.
pub avg_duration_ms: f64,
/// Responses recorded with `is_success == true`.
pub success_count: u64,
/// Responses recorded with `is_success == false`.
pub error_count: u64,
}
/// Owned copy of the collector's state as produced by `Metrics::snapshot`.
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
/// Value of the request counter at snapshot time.
pub total_requests: u64,
/// Value of the success counter at snapshot time.
pub successful_responses: u64,
/// Value of the error counter at snapshot time.
pub error_responses: u64,
/// Mean response time in milliseconds; `0.0` when nothing was recorded.
pub avg_response_time_ms: f64,
/// Smallest recorded response time in milliseconds, if any was recorded.
pub min_response_time_ms: Option<u64>,
/// Largest recorded response time in milliseconds.
pub max_response_time_ms: u64,
/// `total_requests` divided by the uptime in seconds; while the uptime is
/// still below one whole second this is simply `total_requests`.
pub requests_per_second: f64,
/// Clone of the per-method statistics map.
pub method_metrics: HashMap<String, MethodMetrics>,
/// Time elapsed since the collector was created.
pub uptime: Duration,
}
impl Metrics {
#[must_use]
pub fn new() -> Self {
Self {
total_requests: AtomicU64::new(0),
successful_responses: AtomicU64::new(0),
error_responses: AtomicU64::new(0),
response_times: RwLock::new(ResponseTimeStats::default()),
method_metrics: RwLock::new(HashMap::new()),
start_time: Instant::now(),
}
}
pub fn record_request(&self) {
self.total_requests.fetch_add(1, Ordering::Relaxed);
}
pub fn record_response(&self, method: &str, duration: Duration, is_success: bool) {
let duration_ms = duration.as_millis() as u64;
if is_success {
self.successful_responses.fetch_add(1, Ordering::Relaxed);
} else {
self.error_responses.fetch_add(1, Ordering::Relaxed);
}
{
let mut stats = self.response_times.write();
stats.total_ms += duration_ms;
stats.count += 1;
stats.max_ms = stats.max_ms.max(duration_ms);
stats.min_ms = Some(stats.min_ms.map_or(duration_ms, |min| min.min(duration_ms)));
if stats.recent.len() >= 1000 {
stats.recent.remove(0);
}
stats.recent.push(duration_ms);
}
{
let mut methods = self.method_metrics.write();
let entry = methods.entry(method.to_string()).or_default();
entry.count += 1;
if is_success {
entry.success_count += 1;
} else {
entry.error_count += 1;
}
entry.avg_duration_ms = (entry.avg_duration_ms * (entry.count - 1) as f64
+ duration_ms as f64)
/ entry.count as f64;
}
}
#[must_use]
pub fn snapshot(&self) -> MetricsSnapshot {
let total = self.total_requests.load(Ordering::Relaxed);
let successful = self.successful_responses.load(Ordering::Relaxed);
let errors = self.error_responses.load(Ordering::Relaxed);
let uptime = self.start_time.elapsed();
let (avg_ms, min_ms, max_ms) = {
let stats = self.response_times.read();
let avg = if stats.count > 0 {
stats.total_ms as f64 / stats.count as f64
} else {
0.0
};
(avg, stats.min_ms, stats.max_ms)
};
let method_metrics = self.method_metrics.read().clone();
MetricsSnapshot {
total_requests: total,
successful_responses: successful,
error_responses: errors,
avg_response_time_ms: avg_ms,
min_response_time_ms: min_ms,
max_response_time_ms: max_ms,
requests_per_second: if uptime.as_secs() > 0 {
total as f64 / uptime.as_secs_f64()
} else {
total as f64
},
method_metrics,
uptime,
}
}
pub fn reset(&self) {
self.total_requests.store(0, Ordering::Relaxed);
self.successful_responses.store(0, Ordering::Relaxed);
self.error_responses.store(0, Ordering::Relaxed);
*self.response_times.write() = ResponseTimeStats::default();
self.method_metrics.write().clear();
}
}
impl Default for Metrics {
fn default() -> Self {
Self::new()
}
}
/// Tower layer that wraps services in [`MetricsService`].
///
/// Cloning the layer clones the `Arc`, so all clones (and every service the
/// layer produces) report into the same [`Metrics`] instance.
#[derive(Debug, Clone)]
pub struct MetricsLayer {
/// Collector handed to every service created by `layer`.
metrics: Arc<Metrics>,
}
impl MetricsLayer {
#[must_use]
pub fn new(metrics: Arc<Metrics>) -> Self {
Self { metrics }
}
#[must_use]
pub fn with_internal_metrics() -> Self {
Self {
metrics: Arc::new(Metrics::new()),
}
}
#[must_use]
pub fn metrics(&self) -> &Arc<Metrics> {
&self.metrics
}
}
impl<S> Layer<S> for MetricsLayer {
type Service = MetricsService<S>;
fn layer(&self, inner: S) -> Self::Service {
MetricsService {
inner,
metrics: Arc::clone(&self.metrics),
}
}
}
/// Service wrapper created by `MetricsLayer::layer` that records metrics
/// around calls to the wrapped service.
#[derive(Debug, Clone)]
pub struct MetricsService<S> {
/// The wrapped service that actually handles requests.
inner: S,
/// Collector that request counts, outcomes and durations are recorded into.
metrics: Arc<Metrics>,
}
impl<S> MetricsService<S> {
    /// Borrows the wrapped service.
    #[must_use]
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Mutably borrows the wrapped service.
    #[must_use]
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Borrows the shared collector handle.
    ///
    /// Marked `#[must_use]` like the equivalent accessor on `MetricsLayer`.
    #[must_use]
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }
}
impl<S> Service<McpRequest> for MetricsService<S>
where
    S: Service<McpRequest, Response = McpResponse> + Clone + Send + 'static,
    S::Future: Send,
    S::Error: Into<McpError>,
{
    type Response = McpResponse;
    type Error = McpError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Forwards readiness to the wrapped service, converting its error type.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Counts the request immediately, then returns a future that runs the
    /// wrapped service and records the outcome and elapsed time.
    ///
    /// A transport-level `Err` is recorded as a failure; an `Ok` response is
    /// recorded according to `response.is_success()`.
    fn call(&mut self, req: McpRequest) -> Self::Future {
        let method = req.method().to_string();
        let recorder = Arc::clone(&self.metrics);
        let started = Instant::now();

        // Move the instance that was driven to readiness into the future and
        // leave a fresh clone behind in `self`.
        let replacement = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, replacement);

        recorder.record_request();

        Box::pin(async move {
            let outcome: Result<McpResponse, McpError> =
                ready.call(req).await.map_err(Into::into);
            let elapsed = started.elapsed();
            let succeeded = match &outcome {
                Ok(response) => response.is_success(),
                Err(_) => false,
            };
            recorder.record_response(&method, elapsed, succeeded);
            outcome
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// A freshly built collector reports zero for every counter.
    #[test]
    fn test_metrics_creation() {
        let collector = Metrics::new();
        let snap = collector.snapshot();

        assert_eq!(snap.total_requests, 0);
        assert_eq!(snap.successful_responses, 0);
        assert_eq!(snap.error_responses, 0);
    }

    /// Requests, outcomes and min/max durations are all tracked.
    #[test]
    fn test_metrics_recording() {
        let collector = Metrics::new();

        for _ in 0..2 {
            collector.record_request();
        }
        collector.record_response("test/method", Duration::from_millis(100), true);
        collector.record_response("test/method", Duration::from_millis(200), false);

        let snap = collector.snapshot();
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.successful_responses, 1);
        assert_eq!(snap.error_responses, 1);
        assert_eq!(snap.min_response_time_ms, Some(100));
        assert_eq!(snap.max_response_time_ms, 200);
    }

    /// Statistics are kept separately for each method name.
    #[test]
    fn test_method_metrics() {
        let collector = Metrics::new();

        collector.record_response("tools/call", Duration::from_millis(50), true);
        collector.record_response("tools/call", Duration::from_millis(100), true);
        collector.record_response("resources/read", Duration::from_millis(75), false);

        let snap = collector.snapshot();
        let per_method = &snap.method_metrics;

        let tools = per_method.get("tools/call").unwrap();
        assert_eq!(tools.count, 2);
        assert_eq!(tools.success_count, 2);
        assert_eq!(tools.error_count, 0);
        assert_eq!(tools.avg_duration_ms, 75.0);

        let resources = per_method.get("resources/read").unwrap();
        assert_eq!(resources.count, 1);
        assert_eq!(resources.success_count, 0);
        assert_eq!(resources.error_count, 1);
    }

    /// `reset` clears the counters and the per-method map.
    #[test]
    fn test_metrics_reset() {
        let collector = Metrics::new();
        collector.record_request();
        collector.record_response("test", Duration::from_millis(100), true);

        collector.reset();

        let snap = collector.snapshot();
        assert_eq!(snap.total_requests, 0);
        assert!(snap.method_metrics.is_empty());
    }

    /// The layer hands back the very same `Arc` it was built with.
    #[test]
    fn test_metrics_layer_creation() {
        let shared = Arc::new(Metrics::new());
        let layer = MetricsLayer::new(Arc::clone(&shared));

        assert!(Arc::ptr_eq(&shared, layer.metrics()));
    }

    /// One successful call through the wrapped service shows up in the metrics.
    #[tokio::test]
    async fn test_metrics_service() {
        use tower::ServiceExt;

        let shared = Arc::new(Metrics::new());
        let layer = MetricsLayer::new(Arc::clone(&shared));

        let backend = tower::service_fn(|_req: McpRequest| async {
            Ok::<_, McpError>(McpResponse::success(
                json!({"result": "ok"}),
                Duration::from_millis(10),
            ))
        });
        let mut svc = layer.layer(backend);

        let rpc = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test-1"),
            method: "test/method".to_string(),
            params: None,
        };
        let request = McpRequest::new(rpc);

        let ready_svc = svc.ready().await.unwrap();
        let _ = ready_svc.call(request).await.unwrap();

        let snap = shared.snapshot();
        assert_eq!(snap.total_requests, 1);
        assert_eq!(snap.successful_responses, 1);
    }
}