1#[cfg(feature = "server")]
6use axum::{
7 extract::{Json, State},
8 response::{IntoResponse, Sse},
9 routing::{get, post},
10 Router,
11};
12use futures::stream::{self, Stream};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::sync::Arc;
16
17use crate::error::ContextResult;
18use crate::protocol::{
19 CallToolRequest, InitializeResult, JsonRpcError, JsonRpcRequest,
20 JsonRpcResponse, MCP_VERSION, RequestId, ServerCapabilities, ServerInfo,
21 ToolsCapability,
22};
23use crate::rag::{RagConfig, RagProcessor};
24use crate::storage::{ContextStore, StorageConfig};
25use crate::tools::ToolRegistry;
26
27#[derive(Debug, Clone)]
29pub struct ServerConfig {
30 pub host: String,
32 pub port: u16,
34 pub storage: StorageConfig,
36 pub rag: RagConfig,
38}
39
40impl Default for ServerConfig {
41 fn default() -> Self {
42 Self {
43 host: "127.0.0.1".to_string(),
44 port: 3000,
45 storage: StorageConfig::default(),
46 rag: RagConfig::default(),
47 }
48 }
49}
50
51#[allow(dead_code)]
53pub struct ServerState {
54 store: Arc<ContextStore>,
55 rag: Arc<RagProcessor>,
56 tools: Arc<ToolRegistry>,
57}
58
59impl ServerState {
60 pub fn new(config: &ServerConfig) -> ContextResult<Self> {
62 let store = Arc::new(ContextStore::new(config.storage.clone())?);
63 let rag = Arc::new(RagProcessor::new(store.clone(), config.rag.clone()));
64 let tools = Arc::new(ToolRegistry::new(store.clone(), rag.clone()));
65
66 Ok(Self { store, rag, tools })
67 }
68}
69
70pub struct McpServer {
72 config: ServerConfig,
73 state: Arc<ServerState>,
74}
75
76impl McpServer {
77 pub fn new(config: ServerConfig) -> ContextResult<Self> {
79 let state = Arc::new(ServerState::new(&config)?);
80 Ok(Self { config, state })
81 }
82
83 pub fn with_defaults() -> ContextResult<Self> {
85 Self::new(ServerConfig::default())
86 }
87
88 pub fn router(&self) -> Router {
90 Router::new()
91 .route("/", get(health))
92 .route("/health", get(health))
93 .route("/mcp", post(handle_mcp_request))
94 .route("/sse", get(sse_handler))
95 .with_state(self.state.clone())
96 }
97
98 pub async fn run(&self) -> ContextResult<()> {
100 let addr = format!("{}:{}", self.config.host, self.config.port);
101 let listener = tokio::net::TcpListener::bind(&addr)
102 .await
103 .map_err(|e| crate::error::ContextError::Io(e))?;
104
105 tracing::info!("MCP Context Server listening on {}", addr);
106
107 axum::serve(listener, self.router())
108 .await
109 .map_err(|e| crate::error::ContextError::Internal(e.to_string()))?;
110
111 Ok(())
112 }
113
114 pub fn address(&self) -> String {
116 format!("{}:{}", self.config.host, self.config.port)
117 }
118}
119
120async fn health() -> impl IntoResponse {
122 Json(json!({
123 "status": "ok",
124 "server": "context-mcp",
125 "version": env!("CARGO_PKG_VERSION")
126 }))
127}
128
129async fn handle_mcp_request(
131 State(state): State<Arc<ServerState>>,
132 Json(request): Json<JsonRpcRequest>,
133) -> impl IntoResponse {
134 let response = process_request(&state, request).await;
135 Json(response)
136}
137
138async fn process_request(state: &ServerState, request: JsonRpcRequest) -> JsonRpcResponse {
140 match request.method.as_str() {
141 "initialize" => handle_initialize(request.id),
142 "initialized" => handle_initialized(request.id),
143 "tools/list" => handle_list_tools(request.id, state),
144 "tools/call" => handle_call_tool(request.id, state, request.params).await,
145 "ping" => handle_ping(request.id),
146 method => JsonRpcResponse::error(
147 request.id,
148 JsonRpcError::method_not_found(method),
149 ),
150 }
151}
152
153fn handle_initialize(id: RequestId) -> JsonRpcResponse {
155 let result = InitializeResult {
156 protocol_version: MCP_VERSION.to_string(),
157 capabilities: ServerCapabilities {
158 tools: Some(ToolsCapability { list_changed: true }),
159 resources: None,
160 prompts: None,
161 },
162 server_info: ServerInfo {
163 name: "context-mcp".to_string(),
164 version: env!("CARGO_PKG_VERSION").to_string(),
165 },
166 };
167
168 JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
169}
170
171fn handle_initialized(id: RequestId) -> JsonRpcResponse {
173 JsonRpcResponse::success(id, json!({}))
174}
175
176fn handle_list_tools(id: RequestId, state: &ServerState) -> JsonRpcResponse {
178 let tools = state.tools.list_tools();
179 JsonRpcResponse::success(id, json!({ "tools": tools }))
180}
181
182async fn handle_call_tool(
184 id: RequestId,
185 state: &ServerState,
186 params: Option<Value>,
187) -> JsonRpcResponse {
188 let params = match params {
189 Some(p) => p,
190 None => {
191 return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params"))
192 }
193 };
194
195 let call_request: CallToolRequest = match serde_json::from_value(params) {
196 Ok(r) => r,
197 Err(e) => {
198 return JsonRpcResponse::error(
199 id,
200 JsonRpcError::invalid_params(format!("Invalid params: {}", e)),
201 )
202 }
203 };
204
205 let result = state.tools.execute(&call_request.name, call_request.arguments).await;
206 JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
207}
208
209fn handle_ping(id: RequestId) -> JsonRpcResponse {
211 JsonRpcResponse::success(id, json!({}))
212}
213
214async fn sse_handler(
216 State(_state): State<Arc<ServerState>>,
217) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, Infallible>>> {
218 let stream = stream::iter(vec![
219 Ok(axum::response::sse::Event::default()
220 .event("connected")
221 .data("MCP Context Server connected")),
222 ]);
223
224 Sse::new(stream)
225}
226
227pub struct StdioTransport {
229 state: Arc<ServerState>,
230}
231
232impl StdioTransport {
233 pub fn new(config: ServerConfig) -> ContextResult<Self> {
235 let state = Arc::new(ServerState::new(&config)?);
236 Ok(Self { state })
237 }
238
239 pub async fn run(&self) -> ContextResult<()> {
241 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
242
243 let stdin = tokio::io::stdin();
244 let mut stdout = tokio::io::stdout();
245 let mut reader = BufReader::new(stdin);
246
247 loop {
248 let mut line = String::new();
249 match reader.read_line(&mut line).await {
250 Ok(0) => break, Ok(_) => {
252 let line = line.trim();
253 if line.is_empty() {
254 continue;
255 }
256
257 match serde_json::from_str::<JsonRpcRequest>(line) {
258 Ok(request) => {
259 let response = process_request(&self.state, request).await;
260 let response_str = serde_json::to_string(&response).unwrap();
261 stdout.write_all(response_str.as_bytes()).await.ok();
262 stdout.write_all(b"\n").await.ok();
263 stdout.flush().await.ok();
264 }
265 Err(_e) => {
266 let error = JsonRpcResponse::error(
267 RequestId::Number(0),
268 JsonRpcError::parse_error(),
269 );
270 let error_str = serde_json::to_string(&error).unwrap();
271 stdout.write_all(error_str.as_bytes()).await.ok();
272 stdout.write_all(b"\n").await.ok();
273 stdout.flush().await.ok();
274 }
275 }
276 }
277 Err(_) => break,
278 }
279 }
280
281 Ok(())
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// The health endpoint answers 200 OK with the expected JSON fields.
    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await.into_response();
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("health body should be readable");
        let body: Value = serde_json::from_slice(&bytes).expect("health body should be JSON");
        assert_eq!(body["status"], "ok");
        assert_eq!(body["server"], "context-mcp");
        assert_eq!(body["version"], env!("CARGO_PKG_VERSION"));
    }

    /// The default configuration listens on localhost port 3000.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
    }
}