use crate::error::{McpError, McpResult};
use crate::oauth::OAuthManager;
use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
use crate::server::McpServer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
/// Per-client session record: an identifier, two wall-clock timestamps and a
/// free-form string metadata map.
#[derive(Debug, Clone)]
pub struct Session {
    /// Identifier this session was constructed with.
    pub id: String,
    /// Wall-clock time captured when the session was constructed.
    pub created_at: std::time::SystemTime,
    /// Wall-clock time of construction or of the most recent [`Session::touch`].
    pub last_used: std::time::SystemTime,
    /// Arbitrary string key/value pairs attached to the session.
    pub metadata: HashMap<String, String>,
}

impl Session {
    /// Builds a session for `id` with an empty metadata map.
    ///
    /// `created_at` and `last_used` are both set from a single clock reading,
    /// so they compare equal on a freshly built session.
    pub fn new(id: String) -> Self {
        let created = std::time::SystemTime::now();
        Session {
            metadata: HashMap::default(),
            last_used: created,
            created_at: created,
            id,
        }
    }

    /// Marks the session as used right now by overwriting `last_used` with the
    /// current wall-clock time. No other field is modified.
    pub fn touch(&mut self) {
        let stamp = std::time::SystemTime::now();
        self.last_used = stamp;
    }
}
/// Registry of sessions keyed by session id, with a time-to-live used to
/// discard sessions that have not been used recently.
pub struct SessionManager {
// Session id -> shared handle to the session. The outer `RwLock` guards the
// map itself; each session has its own `Mutex`.
sessions: RwLock<HashMap<String, Arc<Mutex<Session>>>>,
// Number of seconds since `last_used` after which a session is treated as
// expired by `cleanup_expired`.
ttl_seconds: u64,
}
impl SessionManager {
    /// Creates an empty manager whose sessions expire once they have gone
    /// `ttl_seconds` seconds without being used.
    pub fn new(ttl_seconds: u64) -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
            ttl_seconds,
        }
    }

    /// Returns the session registered under `session_id`, creating it if needed.
    ///
    /// * `Some(id)` that is already registered: the existing session is touched
    ///   and returned.
    /// * `Some(id)` that is not registered (never seen, or just expired): a new
    ///   session is created under that id.
    /// * `None`: a new session is created under a freshly generated UUID v4.
    ///
    /// Expired sessions are swept before the lookup.
    pub async fn get_or_create(&self, session_id: Option<String>) -> Arc<Mutex<Session>> {
        self.cleanup_expired().await;
        let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        // Lookup and insert happen under one write lock so two concurrent
        // callers cannot both create a session for the same id.
        let mut sessions = self.sessions.write().await;
        if let Some(existing) = sessions.get(&session_id) {
            // The guard is a temporary and is released at the end of the statement.
            existing.lock().await.touch();
            return Arc::clone(existing);
        }

        let session = Arc::new(Mutex::new(Session::new(session_id.clone())));
        sessions.insert(session_id, Arc::clone(&session));
        session
    }

    /// Removes every session whose `last_used` is at least `ttl_seconds` old.
    ///
    /// The expiry check and the removal are performed in a single pass while
    /// holding the map's write lock. Deciding under a read lock and removing
    /// later under a separate write lock would allow a session to be touched,
    /// or re-created under the same id, in between and then be removed anyway.
    ///
    /// A session whose mutex is currently held is treated as in use and kept;
    /// it will be reconsidered on the next sweep. If the clock reading is
    /// earlier than `last_used` (clock moved backwards), the session is kept.
    async fn cleanup_expired(&self) {
        let now = std::time::SystemTime::now();
        let ttl = std::time::Duration::from_secs(self.ttl_seconds);

        let mut sessions = self.sessions.write().await;
        sessions.retain(|_, session| match session.try_lock() {
            Ok(guard) => match now.duration_since(guard.last_used) {
                Ok(idle) => idle < ttl,
                Err(_) => true,
            },
            // Locked right now: someone is actively using it.
            Err(_) => true,
        });
    }

    /// Returns the number of sessions currently registered (no sweep is run).
    pub async fn count(&self) -> usize {
        self.sessions.read().await.len()
    }
}
/// Settings used to construct the HTTP front-end.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Host part of the address reported by the server.
    pub host: String,
    /// Port part of the address reported by the server.
    pub port: u16,
    /// Session time-to-live, in seconds, handed to the session manager.
    pub session_ttl_seconds: u64,
    /// When `true`, requests must carry a valid OAuth bearer token.
    pub enable_oauth: bool,
    /// OAuth secret key; must be set when `enable_oauth` is `true`.
    pub oauth_secret_key: Option<String>,
    /// OAuth issuer; must be set when `enable_oauth` is `true`.
    pub oauth_issuer: Option<String>,
    /// OAuth audience; must be set when `enable_oauth` is `true`.
    pub oauth_audience: Option<String>,
}

