context_mcp/
server.rs

1//! MCP server implementation using Axum
2//!
3//! Provides HTTP/SSE transport for the context management MCP server.
4
5#[cfg(feature = "server")]
6use axum::{
7    extract::{Json, State},
8    response::{IntoResponse, Sse},
9    routing::{get, post},
10    Router,
11};
12use futures::stream::{self, Stream};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::sync::Arc;
16
17use crate::error::ContextResult;
18use crate::protocol::{
19    CallToolRequest, InitializeResult, JsonRpcError, JsonRpcRequest,
20    JsonRpcResponse, MCP_VERSION, RequestId, ServerCapabilities, ServerInfo,
21    ToolsCapability,
22};
23use crate::rag::{RagConfig, RagProcessor};
24use crate::storage::{ContextStore, StorageConfig};
25use crate::tools::ToolRegistry;
26
27/// Server configuration
28#[derive(Debug, Clone)]
29pub struct ServerConfig {
30    /// Server host
31    pub host: String,
32    /// Server port
33    pub port: u16,
34    /// Storage configuration
35    pub storage: StorageConfig,
36    /// RAG configuration
37    pub rag: RagConfig,
38}
39
40impl Default for ServerConfig {
41    fn default() -> Self {
42        Self {
43            host: "127.0.0.1".to_string(),
44            port: 3000,
45            storage: StorageConfig::default(),
46            rag: RagConfig::default(),
47        }
48    }
49}
50
51/// Shared server state
52#[allow(dead_code)]
53pub struct ServerState {
54    store: Arc<ContextStore>,
55    rag: Arc<RagProcessor>,
56    tools: Arc<ToolRegistry>,
57}
58
59impl ServerState {
60    /// Create new server state
61    pub fn new(config: &ServerConfig) -> ContextResult<Self> {
62        let store = Arc::new(ContextStore::new(config.storage.clone())?);
63        let rag = Arc::new(RagProcessor::new(store.clone(), config.rag.clone()));
64        let tools = Arc::new(ToolRegistry::new(store.clone(), rag.clone()));
65
66        Ok(Self { store, rag, tools })
67    }
68}
69
70/// MCP Server
71pub struct McpServer {
72    config: ServerConfig,
73    state: Arc<ServerState>,
74}
75
76impl McpServer {
77    /// Create a new MCP server
78    pub fn new(config: ServerConfig) -> ContextResult<Self> {
79        let state = Arc::new(ServerState::new(&config)?);
80        Ok(Self { config, state })
81    }
82
83    /// Create with default configuration
84    pub fn with_defaults() -> ContextResult<Self> {
85        Self::new(ServerConfig::default())
86    }
87
88    /// Build the router
89    pub fn router(&self) -> Router {
90        Router::new()
91            .route("/", get(health))
92            .route("/health", get(health))
93            .route("/mcp", post(handle_mcp_request))
94            .route("/sse", get(sse_handler))
95            .with_state(self.state.clone())
96    }
97
98    /// Run the server
99    pub async fn run(&self) -> ContextResult<()> {
100        let addr = format!("{}:{}", self.config.host, self.config.port);
101        let listener = tokio::net::TcpListener::bind(&addr)
102            .await
103            .map_err(|e| crate::error::ContextError::Io(e))?;
104
105        tracing::info!("MCP Context Server listening on {}", addr);
106
107        axum::serve(listener, self.router())
108            .await
109            .map_err(|e| crate::error::ContextError::Internal(e.to_string()))?;
110
111        Ok(())
112    }
113
114    /// Get server address
115    pub fn address(&self) -> String {
116        format!("{}:{}", self.config.host, self.config.port)
117    }
118}
119
120/// Health check endpoint
121async fn health() -> impl IntoResponse {
122    Json(json!({
123        "status": "ok",
124        "server": "context-mcp",
125        "version": env!("CARGO_PKG_VERSION")
126    }))
127}
128
129/// Handle MCP JSON-RPC request
130async fn handle_mcp_request(
131    State(state): State<Arc<ServerState>>,
132    Json(request): Json<JsonRpcRequest>,
133) -> impl IntoResponse {
134    let response = process_request(&state, request).await;
135    Json(response)
136}
137
138/// Process a single MCP request
139async fn process_request(state: &ServerState, request: JsonRpcRequest) -> JsonRpcResponse {
140    match request.method.as_str() {
141        "initialize" => handle_initialize(request.id),
142        "initialized" => handle_initialized(request.id),
143        "tools/list" => handle_list_tools(request.id, state),
144        "tools/call" => handle_call_tool(request.id, state, request.params).await,
145        "ping" => handle_ping(request.id),
146        method => JsonRpcResponse::error(
147            request.id,
148            JsonRpcError::method_not_found(method),
149        ),
150    }
151}
152
153/// Handle initialize request
154fn handle_initialize(id: RequestId) -> JsonRpcResponse {
155    let result = InitializeResult {
156        protocol_version: MCP_VERSION.to_string(),
157        capabilities: ServerCapabilities {
158            tools: Some(ToolsCapability { list_changed: true }),
159            resources: None,
160            prompts: None,
161        },
162        server_info: ServerInfo {
163            name: "context-mcp".to_string(),
164            version: env!("CARGO_PKG_VERSION").to_string(),
165        },
166    };
167
168    JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
169}
170
171/// Handle initialized notification
172fn handle_initialized(id: RequestId) -> JsonRpcResponse {
173    JsonRpcResponse::success(id, json!({}))
174}
175
176/// Handle tools/list request
177fn handle_list_tools(id: RequestId, state: &ServerState) -> JsonRpcResponse {
178    let tools = state.tools.list_tools();
179    JsonRpcResponse::success(id, json!({ "tools": tools }))
180}
181
182/// Handle tools/call request
183async fn handle_call_tool(
184    id: RequestId,
185    state: &ServerState,
186    params: Option<Value>,
187) -> JsonRpcResponse {
188    let params = match params {
189        Some(p) => p,
190        None => {
191            return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params"))
192        }
193    };
194
195    let call_request: CallToolRequest = match serde_json::from_value(params) {
196        Ok(r) => r,
197        Err(e) => {
198            return JsonRpcResponse::error(
199                id,
200                JsonRpcError::invalid_params(format!("Invalid params: {}", e)),
201            )
202        }
203    };
204
205    let result = state.tools.execute(&call_request.name, call_request.arguments).await;
206    JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
207}
208
209/// Handle ping request
210fn handle_ping(id: RequestId) -> JsonRpcResponse {
211    JsonRpcResponse::success(id, json!({}))
212}
213
214/// SSE handler for streaming updates
215async fn sse_handler(
216    State(_state): State<Arc<ServerState>>,
217) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, Infallible>>> {
218    let stream = stream::iter(vec![
219        Ok(axum::response::sse::Event::default()
220            .event("connected")
221            .data("MCP Context Server connected")),
222    ]);
223
224    Sse::new(stream)
225}
226
227/// Stdio transport for MCP
228pub struct StdioTransport {
229    state: Arc<ServerState>,
230}
231
232impl StdioTransport {
233    /// Create a new stdio transport
234    pub fn new(config: ServerConfig) -> ContextResult<Self> {
235        let state = Arc::new(ServerState::new(&config)?);
236        Ok(Self { state })
237    }
238
239    /// Run the stdio transport
240    pub async fn run(&self) -> ContextResult<()> {
241        use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
242
243        let stdin = tokio::io::stdin();
244        let mut stdout = tokio::io::stdout();
245        let mut reader = BufReader::new(stdin);
246
247        loop {
248            let mut line = String::new();
249            match reader.read_line(&mut line).await {
250                Ok(0) => break, // EOF
251                Ok(_) => {
252                    let line = line.trim();
253                    if line.is_empty() {
254                        continue;
255                    }
256
257                    match serde_json::from_str::<JsonRpcRequest>(line) {
258                        Ok(request) => {
259                            let response = process_request(&self.state, request).await;
260                            let response_str = serde_json::to_string(&response).unwrap();
261                            stdout.write_all(response_str.as_bytes()).await.ok();
262                            stdout.write_all(b"\n").await.ok();
263                            stdout.flush().await.ok();
264                        }
265                        Err(_e) => {
266                            let error = JsonRpcResponse::error(
267                                RequestId::Number(0),
268                                JsonRpcError::parse_error(),
269                            );
270                            let error_str = serde_json::to_string(&error).unwrap();
271                            stdout.write_all(error_str.as_bytes()).await.ok();
272                            stdout.write_all(b"\n").await.ok();
273                            stdout.flush().await.ok();
274                        }
275                    }
276                }
277                Err(_) => break,
278            }
279        }
280
281        Ok(())
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// The health endpoint must answer 200 with the documented JSON body.
    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await.into_response();
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("health body should be readable");
        let body: Value = serde_json::from_slice(&bytes).expect("health body should be JSON");
        assert_eq!(body["status"], "ok");
        assert_eq!(body["server"], "context-mcp");
        assert_eq!(body["version"], env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
    }
}