use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use hashbrown::HashMap as HashbrownMap;
use serde_json::Value;
use crate::auth::Principal;
use crate::error::{McpError, McpResult};
use crate::session::McpSession;
#[cfg(feature = "std")]
use crate::session::Cancellable;
#[cfg(feature = "std")]
use std::time::Instant;
use turbomcp_types::{CreateMessageRequest, CreateMessageResult, ElicitResult};
/// The transport over which an MCP request was received.
///
/// Serialized by serde in lowercase (e.g. `WebSocket` <-> `"websocket"`), which is the
/// same spelling returned by `TransportType::as_str()` and used by its `Display` impl.
/// The default is [`TransportType::Stdio`].
///
/// The enum is `#[non_exhaustive]`: code in other crates must include a wildcard arm
/// when matching on it, because variants may be added later.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize,
)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum TransportType {
    /// Standard input/output; the `Default` variant. Classified as local.
    #[default]
    Stdio,
    /// HTTP transport. Classified as network.
    Http,
    /// WebSocket transport. Classified as network.
    WebSocket,
    /// TCP transport. Classified as network.
    Tcp,
    /// Unix transport (presumably a Unix domain socket — TODO confirm). Classified as local.
    Unix,
    /// WebAssembly transport. Classified as neither network nor local.
    Wasm,
    /// Channel transport (presumably in-process — TODO confirm). Classified as local.
    Channel,
    /// Transport could not be determined. Classified as neither network nor local.
    Unknown,
}
impl TransportType {
    /// Returns `true` for transports that talk over the network:
    /// [`Http`](Self::Http), [`WebSocket`](Self::WebSocket) and [`Tcp`](Self::Tcp).
    ///
    /// `Wasm` and `Unknown` are classified as neither network nor local.
    /// This is a `const fn`, so it can be evaluated in constant contexts.
    #[inline]
    #[must_use]
    pub const fn is_network(&self) -> bool {
        matches!(self, Self::Http | Self::WebSocket | Self::Tcp)
    }

    /// Returns `true` for local transports:
    /// [`Stdio`](Self::Stdio), [`Unix`](Self::Unix) and [`Channel`](Self::Channel).
    ///
    /// `Wasm` and `Unknown` are classified as neither network nor local.
    /// This is a `const fn`, so it can be evaluated in constant contexts.
    #[inline]
    #[must_use]
    pub const fn is_local(&self) -> bool {
        matches!(self, Self::Stdio | Self::Unix | Self::Channel)
    }

