Skip to main content

squawk_server/
lib.rs

1use ::line_index::LineIndex;
2use anyhow::{Context, Result};
3use etcetera::BaseStrategy;
4use log::info;
5use lsp_server::{Connection, Message, Notification, Response};
6use lsp_types::{
7    CodeAction, CodeActionKind, CodeActionOptions, CodeActionOrCommand, CodeActionParams,
8    CodeActionProviderCapability, CodeActionResponse, Command, CompletionOptions, CompletionParams,
9    CompletionResponse, Diagnostic, DidChangeTextDocumentParams, DidCloseTextDocumentParams,
10    DidOpenTextDocumentParams, DocumentSymbol, DocumentSymbolParams, GotoDefinitionParams,
11    GotoDefinitionResponse, Hover, HoverContents, HoverParams, HoverProviderCapability,
12    InitializeParams, InlayHint, InlayHintKind, InlayHintLabel, InlayHintLabelPart,
13    InlayHintParams, LanguageString, Location, MarkedString, OneOf, PublishDiagnosticsParams,
14    ReferenceParams, SelectionRangeParams, SelectionRangeProviderCapability, ServerCapabilities,
15    SymbolKind, TextDocumentSyncCapability, TextDocumentSyncKind, Url, WorkDoneProgressOptions,
16    WorkspaceEdit,
17    notification::{
18        DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, Notification as _,
19        PublishDiagnostics,
20    },
21    request::{
22        CodeActionRequest, Completion, DocumentSymbolRequest, GotoDefinition, HoverRequest,
23        InlayHintRequest, References, Request, SelectionRangeRequest,
24    },
25};
26use rowan::TextRange;
27use salsa::Setter;
28use squawk_ide::completion::completion;
29use squawk_ide::document_symbols::{DocumentSymbolKind, document_symbols};
30use squawk_ide::find_references::find_references;
31use squawk_ide::goto_definition::goto_definition;
32use squawk_ide::hover::hover;
33use squawk_ide::inlay_hints::inlay_hints;
34use squawk_ide::{builtins::BUILTINS_SQL, code_actions::code_actions};
35use std::{collections::HashMap, fs, sync::OnceLock};
36
37use diagnostic::DIAGNOSTIC_NAME;
38
39use crate::db::{Database, File, line_index, parse};
40use crate::diagnostic::AssociatedDiagnosticData;
41mod db;
42mod diagnostic;
43mod ignore;
44mod lint;
45mod lsp_utils;
46
47fn builtins_url() -> Option<Url> {
48    // TODO: once we get salsa setup, we can migrate this over
49    static BUILTINS_URL: OnceLock<Option<Url>> = OnceLock::new();
50    BUILTINS_URL
51        .get_or_init(|| {
52            let strategy = etcetera::base_strategy::choose_base_strategy().ok()?;
53            let config_dir = strategy.config_dir();
54            let cache_dir = config_dir.join("squawk/stubs");
55            let path = cache_dir.join("builtins.sql");
56            fs::create_dir_all(cache_dir).ok()?;
57            fs::write(&path, BUILTINS_SQL).ok()?;
58            Url::from_file_path(&path).ok()
59        })
60        .clone()
61}
62
/// Snapshot of a document's text as reported by the client.
struct DocumentState {
    /// Full text of the document.
    content: String,
    /// Version number the client attached to this text.
    version: i32,
}
68
69trait FileSystem {
70    fn db(&self) -> &Database;
71    fn file(&self, uri: &Url) -> Option<File>;
72    fn set(&mut self, uri: Url, state: DocumentState);
73    fn remove(&mut self, uri: &Url);
74}
75
76struct FileDatabase {
77    pub db: Database,
78    files: HashMap<Url, File>,
79}
80
81impl FileDatabase {
82    fn new() -> Self {
83        Self {
84            db: Database::default(),
85            files: HashMap::new(),
86        }
87    }
88}
89
90impl FileSystem for FileDatabase {
91    fn db(&self) -> &Database {
92        return &self.db;
93    }
94
95    fn file(&self, uri: &Url) -> Option<File> {
96        self.files.get(uri).copied()
97    }
98
99    fn set(&mut self, uri: Url, state: DocumentState) {
100        if let Some(file) = self.files.get(&uri).copied() {
101            file.set_content(&mut self.db).to(state.content);
102            file.set_version(&mut self.db).to(state.version);
103        } else {
104            let file = File::new(&self.db, state.content, state.version);
105            self.files.insert(uri, file);
106        }
107    }
108
109    fn remove(&mut self, uri: &Url) {
110        self.files.remove(uri);
111    }
112}
113
114pub fn run() -> Result<()> {
115    info!("Starting Squawk LSP server");
116
117    let (connection, io_threads) = Connection::stdio();
118
119    let server_capabilities = serde_json::to_value(&ServerCapabilities {
120        text_document_sync: Some(TextDocumentSyncCapability::Kind(
121            TextDocumentSyncKind::INCREMENTAL,
122        )),
123        code_action_provider: Some(CodeActionProviderCapability::Options(CodeActionOptions {
124            code_action_kinds: Some(vec![
125                CodeActionKind::QUICKFIX,
126                CodeActionKind::REFACTOR_REWRITE,
127            ]),
128            work_done_progress_options: WorkDoneProgressOptions {
129                work_done_progress: None,
130            },
131            resolve_provider: None,
132        })),
133        selection_range_provider: Some(SelectionRangeProviderCapability::Simple(true)),
134        references_provider: Some(OneOf::Left(true)),
135        definition_provider: Some(OneOf::Left(true)),
136        hover_provider: Some(HoverProviderCapability::Simple(true)),
137        inlay_hint_provider: Some(OneOf::Left(true)),
138        document_symbol_provider: Some(OneOf::Left(true)),
139        completion_provider: Some(CompletionOptions {
140            resolve_provider: Some(false),
141            trigger_characters: Some(vec![".".to_owned()]),
142            all_commit_characters: None,
143            work_done_progress_options: WorkDoneProgressOptions {
144                work_done_progress: None,
145            },
146            completion_item: None,
147        }),
148        ..Default::default()
149    })
150    .unwrap();
151
152    info!("LSP server initializing connection...");
153    let initialization_params = connection.initialize(server_capabilities)?;
154    info!("LSP server initialized, entering main loop");
155
156    main_loop(connection, initialization_params)?;
157
158    info!("LSP server shutting down");
159
160    io_threads.join()?;
161    Ok(())
162}
163
164fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
165    info!("Server main loop");
166
167    let init_params: InitializeParams = serde_json::from_value(params).unwrap_or_default();
168    info!("Client process ID: {:?}", init_params.process_id);
169    let client_name = init_params.client_info.map(|x| x.name);
170    info!("Client name: {client_name:?}");
171
172    let mut file_system = FileDatabase::new();
173
174    for msg in &connection.receiver {
175        match msg {
176            Message::Request(req) => {
177                info!("Received request: method={}, id={:?}", req.method, req.id);
178
179                if connection.handle_shutdown(&req)? {
180                    info!("Received shutdown request, exiting");
181                    return Ok(());
182                }
183
184                match req.method.as_ref() {
185                    GotoDefinition::METHOD => {
186                        handle_goto_definition(&connection, req, &file_system)?;
187                    }
188                    HoverRequest::METHOD => {
189                        handle_hover(&connection, req, &file_system)?;
190                    }
191                    CodeActionRequest::METHOD => {
192                        handle_code_action(&connection, req, &file_system)?;
193                    }
194                    SelectionRangeRequest::METHOD => {
195                        handle_selection_range(&connection, req, &file_system)?;
196                    }
197                    InlayHintRequest::METHOD => {
198                        handle_inlay_hints(&connection, req, &file_system)?;
199                    }
200                    DocumentSymbolRequest::METHOD => {
201                        handle_document_symbol(&connection, req, &file_system)?;
202                    }
203                    Completion::METHOD => {
204                        handle_completion(&connection, req, &file_system)?;
205                    }
206                    "squawk/syntaxTree" => {
207                        handle_syntax_tree(&connection, req, &file_system)?;
208                    }
209                    "squawk/tokens" => {
210                        handle_tokens(&connection, req, &file_system)?;
211                    }
212                    References::METHOD => {
213                        handle_references(&connection, req, &file_system)?;
214                    }
215                    _ => {
216                        info!("Ignoring unhandled request: {}", req.method);
217                    }
218                }
219            }
220            Message::Response(resp) => {
221                info!("Received response: id={:?}", resp.id);
222            }
223            Message::Notification(notif) => {
224                info!("Received notification: method={}", notif.method);
225                match notif.method.as_ref() {
226                    DidOpenTextDocument::METHOD => {
227                        handle_did_open(&connection, notif, &mut file_system)?;
228                    }
229                    DidChangeTextDocument::METHOD => {
230                        handle_did_change(&connection, notif, &mut file_system)?;
231                    }
232                    DidCloseTextDocument::METHOD => {
233                        handle_did_close(&connection, notif, &mut file_system)?;
234                    }
235                    _ => {
236                        info!("Ignoring unhandled notification: {}", notif.method);
237                    }
238                }
239            }
240        }
241    }
242    Ok(())
243}
244
245fn handle_goto_definition(
246    connection: &Connection,
247    req: lsp_server::Request,
248    file_system: &impl FileSystem,
249) -> Result<()> {
250    let params: GotoDefinitionParams = serde_json::from_value(req.params)?;
251    let uri = params.text_document_position_params.text_document.uri;
252    let position = params.text_document_position_params.position;
253
254    let db = file_system.db();
255    let file = file_system.file(&uri).unwrap();
256    let parse = parse(db, file);
257    let line_index = line_index(db, file);
258    let offset = lsp_utils::offset(&line_index, position).unwrap();
259
260    let ranges = goto_definition(&parse.tree(), offset)
261        .into_iter()
262        .filter_map(|location| {
263            debug_assert!(
264                !location.range.contains(offset),
265                "Our target destination range must not include the source range otherwise go to def won't work in vscode."
266            );
267
268            let uri = match location.file {
269                squawk_ide::goto_definition::FileId::Current => uri.clone(),
270                squawk_ide::goto_definition::FileId::Builtins => builtins_url()?,
271            };
272
273            let line_index = match location.file {
274                squawk_ide::goto_definition::FileId::Current => &line_index,
275                squawk_ide::goto_definition::FileId::Builtins => &LineIndex::new(BUILTINS_SQL),
276            };
277            let range = lsp_utils::range(line_index, location.range);
278
279            Some(Location {
280                uri,
281                range,
282            })
283        })
284        .collect();
285
286    let result = GotoDefinitionResponse::Array(ranges);
287    let resp = Response {
288        id: req.id,
289        result: Some(serde_json::to_value(&result).unwrap()),
290        error: None,
291    };
292
293    connection.sender.send(Message::Response(resp))?;
294    Ok(())
295}
296
297fn handle_hover(
298    connection: &Connection,
299    req: lsp_server::Request,
300    file_system: &impl FileSystem,
301) -> Result<()> {
302    let params: HoverParams = serde_json::from_value(req.params)?;
303    let uri = params.text_document_position_params.text_document.uri;
304    let position = params.text_document_position_params.position;
305
306    let db = file_system.db();
307    let file = file_system.file(&uri).unwrap();
308    let parse = parse(db, file);
309    let line_index = line_index(db, file);
310    let offset = lsp_utils::offset(&line_index, position).unwrap();
311
312    let type_info = hover(&parse.tree(), offset);
313
314    let result = type_info.map(|type_str| Hover {
315        contents: HoverContents::Scalar(MarkedString::LanguageString(LanguageString {
316            language: "sql".to_string(),
317            value: type_str,
318        })),
319        range: None,
320    });
321
322    let resp = Response {
323        id: req.id,
324        result: Some(serde_json::to_value(&result).unwrap()),
325        error: None,
326    };
327
328    connection.sender.send(Message::Response(resp))?;
329    Ok(())
330}
331
332fn handle_inlay_hints(
333    connection: &Connection,
334    req: lsp_server::Request,
335    file_system: &impl FileSystem,
336) -> Result<()> {
337    let params: InlayHintParams = serde_json::from_value(req.params)?;
338    let uri = params.text_document.uri;
339
340    let db = file_system.db();
341    let file = file_system.file(&uri).unwrap();
342    let parse = parse(db, file);
343    let line_index = line_index(db, file);
344
345    // TODO: move this to a tracked function
346    let hints = inlay_hints(&parse.tree());
347
348    let lsp_hints: Vec<InlayHint> = hints
349        .into_iter()
350        .flat_map(|hint| {
351            let line_col = line_index.line_col(hint.position);
352            let position = lsp_types::Position::new(line_col.line, line_col.col);
353
354            let uri = match hint.file {
355                Some(squawk_ide::goto_definition::FileId::Current) | None => uri.clone(),
356                Some(squawk_ide::goto_definition::FileId::Builtins) => builtins_url()?,
357            };
358
359            let line_index = match hint.file {
360                Some(squawk_ide::goto_definition::FileId::Current) | None => &line_index,
361                Some(squawk_ide::goto_definition::FileId::Builtins) => {
362                    &LineIndex::new(BUILTINS_SQL)
363                }
364            };
365
366            let kind: InlayHintKind = match hint.kind {
367                squawk_ide::inlay_hints::InlayHintKind::Type => InlayHintKind::TYPE,
368                squawk_ide::inlay_hints::InlayHintKind::Parameter => InlayHintKind::PARAMETER,
369            };
370
371            let label = if let Some(target_range) = hint.target {
372                InlayHintLabel::LabelParts(vec![InlayHintLabelPart {
373                    value: hint.label,
374                    location: Some(Location {
375                        uri: uri.clone(),
376                        range: lsp_utils::range(line_index, target_range),
377                    }),
378                    tooltip: None,
379                    command: None,
380                }])
381            } else {
382                InlayHintLabel::String(hint.label)
383            };
384
385            Some(InlayHint {
386                position,
387                label,
388                kind: Some(kind),
389                text_edits: None,
390                tooltip: None,
391                padding_left: None,
392                padding_right: None,
393                data: None,
394            })
395        })
396        .collect();
397
398    let resp = Response {
399        id: req.id,
400        result: Some(serde_json::to_value(&lsp_hints).unwrap()),
401        error: None,
402    };
403
404    connection.sender.send(Message::Response(resp))?;
405    Ok(())
406}
407
408fn handle_document_symbol(
409    connection: &Connection,
410    req: lsp_server::Request,
411    file_system: &impl FileSystem,
412) -> Result<()> {
413    let params: DocumentSymbolParams = serde_json::from_value(req.params)?;
414    let uri = params.text_document.uri;
415
416    let db = file_system.db();
417    let file = file_system.file(&uri).unwrap();
418    let parse = parse(db, file);
419    let line_index = line_index(db, file);
420
421    let symbols = document_symbols(&parse.tree());
422
423    fn convert_symbol(
424        sym: squawk_ide::document_symbols::DocumentSymbol,
425        line_index: &LineIndex,
426    ) -> DocumentSymbol {
427        let range = lsp_utils::range(line_index, sym.full_range);
428        let selection_range = lsp_utils::range(line_index, sym.focus_range);
429
430        let children = sym
431            .children
432            .into_iter()
433            .map(|child| convert_symbol(child, line_index))
434            .collect::<Vec<_>>();
435
436        let children = (!children.is_empty()).then_some(children);
437
438        DocumentSymbol {
439            name: sym.name,
440            detail: sym.detail,
441            kind: match sym.kind {
442                DocumentSymbolKind::Schema => SymbolKind::NAMESPACE,
443                DocumentSymbolKind::Table => SymbolKind::STRUCT,
444                DocumentSymbolKind::View => SymbolKind::STRUCT,
445                DocumentSymbolKind::MaterializedView => SymbolKind::STRUCT,
446                DocumentSymbolKind::Function => SymbolKind::FUNCTION,
447                DocumentSymbolKind::Aggregate => SymbolKind::FUNCTION,
448                DocumentSymbolKind::Procedure => SymbolKind::FUNCTION,
449                DocumentSymbolKind::Type => SymbolKind::CLASS,
450                DocumentSymbolKind::Enum => SymbolKind::ENUM,
451                DocumentSymbolKind::Index => SymbolKind::KEY,
452                DocumentSymbolKind::Domain => SymbolKind::CLASS,
453                DocumentSymbolKind::Sequence => SymbolKind::CONSTANT,
454                DocumentSymbolKind::Trigger => SymbolKind::EVENT,
455                DocumentSymbolKind::Tablespace => SymbolKind::NAMESPACE,
456                DocumentSymbolKind::Database => SymbolKind::MODULE,
457                DocumentSymbolKind::Server => SymbolKind::OBJECT,
458                DocumentSymbolKind::Extension => SymbolKind::PACKAGE,
459                DocumentSymbolKind::Column => SymbolKind::FIELD,
460                DocumentSymbolKind::Variant => SymbolKind::ENUM_MEMBER,
461                DocumentSymbolKind::Cursor => SymbolKind::VARIABLE,
462                DocumentSymbolKind::PreparedStatement => SymbolKind::VARIABLE,
463                DocumentSymbolKind::Channel => SymbolKind::EVENT,
464                DocumentSymbolKind::EventTrigger => SymbolKind::EVENT,
465                DocumentSymbolKind::Role => SymbolKind::CLASS,
466                DocumentSymbolKind::Policy => SymbolKind::VARIABLE,
467            },
468            tags: None,
469            range,
470            selection_range,
471            children,
472            #[allow(deprecated)]
473            deprecated: None,
474        }
475    }
476
477    let lsp_symbols: Vec<DocumentSymbol> = symbols
478        .into_iter()
479        .map(|sym| convert_symbol(sym, &line_index))
480        .collect();
481
482    let resp = Response {
483        id: req.id,
484        result: Some(serde_json::to_value(&lsp_symbols).unwrap()),
485        error: None,
486    };
487
488    connection.sender.send(Message::Response(resp))?;
489    Ok(())
490}
491
492fn handle_selection_range(
493    connection: &Connection,
494    req: lsp_server::Request,
495    file_system: &impl FileSystem,
496) -> Result<()> {
497    let params: SelectionRangeParams = serde_json::from_value(req.params)?;
498    let uri = params.text_document.uri;
499
500    let db = file_system.db();
501    let file = file_system.file(&uri).unwrap();
502    let parse = parse(db, file);
503    let root = parse.syntax_node();
504    let line_index = line_index(db, file);
505
506    let mut selection_ranges = vec![];
507
508    for position in params.positions {
509        let Some(offset) = lsp_utils::offset(&line_index, position) else {
510            continue;
511        };
512
513        let mut ranges = Vec::new();
514        {
515            let mut range = TextRange::new(offset, offset);
516            loop {
517                ranges.push(range);
518                let next = squawk_ide::expand_selection::extend_selection(&root, range);
519                if next == range {
520                    break;
521                } else {
522                    range = next
523                }
524            }
525        }
526
527        let mut range = lsp_types::SelectionRange {
528            range: lsp_utils::range(&line_index, *ranges.last().unwrap()),
529            parent: None,
530        };
531        for &r in ranges.iter().rev().skip(1) {
532            range = lsp_types::SelectionRange {
533                range: lsp_utils::range(&line_index, r),
534                parent: Some(Box::new(range)),
535            }
536        }
537        selection_ranges.push(range);
538    }
539
540    let resp = Response {
541        id: req.id,
542        result: Some(serde_json::to_value(&selection_ranges).unwrap()),
543        error: None,
544    };
545
546    connection.sender.send(Message::Response(resp))?;
547    Ok(())
548}
549
550fn handle_references(
551    connection: &Connection,
552    req: lsp_server::Request,
553    file_system: &impl FileSystem,
554) -> Result<()> {
555    let params: ReferenceParams = serde_json::from_value(req.params)?;
556    let uri = params.text_document_position.text_document.uri;
557    let position = params.text_document_position.position;
558
559    let db = file_system.db();
560    let file = file_system.file(&uri).unwrap();
561    let parse = parse(db, file);
562    let line_index = line_index(db, file);
563    let offset = lsp_utils::offset(&line_index, position).unwrap();
564
565    let refs = find_references(&parse.tree(), offset);
566    let include_declaration = params.context.include_declaration;
567
568    let locations: Vec<Location> = refs
569        .into_iter()
570        .filter(|loc| include_declaration || !loc.range.contains(offset))
571        .filter_map(|loc| {
572            let uri = match loc.file {
573                squawk_ide::goto_definition::FileId::Current => uri.clone(),
574                squawk_ide::goto_definition::FileId::Builtins => builtins_url()?,
575            };
576            let line_index = match loc.file {
577                squawk_ide::goto_definition::FileId::Current => &line_index,
578                squawk_ide::goto_definition::FileId::Builtins => &LineIndex::new(BUILTINS_SQL),
579            };
580            Some(Location {
581                uri,
582                range: lsp_utils::range(line_index, loc.range),
583            })
584        })
585        .collect();
586
587    let resp = Response {
588        id: req.id,
589        result: Some(serde_json::to_value(&locations).unwrap()),
590        error: None,
591    };
592
593    connection.sender.send(Message::Response(resp))?;
594    Ok(())
595}
596
597fn handle_completion(
598    connection: &Connection,
599    req: lsp_server::Request,
600    file_system: &impl FileSystem,
601) -> Result<()> {
602    let params: CompletionParams = serde_json::from_value(req.params)?;
603    let uri = params.text_document_position.text_document.uri;
604    let position = params.text_document_position.position;
605
606    let db = file_system.db();
607    let file = file_system.file(&uri).unwrap();
608    let parse = parse(db, file);
609    let line_index = line_index(db, file);
610
611    let Some(offset) = lsp_utils::offset(&line_index, position) else {
612        let resp = Response {
613            id: req.id,
614            result: Some(serde_json::to_value(CompletionResponse::Array(vec![])).unwrap()),
615            error: None,
616        };
617        connection.sender.send(Message::Response(resp))?;
618        return Ok(());
619    };
620
621    // TODO: move this to a tracked function
622    let completion_items = completion(&parse.tree(), offset)
623        .into_iter()
624        .map(lsp_utils::completion_item)
625        .collect();
626
627    let result = CompletionResponse::Array(completion_items);
628
629    let resp = Response {
630        id: req.id,
631        result: Some(serde_json::to_value(&result).unwrap()),
632        error: None,
633    };
634
635    connection.sender.send(Message::Response(resp))?;
636    Ok(())
637}
638
639fn handle_code_action(
640    connection: &Connection,
641    req: lsp_server::Request,
642    file_system: &impl FileSystem,
643) -> Result<()> {
644    let params: CodeActionParams = serde_json::from_value(req.params)?;
645    let uri = params.text_document.uri;
646
647    let mut actions: CodeActionResponse = Vec::new();
648
649    let db = file_system.db();
650    let file = file_system.file(&uri).unwrap();
651    let parse = parse(db, file);
652    let line_index = line_index(db, file);
653    let offset = lsp_utils::offset(&line_index, params.range.start).unwrap();
654
655    // TODO: move this to a tracked function
656    let ide_actions = code_actions(parse.tree(), offset).unwrap_or_default();
657
658    for action in ide_actions {
659        let lsp_action = lsp_utils::code_action(&line_index, uri.clone(), action);
660        actions.push(CodeActionOrCommand::CodeAction(lsp_action));
661    }
662
663    for mut diagnostic in params
664        .context
665        .diagnostics
666        .into_iter()
667        .filter(|diagnostic| diagnostic.source.as_deref() == Some(DIAGNOSTIC_NAME))
668    {
669        let Some(rule_name) = diagnostic.code.as_ref().map(|x| match x {
670            lsp_types::NumberOrString::String(s) => s.clone(),
671            lsp_types::NumberOrString::Number(n) => n.to_string(),
672        }) else {
673            continue;
674        };
675        let Some(data) = diagnostic.data.take() else {
676            continue;
677        };
678
679        let associated_data: AssociatedDiagnosticData =
680            serde_json::from_value(data).context("deserializing diagnostic data")?;
681
682        if let Some(ignore_line_edit) = associated_data.ignore_line_edit {
683            let disable_line_action = CodeAction {
684                title: format!("Disable {rule_name} for this line"),
685                kind: Some(CodeActionKind::QUICKFIX),
686                diagnostics: Some(vec![diagnostic.clone()]),
687                edit: Some(WorkspaceEdit {
688                    changes: Some({
689                        let mut changes = HashMap::new();
690                        changes.insert(uri.clone(), vec![ignore_line_edit]);
691                        changes
692                    }),
693                    ..Default::default()
694                }),
695                command: None,
696                is_preferred: Some(false),
697                disabled: None,
698                data: None,
699            };
700            actions.push(CodeActionOrCommand::CodeAction(disable_line_action));
701        }
702        if let Some(ignore_file_edit) = associated_data.ignore_file_edit {
703            let disable_file_action = CodeAction {
704                title: format!("Disable {rule_name} for the entire file"),
705                kind: Some(CodeActionKind::QUICKFIX),
706                diagnostics: Some(vec![diagnostic.clone()]),
707                edit: Some(WorkspaceEdit {
708                    changes: Some({
709                        let mut changes = HashMap::new();
710                        changes.insert(uri.clone(), vec![ignore_file_edit]);
711                        changes
712                    }),
713                    ..Default::default()
714                }),
715                command: None,
716                is_preferred: Some(false),
717                disabled: None,
718                data: None,
719            };
720            actions.push(CodeActionOrCommand::CodeAction(disable_file_action));
721        }
722
723        let title = format!("Show documentation for {rule_name}");
724        let documentation_action = CodeAction {
725            title: title.clone(),
726            kind: Some(CodeActionKind::QUICKFIX),
727            diagnostics: Some(vec![diagnostic.clone()]),
728            edit: None,
729            command: Some(Command {
730                title,
731                command: "vscode.open".to_string(),
732                arguments: Some(vec![serde_json::to_value(format!(
733                    "https://squawkhq.com/docs/{rule_name}"
734                ))?]),
735            }),
736            is_preferred: Some(false),
737            disabled: None,
738            data: None,
739        };
740        actions.push(CodeActionOrCommand::CodeAction(documentation_action));
741
742        if !associated_data.title.is_empty() && !associated_data.edits.is_empty() {
743            let fix_action = CodeAction {
744                title: associated_data.title,
745                kind: Some(CodeActionKind::QUICKFIX),
746                diagnostics: Some(vec![diagnostic.clone()]),
747                edit: Some(WorkspaceEdit {
748                    changes: Some({
749                        let mut changes = HashMap::new();
750                        changes.insert(uri.clone(), associated_data.edits);
751                        changes
752                    }),
753                    ..Default::default()
754                }),
755                command: None,
756                is_preferred: Some(true),
757                disabled: None,
758                data: None,
759            };
760            actions.push(CodeActionOrCommand::CodeAction(fix_action));
761        }
762    }
763
764    let result: CodeActionResponse = actions;
765    let resp = Response {
766        id: req.id,
767        result: Some(serde_json::to_value(&result).unwrap()),
768        error: None,
769    };
770
771    connection.sender.send(Message::Response(resp))?;
772    Ok(())
773}
774
775fn publish_diagnostics(
776    connection: &Connection,
777    uri: Url,
778    version: i32,
779    diagnostics: Vec<Diagnostic>,
780) -> Result<()> {
781    let publish_params = PublishDiagnosticsParams {
782        uri,
783        diagnostics,
784        version: Some(version),
785    };
786
787    let notification = Notification {
788        method: PublishDiagnostics::METHOD.to_owned(),
789        params: serde_json::to_value(publish_params)?,
790    };
791
792    connection
793        .sender
794        .send(Message::Notification(notification))?;
795    Ok(())
796}
797
798fn handle_did_open(
799    connection: &Connection,
800    notif: lsp_server::Notification,
801    file_system: &mut impl FileSystem,
802) -> Result<()> {
803    let params: DidOpenTextDocumentParams = serde_json::from_value(notif.params)?;
804    let uri = params.text_document.uri;
805    let content = params.text_document.text;
806    let version = params.text_document.version;
807
808    // TODO: move this to a tracked function
809    let diagnostics = lint::lint(&content);
810
811    file_system.set(uri.clone(), DocumentState { content, version });
812
813    // TODO: we need a better setup for "run func when input changed"
814    publish_diagnostics(connection, uri, version, diagnostics)?;
815
816    Ok(())
817}
818
819fn handle_did_change(
820    connection: &Connection,
821    notif: lsp_server::Notification,
822    file_system: &mut impl FileSystem,
823) -> Result<()> {
824    let params: DidChangeTextDocumentParams = serde_json::from_value(notif.params)?;
825    let uri = params.text_document.uri;
826    let version = params.text_document.version;
827
828    let db = file_system.db();
829    let file = file_system.file(&uri).unwrap();
830    let content = file.content(db);
831
832    let updated_content = lsp_utils::apply_incremental_changes(content, params.content_changes);
833
834    // TODO: move this to a tracked function
835    let diagnostics = lint::lint(&updated_content);
836    publish_diagnostics(connection, uri.clone(), version, diagnostics)?;
837
838    file_system.set(
839        uri,
840        DocumentState {
841            content: updated_content,
842            version,
843        },
844    );
845
846    Ok(())
847}
848
849fn handle_did_close(
850    connection: &Connection,
851    notif: lsp_server::Notification,
852    file_system: &mut impl FileSystem,
853) -> Result<()> {
854    let params: DidCloseTextDocumentParams = serde_json::from_value(notif.params)?;
855    let uri = params.text_document.uri;
856
857    file_system.remove(&uri);
858
859    let publish_params = PublishDiagnosticsParams {
860        uri,
861        diagnostics: vec![],
862        version: None,
863    };
864
865    let notification = Notification {
866        method: PublishDiagnostics::METHOD.to_owned(),
867        params: serde_json::to_value(publish_params)?,
868    };
869
870    connection
871        .sender
872        .send(Message::Notification(notification))?;
873
874    Ok(())
875}
876
877#[derive(serde::Deserialize)]
878struct SyntaxTreeParams {
879    #[serde(rename = "textDocument")]
880    text_document: lsp_types::TextDocumentIdentifier,
881}
882
883fn handle_syntax_tree(
884    connection: &Connection,
885    req: lsp_server::Request,
886    file_system: &impl FileSystem,
887) -> Result<()> {
888    let params: SyntaxTreeParams = serde_json::from_value(req.params)?;
889    let uri = params.text_document.uri;
890
891    info!("Generating syntax tree for: {uri}");
892
893    let db = file_system.db();
894    let file = file_system.file(&uri).unwrap();
895    let parse = parse(db, file);
896    let syntax_tree = format!("{:#?}", parse.syntax_node());
897
898    let resp = Response {
899        id: req.id,
900        result: Some(serde_json::to_value(&syntax_tree).unwrap()),
901        error: None,
902    };
903
904    connection.sender.send(Message::Response(resp))?;
905    Ok(())
906}
907
908#[derive(serde::Deserialize)]
909struct TokensParams {
910    #[serde(rename = "textDocument")]
911    text_document: lsp_types::TextDocumentIdentifier,
912}
913
914fn handle_tokens(
915    connection: &Connection,
916    req: lsp_server::Request,
917    file_system: &impl FileSystem,
918) -> Result<()> {
919    let params: TokensParams = serde_json::from_value(req.params)?;
920    let uri = params.text_document.uri;
921
922    info!("Generating tokens for: {uri}");
923
924    let db = file_system.db();
925    let file = file_system.file(&uri).unwrap();
926    let content = file.content(db);
927
928    // TODO: move this to a tracked function
929    let tokens = squawk_lexer::tokenize(content);
930
931    let mut output = Vec::new();
932    let mut char_pos = 0;
933    for token in tokens {
934        let token_start = char_pos;
935        let token_end = token_start + token.len as usize;
936        let token_text = &content[token_start..token_end];
937        output.push(format!(
938            "{:?}@{}..{} {:?}",
939            token.kind, token_start, token_end, token_text
940        ));
941        char_pos = token_end;
942    }
943
944    let tokens_output = output.join("\n");
945
946    let resp = Response {
947        id: req.id,
948        result: Some(serde_json::to_value(&tokens_output).unwrap()),
949        error: None,
950    };
951
952    connection.sender.send(Message::Response(resp))?;
953    Ok(())
954}