1use crate::language_catalog::LanguageId;
2use lsp_types::Uri;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::io;
6use std::path::PathBuf;
7
#[doc = include_str!("docs/protocol.md")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DaemonRequest {
    /// Initialization request; its payload is an [`InitializeRequest`]
    /// (workspace root + language).
    Initialize(InitializeRequest),
    /// An LSP call: the LSP `method` name and its raw JSON `params`.
    ///
    /// NOTE(review): `client_id` looks like a correlation id that is echoed in
    /// `DaemonResponse::LspResult { client_id, .. }` — confirm against the daemon.
    LspCall {
        client_id: i64,
        method: String,
        params: Value,
    },
    /// Request for diagnostics. `uri` is optional; presumably `None` means
    /// "all known documents" — TODO confirm against the request handler.
    GetDiagnostics {
        client_id: i64,
        uri: Option<Uri>,
    },
    /// Request to refresh diagnostics for the single document at `uri`.
    /// Whether the refresh is asynchronous is not visible here (the name
    /// suggests it is queued) — confirm against the handler.
    QueueDiagnosticRefresh {
        client_id: i64,
        uri: Uri,
    },
    /// Payload-free request to end the session.
    Disconnect,
    /// Payload-free liveness check; `DaemonResponse` has a matching `Pong` variant.
    Ping,
}
29
/// Payload of `DaemonRequest::Initialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeRequest {
    /// Root directory of the workspace the session operates on.
    pub workspace_root: PathBuf,
    /// Language the session is for (a project-defined `LanguageId`).
    pub language: LanguageId,
}
36
/// An LSP notification: a method name plus its raw JSON parameters.
///
/// NOTE(review): not referenced elsewhere in this file; presumably used to relay
/// server-initiated notifications to clients — confirm at the use sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspNotification {
    /// LSP method name, e.g. a `textDocument/...` or `window/...` method.
    pub method: String,
    /// Raw JSON parameters, passed through without interpretation here.
    pub params: Value,
}
43
/// Messages sent from the daemon back to a client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DaemonResponse {
    /// Payload-free acknowledgement; presumably the reply to `DaemonRequest::Initialize`.
    Initialized,
    /// Payload-free reply; presumably answers `DaemonRequest::Ping`.
    Pong,
    /// Outcome of an LSP call for `client_id`: `Ok` carries the raw JSON result,
    /// `Err` carries an [`LspErrorResponse`].
    LspResult { client_id: i64, result: Result<Value, LspErrorResponse> },
    /// Daemon-level failure, described by a [`ProtocolError`].
    Error(ProtocolError),
}
52
/// Error half of `DaemonResponse::LspResult`.
///
/// NOTE(review): the `code` + `message` shape looks like it mirrors the
/// JSON-RPC / LSP `ResponseError` object — confirm where it is constructed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspErrorResponse {
    /// Numeric error code.
    pub code: i32,
    /// Error description.
    pub message: String,
}
59
/// Daemon-level error, carried by `DaemonResponse::Error`.
///
/// Build one with `ProtocolError::new` (no client association) or
/// `ProtocolError::with_client_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtocolError {
    /// Description of the failure.
    pub message: String,
    /// Client id this error relates to, if any. Omitted from the serialized
    /// form entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<i64>,
}
68
69impl ProtocolError {
70 pub fn new(message: impl Into<String>) -> Self {
71 Self { message: message.into(), client_id: None }
72 }
73
74 pub fn with_client_id(message: impl Into<String>, client_id: i64) -> Self {
75 Self { message: message.into(), client_id: Some(client_id) }
76 }
77}
78
79pub fn extract_document_uri(method: &str, params: &Value) -> Option<Uri> {
84 if !method.starts_with("textDocument/") {
85 return None;
86 }
87 params.pointer("/textDocument/uri").and_then(|v| v.as_str()).and_then(|s| s.parse().ok())
88}
89
/// Largest JSON payload, in bytes, allowed in a single frame: 16 MiB.
///
/// Checked against the length prefix when reading and against the serialized
/// payload when writing.
pub const MAX_MESSAGE_SIZE: u32 = 16 << 20;
92
93pub(crate) async fn read_frame<R, T>(reader: &mut R) -> io::Result<Option<T>>
95where
96 R: tokio::io::AsyncReadExt + Unpin,
97 T: for<'de> Deserialize<'de>,
98{
99 let mut len_buf = [0u8; 4];
100 match reader.read_exact(&mut len_buf).await {
101 Ok(_) => {}
102 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
103 Err(e) => return Err(e),
104 }
105
106 let len = u32::from_be_bytes(len_buf);
107
108 if len > MAX_MESSAGE_SIZE {
109 return Err(io::Error::new(io::ErrorKind::InvalidData, format!("Message too large: {len} bytes")));
110 }
111
112 let mut buf = vec![0u8; len as usize];
113 reader.read_exact(&mut buf).await?;
114
115 serde_json::from_slice(&buf).map(Some).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
116}
117
118pub(crate) async fn write_frame<W, T>(writer: &mut W, message: &T) -> io::Result<()>
120where
121 W: tokio::io::AsyncWriteExt + Unpin,
122 T: Serialize,
123{
124 let json = serde_json::to_vec(message)?;
125
126 if json.len() > MAX_MESSAGE_SIZE as usize {
127 return Err(io::Error::new(io::ErrorKind::InvalidData, format!("Message too large: {} bytes", json.len())));
128 }
129
130 let len = u32::try_from(json.len()).unwrap_or(u32::MAX);
131 writer.write_all(&len.to_be_bytes()).await?;
132 writer.write_all(&json).await?;
133 writer.flush().await
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Waker whose wake-up does nothing; only used by `block_on` below.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `fut` once and returns its output, without an async runtime.
    ///
    /// Only meant for futures over in-memory I/O (`&[u8]` readers, `Vec<u8>`
    /// writers), which complete without ever returning `Poll::Pending`; anything
    /// else fails the test instead of hanging it.
    fn block_on<F: Future>(fut: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        match Box::pin(fut).as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future unexpectedly pending on in-memory I/O"),
        }
    }

    /// Runs `extract_document_uri` and returns the URI as an owned string.
    fn extracted(method: &str, params: &Value) -> Option<String> {
        extract_document_uri(method, params).map(|uri| uri.as_str().to_owned())
    }

    #[test]
    fn test_protocol_error_new() {
        let err = ProtocolError::new("test error");
        assert_eq!(err.message, "test error");
        assert_eq!(err.client_id, None);
    }

    #[test]
    fn test_protocol_error_with_client_id() {
        let err = ProtocolError::with_client_id("bad request", 7);
        assert_eq!(err.message, "bad request");
        assert_eq!(err.client_id, Some(7));
    }

    #[test]
    fn test_protocol_error_omits_missing_client_id() {
        let value = serde_json::to_value(ProtocolError::new("oops")).unwrap();
        assert_eq!(value, json!({ "message": "oops" }));
    }

    #[test]
    fn test_daemon_request_lsp_call_roundtrip() {
        let req = DaemonRequest::LspCall {
            client_id: 42,
            method: "textDocument/definition".to_string(),
            params: json!({
                "textDocument": { "uri": "file:///test.rs" },
                "position": { "line": 0, "character": 0 }
            }),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: DaemonRequest = serde_json::from_str(&json).unwrap();
        match decoded {
            DaemonRequest::LspCall { client_id, method, .. } => {
                assert_eq!(client_id, 42);
                assert_eq!(method, "textDocument/definition");
            }
            other => panic!("Wrong variant: {other:?}"),
        }
    }

    #[test]
    fn test_daemon_response_lsp_error_roundtrip() {
        let resp = DaemonResponse::LspResult {
            client_id: 3,
            result: Err(LspErrorResponse { code: -32601, message: "Method not found".to_string() }),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: DaemonResponse = serde_json::from_str(&json).unwrap();
        match decoded {
            DaemonResponse::LspResult { client_id, result: Err(err) } => {
                assert_eq!(client_id, 3);
                assert_eq!(err.code, -32601);
                assert_eq!(err.message, "Method not found");
            }
            other => panic!("Wrong variant: {other:?}"),
        }
    }

    #[test]
    fn test_extract_document_uri_definition() {
        let params = json!({
            "textDocument": { "uri": "file:///src/main.rs" },
            "position": { "line": 10, "character": 5 }
        });
        assert_eq!(
            extracted("textDocument/definition", &params).as_deref(),
            Some("file:///src/main.rs")
        );
    }

    #[test]
    fn test_extract_document_uri_references() {
        let params = json!({
            "textDocument": { "uri": "file:///src/lib.rs" },
            "position": { "line": 5, "character": 3 },
            "context": { "includeDeclaration": true }
        });
        assert_eq!(
            extracted("textDocument/references", &params).as_deref(),
            Some("file:///src/lib.rs")
        );
    }

    #[test]
    fn test_extract_document_uri_document_symbol() {
        let params = json!({
            "textDocument": { "uri": "file:///src/foo.rs" }
        });
        assert_eq!(
            extracted("textDocument/documentSymbol", &params).as_deref(),
            Some("file:///src/foo.rs")
        );
    }

    #[test]
    fn test_extract_document_uri_workspace_symbol_returns_none() {
        let params = json!({ "query": "Foo" });
        assert_eq!(extracted("workspace/symbol", &params), None);
    }

    #[test]
    fn test_extract_document_uri_unknown_method_returns_none() {
        let params = json!({});
        assert_eq!(extracted("textDocument/unknown", &params), None);
    }

    #[test]
    fn test_extract_document_uri_ignores_non_text_document_methods() {
        let params = json!({ "textDocument": { "uri": "file:///src/main.rs" } });
        assert_eq!(extracted("workspace/executeCommand", &params), None);
    }

    #[test]
    fn test_extract_document_uri_non_string_uri_returns_none() {
        let params = json!({ "textDocument": { "uri": 42 } });
        assert_eq!(extracted("textDocument/hover", &params), None);
    }

    #[test]
    fn test_frame_roundtrip() {
        let mut wire = Vec::new();
        block_on(write_frame(&mut wire, &DaemonRequest::Ping)).unwrap();

        let header = [wire[0], wire[1], wire[2], wire[3]];
        assert_eq!(u32::from_be_bytes(header) as usize, wire.len() - 4);

        let mut reader: &[u8] = &wire;
        let decoded: Option<DaemonRequest> = block_on(read_frame(&mut reader)).unwrap();
        assert!(matches!(decoded, Some(DaemonRequest::Ping)));

        // The stream is now exhausted exactly on a frame boundary.
        let next: Option<DaemonRequest> = block_on(read_frame(&mut reader)).unwrap();
        assert!(next.is_none());
    }

    #[test]
    fn test_read_frame_empty_stream_is_clean_eof() {
        let mut reader: &[u8] = &[];
        let result = block_on(read_frame::<_, DaemonRequest>(&mut reader)).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_read_frame_rejects_oversized_length() {
        let header = (MAX_MESSAGE_SIZE + 1).to_be_bytes();
        let mut reader: &[u8] = &header;
        let err = block_on(read_frame::<_, DaemonRequest>(&mut reader)).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_frame_rejects_invalid_json() {
        let mut wire = 3u32.to_be_bytes().to_vec();
        wire.extend_from_slice(b"{{{");
        let mut reader: &[u8] = &wire;
        let err = block_on(read_frame::<_, DaemonRequest>(&mut reader)).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_frame_truncated_payload_is_error() {
        let mut wire = 10u32.to_be_bytes().to_vec();
        wire.extend_from_slice(b"{}");
        let mut reader: &[u8] = &wire;
        let err = block_on(read_frame::<_, DaemonRequest>(&mut reader)).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}