    /// Returns the lowercase name of the transport.
    ///
    /// The spelling matches the serde representation (`rename_all = "lowercase"`)
    /// and is what the `Display` impl writes.
    #[inline]
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Stdio => "stdio",
            Self::Http => "http",
            Self::WebSocket => "websocket",
            Self::Tcp => "tcp",
            Self::Unix => "unix",
            Self::Wasm => "wasm",
            Self::Channel => "channel",
            Self::Unknown => "unknown",
        }
    }
}
impl core::fmt::Display for TransportType {
    /// Writes the transport's lowercase name, as returned by `as_str()`.
    ///
    /// Uses `Formatter::pad` rather than `write!(f, "{}", ..)` so the caller's
    /// width, fill, alignment and precision flags are honoured. For example,
    /// `format!("{:<6}|", TransportType::Http)` yields `"http  |"`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Per-request context handed to MCP handlers.
///
/// Carries:
/// - request identity: id, transport, and user/session/client ids;
/// - authentication data;
/// - free-form JSON metadata;
/// - optional string headers;
/// - when the transport supports it, a bidirectional session used for
///   server-to-client calls (sampling, elicitation, notifications).
///
/// `Default` yields an empty `request_id` and [`TransportType::Stdio`]. Use
/// `RequestContext::new()` to get a generated id when the `std` feature is enabled.
/// Cloning shares the `session` and `cancellation_token` handles (they are `Arc`s).
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    /// Identifier of this request; empty when none was assigned.
    pub request_id: String,
    /// Transport the request arrived on.
    pub transport: TransportType,
    /// User id. When set, `is_authenticated()` returns `true`, and `subject()` falls back to it
    /// when no principal is attached.
    pub user_id: Option<String>,
    /// Session identifier, if known.
    pub session_id: Option<String>,
    /// Client identifier, if known.
    pub client_id: Option<String>,
    /// Free-form JSON metadata. `roles()` consults an `"auth"` object with a `"roles"` array
    /// here when the principal has no roles.
    pub metadata: HashbrownMap<String, Value>,
    /// Authenticated principal; takes precedence over `user_id` in `subject()` and `roles()`.
    pub principal: Option<Principal>,
    /// Bidirectional session handle. `sample`, `elicit_form`, `elicit_url` and
    /// `notify_client` require it and return an error when it is absent.
    pub session: Option<Arc<dyn McpSession>>,
    /// Optional request headers (name -> value). `header()` matches names ASCII case-insensitively.
    pub headers: Option<HashbrownMap<String, String>>,
    /// When processing of this request started; `elapsed()` measures from it.
    #[cfg(feature = "std")]
    pub start_time: Option<Instant>,
    /// Cancellation handle that `is_cancelled()` queries.
    #[cfg(feature = "std")]
    pub cancellation_token: Option<Arc<dyn Cancellable>>,
}
impl RequestContext {
    /// Creates a context with a fresh request id and all other fields defaulted.
    ///
    /// With the `std` feature the id is a UUID v4 string. Without it the id is left
    /// empty, which makes the result identical to `RequestContext::default()`.
    pub fn new() -> Self {
        #[cfg(feature = "std")]
        let request_id = uuid::Uuid::new_v4().to_string();
        #[cfg(not(feature = "std"))]
        let request_id = String::new();
        Self::with_id(request_id)
    }

    /// Creates a context with the given request id and transport; every other field is defaulted.
    pub fn with_id_and_transport(request_id: impl Into<String>, transport: TransportType) -> Self {
        Self {
            transport,
            ..Self::with_id(request_id)
        }
    }

