oak_mcp/
lib.rs

1#![feature(new_range_api)]
2use core::range::Range;
3use oak_lsp::service::LanguageService;
4pub use oak_semantic_search::{NoSemanticSearch, SemanticSearch};
5use serde::{Deserialize, Serialize};
6use serde_json::{Value, json};
7use std::sync::Arc;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tracing::info;
10
11pub mod handlers;
12
13#[derive(Debug, Serialize, Deserialize)]
14pub struct JsonRpcRequest {
15    pub jsonrpc: String,
16    pub id: Value,
17    pub method: String,
18    pub params: Option<Value>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct JsonRpcResponse {
23    pub jsonrpc: String,
24    pub id: Value,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub result: Option<Value>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub error: Option<JsonRpcError>,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32pub struct JsonRpcError {
33    pub code: i32,
34    pub message: String,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub data: Option<Value>,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
40pub struct JsonRpcNotification {
41    pub jsonrpc: String,
42    pub method: String,
43    pub params: Option<Value>,
44}
45
46/// A generic MCP server that wraps an Oak language service.
47pub struct McpServer<S: LanguageService, E: SemanticSearch = NoSemanticSearch> {
48    service: Arc<S>,
49    searcher: Option<Arc<E>>,
50}
51
52impl<S: LanguageService + 'static> McpServer<S, NoSemanticSearch>
53where
54    S::Vfs: oak_vfs::WritableVfs,
55{
56    pub fn new(service: S) -> Self {
57        Self { service: Arc::new(service), searcher: None }
58    }
59}
60
61impl<S: LanguageService + 'static, E: SemanticSearch + 'static> McpServer<S, E>
62where
63    S::Vfs: oak_vfs::WritableVfs,
64{
65    pub fn with_searcher<NewE: SemanticSearch>(self, searcher: NewE) -> McpServer<S, NewE> {
66        McpServer { service: self.service, searcher: Some(Arc::new(searcher)) }
67    }
68
69    pub async fn run(&self) -> tokio::io::Result<()> {
70        let stdin = tokio::io::stdin();
71        let mut stdout = tokio::io::stdout();
72        let mut reader = BufReader::new(stdin);
73        let mut line = String::new();
74
75        while reader.read_line(&mut line).await? > 0 {
76            let input = line.trim();
77            if input.is_empty() {
78                line.clear();
79                continue;
80            }
81
82            if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(input) {
83                let response = self.handle_request(request).await;
84                let response_json = serde_json::to_string(&response).unwrap();
85                stdout.write_all(response_json.as_bytes()).await?;
86                stdout.write_all(b"\n").await?;
87                stdout.flush().await?;
88            }
89            else if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(input) {
90                self.handle_notification(notification).await;
91            }
92
93            line.clear();
94        }
95
96        Ok(())
97    }
98
99    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
100        match request.method.as_str() {
101            "initialize" => JsonRpcResponse {
102                jsonrpc: "2.0".to_string(),
103                id: request.id,
104                result: Some(json!({
105                    "protocolVersion": "2024-11-05",
106                    "capabilities": {
107                        "tools": {
108                            "listChanged": false
109                        }
110                    },
111                    "serverInfo": {
112                        "name": "oak-mcp",
113                        "version": "0.0.1"
114                    }
115                })),
116                error: None,
117            },
118            "tools/list" => {
119                let tools: Value = serde_json::from_str(include_str!("tools.json")).unwrap();
120                JsonRpcResponse {
121                    jsonrpc: "2.0".to_string(),
122                    id: request.id,
123                    result: Some(json!({
124                        "tools": tools
125                    })),
126                    error: None,
127                }
128            }
129            "tools/call" => {
130                let params = request.params.unwrap_or(Value::Null);
131                let name = params.get("name").and_then(|v| v.as_str()).unwrap_or_default();
132                let args = params.get("arguments").cloned().unwrap_or(Value::Null);
133
134                match self.handle_tool_call(name, args).await {
135                    Ok(result) => JsonRpcResponse {
136                        jsonrpc: "2.0".to_string(),
137                        id: request.id,
138                        result: Some(json!({
139                            "content": [
140                                {
141                                    "type": "text",
142                                    "text": serde_json::to_string_pretty(&result).unwrap()
143                                }
144                            ]
145                        })),
146                        error: None,
147                    },
148                    Err(e) => JsonRpcResponse { jsonrpc: "2.0".to_string(), id: request.id, result: None, error: Some(JsonRpcError { code: -32000, message: e, data: None }) },
149                }
150            }
151            _ => JsonRpcResponse { jsonrpc: "2.0".to_string(), id: request.id, result: None, error: Some(JsonRpcError { code: -32601, message: format!("Method not found: {}", request.method), data: None }) },
152        }
153    }
154
155    async fn handle_tool_call(&self, name: &str, args: Value) -> Result<Value, String> {
156        match name {
157            "hover" => {
158                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
159                let offset = args.get("offset").and_then(|v| v.as_u64()).ok_or("Missing offset")? as usize;
160
161                let hover = self.service.hover(uri, Range { start: offset, end: offset }).await;
162                Ok(json!(hover))
163            }
164            "symbols" => {
165                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
166                let symbols = self.service.document_symbols(uri).await;
167                Ok(json!(symbols))
168            }
169            "definition" => {
170                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
171                let offset = args.get("offset").and_then(|v| v.as_u64()).ok_or("Missing offset")? as usize;
172
173                let locs = self.service.definition(uri, Range { start: offset, end: offset }).await;
174                Ok(json!(locs))
175            }
176            "references" => {
177                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
178                let offset = args.get("offset").and_then(|v| v.as_u64()).ok_or("Missing offset")? as usize;
179                let locs = self.service.references(uri, Range { start: offset, end: offset }).await;
180                Ok(json!(locs))
181            }
182            "diagnostics" => {
183                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
184                let diagnostics = self.service.diagnostics(uri).await;
185                Ok(json!(diagnostics))
186            }
187            "completion" => {
188                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
189                let offset = args.get("offset").and_then(|v| v.as_u64()).ok_or("Missing offset")? as usize;
190                let items = self.service.completion(uri, offset).await;
191                Ok(json!(items))
192            }
193            "search" => {
194                let query = args.get("query").and_then(|v| v.as_str()).ok_or("Missing query")?;
195                let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
196
197                if let Some(searcher) = &self.searcher {
198                    let results = searcher.search(query, limit).await.map_err(|e| e.to_string())?;
199                    Ok(json!(results))
200                }
201                else {
202                    Err("Semantic search is not enabled on this server".to_string())
203                }
204            }
205            "set_file_content" => {
206                let uri = args.get("uri").and_then(|v| v.as_str()).ok_or("Missing uri")?;
207                let content = args.get("content").and_then(|v| v.as_str()).ok_or("Missing content")?;
208
209                use oak_vfs::WritableVfs;
210                let vfs = self.service.vfs();
211                vfs.write_file(uri, content.to_string());
212                Ok(json!({"status": "ok"}))
213            }
214            "semantic_search" => {
215                let query = args.get("query").and_then(|v| v.as_str()).ok_or("Missing query")?;
216                let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
217
218                if let Some(searcher) = &self.searcher {
219                    let results = searcher.search(query, limit).await.map_err(|e| e.to_string())?;
220                    Ok(json!(results))
221                }
222                else {
223                    Err("Semantic search is not enabled on this server".to_string())
224                }
225            }
226            _ => Err(format!("Unknown tool: {}", name)),
227        }
228    }
229
230    pub async fn handle_notification(&self, notification: JsonRpcNotification) {
231        info!("Received notification: {}", notification.method);
232    }
233}
234
235/// Extension trait for language services to provide MCP integration.
236pub trait OakMcpService: LanguageService + Sized + 'static
237where
238    Self::Vfs: oak_vfs::WritableVfs,
239{
240    /// Convert this service into an Oak MCP server.
241    fn into_mcp_server(self) -> McpServer<Self> {
242        McpServer::new(self)
243    }
244
245    /// Create an Axum router for this MCP service.
246    #[cfg(feature = "axum")]
247    fn into_mcp_axum_router(self) -> axum::Router {
248        crate::handlers::axum_handlers::create_router(self)
249    }
250
251    /// Register this MCP service with an Actix-web config.
252    #[cfg(feature = "actix-web")]
253    fn register_mcp_actix(self, cfg: &mut actix_web::web::ServiceConfig) {
254        crate::handlers::actix_handlers::config(cfg, self)
255    }
256}
257
258impl<S: LanguageService + 'static> OakMcpService for S where S::Vfs: oak_vfs::WritableVfs {}