use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::types::Timestamp;
pub use turbomcp_core::context::{RequestContext, TransportType};
/// Outcome and timing information for a single response.
///
/// Built with [`ResponseContext::success`] or [`ResponseContext::error`], which stamp
/// the current time and start with an empty metadata map.
#[derive(Debug, Clone)]
pub struct ResponseContext {
/// Identifier of the request this response answers.
pub request_id: String,
/// When this context was created (`Timestamp::now()` in this type's constructors).
pub timestamp: Timestamp,
/// Time spent handling the request, as supplied by the caller of the constructor.
pub duration: std::time::Duration,
/// How the response ended (success, error with code/message, partial, or cancelled).
pub status: ResponseStatus,
/// Extra key/value data. Held behind `Arc` so cloning the context shares the map
/// rather than copying it.
pub metadata: Arc<HashMap<String, serde_json::Value>>,
}
/// Final state of a response, as recorded in [`ResponseContext::status`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ResponseStatus {
/// The request completed successfully (set by `ResponseContext::success`).
Success,
/// The request failed (set by `ResponseContext::error`).
Error {
/// Numeric error code. NOTE(review): this file's tests pass `-32000`, which suggests
/// JSON-RPC-style codes — confirm with callers.
code: i32,
/// Human-readable description of the failure.
message: String,
},
/// NOTE(review): not constructed in this file; presumably only part of the result was
/// produced — confirm.
Partial,
/// NOTE(review): not constructed in this file; presumably the request was cancelled
/// before completion — confirm.
Cancelled,
}
/// Record of one method call made by a client, including how it finished.
///
/// Created with [`RequestInfo::new`] and finished with [`RequestInfo::complete_success`]
/// or [`RequestInfo::complete_error`].
/// NOTE(review): presumably used for request logging/analytics — confirm with callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestInfo {
/// When the record was created; `new` sets this to `Utc::now()`.
pub timestamp: DateTime<Utc>,
/// Identifier of the calling client.
pub client_id: String,
/// Name of the invoked method (e.g. `"tools/list"` in this file's tests).
pub method_name: String,
/// Parameters passed to the method.
pub parameters: serde_json::Value,
/// Response time, presumably in milliseconds per the `_ms` suffix; `None` until one of
/// the `complete_*` methods is called.
pub response_time_ms: Option<u64>,
/// Whether the call succeeded; `false` until `complete_success` sets it.
pub success: bool,
/// Error text; `None` unless `complete_error` sets it.
pub error_message: Option<String>,
/// Status code: `Some(200)` after `complete_success`, `Some(500)` after
/// `complete_error`, or whatever `with_status_code` last set.
pub status_code: Option<u16>,
/// Free-form extra key/value data, added via `with_metadata`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl ResponseContext {
pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
Self {
request_id: request_id.into(),
timestamp: Timestamp::now(),
duration,
status: ResponseStatus::Success,
metadata: Arc::new(HashMap::new()),
}
}
pub fn error(
request_id: impl Into<String>,
duration: std::time::Duration,
code: i32,
message: impl Into<String>,
) -> Self {
Self {
request_id: request_id.into(),
timestamp: Timestamp::now(),
duration,
status: ResponseStatus::Error {
code,
message: message.into(),
},
metadata: Arc::new(HashMap::new()),
}
}
}
impl RequestInfo {
    /// Starts a record for a call to `method_name` made by `client_id`.
    ///
    /// The timestamp is captured now (UTC). The record begins unfinished: no response
    /// time, `success == false`, no error message, no status code, and empty metadata.
    #[must_use]
    pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
        let timestamp = Utc::now();
        Self {
            client_id,
            method_name,
            parameters,
            timestamp,
            metadata: HashMap::new(),
            status_code: None,
            error_message: None,
            success: false,
            response_time_ms: None,
        }
    }

    /// Finishes the record as successful: stores `response_time_ms`, sets
    /// `success = true` and a status code of 200.
    #[must_use]
    pub const fn complete_success(self, response_time_ms: u64) -> Self {
        let mut finished = self.with_status_code(200);
        finished.success = true;
        finished.response_time_ms = Some(response_time_ms);
        finished
    }

    /// Finishes the record as failed: stores `response_time_ms` and `error`, sets
    /// `success = false` and a status code of 500.
    #[must_use]
    pub fn complete_error(self, response_time_ms: u64, error: String) -> Self {
        let mut finished = self.with_status_code(500);
        finished.success = false;
        finished.error_message = Some(error);
        finished.response_time_ms = Some(response_time_ms);
        finished
    }

    /// Sets the status code to `code`, replacing any earlier value (including the one a
    /// `complete_*` call set, when chained afterwards).
    #[must_use]
    pub const fn with_status_code(mut self, code: u16) -> Self {
        self.status_code = Some(code);
        self
    }

    /// Adds a metadata entry; an existing value under `key` is replaced.
    #[must_use]
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
/// Client-identification helpers layered on top of [`RequestContext`].
pub trait RequestContextExt {
    /// Returns the client ID previously attached to the context, if any, as a
    /// `super::client::ClientId`.
    fn get_enhanced_client_id(&self) -> Option<super::client::ClientId>;

    /// Attaches `client_id` to the context and returns the updated context.
    #[must_use]
    fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self;