    /// Creates a context with the given request id; every other field is defaulted
    /// (so the transport is `Stdio`).
    pub fn with_id(request_id: impl Into<String>) -> Self {
        let request_id = request_id.into();
        Self {
            request_id,
            ..Default::default()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::Stdio`].
    #[inline]
    pub fn stdio() -> Self {
        Self {
            transport: TransportType::Stdio,
            ..Self::new()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::Http`].
    #[inline]
    pub fn http() -> Self {
        Self {
            transport: TransportType::Http,
            ..Self::new()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::WebSocket`].
    #[inline]
    pub fn websocket() -> Self {
        Self {
            transport: TransportType::WebSocket,
            ..Self::new()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::Tcp`].
    #[inline]
    pub fn tcp() -> Self {
        Self {
            transport: TransportType::Tcp,
            ..Self::new()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::Unix`].
    #[inline]
    pub fn unix() -> Self {
        Self {
            transport: TransportType::Unix,
            ..Self::new()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::Wasm`].
    #[inline]
    pub fn wasm() -> Self {
        Self {
            transport: TransportType::Wasm,
            ..Self::new()
        }
    }

    /// Like `new()`, but with the transport set to [`TransportType::Channel`].
    #[inline]
    pub fn channel() -> Self {
        Self {
            transport: TransportType::Channel,
            ..Self::new()
        }
    }
}
impl RequestContext {
    /// Returns the context with `request_id` replaced.
    #[must_use]
    pub fn with_request_id(self, id: impl Into<String>) -> Self {
        Self {
            request_id: id.into(),
            ..self
        }
    }

    /// Returns the context with `transport` replaced.
    #[must_use]
    pub fn with_transport(self, transport: TransportType) -> Self {
        Self { transport, ..self }
    }

    /// Returns the context with `user_id` set.
    #[must_use]
    pub fn with_user_id(self, user_id: impl Into<String>) -> Self {
        Self {
            user_id: Some(user_id.into()),
            ..self
        }
    }

    /// Returns the context with `session_id` set.
    #[must_use]
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Returns the context with `client_id` set.
    #[must_use]
    pub fn with_client_id(self, client_id: impl Into<String>) -> Self {
        Self {
            client_id: Some(client_id.into()),
            ..self
        }
    }

    /// Returns the context with the authenticated principal attached (replacing any previous one).
    #[must_use]
    pub fn with_principal(mut self, principal: Principal) -> Self {
        self.set_principal(principal);
        self
    }

    /// Returns the context with one metadata entry inserted; an existing `key` is overwritten.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.insert_metadata(key, value);
        self
    }

    /// Returns the context with a bidirectional session attached (replacing any previous one).
    #[must_use]
    pub fn with_session(mut self, session: Arc<dyn McpSession>) -> Self {
        self.set_session(session);
        self
    }

    /// Returns the context with the whole header map replaced.
    #[must_use]
    pub fn with_headers(self, headers: HashbrownMap<String, String>) -> Self {
        Self {
            headers: Some(headers),
            ..self
        }
    }

    /// Returns the context with the processing start instant recorded.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn with_start_time(self, start: Instant) -> Self {
        Self {
            start_time: Some(start),
            ..self
        }
    }

    /// Returns the context with a cancellation handle attached.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn with_cancellation_token(self, token: Arc<dyn Cancellable>) -> Self {
        Self {
            cancellation_token: Some(token),
            ..self
        }
    }
}
impl RequestContext {
    /// Inserts a metadata entry in place; an existing value under the same key is overwritten.
    pub fn insert_metadata(&mut self, key: impl Into<String>, value: impl Into<Value>) {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
    }

    /// Attaches an authenticated principal in place. Any previous principal is dropped.
    pub fn set_principal(&mut self, principal: Principal) {
        self.principal.replace(principal);
    }

    /// Detaches the principal, if any.
    pub fn clear_principal(&mut self) {
        self.principal.take();
    }

    /// Attaches a bidirectional session in place. Any previous session handle is dropped.
    pub fn set_session(&mut self, session: Arc<dyn McpSession>) {
        self.session.replace(session);
    }
}
impl RequestContext {
    /// Returns the request id; empty when none was assigned (see [`Self::has_request_id`]).
    #[inline]
    pub fn request_id(&self) -> &str {
        &self.request_id
    }

    /// Returns `true` if the request id is non-empty.
    #[inline]
    pub fn has_request_id(&self) -> bool {
        !self.request_id.is_empty()
    }

    /// Returns the transport this request arrived on.
    #[inline]
    pub fn transport(&self) -> TransportType {
        self.transport
    }

    /// Returns the user id, if set.
    #[inline]
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Returns the session id, if set.
    #[inline]
    pub fn session_id(&self) -> Option<&str> {
        self.session_id.as_deref()
    }

    /// Returns the client id, if set.
    #[inline]
    pub fn client_id(&self) -> Option<&str> {
        self.client_id.as_deref()
    }

