use actix_web::{App, HttpMessage, HttpServer, dev::ServiceRequest, middleware};
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::{
ErrorData as McpError, RoleServer, ServerHandler, handler::server::router::tool::ToolRouter,
model::*, service::RequestContext, tool, tool_handler, tool_router,
};
use rmcp_actix_web::transport::StreamableHttpService;
use serde_json::json;
use std::sync::Arc;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[derive(Clone, Debug)]
pub struct UserClaims {
pub user_id: String,
pub role: String,
}
fn extract_claims_from_headers(req: &ServiceRequest) {
let user_id = req
.headers()
.get("X-User-Id")
.and_then(|v| v.to_str().ok())
.map(String::from);
let role = req
.headers()
.get("X-User-Role")
.and_then(|v| v.to_str().ok())
.map(String::from)
.unwrap_or_else(|| "guest".to_string());
if let Some(user_id) = user_id {
let claims = UserClaims { user_id, role };
tracing::info!(?claims, "Extracted user claims from headers");
req.extensions_mut().insert(claims);
}
}
use actix_web::dev::{Service, ServiceResponse, Transform};
use std::future::{Future, Ready, ready};
use std::pin::Pin;
use std::task::{Context, Poll};
pub struct ClaimsExtractor;
impl<S, B> Transform<S, ServiceRequest> for ClaimsExtractor
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type InitError = ();
type Transform = ClaimsExtractorMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(ClaimsExtractorMiddleware { service }))
}
}
pub struct ClaimsExtractorMiddleware<S> {
service: S,
}
impl<S, B> Service<ServiceRequest> for ClaimsExtractorMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, req: ServiceRequest) -> Self::Future {
extract_claims_from_headers(&req);
Box::pin(self.service.call(req))
}
}
#[derive(Clone)]
struct UserAwareService {
tool_router: ToolRouter<UserAwareService>,
}
#[tool_router]
impl UserAwareService {
fn new() -> Self {
Self {
tool_router: Self::tool_router(),
}
}
#[tool(description = "Get current user info from propagated claims")]
async fn whoami(
&self,
context: RequestContext<RoleServer>,
) -> Result<CallToolResult, McpError> {
if let Some(claims) = context.extensions.get::<UserClaims>() {
let info = json!({
"user_id": claims.user_id,
"role": claims.role,
"message": "Claims successfully propagated from middleware via on_request hook!"
});
Ok(CallToolResult::success(vec![Content::text(
info.to_string(),
)]))
} else {
let info = json!({
"message": "No user claims found. Try adding X-User-Id header.",
"hint": "curl -H 'X-User-Id: alice' -H 'X-User-Role: admin' ..."
});
Ok(CallToolResult::success(vec![Content::text(
info.to_string(),
)]))
}
}
#[tool(description = "Admin-only operation - requires admin role")]
async fn admin_action(
&self,
context: RequestContext<RoleServer>,
) -> Result<CallToolResult, McpError> {
if let Some(claims) = context.extensions.get::<UserClaims>() {
if claims.role == "admin" {
Ok(CallToolResult::success(vec![Content::text(format!(
"Admin action executed by user: {}",
claims.user_id
))]))
} else {
Err(McpError::invalid_request(
format!(
"Access denied. User '{}' has role '{}', but 'admin' is required.",
claims.user_id, claims.role
),
None,
))
}
} else {
Err(McpError::invalid_request(
"Authentication required. No user claims found.",
None,
))
}
}
}
#[tool_handler]
impl ServerHandler for UserAwareService {
fn get_info(&self) -> ServerInfo {
ServerInfo {
instructions: Some("User-aware service demonstrating on_request hook".into()),
capabilities: ServerCapabilities::builder().enable_tools().build(),
..Default::default()
}
}
}
const BIND_ADDRESS: &str = "127.0.0.1:8080";
#[actix_web::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "info".to_string().into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
println!("\n🚀 on_request Hook Example running at http://{BIND_ADDRESS}");
println!(
"\n📖 This example demonstrates propagating typed data from middleware to MCP handlers."
);
println!("\n🔧 Available tools:");
println!(" - whoami: Returns current user info from propagated claims");
println!(" - admin_action: Admin-only operation (requires role=admin)");
println!("\n💡 Add headers to simulate authentication:");
println!(" X-User-Id: <user-id>");
println!(" X-User-Role: <role>");
println!("\nPress Ctrl+C to stop the server\n");
let http_service = StreamableHttpService::builder()
.service_factory(Arc::new(|| Ok(UserAwareService::new())))
.session_manager(Arc::new(LocalSessionManager::default()))
.stateful_mode(true)
.on_request_fn(|http_req, mcp_ext| {
if let Some(claims) = http_req.extensions().get::<UserClaims>() {
tracing::debug!(?claims, "Propagating claims to MCP context");
mcp_ext.insert(claims.clone());
}
})
.build();
HttpServer::new(move || {
App::new()
.wrap(middleware::Logger::default())
.wrap(ClaimsExtractor)
.service(http_service.clone().scope())
})
.bind(BIND_ADDRESS)?
.run()
.await?;
Ok(())
}