#[cfg(feature = "server")]
use axum::{
extract::{Json, State},
response::{IntoResponse, Sse},
routing::{get, post},
Router,
};
use futures::stream::{self, Stream};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::sync::Arc;
use crate::error::ContextResult;
use crate::protocol::{
CallToolRequest, InitializeResult, JsonRpcError, JsonRpcRequest, JsonRpcResponse, RequestId,
ServerCapabilities, ServerInfo, ToolsCapability, MCP_VERSION,
};
use crate::rag::{RagConfig, RagProcessor};
use crate::storage::{ContextStore, StorageConfig};
use crate::tools::ToolRegistry;
/// Runtime configuration for the MCP context server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Interface the HTTP listener binds to; combined with `port` as `host:port`.
    pub host: String,
    /// TCP port the HTTP listener binds to.
    pub port: u16,
    /// Settings passed to `ContextStore::new` when the server state is built.
    pub storage: StorageConfig,
    /// Settings passed to `RagProcessor::new` when the server state is built.
    pub rag: RagConfig,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 3000,
storage: StorageConfig::default(),
rag: RagConfig::default(),
}
}
}
/// Shared state handed to every request handler, on both the HTTP and stdio transports.
///
/// Only `tools` is read by the handlers in this file. `store` and `rag` are kept alongside
/// it; the tool registry is given its own `Arc` clones of both at construction time.
// `store` and `rag` are never read directly in this file, hence the lint allowance.
#[allow(dead_code)]
pub struct ServerState {
    store: Arc<ContextStore>,
    rag: Arc<RagProcessor>,
    tools: Arc<ToolRegistry>,
}
impl ServerState {
    /// Builds the context store, then the RAG processor over it, then the tool registry
    /// over both, sharing each component through `Arc`.
    ///
    /// # Errors
    /// Returns an error if `ContextStore::new` fails for `config.storage`.
    pub fn new(config: &ServerConfig) -> ContextResult<Self> {
        let store = ContextStore::new(config.storage.clone()).map(Arc::new)?;
        let rag = Arc::new(RagProcessor::new(Arc::clone(&store), config.rag.clone()));
        let tools = Arc::new(ToolRegistry::new(Arc::clone(&store), Arc::clone(&rag)));
        Ok(Self { store, rag, tools })
    }
}
/// HTTP front end for the MCP context server.
pub struct McpServer {
    /// Listener settings; `storage` and `rag` were already used to build `state`.
    config: ServerConfig,
    /// Shared handler state, cloned into the router.
    state: Arc<ServerState>,
}
impl McpServer {
    /// Creates a server whose shared state (store, RAG processor, tool registry) is built
    /// from `config`.
    ///
    /// # Errors
    /// Propagates any error from building the [`ServerState`].
    pub fn new(config: ServerConfig) -> ContextResult<Self> {
        let state = Arc::new(ServerState::new(&config)?);
        Ok(Self { config, state })
    }

    /// Creates a server from [`ServerConfig::default`].
    ///
    /// # Errors
    /// Propagates any error from building the [`ServerState`].
    pub fn with_defaults() -> ContextResult<Self> {
        Self::new(ServerConfig::default())
    }

    /// Builds the HTTP router.
    ///
    /// Routes:
    /// - `GET /` and `GET /health` return a health report.
    /// - `POST /mcp` accepts a JSON-RPC request.
    /// - `GET /sse` opens an event stream.
    pub fn router(&self) -> Router {
        Router::new()
            .route("/", get(health))
            .route("/health", get(health))
            .route("/mcp", post(handle_mcp_request))
            .route("/sse", get(sse_handler))
            .with_state(self.state.clone())
    }

    /// Binds a TCP listener on [`address`](Self::address) and serves
    /// [`router`](Self::router) on it until serving ends.
    ///
    /// # Errors
    /// Returns `ContextError::Io` if the listener cannot be bound, and
    /// `ContextError::Internal` if `axum::serve` fails.
    pub async fn run(&self) -> ContextResult<()> {
        // Reuse `address()` so the bound address and the reported address cannot diverge.
        let addr = self.address();
        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .map_err(crate::error::ContextError::Io)?;
        tracing::info!("MCP Context Server listening on {}", addr);
        axum::serve(listener, self.router())
            .await
            .map_err(|e| crate::error::ContextError::Internal(e.to_string()))?;
        Ok(())
    }