    /// Returns the raw metadata value stored under `key`, if any.
    #[inline]
    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
        self.metadata.get(key)
    }

    /// Returns the metadata value under `key` when it is a JSON string.
    ///
    /// Returns `None` both when the key is absent and when the value is not a string.
    pub fn get_metadata_str(&self, key: &str) -> Option<&str> {
        self.metadata.get(key).and_then(|v| v.as_str())
    }

    /// Returns `true` if a metadata entry exists under `key`.
    #[inline]
    pub fn has_metadata(&self, key: &str) -> bool {
        self.metadata.contains_key(key)
    }

    /// Returns the authenticated principal, if one is attached.
    #[inline]
    pub fn principal(&self) -> Option<&Principal> {
        self.principal.as_ref()
    }

    /// Returns `true` when a principal is attached or a user id is set.
    ///
    /// This only checks presence; it performs no credential validation itself.
    pub fn is_authenticated(&self) -> bool {
        self.principal.is_some() || self.user_id.is_some()
    }

    /// Returns the caller's identity.
    ///
    /// This is the principal's `subject` when a principal is attached, otherwise the user id.
    pub fn subject(&self) -> Option<&str> {
        self.principal
            .as_ref()
            .map(|p| p.subject.as_str())
            .or(self.user_id.as_deref())
    }

    /// Returns the bidirectional session handle, if attached.
    #[inline]
    pub fn session(&self) -> Option<&Arc<dyn McpSession>> {
        self.session.as_ref()
    }

    /// Returns `true` if a bidirectional session is attached.
    #[inline]
    pub fn has_session(&self) -> bool {
        self.session.is_some()
    }

    /// Returns the full header map, if headers were provided.
    #[inline]
    pub fn headers(&self) -> Option<&HashbrownMap<String, String>> {
        self.headers.as_ref()
    }

    /// Looks up a header by name, ignoring ASCII case.
    ///
    /// An exact-case key is found with a single hash lookup. Only if that misses is the
    /// map scanned linearly for an ASCII case-insensitive match. The exact-case preference
    /// also makes the result deterministic when the caller's spelling is stored. If several
    /// stored keys differ only in case and none matches exactly, which one wins is
    /// unspecified.
    ///
    /// Returns `None` when no headers were provided or no key matches.
    pub fn header(&self, name: &str) -> Option<&str> {
        let headers = self.headers.as_ref()?;
        headers
            .get(name)
            .or_else(|| {
                headers
                    .iter()
                    .find(|(k, _)| k.eq_ignore_ascii_case(name))
                    .map(|(_, v)| v)
            })
            .map(String::as_str)
    }

    /// Time elapsed since `start_time`, or `None` if no start time was recorded.
    #[cfg(feature = "std")]
    pub fn elapsed(&self) -> Option<core::time::Duration> {
        self.start_time.map(|t| t.elapsed())
    }

    /// Returns `true` if a cancellation handle is attached and reports cancellation.
    #[cfg(feature = "std")]
    pub fn is_cancelled(&self) -> bool {
        self.cancellation_token
            .as_ref()
            .is_some_and(|c| c.is_cancelled())
    }

    /// Returns the caller's roles as owned strings.
    ///
    /// The roles are taken from the first source that has any:
    /// 1. the principal's roles, if a principal is attached and has at least one role;
    /// 2. the string entries of the `metadata["auth"]["roles"]` array (non-string
    ///    entries are skipped).
    ///
    /// If neither source has roles, the result is empty.
    pub fn roles(&self) -> Vec<String> {
        if let Some(p) = &self.principal
            && !p.roles.is_empty()
        {
            return p.roles.to_vec();
        }
        self.metadata
            .get("auth")
            .and_then(|auth| auth.get("roles"))
            .and_then(|r| r.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(ToString::to_string))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Returns `true` if `required` is empty or any of its entries is one of the
    /// caller's roles, as resolved by [`Self::roles`].
    ///
    /// Unlike calling `roles()` directly, this inspects the role sources in place and
    /// does not allocate.
    pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
        if required.is_empty() {
            return true;
        }
        let is_required = |have: &str| required.iter().any(|need| need.as_ref() == have);
        // Same source precedence as `roles()`: a principal with roles shadows metadata.
        if let Some(p) = &self.principal
            && !p.roles.is_empty()
        {
            return p.roles.iter().any(|have| is_required(have.as_str()));
        }
        self.metadata
            .get("auth")
            .and_then(|auth| auth.get("roles"))
            .and_then(|r| r.as_array())
            .is_some_and(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str())
                    .any(|have| is_required(have))
            })
    }
}
impl RequestContext {
pub async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
let session = self.require_session("sampling/createMessage")?;
let params = serde_json::to_value(request).map_err(|e| {
McpError::invalid_params(alloc::format!("Failed to serialize sampling request: {e}"))
})?;
let result = session.call("sampling/createMessage", params).await?;
serde_json::from_value(result)
.map_err(|e| McpError::internal(alloc::format!("Failed to parse sampling result: {e}")))
}
pub async fn elicit_form(
&self,
message: impl Into<String>,
schema: Value,
) -> McpResult<ElicitResult> {
let session = self.require_session("elicitation/create")?;
let params = serde_json::json!({
"mode": "form",
"message": message.into(),
"requestedSchema": schema,
});
let result = session.call("elicitation/create", params).await?;
serde_json::from_value(result).map_err(|e| {
McpError::internal(alloc::format!("Failed to parse elicitation result: {e}"))
})
}
pub async fn elicit_url(
&self,
message: impl Into<String>,
url: impl Into<String>,
elicitation_id: impl Into<String>,
) -> McpResult<ElicitResult> {
let session = self.require_session("elicitation/create")?;
let params = serde_json::json!({
"mode": "url",
"message": message.into(),
"url": url.into(),
"elicitationId": elicitation_id.into(),
});
let result = session.call("elicitation/create", params).await?;
serde_json::from_value(result).map_err(|e| {
McpError::internal(alloc::format!("Failed to parse elicitation result: {e}"))
})
}
pub async fn notify_client(&self, method: impl AsRef<str>, params: Value) -> McpResult<()> {
let session = self.require_session(method.as_ref())?;
session.notify(method.as_ref(), params).await
}
fn require_session(&self, op: &str) -> McpResult<&Arc<dyn McpSession>> {
self.session.as_ref().ok_or_else(|| {
McpError::capability_not_supported(alloc::format!(
"Bidirectional session required for {op} but transport does not support it"
))
})
}
}
// Unit tests for `TransportType` and `RequestContext`. The sampling test is gated on
// the `std` feature because it runs under `#[tokio::test]`.
#[cfg(test)]
mod tests {
    use super::*;