impl Default for HttpServerConfig {
    /// Loopback host, port 3000, one-hour session TTL, OAuth disabled and all
    /// OAuth settings unset.
    fn default() -> Self {
        const DEFAULT_HOST: &str = "127.0.0.1";
        const DEFAULT_PORT: u16 = 3000;
        const DEFAULT_SESSION_TTL_SECONDS: u64 = 3600;

        HttpServerConfig {
            host: String::from(DEFAULT_HOST),
            port: DEFAULT_PORT,
            session_ttl_seconds: DEFAULT_SESSION_TTL_SECONDS,
            enable_oauth: false,
            oauth_secret_key: None,
            oauth_issuer: None,
            oauth_audience: None,
        }
    }
}
/// Body of an incoming HTTP request: a JSON-RPC request.
#[derive(Debug, Clone, Deserialize)]
pub struct McpHttpRequest {
// `flatten` means the JSON-RPC fields are read from the top level of the
// payload rather than from a nested `rpc` object.
#[serde(flatten)]
pub rpc: JsonRpcRequest,
}
/// Body of an outgoing HTTP response: a JSON-RPC response plus the id of the
/// session that served it.
#[derive(Debug, Serialize)]
pub struct McpHttpResponse {
// `flatten` means the JSON-RPC fields are written at the top level of the
// payload rather than inside a nested `rpc` object.
#[serde(flatten)]
pub rpc: JsonRpcResponse,
// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
}
/// HTTP-facing wrapper around an MCP server: holds the configuration, the
/// session registry and, when OAuth is enabled, the OAuth manager.
pub struct McpHttpServer {
// Settings this server was built from.
config: HttpServerConfig,
// Registry of client sessions, created with `config.session_ttl_seconds`.
session_manager: Arc<SessionManager>,
// `Some` only when `config.enable_oauth` was `true` at construction.
oauth_manager: Option<Arc<OAuthManager>>,
// Stored but not read by any method in this file (`handle_request` currently
// returns a placeholder result), hence the lint suppression.
#[allow(dead_code)]
mcp_server: Arc<Mutex<McpServer>>,
}
impl McpHttpServer {
    /// Builds the HTTP server from `config` and a shared handle to the MCP server.
    ///
    /// # Errors
    ///
    /// When `config.enable_oauth` is `true`, returns
    /// [`McpError::ValidationError`] naming the first of `oauth_secret_key`,
    /// `oauth_issuer`, `oauth_audience` that is unset.
    pub fn new(config: HttpServerConfig, mcp_server: Arc<Mutex<McpServer>>) -> McpResult<Self> {
        let session_manager = Arc::new(SessionManager::new(config.session_ttl_seconds));

        let oauth_manager = if config.enable_oauth {
            let secret = Self::required_oauth_setting(
                config.oauth_secret_key.as_ref(),
                "oauth_secret_key",
                "OAuth secret key is required when OAuth is enabled",
            )?;
            let issuer = Self::required_oauth_setting(
                config.oauth_issuer.as_ref(),
                "oauth_issuer",
                "OAuth issuer is required when OAuth is enabled",
            )?;
            let audience = Self::required_oauth_setting(
                config.oauth_audience.as_ref(),
                "oauth_audience",
                "OAuth audience is required when OAuth is enabled",
            )?;
            Some(Arc::new(OAuthManager::new(
                secret,
                issuer.clone(),
                audience.clone(),
            )))
        } else {
            None
        };

        Ok(Self {
            config,
            session_manager,
            oauth_manager,
            mcp_server,
        })
    }