    /// The `host:port` string the server binds to.
    pub fn address(&self) -> String {
        format!("{}:{}", self.config.host, self.config.port)
    }
}
async fn health() -> impl IntoResponse {
Json(json!({
"status": "ok",
"server": "context-mcp",
"version": env!("CARGO_PKG_VERSION")
}))
}
/// HTTP handler for `POST /mcp`.
///
/// Decodes the JSON-RPC request body, dispatches it through `process_request`, and replies
/// with the JSON-RPC response serialized as JSON.
async fn handle_mcp_request(
    State(state): State<Arc<ServerState>>,
    Json(request): Json<JsonRpcRequest>,
) -> impl IntoResponse {
    Json(process_request(&*state, request).await)
}
/// Dispatches one JSON-RPC request to the handler for its `method`.
///
/// Both transports (HTTP `/mcp` and stdio) route through this function. Unknown methods
/// produce a JSON-RPC "method not found" error response.
async fn process_request(state: &ServerState, request: JsonRpcRequest) -> JsonRpcResponse {
    match request.method.as_str() {
        "initialize" => handle_initialize(request.id),
        // The MCP lifecycle names this message `notifications/initialized`. The bare
        // `initialized` spelling is kept for existing clients.
        // NOTE(review): it is a notification (no id) in the spec. Whether it reaches this
        // point depends on how `JsonRpcRequest` deserializes a missing `id`; confirm.
        "initialized" | "notifications/initialized" => handle_initialized(request.id),
        "tools/list" => handle_list_tools(request.id, state),
        "tools/call" => handle_call_tool(request.id, state, request.params).await,
        "ping" => handle_ping(request.id),
        method => JsonRpcResponse::error(request.id, JsonRpcError::method_not_found(method)),
    }
}
/// Answers the MCP `initialize` handshake.
///
/// The result carries the server's protocol version, its capabilities and its identity.
/// Only the tools capability is advertised; resources and prompts are reported as absent.
fn handle_initialize(id: RequestId) -> JsonRpcResponse {
    let server_info = ServerInfo {
        name: "context-mcp".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
    };
    // NOTE(review): `list_changed: true` advertises tool-list-changed notifications;
    // nothing in this module sends them — confirm this is intended.
    let capabilities = ServerCapabilities {
        tools: Some(ToolsCapability { list_changed: true }),
        resources: None,
        prompts: None,
    };
    let payload = InitializeResult {
        protocol_version: MCP_VERSION.to_string(),
        capabilities,
        server_info,
    };
    JsonRpcResponse::success(id, serde_json::to_value(payload).unwrap())
}
fn handle_initialized(id: RequestId) -> JsonRpcResponse {
JsonRpcResponse::success(id, json!({}))
}
/// Replies to `tools/list` with `{"tools": [...]}` taken from the tool registry.
fn handle_list_tools(id: RequestId, state: &ServerState) -> JsonRpcResponse {
    JsonRpcResponse::success(id, json!({ "tools": state.tools.list_tools() }))
}
/// Handles `tools/call`: decodes `params` into a [`CallToolRequest`] and runs the named
/// tool through the registry.
///
/// Missing or undecodable params produce a JSON-RPC "invalid params" error. Otherwise the
/// tool's result is returned as the success payload.
async fn handle_call_tool(
    id: RequestId,
    state: &ServerState,
    params: Option<Value>,
) -> JsonRpcResponse {
    let call = match params.map(serde_json::from_value::<CallToolRequest>) {
        None => return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")),
        Some(Err(err)) => {
            return JsonRpcResponse::error(
                id,
                JsonRpcError::invalid_params(format!("Invalid params: {}", err)),
            )
        }
        Some(Ok(call)) => call,
    };
    let outcome = state.tools.execute(&call.name, call.arguments).await;
    JsonRpcResponse::success(id, serde_json::to_value(outcome).unwrap())
}
fn handle_ping(id: RequestId) -> JsonRpcResponse {
JsonRpcResponse::success(id, json!({}))
}
/// `GET /sse`: opens a server-sent-events stream that emits one `connected` event and
/// then ends.
///
/// NOTE(review): the shared state is accepted but unused, so no MCP traffic flows over
/// this stream; it looks like a placeholder.
async fn sse_handler(
    State(_state): State<Arc<ServerState>>,
) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, Infallible>>> {
    let greeting = axum::response::sse::Event::default()
        .event("connected")
        .data("MCP Context Server connected");
    Sse::new(stream::iter(std::iter::once(Ok(greeting))))
}
/// Line-delimited JSON-RPC transport over the process's stdin and stdout.
pub struct StdioTransport {
    /// Shared state used to answer every request read from stdin.
    state: Arc<ServerState>,
}
impl StdioTransport {
pub fn new(config: ServerConfig) -> ContextResult<Self> {
let state = Arc::new(ServerState::new(&config)?);
Ok(Self { state })
}
pub async fn run(&self) -> ContextResult<()> {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
let stdin = tokio::io::stdin();
let mut stdout = tokio::io::stdout();
let mut reader = BufReader::new(stdin);
loop {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => break, Ok(_) => {
let line = line.trim();
if line.is_empty() {
continue;
}
match serde_json::from_str::<JsonRpcRequest>(line) {
Ok(request) => {
let response = process_request(&self.state, request).await;
let response_str = serde_json::to_string(&response).unwrap();
stdout.write_all(response_str.as_bytes()).await.ok();
stdout.write_all(b"\n").await.ok();
stdout.flush().await.ok();
}
Err(_e) => {
let error = JsonRpcResponse::error(
RequestId::Number(0),
JsonRpcError::parse_error(),
);
let error_str = serde_json::to_string(&error).unwrap();
stdout.write_all(error_str.as_bytes()).await.ok();
stdout.write_all(b"\n").await.ok();
stdout.flush().await.ok();
}
}
}
Err(_) => break,
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The health endpoint must answer 200 with a JSON body.
    #[tokio::test]
    async fn test_health_endpoint() {
        let response = axum::response::IntoResponse::into_response(health().await);
        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok());
        assert_eq!(content_type, Some("application/json"));
    }

    /// Defaults bind to loopback on port 3000.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
    }
}