1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
/// Largest JSON payload, in bytes, that [`encode_frame`] will emit and
/// [`decode_frame`] will accept: 10 MiB (`10 << 20` == 10 * 1024 * 1024).
pub(crate) const MAX_FRAME_SIZE: u32 = 10 << 20;
8
/// A request message.
///
/// Serialized as a JSON object whose `"type"` key carries the variant name in
/// snake_case (e.g. `DaemonStop` <-> `{"type":"daemon_stop"}`,
/// `ServerRestart { .. }` <-> `{"type":"server_restart", ..}`); the variant's
/// fields appear as the object's remaining keys. `rename_all` applies to the
/// variant names only, not to field names.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Request {
    Status,
    DaemonStop,
    Init,
    Check {
        path: Option<PathBuf>,
        errors_only: bool,
    },
    FindSymbol {
        name: String,
        /// Left out of the serialized object when `None`; a missing key
        /// deserializes to `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        path_filter: Option<String>,
        /// Optional on the wire: a missing key deserializes to `false`.
        #[serde(default)]
        src_only: bool,
        /// Optional on the wire: a missing key deserializes to `false`.
        #[serde(default)]
        include_body: bool,
    },
    FindImpl {
        name: String,
    },
    FindRefs {
        name: String,
        /// Optional on the wire: a missing key deserializes to `false`.
        #[serde(default)]
        with_symbol: bool,
    },
    ListSymbols {
        path: PathBuf,
        depth: u8,
    },
    // NOTE(review): `from`/`to`/`max_lines` presumably count lines of the
    // file; confirm against the handler before relying on it.
    ReadFile {
        path: PathBuf,
        from: Option<u32>,
        to: Option<u32>,
        max_lines: Option<u32>,
    },
    ReadSymbol {
        name: String,
        /// Required on the wire (no serde default).
        signature_only: bool,
        max_lines: Option<u32>,
        /// Left out of the serialized object when `None`; a missing key
        /// deserializes to `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        path_filter: Option<String>,
        /// Optional on the wire: a missing key deserializes to `false`.
        #[serde(default)]
        has_body: bool,
    },
    EditReplace {
        symbol: String,
        code: String,
    },
    EditInsertAfter {
        symbol: String,
        code: String,
    },
    EditInsertBefore {
        symbol: String,
        code: String,
    },
    Hover {
        name: String,
    },
    Format {
        path: PathBuf,
    },
    Rename {
        name: String,
        new_name: String,
    },
    Fix {
        path: Option<PathBuf>,
    },
    ServerStatus,
    ServerRestart {
        /// Language identifier, e.g. `"rust"` in this module's tests.
        language: String,
    },
}
94
/// Response envelope: a `success` flag plus an optional `data` value and an
/// optional [`ErrorPayload`].
///
/// The constructors on `Response` either set `success: true` with `data`, or
/// `success: false` with `error`. `None` fields are left out of the
/// serialized object entirely rather than emitted as `null`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct Response {
    /// `true` for responses built with `Response::ok`, `false` for the error
    /// constructors.
    pub success: bool,
    /// Result value; omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
    /// Failure details; omitted from the serialized object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ErrorPayload>,
}
103
/// Error details carried in [`Response::error`].
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct ErrorPayload {
    /// Short snake_case identifier, e.g. `"not_implemented"` as produced by
    /// `Response::not_implemented`.
    pub code: String,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional hint for resolving the error; omitted from the serialized
    /// object when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub advice: Option<String>,
}
111
112impl Response {
113 #[must_use]
114 pub fn ok(data: serde_json::Value) -> Self {
115 Self {
116 success: true,
117 data: Some(data),
118 error: None,
119 }
120 }
121
122 pub fn err(code: impl Into<String>, message: impl Into<String>) -> Self {
123 Self {
124 success: false,
125 data: None,
126 error: Some(ErrorPayload {
127 code: code.into(),
128 message: message.into(),
129 advice: None,
130 }),
131 }
132 }
133
134 pub fn err_with_advice(
135 code: impl Into<String>,
136 message: impl Into<String>,
137 advice: impl Into<String>,
138 ) -> Self {
139 Self {
140 success: false,
141 data: None,
142 error: Some(ErrorPayload {
143 code: code.into(),
144 message: message.into(),
145 advice: Some(advice.into()),
146 }),
147 }
148 }
149
150 #[must_use]
151 pub fn not_implemented() -> Self {
152 Self::err("not_implemented", "This command is not yet implemented")
153 }
154}
155
156#[derive(Debug, Error)]
157pub enum FrameError {
158 #[error("Frame exceeds maximum size of {MAX_FRAME_SIZE} bytes (got {size})")]
159 Oversized { size: u32 },
160 #[error("Incomplete frame: expected {expected} bytes, got {got}")]
161 Incomplete { expected: u32, got: usize },
162 #[error("IO error: {0}")]
163 Io(#[from] std::io::Error),
164 #[error("JSON error: {0}")]
165 Json(#[from] serde_json::Error),
166}
167
168pub fn encode_frame<T: Serialize>(value: &T) -> Result<Vec<u8>, FrameError> {
175 let json = serde_json::to_vec(value)?;
176
177 let len = u32::try_from(json.len()).map_err(|_| FrameError::Oversized { size: u32::MAX })?;
178
179 if len > MAX_FRAME_SIZE {
180 return Err(FrameError::Oversized { size: len });
181 }
182
183 let mut buf = Vec::with_capacity(4 + json.len());
184 buf.extend_from_slice(&len.to_be_bytes());
185 buf.extend_from_slice(&json);
186 Ok(buf)
187}
188
189pub fn decode_frame<T: for<'de> Deserialize<'de>>(buf: &[u8]) -> Result<(T, usize), FrameError> {
197 if buf.len() < 4 {
198 return Err(FrameError::Incomplete {
199 expected: 4,
200 got: buf.len(),
201 });
202 }
203
204 let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
205
206 if len > MAX_FRAME_SIZE {
207 return Err(FrameError::Oversized { size: len });
208 }
209
210 let total = 4 + len as usize;
211 if buf.len() < total {
212 return Err(FrameError::Incomplete {
213 expected: len,
214 got: buf.len() - 4,
215 });
216 }
217
218 let value = serde_json::from_slice(&buf[4..total])?;
219 Ok((value, total))
220}
221
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Every `Request` variant must survive a JSON round trip unchanged.
    #[test]
    fn request_roundtrip_serialization() {
        let requests = vec![
            Request::Status,
            Request::DaemonStop,
            Request::Init,
            Request::Check {
                path: None,
                errors_only: false,
            },
            Request::Check {
                path: Some(PathBuf::from("src/lib.rs")),
                errors_only: true,
            },
            Request::FindSymbol {
                name: "MyStruct".into(),
                path_filter: None,
                src_only: false,
                include_body: false,
            },
            Request::FindSymbol {
                name: "Foo".into(),
                path_filter: Some("packages/core".into()),
                src_only: true,
                include_body: false,
            },
            Request::FindSymbol {
                name: "Bar".into(),
                path_filter: None,
                src_only: false,
                include_body: true,
            },
            Request::FindImpl {
                name: "Display".into(),
            },
            Request::FindRefs {
                name: "my_func".into(),
                with_symbol: false,
            },
            Request::FindRefs {
                name: "createStep".into(),
                with_symbol: true,
            },
            Request::ListSymbols {
                path: PathBuf::from("src/lib.rs"),
                depth: 1,
            },
            Request::ReadFile {
                path: PathBuf::from("src/main.rs"),
                from: Some(5),
                to: Some(10),
                max_lines: None,
            },
            Request::ReadFile {
                path: PathBuf::from("src/main.rs"),
                from: None,
                to: None,
                max_lines: Some(50),
            },
            Request::ReadSymbol {
                name: "Config".into(),
                signature_only: true,
                max_lines: Some(20),
                path_filter: None,
                has_body: false,
            },
            Request::ReadSymbol {
                name: "CreatePromotionDTO".into(),
                signature_only: false,
                max_lines: None,
                path_filter: Some("packages/core".into()),
                has_body: true,
            },
            Request::EditReplace {
                symbol: "greet".into(),
                code: "fn greet() {}".into(),
            },
            Request::EditInsertAfter {
                symbol: "greet".into(),
                code: "fn helper() {}".into(),
            },
            Request::EditInsertBefore {
                symbol: "greet".into(),
                code: "#[test]".into(),
            },
            Request::Hover {
                name: "greet".into(),
            },
            Request::Format {
                path: PathBuf::from("src/lib.rs"),
            },
            Request::Rename {
                name: "greet".into(),
                new_name: "welcome".into(),
            },
            Request::Fix { path: None },
            Request::Fix {
                path: Some(PathBuf::from("src/main.rs")),
            },
            Request::ServerStatus,
            Request::ServerRestart {
                language: "rust".into(),
            },
        ];

        for req in &requests {
            let json = serde_json::to_string(req).unwrap();
            let decoded: Request = serde_json::from_str(&json).unwrap();
            assert_eq!(*req, decoded, "roundtrip failed for {json}");
        }
    }

    /// The `type` tag is the snake_case variant name and `None` filters are
    /// omitted from the object.
    #[test]
    fn request_wire_format() {
        assert_eq!(
            serde_json::to_value(&Request::DaemonStop).unwrap(),
            json!({"type": "daemon_stop"})
        );
        assert_eq!(
            serde_json::to_value(&Request::ServerRestart {
                language: "rust".into(),
            })
            .unwrap(),
            json!({"type": "server_restart", "language": "rust"})
        );
        assert_eq!(
            serde_json::to_value(&Request::FindSymbol {
                name: "Foo".into(),
                path_filter: None,
                src_only: false,
                include_body: false,
            })
            .unwrap(),
            json!({"type": "find_symbol", "name": "Foo", "src_only": false, "include_body": false})
        );
    }

    /// Fields marked `#[serde(default)]` may be left out by clients.
    #[test]
    fn request_defaulted_fields_may_be_omitted() {
        let decoded: Request = serde_json::from_str(r#"{"type":"find_refs","name":"x"}"#).unwrap();
        assert_eq!(
            decoded,
            Request::FindRefs {
                name: "x".into(),
                with_symbol: false,
            }
        );

        let decoded: Request =
            serde_json::from_str(r#"{"type":"find_symbol","name":"Foo"}"#).unwrap();
        assert_eq!(
            decoded,
            Request::FindSymbol {
                name: "Foo".into(),
                path_filter: None,
                src_only: false,
                include_body: false,
            }
        );
    }

    #[test]
    fn response_roundtrip_serialization() {
        let responses = vec![
            Response::ok(json!({"pid": 1234})),
            Response::err("not_found", "Symbol not found"),
            Response::err_with_advice("lsp_not_found", "LSP not detected", "Install rust-analyzer"),
            Response::not_implemented(),
        ];

        for resp in &responses {
            let json = serde_json::to_string(resp).unwrap();
            let decoded: Response = serde_json::from_str(&json).unwrap();
            assert_eq!(*resp, decoded, "roundtrip failed for {json}");
        }
    }

    /// `None` fields are omitted rather than serialized as `null`.
    #[test]
    fn response_omits_absent_fields() {
        assert_eq!(
            serde_json::to_value(Response::ok(json!({"pid": 1234}))).unwrap(),
            json!({"success": true, "data": {"pid": 1234}})
        );
        assert_eq!(
            serde_json::to_value(Response::err("not_found", "Symbol not found")).unwrap(),
            json!({"success": false, "error": {"code": "not_found", "message": "Symbol not found"}})
        );
    }

    #[test]
    fn frame_encode_decode() {
        let req = Request::FindSymbol {
            name: "Foo".into(),
            path_filter: None,
            src_only: false,
            include_body: false,
        };
        let frame = encode_frame(&req).unwrap();
        let (decoded, consumed): (Request, usize) = decode_frame(&frame).unwrap();
        assert_eq!(decoded, req);
        assert_eq!(consumed, frame.len());
    }

    /// The header is the big-endian payload length, followed by the JSON.
    #[test]
    fn frame_header_is_big_endian_payload_length() {
        let frame = encode_frame(&Request::Status).unwrap();
        let payload = br#"{"type":"status"}"#;
        let expected_len = u32::try_from(payload.len()).unwrap();
        assert_eq!(&frame[..4], &expected_len.to_be_bytes()[..]);
        assert_eq!(&frame[4..], &payload[..]);
    }

    #[test]
    fn frame_unit_variant() {
        let req = Request::Status;
        let frame = encode_frame(&req).unwrap();
        let (decoded, consumed): (Request, usize) = decode_frame(&frame).unwrap();
        assert_eq!(decoded, req);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn frame_large_payload() {
        let big_code = "x".repeat(1_000_000);
        let req = Request::EditReplace {
            symbol: "f".into(),
            code: big_code.clone(),
        };
        let frame = encode_frame(&req).unwrap();
        let (decoded, _): (Request, usize) = decode_frame(&frame).unwrap();
        assert_eq!(
            decoded,
            Request::EditReplace {
                symbol: "f".into(),
                code: big_code,
            }
        );
    }

    #[test]
    fn frame_rejects_oversized() {
        let huge = "x".repeat(11_000_000);
        let req = Request::EditReplace {
            symbol: "f".into(),
            code: huge,
        };
        let result = encode_frame(&req);
        assert!(matches!(
            &result,
            Err(FrameError::Oversized { size }) if *size > MAX_FRAME_SIZE
        ));
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    /// A header declaring too much data is rejected before any payload arrives.
    #[test]
    fn frame_decode_rejects_oversized_header() {
        let header = (MAX_FRAME_SIZE + 1).to_be_bytes();
        let result: Result<(Request, usize), _> = decode_frame(&header);
        assert!(matches!(
            &result,
            Err(FrameError::Oversized { size }) if *size == MAX_FRAME_SIZE + 1
        ));
    }

    #[test]
    fn frame_decode_incomplete_header() {
        let result: Result<(Request, usize), _> = decode_frame(&[0, 1]);
        assert!(matches!(
            result,
            Err(FrameError::Incomplete {
                expected: 4,
                got: 2
            })
        ));
    }

    #[test]
    fn frame_decode_incomplete_payload() {
        let frame = encode_frame(&Request::Status).unwrap();
        let payload_len = frame.len() - 4;
        let truncated = &frame[..frame.len() - 1];
        let result: Result<(Request, usize), _> = decode_frame(truncated);
        assert!(matches!(
            &result,
            Err(FrameError::Incomplete { expected, got })
                if *expected as usize == payload_len && *got == payload_len - 1
        ));
    }

    /// Decoding consumes exactly one frame and leaves the rest for the caller.
    #[test]
    fn frame_decode_consumes_only_first_frame() {
        let first = encode_frame(&Request::Status).unwrap();
        let second = encode_frame(&Request::DaemonStop).unwrap();
        let stream = [first.as_slice(), second.as_slice()].concat();

        let (decoded, consumed): (Request, usize) = decode_frame(&stream).unwrap();
        assert_eq!(decoded, Request::Status);
        assert_eq!(consumed, first.len());

        let (decoded, consumed): (Request, usize) = decode_frame(&stream[consumed..]).unwrap();
        assert_eq!(decoded, Request::DaemonStop);
        assert_eq!(consumed, second.len());
    }
}