context_mcp/
server.rs

//! MCP server implementation using Axum
//!
//! Provides HTTP/SSE and stdio transports for the context management MCP server.

5#[cfg(feature = "server")]
6use axum::{
7    extract::{Json, State},
8    response::{IntoResponse, Sse},
9    routing::{get, post},
10    Router,
11};
12use futures::stream::{self, Stream};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::sync::Arc;
16
17use crate::error::ContextResult;
18use crate::protocol::{
19    CallToolRequest, InitializeResult, JsonRpcError, JsonRpcRequest, JsonRpcResponse, RequestId,
20    ServerCapabilities, ServerInfo, ToolsCapability, MCP_VERSION,
21};
22use crate::rag::{RagConfig, RagProcessor};
23use crate::storage::{ContextStore, StorageConfig};
24use crate::tools::ToolRegistry;
25
26/// Server configuration
27#[derive(Debug, Clone)]
28pub struct ServerConfig {
29    /// Server host
30    pub host: String,
31    /// Server port
32    pub port: u16,
33    /// Storage configuration
34    pub storage: StorageConfig,
35    /// RAG configuration
36    pub rag: RagConfig,
37}
38
39impl Default for ServerConfig {
40    fn default() -> Self {
41        Self {
42            host: "127.0.0.1".to_string(),
43            port: 3000,
44            storage: StorageConfig::default(),
45            rag: RagConfig::default(),
46        }
47    }
48}
49
50/// Shared server state
51#[allow(dead_code)]
52pub struct ServerState {
53    store: Arc<ContextStore>,
54    rag: Arc<RagProcessor>,
55    tools: Arc<ToolRegistry>,
56}
57
58impl ServerState {
59    /// Create new server state
60    pub fn new(config: &ServerConfig) -> ContextResult<Self> {
61        let store = Arc::new(ContextStore::new(config.storage.clone())?);
62        let rag = Arc::new(RagProcessor::new(store.clone(), config.rag.clone()));
63        let tools = Arc::new(ToolRegistry::new(store.clone(), rag.clone()));
64
65        Ok(Self { store, rag, tools })
66    }
67}
68
69/// MCP Server
70pub struct McpServer {
71    config: ServerConfig,
72    state: Arc<ServerState>,
73}
74
75impl McpServer {
76    /// Create a new MCP server
77    pub fn new(config: ServerConfig) -> ContextResult<Self> {
78        let state = Arc::new(ServerState::new(&config)?);
79        Ok(Self { config, state })
80    }
81
82    /// Create with default configuration
83    pub fn with_defaults() -> ContextResult<Self> {
84        Self::new(ServerConfig::default())
85    }
86
87    /// Build the router
88    pub fn router(&self) -> Router {
89        Router::new()
90            .route("/", get(health))
91            .route("/health", get(health))
92            .route("/mcp", post(handle_mcp_request))
93            .route("/sse", get(sse_handler))
94            .with_state(self.state.clone())
95    }
96
97    /// Run the server
98    pub async fn run(&self) -> ContextResult<()> {
99        let addr = format!("{}:{}", self.config.host, self.config.port);
100        let listener = tokio::net::TcpListener::bind(&addr)
101            .await
102            .map_err(crate::error::ContextError::Io)?;
103
104        tracing::info!("MCP Context Server listening on {}", addr);
105
106        axum::serve(listener, self.router())
107            .await
108            .map_err(|e| crate::error::ContextError::Internal(e.to_string()))?;
109
110        Ok(())
111    }
112
113    /// Get server address
114    pub fn address(&self) -> String {
115        format!("{}:{}", self.config.host, self.config.port)
116    }
117}
118
119/// Health check endpoint
120async fn health() -> impl IntoResponse {
121    Json(json!({
122        "status": "ok",
123        "server": "context-mcp",
124        "version": env!("CARGO_PKG_VERSION")
125    }))
126}
127
128/// Handle MCP JSON-RPC request
129async fn handle_mcp_request(
130    State(state): State<Arc<ServerState>>,
131    Json(request): Json<JsonRpcRequest>,
132) -> impl IntoResponse {
133    let response = process_request(&state, request).await;
134    Json(response)
135}
136
137/// Process a single MCP request
138async fn process_request(state: &ServerState, request: JsonRpcRequest) -> JsonRpcResponse {
139    match request.method.as_str() {
140        "initialize" => handle_initialize(request.id),
141        "initialized" => handle_initialized(request.id),
142        "tools/list" => handle_list_tools(request.id, state),
143        "tools/call" => handle_call_tool(request.id, state, request.params).await,
144        "ping" => handle_ping(request.id),
145        method => JsonRpcResponse::error(request.id, JsonRpcError::method_not_found(method)),
146    }
147}
148
149/// Handle initialize request
150fn handle_initialize(id: RequestId) -> JsonRpcResponse {
151    let result = InitializeResult {
152        protocol_version: MCP_VERSION.to_string(),
153        capabilities: ServerCapabilities {
154            tools: Some(ToolsCapability { list_changed: true }),
155            resources: None,
156            prompts: None,
157        },
158        server_info: ServerInfo {
159            name: "context-mcp".to_string(),
160            version: env!("CARGO_PKG_VERSION").to_string(),
161        },
162    };
163
164    JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
165}
166
167/// Handle initialized notification
168fn handle_initialized(id: RequestId) -> JsonRpcResponse {
169    JsonRpcResponse::success(id, json!({}))
170}
171
172/// Handle tools/list request
173fn handle_list_tools(id: RequestId, state: &ServerState) -> JsonRpcResponse {
174    let tools = state.tools.list_tools();
175    JsonRpcResponse::success(id, json!({ "tools": tools }))
176}
177
178/// Handle tools/call request
179async fn handle_call_tool(
180    id: RequestId,
181    state: &ServerState,
182    params: Option<Value>,
183) -> JsonRpcResponse {
184    let params = match params {
185        Some(p) => p,
186        None => return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")),
187    };
188
189    let call_request: CallToolRequest = match serde_json::from_value(params) {
190        Ok(r) => r,
191        Err(e) => {
192            return JsonRpcResponse::error(
193                id,
194                JsonRpcError::invalid_params(format!("Invalid params: {}", e)),
195            )
196        }
197    };
198
199    let result = state
200        .tools
201        .execute(&call_request.name, call_request.arguments)
202        .await;
203    JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
204}
205
206/// Handle ping request
207fn handle_ping(id: RequestId) -> JsonRpcResponse {
208    JsonRpcResponse::success(id, json!({}))
209}
210
211/// SSE handler for streaming updates
212async fn sse_handler(
213    State(_state): State<Arc<ServerState>>,
214) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, Infallible>>> {
215    let stream = stream::iter(vec![Ok(axum::response::sse::Event::default()
216        .event("connected")
217        .data("MCP Context Server connected"))]);
218
219    Sse::new(stream)
220}
221
222/// Stdio transport for MCP
223pub struct StdioTransport {
224    state: Arc<ServerState>,
225}
226
227impl StdioTransport {
228    /// Create a new stdio transport
229    pub fn new(config: ServerConfig) -> ContextResult<Self> {
230        let state = Arc::new(ServerState::new(&config)?);
231        Ok(Self { state })
232    }
233
234    /// Run the stdio transport
235    pub async fn run(&self) -> ContextResult<()> {
236        use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
237
238        let stdin = tokio::io::stdin();
239        let mut stdout = tokio::io::stdout();
240        let mut reader = BufReader::new(stdin);
241
242        loop {
243            let mut line = String::new();
244            match reader.read_line(&mut line).await {
245                Ok(0) => break, // EOF
246                Ok(_) => {
247                    let line = line.trim();
248                    if line.is_empty() {
249                        continue;
250                    }
251
252                    match serde_json::from_str::<JsonRpcRequest>(line) {
253                        Ok(request) => {
254                            let response = process_request(&self.state, request).await;
255                            let response_str = serde_json::to_string(&response).unwrap();
256                            stdout.write_all(response_str.as_bytes()).await.ok();
257                            stdout.write_all(b"\n").await.ok();
258                            stdout.flush().await.ok();
259                        }
260                        Err(_e) => {
261                            let error = JsonRpcResponse::error(
262                                RequestId::Number(0),
263                                JsonRpcError::parse_error(),
264                            );
265                            let error_str = serde_json::to_string(&error).unwrap();
266                            stdout.write_all(error_str.as_bytes()).await.ok();
267                            stdout.write_all(b"\n").await.ok();
268                            stdout.flush().await.ok();
269                        }
270                    }
271                }
272                Err(_) => break,
273            }
274        }
275
276        Ok(())
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// The health endpoint must return 200 with the documented JSON body.
    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await.into_response();
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("health body should be readable");
        let body: Value = serde_json::from_slice(&bytes).expect("health body should be JSON");
        assert_eq!(body["status"], "ok");
        assert_eq!(body["server"], "context-mcp");
        assert_eq!(body["version"], env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
    }
}