1#[cfg(feature = "server")]
6use axum::{
7 extract::{Json, State},
8 response::{IntoResponse, Sse},
9 routing::{get, post},
10 Router,
11};
12use futures::stream::{self, Stream};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::sync::Arc;
16
17use crate::error::ContextResult;
18use crate::protocol::{
19 CallToolRequest, InitializeResult, JsonRpcError, JsonRpcRequest, JsonRpcResponse, RequestId,
20 ServerCapabilities, ServerInfo, ToolsCapability, MCP_VERSION,
21};
22use crate::rag::{RagConfig, RagProcessor};
23use crate::storage::{ContextStore, StorageConfig};
24use crate::tools::ToolRegistry;
25
26#[derive(Debug, Clone)]
28pub struct ServerConfig {
29 pub host: String,
31 pub port: u16,
33 pub storage: StorageConfig,
35 pub rag: RagConfig,
37}
38
39impl Default for ServerConfig {
40 fn default() -> Self {
41 Self {
42 host: "127.0.0.1".to_string(),
43 port: 3000,
44 storage: StorageConfig::default(),
45 rag: RagConfig::default(),
46 }
47 }
48}
49
50#[allow(dead_code)]
52pub struct ServerState {
53 store: Arc<ContextStore>,
54 rag: Arc<RagProcessor>,
55 tools: Arc<ToolRegistry>,
56}
57
58impl ServerState {
59 pub fn new(config: &ServerConfig) -> ContextResult<Self> {
61 let store = Arc::new(ContextStore::new(config.storage.clone())?);
62 let rag = Arc::new(RagProcessor::new(store.clone(), config.rag.clone()));
63 let tools = Arc::new(ToolRegistry::new(store.clone(), rag.clone()));
64
65 Ok(Self { store, rag, tools })
66 }
67}
68
69pub struct McpServer {
71 config: ServerConfig,
72 state: Arc<ServerState>,
73}
74
75impl McpServer {
76 pub fn new(config: ServerConfig) -> ContextResult<Self> {
78 let state = Arc::new(ServerState::new(&config)?);
79 Ok(Self { config, state })
80 }
81
82 pub fn with_defaults() -> ContextResult<Self> {
84 Self::new(ServerConfig::default())
85 }
86
87 pub fn router(&self) -> Router {
89 Router::new()
90 .route("/", get(health))
91 .route("/health", get(health))
92 .route("/mcp", post(handle_mcp_request))
93 .route("/sse", get(sse_handler))
94 .with_state(self.state.clone())
95 }
96
97 pub async fn run(&self) -> ContextResult<()> {
99 let addr = format!("{}:{}", self.config.host, self.config.port);
100 let listener = tokio::net::TcpListener::bind(&addr)
101 .await
102 .map_err(crate::error::ContextError::Io)?;
103
104 tracing::info!("MCP Context Server listening on {}", addr);
105
106 axum::serve(listener, self.router())
107 .await
108 .map_err(|e| crate::error::ContextError::Internal(e.to_string()))?;
109
110 Ok(())
111 }
112
113 pub fn address(&self) -> String {
115 format!("{}:{}", self.config.host, self.config.port)
116 }
117}
118
119async fn health() -> impl IntoResponse {
121 Json(json!({
122 "status": "ok",
123 "server": "context-mcp",
124 "version": env!("CARGO_PKG_VERSION")
125 }))
126}
127
128async fn handle_mcp_request(
130 State(state): State<Arc<ServerState>>,
131 Json(request): Json<JsonRpcRequest>,
132) -> impl IntoResponse {
133 let response = process_request(&state, request).await;
134 Json(response)
135}
136
137async fn process_request(state: &ServerState, request: JsonRpcRequest) -> JsonRpcResponse {
139 match request.method.as_str() {
140 "initialize" => handle_initialize(request.id),
141 "initialized" => handle_initialized(request.id),
142 "tools/list" => handle_list_tools(request.id, state),
143 "tools/call" => handle_call_tool(request.id, state, request.params).await,
144 "ping" => handle_ping(request.id),
145 method => JsonRpcResponse::error(request.id, JsonRpcError::method_not_found(method)),
146 }
147}
148
149fn handle_initialize(id: RequestId) -> JsonRpcResponse {
151 let result = InitializeResult {
152 protocol_version: MCP_VERSION.to_string(),
153 capabilities: ServerCapabilities {
154 tools: Some(ToolsCapability { list_changed: true }),
155 resources: None,
156 prompts: None,
157 },
158 server_info: ServerInfo {
159 name: "context-mcp".to_string(),
160 version: env!("CARGO_PKG_VERSION").to_string(),
161 },
162 };
163
164 JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
165}
166
167fn handle_initialized(id: RequestId) -> JsonRpcResponse {
169 JsonRpcResponse::success(id, json!({}))
170}
171
172fn handle_list_tools(id: RequestId, state: &ServerState) -> JsonRpcResponse {
174 let tools = state.tools.list_tools();
175 JsonRpcResponse::success(id, json!({ "tools": tools }))
176}
177
178async fn handle_call_tool(
180 id: RequestId,
181 state: &ServerState,
182 params: Option<Value>,
183) -> JsonRpcResponse {
184 let params = match params {
185 Some(p) => p,
186 None => return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")),
187 };
188
189 let call_request: CallToolRequest = match serde_json::from_value(params) {
190 Ok(r) => r,
191 Err(e) => {
192 return JsonRpcResponse::error(
193 id,
194 JsonRpcError::invalid_params(format!("Invalid params: {}", e)),
195 )
196 }
197 };
198
199 let result = state
200 .tools
201 .execute(&call_request.name, call_request.arguments)
202 .await;
203 JsonRpcResponse::success(id, serde_json::to_value(result).unwrap())
204}
205
206fn handle_ping(id: RequestId) -> JsonRpcResponse {
208 JsonRpcResponse::success(id, json!({}))
209}
210
211async fn sse_handler(
213 State(_state): State<Arc<ServerState>>,
214) -> Sse<impl Stream<Item = Result<axum::response::sse::Event, Infallible>>> {
215 let stream = stream::iter(vec![Ok(axum::response::sse::Event::default()
216 .event("connected")
217 .data("MCP Context Server connected"))]);
218
219 Sse::new(stream)
220}
221
222pub struct StdioTransport {
224 state: Arc<ServerState>,
225}
226
227impl StdioTransport {
228 pub fn new(config: ServerConfig) -> ContextResult<Self> {
230 let state = Arc::new(ServerState::new(&config)?);
231 Ok(Self { state })
232 }
233
234 pub async fn run(&self) -> ContextResult<()> {
236 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
237
238 let stdin = tokio::io::stdin();
239 let mut stdout = tokio::io::stdout();
240 let mut reader = BufReader::new(stdin);
241
242 loop {
243 let mut line = String::new();
244 match reader.read_line(&mut line).await {
245 Ok(0) => break, Ok(_) => {
247 let line = line.trim();
248 if line.is_empty() {
249 continue;
250 }
251
252 match serde_json::from_str::<JsonRpcRequest>(line) {
253 Ok(request) => {
254 let response = process_request(&self.state, request).await;
255 let response_str = serde_json::to_string(&response).unwrap();
256 stdout.write_all(response_str.as_bytes()).await.ok();
257 stdout.write_all(b"\n").await.ok();
258 stdout.flush().await.ok();
259 }
260 Err(_e) => {
261 let error = JsonRpcResponse::error(
262 RequestId::Number(0),
263 JsonRpcError::parse_error(),
264 );
265 let error_str = serde_json::to_string(&error).unwrap();
266 stdout.write_all(error_str.as_bytes()).await.ok();
267 stdout.write_all(b"\n").await.ok();
268 stdout.flush().await.ok();
269 }
270 }
271 }
272 Err(_) => break,
273 }
274 }
275
276 Ok(())
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// The health handler must produce a 200 response with a JSON body.
    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await.into_response();
        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok());
        assert_eq!(content_type, Some("application/json"));
    }

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
    }
}