Skip to main content

krait/
protocol.rs

1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
/// Largest JSON payload, in bytes, that a single frame may carry: 10 MiB.
///
/// Only the payload counts toward this limit; the 4-byte length header is extra.
pub(crate) const MAX_FRAME_SIZE: u32 = 10 * (1 << 20);
8
9#[derive(Debug, Serialize, Deserialize, PartialEq)]
10#[serde(tag = "type", rename_all = "snake_case")]
11pub enum Request {
12    Status,
13    DaemonStop,
14    Init,
15    Check {
16        path: Option<PathBuf>,
17        /// If true, suppress warnings and hints.
18        errors_only: bool,
19    },
20    FindSymbol {
21        name: String,
22        /// Substring filter applied to result paths (for disambiguation)
23        #[serde(default, skip_serializing_if = "Option::is_none")]
24        path_filter: Option<String>,
25        /// Exclude noise paths (www/, dist/, `node_modules`/, .d.ts, .mdx)
26        #[serde(default)]
27        src_only: bool,
28        /// Include full symbol body in each result
29        #[serde(default)]
30        include_body: bool,
31    },
32    FindImpl {
33        name: String,
34    },
35    FindRefs {
36        name: String,
37        /// Enrich each reference with its containing symbol
38        #[serde(default)]
39        with_symbol: bool,
40    },
41    ListSymbols {
42        path: PathBuf,
43        depth: u8,
44    },
45    ReadFile {
46        path: PathBuf,
47        from: Option<u32>,
48        to: Option<u32>,
49        max_lines: Option<u32>,
50    },
51    ReadSymbol {
52        name: String,
53        signature_only: bool,
54        max_lines: Option<u32>,
55        /// Substring filter to select the right definition when multiple exist
56        #[serde(default, skip_serializing_if = "Option::is_none")]
57        path_filter: Option<String>,
58        /// Skip overload stubs and return the real implementation body
59        #[serde(default)]
60        has_body: bool,
61    },
62    EditReplace {
63        symbol: String,
64        code: String,
65    },
66    EditInsertAfter {
67        symbol: String,
68        code: String,
69    },
70    EditInsertBefore {
71        symbol: String,
72        code: String,
73    },
74    Hover {
75        name: String,
76    },
77    Format {
78        path: PathBuf,
79    },
80    Rename {
81        name: String,
82        new_name: String,
83    },
84    Fix {
85        path: Option<PathBuf>,
86    },
87    /// Get running LSP server status from daemon
88    ServerStatus,
89    /// Restart a language server in the daemon
90    ServerRestart {
91        language: String,
92    },
93}
94
95#[derive(Debug, Serialize, Deserialize, PartialEq)]
96pub struct Response {
97    pub success: bool,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub data: Option<serde_json::Value>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub error: Option<ErrorPayload>,
102}
103
104#[derive(Debug, Serialize, Deserialize, PartialEq)]
105pub struct ErrorPayload {
106    pub code: String,
107    pub message: String,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub advice: Option<String>,
110}
111
112impl Response {
113    #[must_use]
114    pub fn ok(data: serde_json::Value) -> Self {
115        Self {
116            success: true,
117            data: Some(data),
118            error: None,
119        }
120    }
121
122    pub fn err(code: impl Into<String>, message: impl Into<String>) -> Self {
123        Self {
124            success: false,
125            data: None,
126            error: Some(ErrorPayload {
127                code: code.into(),
128                message: message.into(),
129                advice: None,
130            }),
131        }
132    }
133
134    pub fn err_with_advice(
135        code: impl Into<String>,
136        message: impl Into<String>,
137        advice: impl Into<String>,
138    ) -> Self {
139        Self {
140            success: false,
141            data: None,
142            error: Some(ErrorPayload {
143                code: code.into(),
144                message: message.into(),
145                advice: Some(advice.into()),
146            }),
147        }
148    }
149
150    #[must_use]
151    pub fn not_implemented() -> Self {
152        Self::err("not_implemented", "This command is not yet implemented")
153    }
154}
155
156#[derive(Debug, Error)]
157pub enum FrameError {
158    #[error("Frame exceeds maximum size of {MAX_FRAME_SIZE} bytes (got {size})")]
159    Oversized { size: u32 },
160    #[error("Incomplete frame: expected {expected} bytes, got {got}")]
161    Incomplete { expected: u32, got: usize },
162    #[error("IO error: {0}")]
163    Io(#[from] std::io::Error),
164    #[error("JSON error: {0}")]
165    Json(#[from] serde_json::Error),
166}
167
168/// Encode a serializable value into a length-prefixed frame.
169/// Format: `[4-byte big-endian length][JSON payload]`
170///
171/// # Errors
172/// Returns `FrameError::Oversized` if the payload exceeds 10 MB,
173/// or `FrameError::Json` if serialization fails.
174pub fn encode_frame<T: Serialize>(value: &T) -> Result<Vec<u8>, FrameError> {
175    let json = serde_json::to_vec(value)?;
176
177    let len = u32::try_from(json.len()).map_err(|_| FrameError::Oversized { size: u32::MAX })?;
178
179    if len > MAX_FRAME_SIZE {
180        return Err(FrameError::Oversized { size: len });
181    }
182
183    let mut buf = Vec::with_capacity(4 + json.len());
184    buf.extend_from_slice(&len.to_be_bytes());
185    buf.extend_from_slice(&json);
186    Ok(buf)
187}
188
189/// Decode a length-prefixed frame into a deserialized value.
190/// Returns the value and the number of bytes consumed.
191///
192/// # Errors
193/// Returns `FrameError::Incomplete` if the buffer is too short,
194/// `FrameError::Oversized` if the declared length exceeds 10 MB,
195/// or `FrameError::Json` if deserialization fails.
196pub fn decode_frame<T: for<'de> Deserialize<'de>>(buf: &[u8]) -> Result<(T, usize), FrameError> {
197    if buf.len() < 4 {
198        return Err(FrameError::Incomplete {
199            expected: 4,
200            got: buf.len(),
201        });
202    }
203
204    let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
205
206    if len > MAX_FRAME_SIZE {
207        return Err(FrameError::Oversized { size: len });
208    }
209
210    let total = 4 + len as usize;
211    if buf.len() < total {
212        return Err(FrameError::Incomplete {
213            expected: len,
214            got: buf.len() - 4,
215        });
216    }
217
218    let value = serde_json::from_slice(&buf[4..total])?;
219    Ok((value, total))
220}
221
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn request_roundtrip_serialization() {
        // Covers every `Request` variant at least once.
        let requests = vec![
            Request::Status,
            Request::DaemonStop,
            Request::Init,
            Request::Check {
                path: None,
                errors_only: false,
            },
            Request::Check {
                path: Some(PathBuf::from("src/lib.rs")),
                errors_only: true,
            },
            Request::FindSymbol {
                name: "MyStruct".into(),
                path_filter: None,
                src_only: false,
                include_body: false,
            },
            Request::FindSymbol {
                name: "Foo".into(),
                path_filter: Some("packages/core".into()),
                src_only: true,
                include_body: false,
            },
            Request::FindSymbol {
                name: "Bar".into(),
                path_filter: None,
                src_only: false,
                include_body: true,
            },
            Request::FindImpl {
                name: "Handler".into(),
            },
            Request::FindRefs {
                name: "my_func".into(),
                with_symbol: false,
            },
            Request::FindRefs {
                name: "createStep".into(),
                with_symbol: true,
            },
            Request::ListSymbols {
                path: PathBuf::from("src/lib.rs"),
                depth: 1,
            },
            Request::ReadFile {
                path: PathBuf::from("src/main.rs"),
                from: Some(5),
                to: Some(10),
                max_lines: None,
            },
            Request::ReadFile {
                path: PathBuf::from("src/main.rs"),
                from: None,
                to: None,
                max_lines: Some(50),
            },
            Request::ReadSymbol {
                name: "Config".into(),
                signature_only: true,
                max_lines: Some(20),
                path_filter: None,
                has_body: false,
            },
            Request::ReadSymbol {
                name: "CreatePromotionDTO".into(),
                signature_only: false,
                max_lines: None,
                path_filter: Some("packages/core".into()),
                has_body: true,
            },
            Request::EditReplace {
                symbol: "greet".into(),
                code: "fn greet() {}".into(),
            },
            Request::EditInsertAfter {
                symbol: "greet".into(),
                code: "fn helper() {}".into(),
            },
            Request::EditInsertBefore {
                symbol: "greet".into(),
                code: "#[test]".into(),
            },
            Request::Hover {
                name: "Config".into(),
            },
            Request::Format {
                path: PathBuf::from("src/lib.rs"),
            },
            Request::Rename {
                name: "old_name".into(),
                new_name: "new_name".into(),
            },
            Request::Fix { path: None },
            Request::Fix {
                path: Some(PathBuf::from("src/main.rs")),
            },
            Request::ServerStatus,
            Request::ServerRestart {
                language: "rust".into(),
            },
        ];

        for req in &requests {
            let json = serde_json::to_string(req).unwrap();
            let decoded: Request = serde_json::from_str(&json).unwrap();
            assert_eq!(*req, decoded, "roundtrip failed for {json}");
        }
    }

    #[test]
    fn request_wire_format_is_internally_tagged() {
        assert_eq!(
            serde_json::to_string(&Request::Status).unwrap(),
            r#"{"type":"status"}"#
        );

        // `path_filter: None` is skipped entirely rather than sent as `null`.
        let value = serde_json::to_value(Request::FindSymbol {
            name: "Foo".into(),
            path_filter: None,
            src_only: false,
            include_body: false,
        })
        .unwrap();
        assert_eq!(
            value,
            json!({"type": "find_symbol", "name": "Foo", "src_only": false, "include_body": false})
        );
    }

    #[test]
    fn response_roundtrip_serialization() {
        let responses = vec![
            Response::ok(json!({"pid": 1234})),
            Response::err("not_found", "Symbol not found"),
            Response::err_with_advice("lsp_not_found", "LSP not detected", "Install rust-analyzer"),
            Response::not_implemented(),
        ];

        for resp in &responses {
            let json = serde_json::to_string(resp).unwrap();
            let decoded: Response = serde_json::from_str(&json).unwrap();
            assert_eq!(*resp, decoded, "roundtrip failed for {json}");
        }
    }

    #[test]
    fn response_omits_absent_fields() {
        let ok = serde_json::to_value(Response::ok(json!({"pid": 1}))).unwrap();
        assert_eq!(ok, json!({"success": true, "data": {"pid": 1}}));

        let err = serde_json::to_value(Response::err("not_found", "Symbol not found")).unwrap();
        assert_eq!(
            err,
            json!({"success": false, "error": {"code": "not_found", "message": "Symbol not found"}})
        );
    }

    #[test]
    fn frame_encode_decode() {
        let req = Request::FindSymbol {
            name: "Foo".into(),
            path_filter: None,
            src_only: false,
            include_body: false,
        };
        let frame = encode_frame(&req).unwrap();
        let (decoded, consumed): (Request, usize) = decode_frame(&frame).unwrap();
        assert_eq!(decoded, req);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn frame_header_is_big_endian_payload_length() {
        let frame = encode_frame(&Request::Status).unwrap();
        let payload_len = u32::try_from(frame.len() - 4).unwrap();
        assert_eq!(&frame[..4], &payload_len.to_be_bytes());
        assert_eq!(&frame[4..], br#"{"type":"status"}"#);
    }

    #[test]
    fn frame_unit_variant_roundtrip() {
        let req = Request::Status;
        let frame = encode_frame(&req).unwrap();
        let (decoded, _): (Request, usize) = decode_frame(&frame).unwrap();
        assert_eq!(decoded, req);
    }

    #[test]
    fn frame_large_payload() {
        let big_code = "x".repeat(1_000_000);
        let req = Request::EditReplace {
            symbol: "f".into(),
            code: big_code.clone(),
        };
        let frame = encode_frame(&req).unwrap();
        let (decoded, _): (Request, usize) = decode_frame(&frame).unwrap();
        assert_eq!(
            decoded,
            Request::EditReplace {
                symbol: "f".into(),
                code: big_code,
            }
        );
    }

    #[test]
    fn frame_rejects_oversized() {
        let huge = "x".repeat(11_000_000);
        let req = Request::EditReplace {
            symbol: "f".into(),
            code: huge,
        };
        let result = encode_frame(&req);
        assert!(matches!(result, Err(FrameError::Oversized { .. })));
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[test]
    fn frame_decode_rejects_oversized_header() {
        // A header declaring more than MAX_FRAME_SIZE is rejected before the
        // (absent) payload is even considered.
        let header = (MAX_FRAME_SIZE + 1).to_be_bytes();
        let result: Result<(Request, usize), _> = decode_frame(&header);
        assert!(matches!(
            result,
            Err(FrameError::Oversized { size }) if size == MAX_FRAME_SIZE + 1
        ));
    }

    #[test]
    fn frame_decode_consumes_only_one_frame() {
        let first = encode_frame(&Request::Status).unwrap();
        let second = encode_frame(&Request::DaemonStop).unwrap();
        let stream = [first.as_slice(), second.as_slice()].concat();

        let (decoded, consumed): (Request, usize) = decode_frame(&stream).unwrap();
        assert_eq!(decoded, Request::Status);
        assert_eq!(consumed, first.len());

        let (decoded, consumed): (Request, usize) = decode_frame(&stream[consumed..]).unwrap();
        assert_eq!(decoded, Request::DaemonStop);
        assert_eq!(consumed, second.len());
    }

    #[test]
    fn frame_decode_incomplete_header() {
        let result: Result<(Request, usize), _> = decode_frame(&[0, 1]);
        assert!(matches!(
            result,
            Err(FrameError::Incomplete {
                expected: 4,
                got: 2
            })
        ));
    }

    #[test]
    fn frame_decode_incomplete_payload() {
        let frame = encode_frame(&Request::Status).unwrap();
        let payload_len = frame.len() - 4;
        let truncated = &frame[..frame.len() - 1];
        let result: Result<(Request, usize), _> = decode_frame(truncated);
        assert!(matches!(
            result,
            Err(FrameError::Incomplete { expected, got })
                if expected as usize == payload_len && got == payload_len - 1
        ));
    }
}