squawk_server/
lib.rs

1use anyhow::{Context, Result};
2use line_index::LineIndex;
3use log::info;
4use lsp_server::{Connection, Message, Notification, Response};
5use lsp_types::{
6    CodeAction, CodeActionKind, CodeActionOptions, CodeActionOrCommand, CodeActionParams,
7    CodeActionProviderCapability, CodeActionResponse, Command, Diagnostic,
8    DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams,
9    GotoDefinitionParams, GotoDefinitionResponse, InitializeParams, Location, OneOf,
10    PublishDiagnosticsParams, SelectionRangeParams, SelectionRangeProviderCapability,
11    ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, Url,
12    WorkDoneProgressOptions, WorkspaceEdit,
13    notification::{
14        DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, Notification as _,
15        PublishDiagnostics,
16    },
17    request::{CodeActionRequest, GotoDefinition, Request, SelectionRangeRequest},
18};
19use rowan::TextRange;
20use squawk_ide::code_actions::code_actions;
21use squawk_ide::goto_definition::goto_definition;
22use squawk_syntax::{Parse, SourceFile};
23use std::collections::HashMap;
24
25use diagnostic::DIAGNOSTIC_NAME;
26
27use crate::diagnostic::AssociatedDiagnosticData;
28mod diagnostic;
29mod ignore;
30mod lint;
31mod lsp_utils;
32
/// In-memory copy of a document the client has opened.
struct DocumentState {
    /// Version number from the most recent `didOpen`/`didChange` notification.
    version: i32,
    /// Current full text of the document, with all received changes applied.
    content: String,
}
37
38pub fn run() -> Result<()> {
39    info!("Starting Squawk LSP server");
40
41    let (connection, io_threads) = Connection::stdio();
42
43    let server_capabilities = serde_json::to_value(&ServerCapabilities {
44        text_document_sync: Some(TextDocumentSyncCapability::Kind(
45            TextDocumentSyncKind::INCREMENTAL,
46        )),
47        code_action_provider: Some(CodeActionProviderCapability::Options(CodeActionOptions {
48            code_action_kinds: Some(vec![CodeActionKind::QUICKFIX]),
49            work_done_progress_options: WorkDoneProgressOptions {
50                work_done_progress: None,
51            },
52            resolve_provider: None,
53        })),
54        selection_range_provider: Some(SelectionRangeProviderCapability::Simple(true)),
55        definition_provider: Some(OneOf::Left(true)),
56        ..Default::default()
57    })
58    .unwrap();
59
60    info!("LSP server initializing connection...");
61    let initialization_params = connection.initialize(server_capabilities)?;
62    info!("LSP server initialized, entering main loop");
63
64    main_loop(connection, initialization_params)?;
65
66    info!("LSP server shutting down");
67
68    io_threads.join()?;
69    Ok(())
70}
71
72fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
73    info!("Server main loop");
74
75    let init_params: InitializeParams = serde_json::from_value(params).unwrap_or_default();
76    info!("Client process ID: {:?}", init_params.process_id);
77    let client_name = init_params.client_info.map(|x| x.name);
78    info!("Client name: {client_name:?}");
79
80    let mut documents: HashMap<Url, DocumentState> = HashMap::new();
81
82    for msg in &connection.receiver {
83        match msg {
84            Message::Request(req) => {
85                info!("Received request: method={}, id={:?}", req.method, req.id);
86
87                if connection.handle_shutdown(&req)? {
88                    info!("Received shutdown request, exiting");
89                    return Ok(());
90                }
91
92                match req.method.as_ref() {
93                    GotoDefinition::METHOD => {
94                        handle_goto_definition(&connection, req, &documents)?;
95                    }
96                    CodeActionRequest::METHOD => {
97                        handle_code_action(&connection, req, &documents)?;
98                    }
99                    SelectionRangeRequest::METHOD => {
100                        handle_selection_range(&connection, req, &documents)?;
101                    }
102                    "squawk/syntaxTree" => {
103                        handle_syntax_tree(&connection, req, &documents)?;
104                    }
105                    "squawk/tokens" => {
106                        handle_tokens(&connection, req, &documents)?;
107                    }
108                    _ => {
109                        info!("Ignoring unhandled request: {}", req.method);
110                    }
111                }
112            }
113            Message::Response(resp) => {
114                info!("Received response: id={:?}", resp.id);
115            }
116            Message::Notification(notif) => {
117                info!("Received notification: method={}", notif.method);
118                match notif.method.as_ref() {
119                    DidOpenTextDocument::METHOD => {
120                        handle_did_open(&connection, notif, &mut documents)?;
121                    }
122                    DidChangeTextDocument::METHOD => {
123                        handle_did_change(&connection, notif, &mut documents)?;
124                    }
125                    DidCloseTextDocument::METHOD => {
126                        handle_did_close(&connection, notif, &mut documents)?;
127                    }
128                    _ => {
129                        info!("Ignoring unhandled notification: {}", notif.method);
130                    }
131                }
132            }
133        }
134    }
135    Ok(())
136}
137
138fn handle_goto_definition(
139    connection: &Connection,
140    req: lsp_server::Request,
141    documents: &HashMap<Url, DocumentState>,
142) -> Result<()> {
143    let params: GotoDefinitionParams = serde_json::from_value(req.params)?;
144    let uri = params.text_document_position_params.text_document.uri;
145    let position = params.text_document_position_params.position;
146
147    let content = documents.get(&uri).map_or("", |doc| &doc.content);
148    let parse: Parse<SourceFile> = SourceFile::parse(content);
149    let file = parse.tree();
150    let line_index = LineIndex::new(content);
151    let offset = lsp_utils::offset(&line_index, position).unwrap();
152
153    let range = goto_definition(file, offset);
154
155    let result = match range {
156        Some(target_range) => {
157            debug_assert!(
158                !target_range.contains(offset),
159                "Our target destination range must not include the source range otherwise go to def won't work in vscode."
160            );
161            GotoDefinitionResponse::Scalar(Location {
162                uri: uri.clone(),
163                range: lsp_utils::range(&line_index, target_range),
164            })
165        }
166        None => GotoDefinitionResponse::Array(vec![]),
167    };
168
169    let resp = Response {
170        id: req.id,
171        result: Some(serde_json::to_value(&result).unwrap()),
172        error: None,
173    };
174
175    connection.sender.send(Message::Response(resp))?;
176    Ok(())
177}
178
179fn handle_selection_range(
180    connection: &Connection,
181    req: lsp_server::Request,
182    documents: &HashMap<Url, DocumentState>,
183) -> Result<()> {
184    let params: SelectionRangeParams = serde_json::from_value(req.params)?;
185    let uri = params.text_document.uri;
186
187    let content = documents.get(&uri).map_or("", |doc| &doc.content);
188    let parse: Parse<SourceFile> = SourceFile::parse(content);
189    let root = parse.syntax_node();
190    let line_index = LineIndex::new(content);
191
192    let mut selection_ranges = vec![];
193
194    for position in params.positions {
195        let Some(offset) = lsp_utils::offset(&line_index, position) else {
196            continue;
197        };
198
199        let mut ranges = Vec::new();
200        {
201            let mut range = TextRange::new(offset, offset);
202            loop {
203                ranges.push(range);
204                let next = squawk_ide::expand_selection::extend_selection(&root, range);
205                if next == range {
206                    break;
207                } else {
208                    range = next
209                }
210            }
211        }
212
213        let mut range = lsp_types::SelectionRange {
214            range: lsp_utils::range(&line_index, *ranges.last().unwrap()),
215            parent: None,
216        };
217        for &r in ranges.iter().rev().skip(1) {
218            range = lsp_types::SelectionRange {
219                range: lsp_utils::range(&line_index, r),
220                parent: Some(Box::new(range)),
221            }
222        }
223        selection_ranges.push(range);
224    }
225
226    let resp = Response {
227        id: req.id,
228        result: Some(serde_json::to_value(&selection_ranges).unwrap()),
229        error: None,
230    };
231
232    connection.sender.send(Message::Response(resp))?;
233    Ok(())
234}
235
236fn handle_code_action(
237    connection: &Connection,
238    req: lsp_server::Request,
239    documents: &HashMap<Url, DocumentState>,
240) -> Result<()> {
241    let params: CodeActionParams = serde_json::from_value(req.params)?;
242    let uri = params.text_document.uri;
243
244    let mut actions: CodeActionResponse = Vec::new();
245
246    let content = documents.get(&uri).map_or("", |doc| &doc.content);
247    let parse: Parse<SourceFile> = SourceFile::parse(content);
248    let file = parse.tree();
249    let line_index = LineIndex::new(content);
250    let offset = lsp_utils::offset(&line_index, params.range.start).unwrap();
251
252    let ide_actions = code_actions(file, offset).unwrap_or_default();
253
254    for action in ide_actions {
255        let lsp_action = lsp_utils::code_action(&line_index, uri.clone(), action);
256        actions.push(CodeActionOrCommand::CodeAction(lsp_action));
257    }
258
259    for mut diagnostic in params
260        .context
261        .diagnostics
262        .into_iter()
263        .filter(|diagnostic| diagnostic.source.as_deref() == Some(DIAGNOSTIC_NAME))
264    {
265        let Some(rule_name) = diagnostic.code.as_ref().map(|x| match x {
266            lsp_types::NumberOrString::String(s) => s.clone(),
267            lsp_types::NumberOrString::Number(n) => n.to_string(),
268        }) else {
269            continue;
270        };
271        let Some(data) = diagnostic.data.take() else {
272            continue;
273        };
274
275        let associated_data: AssociatedDiagnosticData =
276            serde_json::from_value(data).context("deserializing diagnostic data")?;
277
278        if let Some(ignore_line_edit) = associated_data.ignore_line_edit {
279            let disable_line_action = CodeAction {
280                title: format!("Disable {rule_name} for this line"),
281                kind: Some(CodeActionKind::QUICKFIX),
282                diagnostics: Some(vec![diagnostic.clone()]),
283                edit: Some(WorkspaceEdit {
284                    changes: Some({
285                        let mut changes = HashMap::new();
286                        changes.insert(uri.clone(), vec![ignore_line_edit]);
287                        changes
288                    }),
289                    ..Default::default()
290                }),
291                command: None,
292                is_preferred: Some(false),
293                disabled: None,
294                data: None,
295            };
296            actions.push(CodeActionOrCommand::CodeAction(disable_line_action));
297        }
298        if let Some(ignore_file_edit) = associated_data.ignore_file_edit {
299            let disable_file_action = CodeAction {
300                title: format!("Disable {rule_name} for the entire file"),
301                kind: Some(CodeActionKind::QUICKFIX),
302                diagnostics: Some(vec![diagnostic.clone()]),
303                edit: Some(WorkspaceEdit {
304                    changes: Some({
305                        let mut changes = HashMap::new();
306                        changes.insert(uri.clone(), vec![ignore_file_edit]);
307                        changes
308                    }),
309                    ..Default::default()
310                }),
311                command: None,
312                is_preferred: Some(false),
313                disabled: None,
314                data: None,
315            };
316            actions.push(CodeActionOrCommand::CodeAction(disable_file_action));
317        }
318
319        let title = format!("Show documentation for {rule_name}");
320        let documentation_action = CodeAction {
321            title: title.clone(),
322            kind: Some(CodeActionKind::QUICKFIX),
323            diagnostics: Some(vec![diagnostic.clone()]),
324            edit: None,
325            command: Some(Command {
326                title,
327                command: "vscode.open".to_string(),
328                arguments: Some(vec![serde_json::to_value(format!(
329                    "https://squawkhq.com/docs/{rule_name}"
330                ))?]),
331            }),
332            is_preferred: Some(false),
333            disabled: None,
334            data: None,
335        };
336        actions.push(CodeActionOrCommand::CodeAction(documentation_action));
337
338        if !associated_data.title.is_empty() && !associated_data.edits.is_empty() {
339            let fix_action = CodeAction {
340                title: associated_data.title,
341                kind: Some(CodeActionKind::QUICKFIX),
342                diagnostics: Some(vec![diagnostic.clone()]),
343                edit: Some(WorkspaceEdit {
344                    changes: Some({
345                        let mut changes = HashMap::new();
346                        changes.insert(uri.clone(), associated_data.edits);
347                        changes
348                    }),
349                    ..Default::default()
350                }),
351                command: None,
352                is_preferred: Some(true),
353                disabled: None,
354                data: None,
355            };
356            actions.push(CodeActionOrCommand::CodeAction(fix_action));
357        }
358    }
359
360    let result: CodeActionResponse = actions;
361    let resp = Response {
362        id: req.id,
363        result: Some(serde_json::to_value(&result).unwrap()),
364        error: None,
365    };
366
367    connection.sender.send(Message::Response(resp))?;
368    Ok(())
369}
370
371fn publish_diagnostics(
372    connection: &Connection,
373    uri: Url,
374    version: i32,
375    diagnostics: Vec<Diagnostic>,
376) -> Result<()> {
377    let publish_params = PublishDiagnosticsParams {
378        uri,
379        diagnostics,
380        version: Some(version),
381    };
382
383    let notification = Notification {
384        method: PublishDiagnostics::METHOD.to_owned(),
385        params: serde_json::to_value(publish_params)?,
386    };
387
388    connection
389        .sender
390        .send(Message::Notification(notification))?;
391    Ok(())
392}
393
394fn handle_did_open(
395    connection: &Connection,
396    notif: lsp_server::Notification,
397    documents: &mut HashMap<Url, DocumentState>,
398) -> Result<()> {
399    let params: DidOpenTextDocumentParams = serde_json::from_value(notif.params)?;
400    let uri = params.text_document.uri;
401    let content = params.text_document.text;
402    let version = params.text_document.version;
403
404    documents.insert(uri.clone(), DocumentState { content, version });
405
406    let content = documents.get(&uri).map_or("", |doc| &doc.content);
407
408    // TODO: we need a better setup for "run func when input changed"
409    let diagnostics = lint::lint(content);
410    publish_diagnostics(connection, uri, version, diagnostics)?;
411
412    Ok(())
413}
414
415fn handle_did_change(
416    connection: &Connection,
417    notif: lsp_server::Notification,
418    documents: &mut HashMap<Url, DocumentState>,
419) -> Result<()> {
420    let params: DidChangeTextDocumentParams = serde_json::from_value(notif.params)?;
421    let uri = params.text_document.uri;
422    let version = params.text_document.version;
423
424    let Some(doc_state) = documents.get_mut(&uri) else {
425        return Ok(());
426    };
427
428    doc_state.content =
429        lsp_utils::apply_incremental_changes(&doc_state.content, params.content_changes);
430    doc_state.version = version;
431
432    let diagnostics = lint::lint(&doc_state.content);
433    publish_diagnostics(connection, uri, version, diagnostics)?;
434
435    Ok(())
436}
437
438fn handle_did_close(
439    connection: &Connection,
440    notif: lsp_server::Notification,
441    documents: &mut HashMap<Url, DocumentState>,
442) -> Result<()> {
443    let params: DidCloseTextDocumentParams = serde_json::from_value(notif.params)?;
444    let uri = params.text_document.uri;
445
446    documents.remove(&uri);
447
448    let publish_params = PublishDiagnosticsParams {
449        uri,
450        diagnostics: vec![],
451        version: None,
452    };
453
454    let notification = Notification {
455        method: PublishDiagnostics::METHOD.to_owned(),
456        params: serde_json::to_value(publish_params)?,
457    };
458
459    connection
460        .sender
461        .send(Message::Notification(notification))?;
462
463    Ok(())
464}
465
466#[derive(serde::Deserialize)]
467struct SyntaxTreeParams {
468    #[serde(rename = "textDocument")]
469    text_document: lsp_types::TextDocumentIdentifier,
470}
471
472fn handle_syntax_tree(
473    connection: &Connection,
474    req: lsp_server::Request,
475    documents: &HashMap<Url, DocumentState>,
476) -> Result<()> {
477    let params: SyntaxTreeParams = serde_json::from_value(req.params)?;
478    let uri = params.text_document.uri;
479
480    info!("Generating syntax tree for: {uri}");
481
482    let content = documents.get(&uri).map_or("", |doc| &doc.content);
483
484    let parse: Parse<SourceFile> = SourceFile::parse(content);
485    let syntax_tree = format!("{:#?}", parse.syntax_node());
486
487    let resp = Response {
488        id: req.id,
489        result: Some(serde_json::to_value(&syntax_tree).unwrap()),
490        error: None,
491    };
492
493    connection.sender.send(Message::Response(resp))?;
494    Ok(())
495}
496
497#[derive(serde::Deserialize)]
498struct TokensParams {
499    #[serde(rename = "textDocument")]
500    text_document: lsp_types::TextDocumentIdentifier,
501}
502
503fn handle_tokens(
504    connection: &Connection,
505    req: lsp_server::Request,
506    documents: &HashMap<Url, DocumentState>,
507) -> Result<()> {
508    let params: TokensParams = serde_json::from_value(req.params)?;
509    let uri = params.text_document.uri;
510
511    info!("Generating tokens for: {uri}");
512
513    let content = documents.get(&uri).map_or("", |doc| &doc.content);
514
515    let tokens = squawk_lexer::tokenize(content);
516
517    let mut output = Vec::new();
518    let mut char_pos = 0;
519    for token in tokens {
520        let token_start = char_pos;
521        let token_end = token_start + token.len as usize;
522        let token_text = &content[token_start..token_end];
523        output.push(format!(
524            "{:?}@{}..{} {:?}",
525            token.kind, token_start, token_end, token_text
526        ));
527        char_pos = token_end;
528    }
529
530    let tokens_output = output.join("\n");
531
532    let resp = Response {
533        id: req.id,
534        result: Some(serde_json::to_value(&tokens_output).unwrap()),
535        error: None,
536    };
537
538    connection.sender.send(Message::Response(resp))?;
539    Ok(())
540}