    /// Uses `extractor` to derive a client ID from the optional `headers` and
    /// `query_params`, then attaches it to the context.
    #[must_use]
    fn extract_client_id(
        self,
        extractor: &super::client::ClientIdExtractor,
        headers: Option<&HashMap<String, String>>,
        query_params: Option<&HashMap<String, String>>,
    ) -> Self;
}
impl RequestContextExt for RequestContext {
    /// Stores `client_id` as the context's client ID plus two metadata entries:
    /// `client_id_method` (the ID's `auth_method()` rendered as a string) and
    /// `client_authenticated` (its `is_authenticated()` flag).
    fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self {
        let method = serde_json::Value::String(client_id.auth_method().to_string());
        let authenticated = serde_json::Value::Bool(client_id.is_authenticated());

        let ctx = self.with_client_id(client_id.as_str());
        let ctx = ctx.with_metadata("client_id_method", method);
        ctx.with_metadata("client_authenticated", authenticated)
    }

    /// Asks `extractor` for a client ID based on `headers` / `query_params` and
    /// attaches it via `with_enhanced_client_id`.
    fn extract_client_id(
        self,
        extractor: &super::client::ClientIdExtractor,
        headers: Option<&HashMap<String, String>>,
        query_params: Option<&HashMap<String, String>>,
    ) -> Self {
        self.with_enhanced_client_id(extractor.extract_client_id(headers, query_params))
    }

    /// Rebuilds a `ClientId` from the stored client ID and the `client_id_method`
    /// metadata entry.
    ///
    /// Returns `None` when the context has no client ID. A missing, non-string, or
    /// unrecognized method falls back to `ClientId::Header`.
    /// NOTE(review): the recognized strings must match what `auth_method()` produces
    /// in `with_enhanced_client_id` — confirm against `super::client`.
    fn get_enhanced_client_id(&self) -> Option<super::client::ClientId> {
        use super::client::ClientId;

        let id = self.client_id.as_ref()?.clone();
        let method = self
            .get_metadata("client_id_method")
            .and_then(|value| value.as_str())
            .unwrap_or("header");

        Some(match method {
            "bearer_token" => ClientId::Token(id),
            "session_cookie" => ClientId::Session(id),
            "query_param" => ClientId::QueryParam(id),
            _ => ClientId::Header(id),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Both constructors record the id, the duration, the exact status, and empty metadata.
    #[test]
    fn response_context_builders() {
        let success = ResponseContext::success("req-1", Duration::from_millis(10));
        assert_eq!(success.request_id, "req-1");
        assert_eq!(success.duration, Duration::from_millis(10));
        assert_eq!(success.status, ResponseStatus::Success);
        assert!(success.metadata.is_empty());

        let err = ResponseContext::error("req-2", Duration::from_millis(5), -32000, "boom");
        assert_eq!(err.request_id, "req-2");
        assert_eq!(err.duration, Duration::from_millis(5));
        // Pin the payload, not just the variant.
        assert_eq!(
            err.status,
            ResponseStatus::Error {
                code: -32000,
                message: "boom".to_string(),
            }
        );
        assert!(err.metadata.is_empty());
    }

    /// A fresh record carries its inputs and is otherwise unfinished.
    #[test]
    fn request_info_starts_incomplete() {
        let info = RequestInfo::new(
            "client-1".into(),
            "tools/list".into(),
            serde_json::json!({}),
        );
        assert_eq!(info.client_id, "client-1");
        assert_eq!(info.method_name, "tools/list");
        assert_eq!(info.parameters, serde_json::json!({}));
        assert_eq!(info.response_time_ms, None);
        assert!(!info.success);
        assert_eq!(info.error_message, None);
        assert_eq!(info.status_code, None);
        assert!(info.metadata.is_empty());
    }

    #[test]
    fn request_info_lifecycle() {
        let info = RequestInfo::new(
            "client-1".into(),
            "tools/list".into(),
            serde_json::json!({}),
        )
        .complete_success(42)
        .with_status_code(200)
        .with_metadata("foo".into(), serde_json::json!("bar"));
        assert!(info.success);
        assert_eq!(info.response_time_ms, Some(42));
        assert_eq!(info.status_code, Some(200));
        assert_eq!(info.error_message, None);
        assert_eq!(info.metadata.get("foo"), Some(&serde_json::json!("bar")));
    }

    /// `complete_error` records the failure with status 500; a later
    /// `with_status_code` overrides it.
    #[test]
    fn request_info_error_lifecycle() {
        let failed = RequestInfo::new(
            "client-2".into(),
            "tools/call".into(),
            serde_json::json!({ "name": "x" }),
        )
        .complete_error(7, "bad".into());
        assert!(!failed.success);
        assert_eq!(failed.response_time_ms, Some(7));
        assert_eq!(failed.error_message.as_deref(), Some("bad"));
        assert_eq!(failed.status_code, Some(500));

        let overridden = failed.with_status_code(404);
        assert_eq!(overridden.status_code, Some(404));
        assert_eq!(overridden.error_message.as_deref(), Some("bad"));
    }

    /// Re-adding a key replaces its value rather than keeping both.
    #[test]
    fn request_info_metadata_overwrites() {
        let info = RequestInfo::new("c".into(), "m".into(), serde_json::json!(null))
            .with_metadata("k".into(), serde_json::json!(1))
            .with_metadata("k".into(), serde_json::json!(2));
        assert_eq!(info.metadata.len(), 1);
        assert_eq!(info.metadata.get("k"), Some(&serde_json::json!(2)));
    }
}