    // `Display` output must match the lowercase name of every variant.
    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportType::Stdio.to_string(), "stdio");
        assert_eq!(TransportType::Http.to_string(), "http");
        assert_eq!(TransportType::WebSocket.to_string(), "websocket");
        assert_eq!(TransportType::Tcp.to_string(), "tcp");
        assert_eq!(TransportType::Unix.to_string(), "unix");
        assert_eq!(TransportType::Wasm.to_string(), "wasm");
        assert_eq!(TransportType::Channel.to_string(), "channel");
        assert_eq!(TransportType::Unknown.to_string(), "unknown");
    }

    // Spot-checks the network/local classification helpers.
    #[test]
    fn test_transport_type_classification() {
        assert!(TransportType::Http.is_network());
        assert!(TransportType::WebSocket.is_network());
        assert!(TransportType::Tcp.is_network());
        assert!(!TransportType::Stdio.is_network());
        assert!(TransportType::Stdio.is_local());
        assert!(TransportType::Unix.is_local());
        assert!(TransportType::Channel.is_local());
        assert!(!TransportType::Http.is_local());
    }

    // An explicitly constructed context keeps the given id/transport and starts
    // with no metadata and no session.
    #[test]
    fn test_request_context_new() {
        let ctx = RequestContext::with_id_and_transport("test-123", TransportType::Http);
        assert_eq!(ctx.request_id(), "test-123");
        assert_eq!(ctx.transport(), TransportType::Http);
        assert!(ctx.metadata.is_empty());
        assert!(!ctx.has_session());
    }

    // Each per-transport factory sets the matching transport.
    #[test]
    fn test_request_context_factory_methods() {
        assert_eq!(RequestContext::stdio().transport(), TransportType::Stdio);
        assert_eq!(RequestContext::http().transport(), TransportType::Http);
        assert_eq!(
            RequestContext::websocket().transport(),
            TransportType::WebSocket
        );
        assert_eq!(RequestContext::tcp().transport(), TransportType::Tcp);
        assert_eq!(RequestContext::unix().transport(), TransportType::Unix);
        assert_eq!(RequestContext::wasm().transport(), TransportType::Wasm);
        assert_eq!(
            RequestContext::channel().transport(),
            TransportType::Channel
        );
    }

    // Metadata round-trips through the builder; string and non-string values and
    // missing keys are covered.
    #[test]
    fn test_request_context_metadata() {
        let ctx = RequestContext::with_id_and_transport("1", TransportType::Http)
            .with_metadata("key1", "value1")
            .with_metadata("count", 42);
        assert_eq!(ctx.get_metadata_str("key1"), Some("value1"));
        assert_eq!(ctx.get_metadata("count"), Some(&serde_json::json!(42)));
        assert_eq!(ctx.get_metadata("key3"), None);
        assert!(ctx.has_metadata("key1"));
        assert!(!ctx.has_metadata("key3"));
    }

    // Setting a user id alone is enough for `is_authenticated()`.
    #[test]
    fn test_request_context_ids() {
        let ctx = RequestContext::with_id_and_transport("r", TransportType::Http)
            .with_user_id("u")
            .with_session_id("s")
            .with_client_id("c");
        assert_eq!(ctx.user_id(), Some("u"));
        assert_eq!(ctx.session_id(), Some("s"));
        assert_eq!(ctx.client_id(), Some("c"));
        assert!(ctx.is_authenticated());
    }

    // Attaching a principal drives authentication, subject and role checks.
    #[test]
    fn test_request_context_principal() {
        let ctx = RequestContext::with_id_and_transport("1", TransportType::Http);
        assert!(!ctx.is_authenticated());
        assert!(ctx.principal().is_none());
        assert!(ctx.subject().is_none());
        let principal = Principal::new("user-123")
            .with_email("user@example.com")
            .with_role("admin");
        let ctx = ctx.with_principal(principal);
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.subject(), Some("user-123"));
        assert!(ctx.principal().unwrap().has_role("admin"));
        assert_eq!(ctx.roles(), alloc::vec![String::from("admin")]);
        assert!(ctx.has_any_role(&["admin"]));
        assert!(!ctx.has_any_role(&["root"]));
    }

    // `Default` leaves the request id empty (no generated id) and uses Stdio.
    #[test]
    fn test_request_context_default() {
        let ctx = RequestContext::default();
        assert!(ctx.request_id.is_empty());
        assert_eq!(ctx.transport, TransportType::Stdio);
        assert!(ctx.metadata.is_empty());
        assert!(!ctx.has_session());
    }

    // Header lookup ignores ASCII case.
    #[test]
    fn test_request_context_headers() {
        let mut headers: HashbrownMap<String, String> = HashbrownMap::new();
        headers.insert("User-Agent".into(), "Test/1.0".into());
        let ctx =
            RequestContext::with_id_and_transport("1", TransportType::Http).with_headers(headers);
        assert_eq!(ctx.header("user-agent"), Some("Test/1.0"));
        assert_eq!(ctx.header("USER-AGENT"), Some("Test/1.0"));
        assert_eq!(ctx.header("missing"), None);
    }

    // Without an attached session, `sample` must fail with CapabilityNotSupported.
    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_sampling_without_session_fails() {
        use turbomcp_types::CreateMessageRequest;
        let ctx = RequestContext::stdio();
        let err = ctx
            .sample(CreateMessageRequest::default())
            .await
            .unwrap_err();
        assert_eq!(err.kind, crate::error::ErrorKind::CapabilityNotSupported);
    }
}