Skip to main content

aether_lspd/
protocol.rs

1use crate::language_catalog::LanguageId;
2use lsp_types::Uri;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::io;
6use std::path::PathBuf;
7
8#[doc = include_str!("docs/protocol.md")]
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub enum DaemonRequest {
11    Initialize(InitializeRequest),
12    LspCall {
13        client_id: i64,
14        method: String,
15        params: Value,
16    },
17    GetDiagnostics {
18        client_id: i64,
19        /// If None, return all cached diagnostics for the workspace
20        uri: Option<Uri>,
21    },
22    QueueDiagnosticRefresh {
23        client_id: i64,
24        uri: Uri,
25    },
26    Disconnect,
27    Ping,
28}
29
30/// Initialize request to set up LSP for a workspace
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct InitializeRequest {
33    pub workspace_root: PathBuf,
34    pub language: LanguageId,
35}
36
37/// LSP notification from client to server
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct LspNotification {
40    pub method: String,
41    pub params: Value,
42}
43
44/// Top-level daemon response
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub enum DaemonResponse {
47    Initialized,
48    Pong,
49    LspResult { client_id: i64, result: Result<Value, LspErrorResponse> },
50    Error(ProtocolError),
51}
52
53/// LSP error response
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct LspErrorResponse {
56    pub code: i32,
57    pub message: String,
58}
59
60/// Protocol-level error (not LSP error)
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ProtocolError {
63    pub message: String,
64    /// Optional `client_id` for correlating errors back to LSP requests
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub client_id: Option<i64>,
67}
68
69impl ProtocolError {
70    pub fn new(message: impl Into<String>) -> Self {
71        Self { message: message.into(), client_id: None }
72    }
73
74    pub fn with_client_id(message: impl Into<String>, client_id: i64) -> Self {
75        Self { message: message.into(), client_id: Some(client_id) }
76    }
77}
78
79/// Extract the document URI from an LSP request's params by method name.
80///
81/// Used by the daemon for auto-open: if the request targets a specific file,
82/// the daemon ensures the file is opened before forwarding the request.
83pub fn extract_document_uri(method: &str, params: &Value) -> Option<Uri> {
84    if !method.starts_with("textDocument/") {
85        return None;
86    }
87    params.pointer("/textDocument/uri").and_then(|v| v.as_str()).and_then(|s| s.parse().ok())
88}
89
/// Largest frame payload, in bytes, that `read_frame` accepts and `write_frame`
/// produces: 16 MiB (2^24 bytes).
pub const MAX_MESSAGE_SIZE: u32 = 1 << 24;
92
93/// Read a length-prefixed frame from an async reader
94pub(crate) async fn read_frame<R, T>(reader: &mut R) -> io::Result<Option<T>>
95where
96    R: tokio::io::AsyncReadExt + Unpin,
97    T: for<'de> Deserialize<'de>,
98{
99    let mut len_buf = [0u8; 4];
100    match reader.read_exact(&mut len_buf).await {
101        Ok(_) => {}
102        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
103        Err(e) => return Err(e),
104    }
105
106    let len = u32::from_be_bytes(len_buf);
107
108    if len > MAX_MESSAGE_SIZE {
109        return Err(io::Error::new(io::ErrorKind::InvalidData, format!("Message too large: {len} bytes")));
110    }
111
112    let mut buf = vec![0u8; len as usize];
113    reader.read_exact(&mut buf).await?;
114
115    serde_json::from_slice(&buf).map(Some).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
116}
117
118/// Write a length-prefixed frame to an async writer
119pub(crate) async fn write_frame<W, T>(writer: &mut W, message: &T) -> io::Result<()>
120where
121    W: tokio::io::AsyncWriteExt + Unpin,
122    T: Serialize,
123{
124    let json = serde_json::to_vec(message)?;
125
126    if json.len() > MAX_MESSAGE_SIZE as usize {
127        return Err(io::Error::new(io::ErrorKind::InvalidData, format!("Message too large: {} bytes", json.len())));
128    }
129
130    let len = u32::try_from(json.len()).unwrap_or(u32::MAX);
131    writer.write_all(&len.to_be_bytes()).await?;
132    writer.write_all(&json).await?;
133    writer.flush().await
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_document_uri` and returns the URI as an owned string for
    /// easy comparison.
    fn extracted(method: &str, params: &serde_json::Value) -> Option<String> {
        extract_document_uri(method, params).map(|uri| uri.as_str().to_owned())
    }

    #[test]
    fn test_protocol_error_new() {
        let err = ProtocolError::new("test error");
        assert_eq!(err.message, "test error");
        assert_eq!(err.client_id, None);
    }

    #[test]
    fn test_protocol_error_with_client_id() {
        let err = ProtocolError::with_client_id("boom", 7);
        assert_eq!(err.message, "boom");
        assert_eq!(err.client_id, Some(7));
    }

    #[test]
    fn test_protocol_error_serialization_omits_absent_client_id() {
        let without = serde_json::to_value(ProtocolError::new("x")).unwrap();
        assert_eq!(without, serde_json::json!({ "message": "x" }));

        let with = serde_json::to_value(ProtocolError::with_client_id("x", 3)).unwrap();
        assert_eq!(with, serde_json::json!({ "message": "x", "client_id": 3 }));
    }

    #[test]
    fn test_protocol_error_deserializes_without_client_id() {
        let err: ProtocolError = serde_json::from_str(r#"{"message":"x"}"#).unwrap();
        assert_eq!(err.message, "x");
        assert_eq!(err.client_id, None);
    }

    #[test]
    fn test_daemon_request_lsp_call_roundtrip() {
        let params = serde_json::json!({
            "textDocument": { "uri": "file:///test.rs" },
            "position": { "line": 0, "character": 0 }
        });
        let req = DaemonRequest::LspCall {
            client_id: 42,
            method: "textDocument/definition".to_string(),
            params: params.clone(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: DaemonRequest = serde_json::from_str(&json).unwrap();
        match decoded {
            DaemonRequest::LspCall { client_id, method, params: decoded_params } => {
                assert_eq!(client_id, 42);
                assert_eq!(method, "textDocument/definition");
                assert_eq!(decoded_params, params);
            }
            other => panic!("Wrong variant: {other:?}"),
        }
    }

    #[test]
    fn test_extract_document_uri_definition() {
        let params = serde_json::json!({
            "textDocument": { "uri": "file:///src/main.rs" },
            "position": { "line": 10, "character": 5 }
        });
        assert_eq!(extracted("textDocument/definition", &params).as_deref(), Some("file:///src/main.rs"));
    }

    #[test]
    fn test_extract_document_uri_references() {
        let params = serde_json::json!({
            "textDocument": { "uri": "file:///src/lib.rs" },
            "position": { "line": 5, "character": 3 },
            "context": { "includeDeclaration": true }
        });
        assert_eq!(extracted("textDocument/references", &params).as_deref(), Some("file:///src/lib.rs"));
    }

    #[test]
    fn test_extract_document_uri_document_symbol() {
        let params = serde_json::json!({
            "textDocument": { "uri": "file:///src/foo.rs" }
        });
        assert_eq!(extracted("textDocument/documentSymbol", &params).as_deref(), Some("file:///src/foo.rs"));
    }

    #[test]
    fn test_extract_document_uri_workspace_symbol_returns_none() {
        let params = serde_json::json!({ "query": "Foo" });
        assert_eq!(extracted("workspace/symbol", &params), None);
    }

    #[test]
    fn test_extract_document_uri_non_text_document_method_ignores_uri() {
        // Even a well-formed textDocument.uri is ignored outside the
        // `textDocument/` namespace.
        let params = serde_json::json!({
            "textDocument": { "uri": "file:///src/main.rs" }
        });
        assert_eq!(extracted("workspace/executeCommand", &params), None);
    }

    #[test]
    fn test_extract_document_uri_missing_text_document_returns_none() {
        let params = serde_json::json!({});
        assert_eq!(extracted("textDocument/unknown", &params), None);
    }

    #[test]
    fn test_extract_document_uri_non_string_uri_returns_none() {
        let params = serde_json::json!({
            "textDocument": { "uri": 42 }
        });
        assert_eq!(extracted("textDocument/hover", &params), None);
    }
}