lsp_client_rs/
protocol.rs

1use anyhow::{bail, Result};
2use serde::{Deserialize, Serialize};
3
4#[derive(Serialize, Deserialize, Debug)]
5pub struct BaseMessage {
6    pub jsonrpc: String,
7}
8
9#[derive(Serialize, Deserialize, Debug)]
10pub struct RequestMessage {
11    #[serde(flatten)]
12    pub base_message: BaseMessage,
13    pub id: serde_json::Value,
14    pub method: String,
15    pub params: serde_json::Value,
16}
17
18#[derive(Serialize, Deserialize, Debug)]
19pub struct ResponseMessage {
20    #[serde(flatten)]
21    pub base_message: BaseMessage,
22    pub id: serde_json::Value,
23    pub result: serde_json::Value,
24    pub error: Option<serde_json::Value>,
25}
26
27#[derive(Serialize, Deserialize, Debug)]
28pub struct NotificationMessage {
29    #[serde(flatten)]
30    pub base_message: BaseMessage,
31    pub method: String,
32    pub params: serde_json::Value,
33}
34
35#[derive(Serialize, Deserialize, Debug)]
36pub struct InitializeParams {
37    #[serde(rename = "processId")]
38    pub process_id: u32,
39    #[serde(rename = "rootUri")]
40    pub root_uri: String,
41    #[serde(rename = "clientInfo")]
42    pub client_info: ClientInfo,
43    pub capabilities: ClientCapabilities, // Direct embedding
44    #[serde(rename = "workspaceFolders")]
45    pub workspace_folders: Option<Vec<WorkspaceFolder>>,
46}
47
48#[derive(Serialize, Deserialize, Debug)]
49pub struct ClientInfo {
50    pub name: String,
51    pub version: String,
52}
53
54#[derive(Serialize, Deserialize, Debug)]
55pub struct WorkspaceFolder {
56    pub uri: String,
57    pub name: String,
58}
59
60#[derive(Serialize, Deserialize, Debug)]
61pub struct ClientCapabilities {
62    pub workspace: Option<CapabilitiesWorkspace>, // Changed from HashMap to direct struct
63    #[serde(rename = "textDocument")]
64    pub text_document: Option<CapabilitiesTextDocument>, // Changed from HashMap to direct struct
65}
66
67#[derive(Serialize, Deserialize, Debug)]
68pub struct CapabilitiesWorkspace {
69    #[serde(rename = "workspaceFolders")]
70    pub workspace_folders: bool,
71    #[serde(rename = "didChangeConfiguration")]
72    pub did_change_configuration: DidChangeConfiguration,
73    #[serde(rename = "workspaceEdit")]
74    pub workspace_edit: WorkspaceEdit,
75    pub configuration: bool,
76}
77
78#[derive(Serialize, Deserialize, Debug)]
79pub struct DidChangeConfiguration {
80    #[serde(rename = "dynamicRegistration")]
81    pub dynamic_registration: bool,
82}
83
84#[derive(Serialize, Deserialize, Debug)]
85pub struct WorkspaceEdit {
86    #[serde(rename = "documentChanges")]
87    pub document_changes: bool,
88}
89
90#[derive(Serialize, Deserialize, Debug)]
91pub struct CapabilitiesTextDocument {
92    pub hover: Hover,
93    pub completion: Completion,
94    #[serde(rename = "codeAction")]
95    pub code_action: CodeAction,
96}
97
98#[derive(Serialize, Deserialize, Debug)]
99pub struct Hover {
100    #[serde(rename = "contentFormat")]
101    pub content_format: Vec<String>,
102}
103
104#[derive(Serialize, Deserialize, Debug)]
105pub struct Completion {
106    #[serde(rename = "completionItem")]
107    pub completion_item: CompletionItem,
108}
109
110#[derive(Serialize, Deserialize, Debug)]
111pub struct CompletionItem {
112    #[serde(rename = "snippetSupport")]
113    pub snippet_support: bool,
114}
115
116#[derive(Serialize, Deserialize, Debug)]
117pub struct CodeAction {
118    #[serde(rename = "codeActionLiteralSupport")]
119    pub code_action_literal_support: CodeActionLiteralSupport,
120}
121
122#[derive(Serialize, Deserialize, Debug)]
123pub struct CodeActionLiteralSupport {
124    #[serde(rename = "codeActionKind")]
125    pub code_action_kind: CodeActionKind,
126}
127
128#[derive(Serialize, Deserialize, Debug)]
129pub struct CodeActionKind {
130    #[serde(rename = "valueSet")]
131    pub value_set: Vec<String>,
132}
133
134#[derive(Serialize, Deserialize, Debug)]
135pub struct Location {
136    uri: String,
137    range: Range,
138}
139
140#[derive(Serialize, Deserialize, Debug)]
141pub struct Range {
142    start: Position,
143    end: Position,
144}
145
146#[derive(Serialize, Deserialize, Debug)]
147pub struct Position {
148    line: u32,
149    character: u32,
150}
151
152impl Position {
153    pub fn new(line: u32, character: u32) -> Self {
154        Position { line, character }
155    }
156}
157
158impl RequestMessage {
159    /// Helper function to create a new `initialize` request message.
160    /// id - The ID of the request message.
161    /// process_id - The process ID of the client. (usually `std::process::id()`)
162    /// root_uri - The root URI of the workspace. (e.g. `file://path/to/code`)
163    /// client_name - The name of the client. (e.g. `vim-go`)
164    /// workspace_folders - List of folders that the lsp needs context for.
165    /// TODO: This function is currently a bit opinionated towards textdefintion.
166    /// To have a custom initialize message, the workaround for now is to directly
167    /// create a `RequestMessage` with desired capabilities.
168    pub fn new_initialize(
169        id: u32,
170        process_id: u32,
171        root_uri: String,
172        client_name: String,
173        client_version: String,
174        workspace_folders: Vec<WorkspaceFolder>,
175    ) -> Self {
176        let client_info = ClientInfo {
177            name: client_name,
178            version: client_version,
179        };
180
181        let capabilities = ClientCapabilities {
182            workspace: Some(CapabilitiesWorkspace {
183                workspace_folders: true,
184                did_change_configuration: DidChangeConfiguration {
185                    dynamic_registration: true,
186                },
187                workspace_edit: WorkspaceEdit {
188                    document_changes: true,
189                },
190                configuration: true,
191            }),
192            text_document: Some(CapabilitiesTextDocument {
193                hover: Hover {
194                    content_format: vec!["plaintext".to_string()],
195                },
196                completion: Completion {
197                    completion_item: CompletionItem {
198                        snippet_support: true,
199                    },
200                },
201                code_action: CodeAction {
202                    code_action_literal_support: CodeActionLiteralSupport {
203                        code_action_kind: CodeActionKind {
204                            value_set: vec![
205                                "source.organizeImports".to_string(),
206                                "refactor.rewrite".to_string(),
207                                "refactor.extract".to_string(),
208                            ],
209                        },
210                    },
211                },
212            }),
213        };
214
215        RequestMessage {
216            base_message: BaseMessage {
217                jsonrpc: "2.0".to_string(),
218            },
219            id: serde_json::Value::from(id),
220            method: "initialize".to_string(),
221            params: serde_json::to_value(InitializeParams {
222                process_id,
223                root_uri,
224                client_info,
225                capabilities,
226                workspace_folders: Some(workspace_folders),
227            })
228            .unwrap(),
229        }
230    }
231
232    /// Helper function to create a new `textDocument/definition` request message.
233    /// id - The ID of the request message.
234    /// uri - The URI of the text document. (e.g. `file://path/to/code/main.go`)
235    /// line - The line number of the cursor position.
236    /// character - The the cursor position of the character we want to get the definition of.
237    pub fn new_get_definition(id: u32, uri: String, position: Position) -> Self {
238        RequestMessage {
239            base_message: BaseMessage {
240                jsonrpc: "2.0".to_string(),
241            },
242            id: serde_json::Value::from(id),
243            method: "textDocument/definition".to_string(),
244            params: serde_json::json!({
245                "textDocument": {
246                    "uri": uri
247                },
248                "position": {
249                    "line": position.line,
250                    "character": position.character,
251                }
252            }),
253        }
254    }
255}
256
257impl NotificationMessage {
258    /// Helper function to create a new `initialized` notification message.
259    /// This message is sent by the client to the server once it has finished initializing
260    /// and signals that the client is ready to receive requests.
261    pub fn new_initialized() -> Self {
262        NotificationMessage {
263            base_message: BaseMessage {
264                jsonrpc: "2.0".to_string(),
265            },
266            method: "initialized".to_string(),
267            params: serde_json::Value::Object(serde_json::Map::new()),
268        }
269    }
270}
271
272impl ResponseMessage {
273    pub fn handle_initialize(&self) -> Result<()> {
274        if self.error.is_some() {
275            bail!("Error from LSP server: {:?}", self.error);
276        };
277
278        Ok(())
279    }
280
281    pub fn handle_definition(&self) -> Result<Vec<Location>> {
282        if self.error.is_some() {
283            bail!("Error from LSP server: {:?}", self.error);
284        };
285
286        let location: Result<Location, _> = serde_json::from_value(self.result.clone());
287        let locations: Result<Vec<Location>, _> = serde_json::from_value(self.result.clone());
288
289        match location {
290            Ok(loc) => Ok(vec![loc]),
291            Err(_) => match locations {
292                Ok(locs) => Ok(locs),
293                Err(_) => anyhow::bail!("Failed to parse definition location(s) from response."),
294            },
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a raw server response, as it would arrive over the wire.
    fn response(value: serde_json::Value) -> ResponseMessage {
        serde_json::from_value(value).expect("test response must deserialize")
    }

    /// A sample `Location` in its JSON form.
    fn sample_location_json() -> serde_json::Value {
        json!({
            "uri": "file:///path/to/code/main.go",
            "range": {
                "start": {"line": 3, "character": 4},
                "end": {"line": 3, "character": 9}
            }
        })
    }

    #[test]
    fn test_initialize_message() {
        let process_id = std::process::id();
        let expected_init_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "processId": process_id,
                "clientInfo": {
                    "name": "YourLSPClientName",
                    "version": "1.0.0"
                },
                "rootUri": "file://path/to/root",
                "capabilities": {
                    "workspace": {
                        "workspaceFolders": true,
                        "didChangeConfiguration": {
                            "dynamicRegistration": true
                        },
                        "workspaceEdit": {
                            "documentChanges": true
                        },
                        "configuration": true
                    },
                    "textDocument": {
                        "hover": {
                            "contentFormat": ["plaintext"]
                        },
                        "completion": {
                            "completionItem": {
                                "snippetSupport": true
                            }
                        },
                        "codeAction": {
                            "codeActionLiteralSupport": {
                                "codeActionKind": {
                                    "valueSet": ["source.organizeImports", "refactor.rewrite", "refactor.extract"]
                                }
                            }
                        }
                    }
                },
                "workspaceFolders": [{
                    "uri": "file://path/to/workspace",
                    "name": "file://path/to/workspace",
                }]
            }
        });

        let init_params = RequestMessage::new_initialize(
            1,
            process_id,
            "file://path/to/root".to_string(),
            "YourLSPClientName".to_string(),
            "1.0.0".to_string(),
            vec![WorkspaceFolder {
                uri: "file://path/to/workspace".to_string(),
                name: "file://path/to/workspace".to_string(),
            }],
        );

        // Check that the JSON serialization is correct.
        let init_params_json = serde_json::to_value(init_params).unwrap();
        assert_eq!(expected_init_json, init_params_json);
    }

    #[test]
    fn test_initialized_notification() {
        let expected_initialized_json = json!({
            "jsonrpc": "2.0",
            "method": "initialized",
            "params": {}
        });

        let initialized_notification = NotificationMessage::new_initialized();
        let initialized_notification_json = serde_json::to_value(initialized_notification).unwrap();
        assert_eq!(expected_initialized_json, initialized_notification_json);
    }

    #[test]
    fn test_get_definition() {
        let expected_get_definition_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "textDocument/definition",
            "params": {
                "textDocument": {
                    "uri": "file://path/to/code/main.go"
                },
                "position": {
                    "line": 1,
                    "character": 2
                }
            }
        });

        let get_definition = RequestMessage::new_get_definition(
            1,
            "file://path/to/code/main.go".to_string(),
            Position {
                line: 1,
                character: 2,
            },
        );

        let get_definition_json = serde_json::to_value(get_definition).unwrap();
        assert_eq!(expected_get_definition_json, get_definition_json);
    }

    #[test]
    fn test_position_new() {
        let position = Position::new(7, 11);
        assert_eq!(
            serde_json::to_value(position).unwrap(),
            json!({"line": 7, "character": 11})
        );
    }

    #[test]
    fn test_handle_initialize() {
        let ok = response(json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"capabilities": {}}
        }));
        assert!(ok.handle_initialize().is_ok());

        let failed = response(json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": null,
            "error": {"code": -32601, "message": "Method not found"}
        }));
        assert!(failed.handle_initialize().is_err());
    }

    #[test]
    fn test_handle_definition_single_location() {
        let location = sample_location_json();
        let single = response(json!({"jsonrpc": "2.0", "id": 2, "result": location.clone()}));

        let locations = single.handle_definition().unwrap();
        assert_eq!(serde_json::to_value(&locations).unwrap(), json!([location]));
    }

    #[test]
    fn test_handle_definition_location_array() {
        let location = sample_location_json();
        let many = response(json!({
            "jsonrpc": "2.0",
            "id": 2,
            "result": [location.clone(), location.clone()]
        }));

        let locations = many.handle_definition().unwrap();
        assert_eq!(locations.len(), 2);
        assert_eq!(
            serde_json::to_value(&locations).unwrap(),
            json!([location.clone(), location])
        );

        let empty = response(json!({"jsonrpc": "2.0", "id": 2, "result": []}));
        assert!(empty.handle_definition().unwrap().is_empty());
    }

    #[test]
    fn test_handle_definition_server_error() {
        let failed = response(json!({
            "jsonrpc": "2.0",
            "id": 2,
            "result": null,
            "error": {"code": -32603, "message": "Internal error"}
        }));
        assert!(failed.handle_definition().is_err());
    }
}