    /// Returns `value` or a [`McpError::ValidationError`] built from `field`
    /// and `message` when the setting is unset.
    fn required_oauth_setting<'a>(
        value: Option<&'a String>,
        field: &str,
        message: &str,
    ) -> McpResult<&'a String> {
        value.ok_or_else(|| McpError::ValidationError {
            field: field.to_string(),
            message: message.to_string(),
        })
    }

    /// Checks the `Authorization` header value when OAuth is enabled.
    ///
    /// Succeeds immediately when OAuth is disabled. Otherwise the header must
    /// have the form `Bearer <token>` and the token must pass
    /// `OAuthManager::validate_token`. The scheme name is compared without
    /// regard to ASCII case (`bearer`, `BEARER`, ... are accepted), because
    /// HTTP authentication scheme names are case-insensitive.
    ///
    /// # Errors
    ///
    /// * [`McpError::InternalError`] if OAuth is enabled but no manager exists.
    /// * [`McpError::AuthenticationError`] if the header is missing or is not
    ///   of the form `Bearer <token>`.
    /// * Whatever error `validate_token` reports for the token.
    fn verify_oauth_token(&self, auth_header: Option<&str>) -> McpResult<()> {
        if !self.config.enable_oauth {
            return Ok(());
        }

        let oauth = self
            .oauth_manager
            .as_ref()
            .ok_or_else(|| McpError::InternalError {
                message: "OAuth manager not initialized".to_string(),
            })?;

        let auth_header = auth_header.ok_or_else(|| McpError::AuthenticationError {
            message: "Missing Authorization header".to_string(),
        })?;

        // Split at the first space: everything before it is the scheme,
        // everything after it is handed to the validator unchanged.
        let token = auth_header
            .split_once(' ')
            .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
            .map(|(_, token)| token)
            .ok_or_else(|| McpError::AuthenticationError {
                message: "Invalid Authorization header format (expected 'Bearer <token>')"
                    .to_string(),
            })?;

        oauth.validate_token(token)?;
        Ok(())
    }

    /// Returns a shared handle to the OAuth manager, or `None` when the server
    /// was built with OAuth disabled.
    pub fn oauth_manager(&self) -> Option<Arc<OAuthManager>> {
        self.oauth_manager.clone()
    }

    /// Authenticates the request, resolves (or creates) its session and returns
    /// a JSON-RPC response carrying the session id.
    ///
    /// The JSON-RPC result is currently a fixed placeholder; the request's
    /// method and params are not dispatched.
    ///
    /// # Errors
    ///
    /// Propagates any error from the OAuth check; no session is created in
    /// that case.
    pub async fn handle_request(
        &self,
        session_id: Option<String>,
        auth_header: Option<&str>,
        request: McpHttpRequest,
    ) -> McpResult<McpHttpResponse> {
        self.verify_oauth_token(auth_header)?;

        let session = self.session_manager.get_or_create(session_id).await;
        // The lock guard is a temporary and is released at the end of the statement.
        let session_id = session.lock().await.id.clone();

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: request.rpc.id.clone(),
            result: Some(serde_json::json!({
                "message": "Not yet implemented"
            })),
            error: None,
        };

        Ok(McpHttpResponse {
            rpc: response,
            session_id: Some(session_id),
        })
    }

    /// Returns the configured endpoint as `host:port`.
    ///
    /// A host containing `:` that is not already bracketed is taken to be an
    /// IPv6 literal and is wrapped in brackets (`[::1]:3000`), since the
    /// unbracketed form is not a parseable socket address.
    pub fn address(&self) -> String {
        let host = self.config.host.as_str();
        if host.contains(':') && !host.starts_with('[') {
            format!("[{}]:{}", host, self.config.port)
        } else {
            format!("{}:{}", host, self.config.port)
        }
    }

    /// Returns the number of sessions currently registered.
    pub async fn session_count(&self) -> usize {
        self.session_manager.count().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::{Duration, SystemTime};
    use tokio::sync::Mutex;

    /// Minimal JSON-RPC request shared by the HTTP handler tests.
    fn sample_request() -> McpHttpRequest {
        McpHttpRequest {
            rpc: crate::protocol::JsonRpcRequest {
                jsonrpc: "2.0".to_string(),
                id: serde_json::json!(1),
                method: "test".to_string(),
                params: serde_json::json!({}),
            },
        }
    }

    /// Builds an HTTP server around a freshly constructed MCP server.
    async fn build_http_server(config: HttpServerConfig) -> McpHttpServer {
        let mcp_server = Arc::new(Mutex::new(crate::server::McpServer::new().await.unwrap()));
        McpHttpServer::new(config, mcp_server).unwrap()
    }

    /// Two anonymous requests must yield two distinct sessions.
    #[tokio::test]
    async fn test_session_creation() {
        let manager = SessionManager::new(3600);
        let session1 = manager.get_or_create(None).await;
        let session2 = manager.get_or_create(None).await;
        let id1 = session1.lock().await.id.clone();
        let id2 = session2.lock().await.id.clone();
        assert_ne!(id1, id2);
        assert_eq!(manager.count().await, 2);
    }

    /// Asking twice for the same id must return the same single session.
    #[tokio::test]
    async fn test_session_reuse() {
        let manager = SessionManager::new(3600);
        let session1 = manager
            .get_or_create(Some("test-session".to_string()))
            .await;
        let session2 = manager
            .get_or_create(Some("test-session".to_string()))
            .await;
        let id1 = session1.lock().await.id.clone();
        let id2 = session2.lock().await.id.clone();
        assert_eq!(id1, id2);
        assert_eq!(manager.count().await, 1);
    }

    /// A session idle for longer than the TTL is removed by the sweep.
    #[tokio::test]
    async fn test_session_expiry() {
        let manager = SessionManager::new(1);
        let session = manager
            .get_or_create(Some("test-session".to_string()))
            .await;
        assert_eq!(manager.count().await, 1);

        // Backdate the session instead of sleeping past the TTL: keeps the
        // test fast and independent of scheduler timing.
        {
            let mut guard = session.lock().await;
            guard.last_used = SystemTime::now() - Duration::from_secs(2);
        }

        manager.cleanup_expired().await;
        assert_eq!(manager.count().await, 0);
    }

    /// With OAuth enabled, only a request carrying a valid bearer token passes.
    #[tokio::test]
    async fn test_oauth_integration() {
        use crate::oauth::ClientCredentials;

        let config = HttpServerConfig {
            enable_oauth: true,
            oauth_secret_key: Some("test-secret-key".to_string()),
            oauth_issuer: Some("test-issuer".to_string()),
            oauth_audience: Some("test-audience".to_string()),
            ..Default::default()
        };
        let http_server = build_http_server(config).await;
        let oauth = http_server.oauth_manager().unwrap();

        let credentials = ClientCredentials::new(
            "test-client".to_string(),
            "test-secret",
            "Test Client".to_string(),
            vec!["read".to_string(), "write".to_string()],
        );
        oauth.register_client(credentials).await.unwrap();
        let token_response = oauth
            .client_credentials_grant("test-client", "test-secret", vec!["read".to_string()])
            .await
            .unwrap();

        let request = sample_request();

        // Valid token.
        let auth_header = format!("Bearer {}", token_response.access_token);
        let result = http_server
            .handle_request(None, Some(&auth_header), request.clone())
            .await;
        assert!(result.is_ok());

        // Missing header.
        let result = http_server
            .handle_request(None, None, request.clone())
            .await;
        assert!(result.is_err());

        // Well-formed header, bogus token.
        let result = http_server
            .handle_request(None, Some("Bearer invalid-token"), request)
            .await;
        assert!(result.is_err());
    }

    /// With OAuth disabled, a request without any Authorization header passes.
    #[tokio::test]
    async fn test_oauth_disabled() {
        let http_server = build_http_server(HttpServerConfig::default()).await;
        let result = http_server
            .handle_request(None, None, sample_request())
            .await;
        assert!(result.is_ok());
    }
}