// oak_lsp/server.rs

1use crate::{
2    service::LanguageService,
3    types::{InitializeParams, Range},
4};
5use oak_vfs::{LineMap, Vfs, WritableVfs};
6use serde_json::{Value, json};
7use std::sync::Arc;
8use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
9
/// Language Server Protocol front-end that decodes JSON-RPC messages and delegates the
/// actual language features to a shared backend `service`.
///
/// The type itself carries no trait bound; the `LanguageService` requirement lives on the
/// `impl` blocks that actually need it.
pub struct LspServer<S> {
    /// Backend that every request and notification is forwarded to.
    service: Arc<S>,
}
13
14impl<S: LanguageService + 'static> LspServer<S> {
15    pub fn new(service: Arc<S>) -> Self {
16        Self { service }
17    }
18
19    pub async fn run<R, W>(&self, reader: R, mut writer: W) -> anyhow::Result<()>
20    where
21        R: AsyncRead + Unpin,
22        W: AsyncWrite + Unpin,
23    {
24        let mut reader = BufReader::new(reader);
25
26        loop {
27            let mut content_length = 0;
28
29            // Read headers
30            loop {
31                let mut line = String::new();
32                if reader.read_line(&mut line).await? == 0 {
33                    return Ok(());
34                }
35                if line.trim().is_empty() {
36                    break;
37                }
38                if line.to_lowercase().starts_with("content-length:") {
39                    content_length = line["content-length:".len()..].trim().parse::<usize>()?;
40                }
41            }
42
43            if content_length == 0 {
44                continue;
45            }
46
47            // Read body
48            let mut body = vec![0u8; content_length];
49            reader.read_exact(&mut body).await?;
50            let body_str = String::from_utf8(body)?;
51            let request: Value = serde_json::from_str(&body_str)?;
52
53            // Handle request
54            if let Some(id) = request.get("id") {
55                // Method call
56                let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
57                let params = request.get("params").cloned().unwrap_or(json!({}));
58
59                let response = self.handle_request(id.clone(), method, params).await;
60                self.send_payload(&mut writer, response).await?;
61            }
62            else {
63                // Notification
64                let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
65                let params = request.get("params").cloned().unwrap_or(json!({}));
66                self.handle_notification(method, params, &mut writer).await?;
67            }
68        }
69    }
70
71    async fn handle_request(&self, id: Value, method: &str, params: Value) -> Value {
72        match method {
73            "initialize" => {
74                let params: InitializeParams = serde_json::from_value(params).unwrap_or_default();
75                self.service.initialize(params).await;
76                json!({
77                    "jsonrpc": "2.0",
78                    "id": id,
79                    "result": {
80                        "capabilities": {
81                            "textDocumentSync": 1,
82                            "hoverProvider": true,
83                            "completionProvider": {
84                                "resolveProvider": false,
85                                "triggerCharacters": [".", "<", "@", ":"]
86                            },
87                            "definitionProvider": true,
88                            "referencesProvider": true,
89                            "documentSymbolProvider": true,
90                            "workspaceSymbolProvider": true,
91                            "renameProvider": true
92                        }
93                    }
94                })
95            }
96            "shutdown" => {
97                self.service.shutdown().await;
98                json!({
99                    "jsonrpc": "2.0",
100                    "id": id,
101                    "result": null
102                })
103            }
104            "textDocument/hover" => {
105                if let (Some(uri), Some(pos)) = (params.get("textDocument").and_then(|d| d.get("uri")).and_then(|u| u.as_str()), params.get("position")) {
106                    if let Some(source) = self.service.vfs().get_source(uri) {
107                        let line = pos.get("line").and_then(|l| l.as_u64()).unwrap_or(0) as u32;
108                        let character = pos.get("character").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
109                        let line_map = LineMap::from_source(&source);
110                        let offset = line_map.line_col_utf16_to_offset(&source, line, character);
111
112                        if let Some(hover) = self.service.hover(uri, Range { start: offset, end: offset }).await {
113                            return json!({
114                                "jsonrpc": "2.0",
115                                "id": id,
116                                "result": {
117                                    "contents": { "kind": "markdown", "value": hover.contents }
118                                }
119                            });
120                        }
121                    }
122                }
123                json!({ "jsonrpc": "2.0", "id": id, "result": null })
124            }
125            "textDocument/completion" => {
126                if let (Some(uri), Some(pos)) = (params.get("textDocument").and_then(|d| d.get("uri")).and_then(|u| u.as_str()), params.get("position")) {
127                    if let Some(source) = self.service.vfs().get_source(uri) {
128                        let line = pos.get("line").and_then(|l| l.as_u64()).unwrap_or(0) as u32;
129                        let character = pos.get("character").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
130                        let line_map = LineMap::from_source(&source);
131                        let offset = line_map.line_col_utf16_to_offset(&source, line, character);
132
133                        let items = self.service.completion(uri, offset).await;
134                        return json!({
135                            "jsonrpc": "2.0",
136                            "id": id,
137                            "result": items
138                        });
139                    }
140                }
141                json!({ "jsonrpc": "2.0", "id": id, "result": [] })
142            }
143            "textDocument/definition" => {
144                if let (Some(uri), Some(pos)) = (params.get("textDocument").and_then(|d| d.get("uri")).and_then(|u| u.as_str()), params.get("position")) {
145                    if let Some(source) = self.service.vfs().get_source(uri) {
146                        let line = pos.get("line").and_then(|l| l.as_u64()).unwrap_or(0) as u32;
147                        let character = pos.get("character").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
148                        let line_map = LineMap::from_source(&source);
149                        let offset = line_map.line_col_utf16_to_offset(&source, line, character);
150
151                        let locations = self.service.definition(uri, Range { start: offset, end: offset }).await;
152                        let mut result = Vec::new();
153                        for loc in locations {
154                            if let Some(target_source) = self.service.vfs().get_source(&loc.uri) {
155                                let target_line_map = LineMap::from_source(&target_source);
156                                let (start_line, start_character) = target_line_map.offset_to_line_col_utf16(&target_source, loc.range.start);
157                                let (end_line, end_character) = target_line_map.offset_to_line_col_utf16(&target_source, loc.range.end);
158                                result.push(json!({
159                                    "uri": loc.uri,
160                                    "range": {
161                                        "start": { "line": start_line, "character": start_character },
162                                        "end": { "line": end_line, "character": end_character }
163                                    }
164                                }));
165                            }
166                        }
167                        return json!({ "jsonrpc": "2.0", "id": id, "result": result });
168                    }
169                }
170                json!({ "jsonrpc": "2.0", "id": id, "result": null })
171            }
172            _ => {
173                json!({
174                    "jsonrpc": "2.0",
175                    "id": id,
176                    "error": {
177                        "code": -32601,
178                        "message": format!("Method not found: {}", method)
179                    }
180                })
181            }
182        }
183    }
184
185    async fn handle_notification<W: AsyncWrite + Unpin>(&self, method: &str, params: Value, writer: &mut W) -> anyhow::Result<()> {
186        match method {
187            "initialized" => {
188                self.service.initialized().await;
189            }
190            "exit" => {
191                std::process::exit(0);
192            }
193            "textDocument/didOpen" => {
194                if let Some(doc) = params.get("textDocument") {
195                    if let (Some(uri), Some(text)) = (doc.get("uri").and_then(|u| u.as_str()), doc.get("text").and_then(|t| t.as_str())) {
196                        self.service.vfs().write_file(uri, text.to_string());
197                        self.publish_diagnostics(uri, writer).await?;
198                    }
199                }
200            }
201            "textDocument/didChange" => {
202                if let (Some(uri), Some(changes)) = (params.get("textDocument").and_then(|d| d.get("uri")).and_then(|u| u.as_str()), params.get("contentChanges").and_then(|c| c.as_array())) {
203                    if let Some(change) = changes.first() {
204                        if let Some(text) = change.get("text").and_then(|t| t.as_str()) {
205                            self.service.vfs().write_file(uri, text.to_string());
206                            self.publish_diagnostics(uri, writer).await?;
207                        }
208                    }
209                }
210            }
211            "textDocument/didSave" => {
212                if let Some(uri) = params.get("textDocument").and_then(|d| d.get("uri")).and_then(|u| u.as_str()) {
213                    self.service.did_save(uri).await;
214                }
215            }
216            "textDocument/didClose" => {
217                if let Some(uri) = params.get("textDocument").and_then(|d| d.get("uri")).and_then(|u| u.as_str()) {
218                    self.service.did_close(uri).await;
219                }
220            }
221            _ => {}
222        }
223        Ok(())
224    }
225
226    async fn publish_diagnostics<W: AsyncWrite + Unpin>(&self, uri: &str, writer: &mut W) -> anyhow::Result<()> {
227        use oak_vfs::LineMap;
228        let diags = self.service.diagnostics(uri).await;
229        if let Some(source) = self.service.vfs().get_source(uri) {
230            let line_map = LineMap::from_source(&source);
231            let mut result = Vec::new();
232            for diag in diags {
233                let (start_line, start_character) = line_map.offset_to_line_col_utf16(&source, diag.range.start);
234                let (end_line, end_character) = line_map.offset_to_line_col_utf16(&source, diag.range.end);
235                result.push(json!({
236                    "range": {
237                        "start": { "line": start_line, "character": start_character },
238                        "end": { "line": end_line, "character": end_character }
239                    },
240                    "severity": diag.severity.map(|s| s as u32).unwrap_or(1),
241                    "message": diag.message,
242                    "source": diag.source
243                }));
244            }
245            let payload = json!({
246                "jsonrpc": "2.0",
247                "method": "textDocument/publishDiagnostics",
248                "params": {
249                    "uri": uri,
250                    "diagnostics": result
251                }
252            });
253            self.send_payload(writer, payload).await?;
254        }
255        Ok(())
256    }
257
258    async fn send_payload<W: AsyncWrite + Unpin>(&self, writer: &mut W, payload: Value) -> anyhow::Result<()> {
259        let body = serde_json::to_string(&payload)?;
260        let header = format!("Content-Length: {}\r\n\r\n", body.len());
261        writer.write_all(header.as_bytes()).await?;
262        writer.write_all(body.as_bytes()).await?;
263        writer.flush().await?;
264        Ok(())
265    }
266}
267
/// Builds an [`axum::Router`] that exposes the language server at `POST /lsp`.
///
/// Each HTTP request body is one JSON-RPC message. A message carrying an `id` is dispatched
/// through the server's request handler, and the resulting JSON-RPC response becomes the HTTP
/// body. A message without an `id` (a notification) is not dispatched at all; the route just
/// answers `null`.
#[cfg(feature = "axum")]
pub fn axum_router<S: LanguageService + 'static>(service: Arc<S>) -> axum::Router {
    use axum::{Json, Router, extract::State, routing::post};

    let handle_message = |State(service): State<Arc<S>>, Json(message): Json<Value>| async move {
        // Notifications are not handled over HTTP; acknowledge them with `null`.
        let Some(id) = message.get("id").cloned() else {
            return Json(Value::Null);
        };
        let method = message.get("method").and_then(Value::as_str).unwrap_or("");
        let params = message.get("params").cloned().unwrap_or_else(|| json!({}));

        let server = LspServer::new(service);
        let response = server.handle_request(id, method, params).await;
        Json(response)
    };

    Router::new().route("/lsp", post(handle_message)).with_state(service)
}