1mod utils;
2
3use dashmap::DashMap;
4use serde_json::Value;
5use tower_lsp_server::jsonrpc::Result;
6use tower_lsp_server::ls_types::notification::{DidChangeWatchedFiles, Notification};
7use tower_lsp_server::ls_types::*;
8use tower_lsp_server::{Client, LanguageServer};
9
10use ast_grep_config::{CombinedScan, RuleCollection, Severity};
11use ast_grep_core::{
12 tree_sitter::{LanguageExt, StrDoc},
13 AstGrep, Doc,
14};
15
16use std::collections::{BTreeMap, HashMap};
17use std::path::PathBuf;
18use std::str::FromStr;
19use std::sync::{Arc, RwLock};
20
21use utils::{convert_match_to_diagnostic, diagnostic_to_code_action, Fixes, RewriteData};
22
23pub use tower_lsp_server::{LspService, Server};
24
25pub trait LSPLang: LanguageExt + Eq + Send + Sync + FromStr + 'static {}
26impl<T> LSPLang for T where T: LanguageExt + Eq + Send + Sync + FromStr + 'static {}
27
/// Position key of a note: `(start_line, start_character, end_line, end_character)`.
type NoteKey = (u32, u32, u32, u32);

/// Rule notes attached to a document's diagnostics, ordered by [`NoteKey`].
/// Note texts are shared `Arc<String>`s so identical notes are not copied.
type Notes = BTreeMap<NoteKey, Arc<String>>;
29
30struct VersionedAst<D: Doc> {
31 version: i32,
32 root: AstGrep<D>,
33 notes: Notes,
34 fixes: Fixes,
35}
36
37pub struct Backend<L: LSPLang> {
38 client: Client,
39 map: DashMap<String, VersionedAst<StrDoc<L>>>,
40 base: PathBuf,
41 rules: Arc<RwLock<RuleCollection<L>>>,
42 interner: DashMap<String, Arc<String>>,
44 rule_finder: Box<dyn Fn() -> anyhow::Result<RuleCollection<L>> + Send + Sync>,
46 capabilities: Arc<RwLock<ClientCapabilities>>,
48}
49
/// Code action capability advertised when `code_action_provider` returns
/// `None`, i.e. when the client does not declare code action literal support.
const FALLBACK_CODE_ACTION_PROVIDER: Option<CodeActionProviderCapability> =
  Some(CodeActionProviderCapability::Simple(true));

/// Id of the workspace command that applies every available fix to a document.
const APPLY_ALL_FIXES: &str = "ast-grep.applyAllFixes";
/// Code action kind advertised for per-diagnostic quick fixes.
const QUICKFIX_AST_GREP: &str = "quickfix.ast-grep";
/// Code action kind advertised for the "fix all" source action.
const FIX_ALL_AST_GREP: &str = "source.fixAll.ast-grep";
56
57fn code_action_provider(
58 client_capability: &ClientCapabilities,
59) -> Option<CodeActionProviderCapability> {
60 let is_literal_supported = client_capability
61 .text_document
62 .as_ref()?
63 .code_action
64 .as_ref()?
65 .code_action_literal_support
66 .is_some();
67 if !is_literal_supported {
68 return None;
69 }
70 Some(CodeActionProviderCapability::Options(CodeActionOptions {
71 code_action_kinds: Some(vec![
72 CodeActionKind::new(QUICKFIX_AST_GREP),
73 CodeActionKind::new(FIX_ALL_AST_GREP),
74 ]),
75 work_done_progress_options: Default::default(),
76 resolve_provider: Some(true),
77 }))
78}
79
80impl<L: LSPLang> LanguageServer for Backend<L> {
81 async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
82 let code_action_provider = code_action_provider(¶ms.capabilities);
83 if let Ok(mut cap) = self.capabilities.write() {
84 *cap = params.capabilities;
85 }
86 Ok(InitializeResult {
87 server_info: Some(ServerInfo {
88 name: "ast-grep language server".to_string(),
89 version: None,
90 }),
91 capabilities: ServerCapabilities {
92 text_document_sync: Some(TextDocumentSyncCapability::Kind(TextDocumentSyncKind::FULL)),
94 code_action_provider: code_action_provider.or(FALLBACK_CODE_ACTION_PROVIDER),
95 execute_command_provider: Some(ExecuteCommandOptions {
96 commands: vec![APPLY_ALL_FIXES.to_string()],
97 work_done_progress_options: Default::default(),
98 }),
99 hover_provider: Some(HoverProviderCapability::Simple(true)),
100 ..ServerCapabilities::default()
101 },
102 offset_encoding: None,
103 })
104 }
105
106 async fn initialized(&self, _: InitializedParams) {
107 self
108 .client
109 .log_message(MessageType::INFO, "server initialized!")
110 .await;
111 if let Err(e) = self.reload_rules().await {
112 self
113 .client
114 .show_message(MessageType::ERROR, format!("Failed to load rules: {e}"))
115 .await;
116 }
117
118 if let Err(e) = self.register_file_watchers().await {
120 self
121 .client
122 .log_message(
123 MessageType::ERROR,
124 format!("Failed to register file watchers: {e:?}"),
125 )
126 .await;
127 }
128 }
129
130 async fn shutdown(&self) -> Result<()> {
131 Ok(())
132 }
133
134 async fn did_change_workspace_folders(&self, _: DidChangeWorkspaceFoldersParams) {
135 self
136 .client
137 .log_message(MessageType::INFO, "workspace folders changed!")
138 .await;
139 }
140
141 async fn did_change_configuration(&self, _: DidChangeConfigurationParams) {
142 self
143 .client
144 .log_message(MessageType::INFO, "configuration changed!")
145 .await;
146 }
147
148 async fn did_change_watched_files(&self, _params: DidChangeWatchedFilesParams) {
149 self
151 .client
152 .log_message(
153 MessageType::INFO,
154 "Configuration files changed, reloading rules...",
155 )
156 .await;
157
158 if let Err(e) = self.reload_rules().await {
159 self
160 .client
161 .show_message(MessageType::ERROR, format!("Failed to reload rules: {e}"))
162 .await;
163 } else {
164 self
165 .client
166 .log_message(MessageType::INFO, "Rules reloaded successfully")
167 .await;
168 }
169 }
170 async fn did_open(&self, params: DidOpenTextDocumentParams) {
171 self
172 .client
173 .log_message(MessageType::INFO, "file opened!")
174 .await;
175 self.on_open(params).await;
176 }
177
178 async fn did_change(&self, params: DidChangeTextDocumentParams) {
179 self.on_change(params).await;
180 }
181
182 async fn did_save(&self, _: DidSaveTextDocumentParams) {
183 self
184 .client
185 .log_message(MessageType::INFO, "file saved!")
186 .await;
187 }
188
189 async fn did_close(&self, params: DidCloseTextDocumentParams) {
190 self.on_close(params).await;
191 self
192 .client
193 .log_message(MessageType::INFO, "file closed!")
194 .await;
195 }
196
197 async fn code_action(&self, params: CodeActionParams) -> Result<Option<CodeActionResponse>> {
198 Ok(self.on_code_action(params).await)
199 }
200
201 async fn execute_command(&self, params: ExecuteCommandParams) -> Result<Option<Value>> {
202 Ok(self.on_execute_command(params).await)
203 }
204
205 async fn hover(&self, params: HoverParams) -> Result<Option<Hover>> {
206 self
207 .client
208 .log_message(MessageType::LOG, "Get Hover Notes")
209 .await;
210 Ok(self.do_hover(params.text_document_position_params))
211 }
212}
213
214fn pos_tuple_to_range((line, character, end_line, end_character): (u32, u32, u32, u32)) -> Range {
215 Range {
216 start: Position { line, character },
217 end: Position {
218 line: end_line,
219 character: end_character,
220 },
221 }
222}
223
224impl<L: LSPLang> Backend<L> {
225 pub fn new<F>(client: Client, base: PathBuf, rule_finder: F) -> Self
226 where
227 F: Fn() -> anyhow::Result<RuleCollection<L>> + Send + Sync + 'static,
228 {
229 Self {
230 client,
231 rules: Arc::new(RwLock::new(RuleCollection::default())),
232 base,
233 map: DashMap::new(),
234 interner: DashMap::new(),
235 rule_finder: Box::new(rule_finder),
236 capabilities: Arc::new(RwLock::new(ClientCapabilities::default())),
237 }
238 }
239
240 fn uri_to_relative_path(&self, uri: &Uri) -> Option<PathBuf> {
242 let absolute_path = uri.to_file_path()?;
243 if let Ok(relative_path) = absolute_path.strip_prefix(&self.base) {
244 Some(relative_path.to_path_buf())
245 } else {
246 Some(absolute_path.to_path_buf())
247 }
248 }
249
250 fn do_hover(&self, pos_params: TextDocumentPositionParams) -> Option<Hover> {
251 let uri = pos_params.text_document.uri;
252 let Position {
253 line,
254 character: column,
255 } = pos_params.position;
256 let ast = self.map.get(uri.as_str())?;
257 let query = (line, column, line, column);
258 let (pos, markdown) = ast.notes.range(..=query).next_back()?;
260 if pos.0 > line || pos.2 < line {
262 return None;
263 }
264 if pos.0 == line && pos.1 > column || pos.2 == line && pos.3 < column {
265 return None;
266 }
267 Some(Hover {
268 contents: HoverContents::Markup(MarkupContent {
269 kind: MarkupKind::Markdown,
270 value: markdown.to_string(),
271 }),
272 range: Some(pos_tuple_to_range(*pos)),
273 })
274 }
275
276 fn get_diagnostics(
277 &self,
278 uri: &Uri,
279 root: &AstGrep<StrDoc<L>>,
280 ) -> Option<(Vec<Diagnostic>, Fixes)> {
281 let path = self.uri_to_relative_path(uri)?;
282
283 let rules = self.rules.read().ok()?;
284 let mut diagnostics = vec![];
285 let mut fixes = Fixes::new();
286 let injections = root.get_injections(|lang| L::from_str(lang).ok());
287 let docs = std::iter::once(root).chain(injections.iter());
288 let doc_and_rules = docs.filter_map(|injected| {
289 let rule_refs = rules.get_rule_from_lang(&path, injected.lang().clone());
290 if rule_refs.is_empty() {
291 None
292 } else {
293 Some((injected, rule_refs))
294 }
295 });
296 for (injected, rule_refs) in doc_and_rules {
298 let unused_suppression_rule =
299 CombinedScan::unused_config(Severity::Hint, injected.lang().clone());
300 let mut scan = CombinedScan::new(rule_refs);
301 scan.set_unused_suppression_rule(&unused_suppression_rule);
302 let all_matches = scan.scan(injected, false).matches;
304 let rule_and_matches = all_matches
305 .into_iter()
306 .flat_map(|(rule, ms)| ms.into_iter().map(move |m| (rule, m)));
307 diagnostics.extend(rule_and_matches.map(|(rule, m)| {
308 let diagnostic = convert_match_to_diagnostic(uri, &m, rule);
309 let rewrite_data = RewriteData::from_node_match(&m, rule);
310 if let Some(r) = rewrite_data {
311 fixes.insert((diagnostic.range, rule.id.clone()), r);
312 }
313 diagnostic
314 }));
315 }
316 Some((diagnostics, fixes))
317 }
318
319 fn build_notes(&self, diagnostics: &[Diagnostic]) -> Notes {
320 let mut notes = BTreeMap::new();
321 for diagnostic in diagnostics {
322 let Some(NumberOrString::String(id)) = &diagnostic.code else {
323 continue;
324 };
325 let Ok(rules) = self.rules.read() else {
326 continue;
327 };
328 let Some(note) = rules.get_rule(id).and_then(|r| r.note.clone()) else {
329 continue;
330 };
331 let start = diagnostic.range.start;
332 let end = diagnostic.range.end;
333 let atom = self
334 .interner
335 .entry(id.clone())
336 .or_insert_with(|| Arc::new(note.clone()))
337 .clone();
338 notes.insert((start.line, start.character, end.line, end.character), atom);
339 }
340 notes
341 }
342
343 async fn publish_diagnostics(
344 &self,
345 uri: Uri,
346 version: i32,
347 diagnostics: Vec<Diagnostic>,
348 ) -> Option<()> {
349 self
350 .client
351 .publish_diagnostics(uri, diagnostics, Some(version))
352 .await;
353 Some(())
354 }
355
356 async fn get_path_of_first_workspace(&self) -> Option<std::path::PathBuf> {
357 let client_support_workspace = {
359 let cap = self.capabilities.read().ok()?;
360 cap
361 .workspace
362 .as_ref()
363 .and_then(|w| w.workspace_folders)
364 .unwrap_or(false)
365 };
366 if !client_support_workspace {
367 return None;
368 }
369 let folders = self.client.workspace_folders().await.ok()??;
370 let folder = folders.first()?;
371 folder.uri.to_file_path().map(PathBuf::from)
372 }
373
374 async fn should_skip_file_outside_workspace(&self, text_doc: &TextDocumentItem) -> Option<()> {
376 let workspace_root = self
378 .get_path_of_first_workspace()
379 .await
380 .unwrap_or_else(|| self.base.clone());
381 let doc_file_path = text_doc.uri.to_file_path()?;
382 if doc_file_path.starts_with(workspace_root) {
383 None
384 } else {
385 Some(())
386 }
387 }
388
389 fn compute_diagnostics(
390 &self,
391 uri: &Uri,
392 versioned: &mut VersionedAst<StrDoc<L>>,
393 ) -> Vec<Diagnostic> {
394 let (diagnostics, fixes) = self
395 .get_diagnostics(uri, &versioned.root)
396 .unwrap_or_default();
397 versioned.notes = self.build_notes(&diagnostics);
398 versioned.fixes = fixes;
399 diagnostics
400 }
401
402 async fn on_open(&self, params: DidOpenTextDocumentParams) -> Option<()> {
403 let text_doc = params.text_document;
404 if self
405 .should_skip_file_outside_workspace(&text_doc)
406 .await
407 .is_some()
408 {
409 return None;
410 }
411 let uri = text_doc.uri.as_str().to_owned();
412 self
413 .client
414 .log_message(MessageType::LOG, "Parsing doc.")
415 .await;
416 let (versioned, diagnostics) =
417 self.get_versioned_ast(text_doc.version, &text_doc.uri, &text_doc.text)?;
418 self
419 .client
420 .log_message(MessageType::LOG, "Publishing init diagnostics.")
421 .await;
422 self
423 .publish_diagnostics(text_doc.uri, versioned.version, diagnostics)
424 .await;
425 self.map.insert(uri, versioned); Some(())
427 }
428 fn get_versioned_ast(
429 &self,
430 version: i32,
431 uri: &Uri,
432 text: &str,
433 ) -> Option<(VersionedAst<StrDoc<L>>, Vec<Diagnostic>)> {
434 let lang = Self::infer_lang_from_uri(uri)?;
435 let root = AstGrep::new(text, lang);
436 let mut versioned = VersionedAst {
437 version,
438 root,
439 notes: BTreeMap::new(),
440 fixes: Fixes::new(),
441 };
442 let diagnostics = self.compute_diagnostics(uri, &mut versioned);
443 Some((versioned, diagnostics))
444 }
445
446 async fn on_change(&self, params: DidChangeTextDocumentParams) -> Option<()> {
447 let text_doc = params.text_document;
448 let uri = text_doc.uri.as_str();
449 self
450 .client
451 .log_message(MessageType::LOG, "Parsing changed doc.")
452 .await;
453 let (diagnostics, version) = {
454 let mut versioned = self.map.get_mut(uri)?;
455 if versioned.version > text_doc.version {
457 return None;
458 }
459 let change = ¶ms.content_changes.first()?;
460 let text = &change.text;
461 let (new_version, diagnostics) =
462 self.get_versioned_ast(text_doc.version, &text_doc.uri, text)?;
463 *versioned = new_version;
464 (diagnostics, versioned.version)
465 };
466 self
467 .client
468 .log_message(MessageType::LOG, "Publishing diagnostics.")
469 .await;
470 self
471 .publish_diagnostics(text_doc.uri, version, diagnostics)
472 .await;
473 Some(())
474 }
475 async fn on_close(&self, params: DidCloseTextDocumentParams) {
476 self.map.remove(params.text_document.uri.as_str());
477 }
478
479 fn compute_all_fixes(
480 &self,
481 text_document: TextDocumentIdentifier,
482 ) -> std::result::Result<HashMap<Uri, Vec<TextEdit>>, LspError>
483 where
484 L: ast_grep_core::Language + std::cmp::Eq,
485 {
486 let uri = text_document.uri;
487 let versioned = self
488 .map
489 .get(uri.as_str())
490 .ok_or(LspError::UnsupportedFileType)?;
491 let (_diagnostics, fixes) = self
492 .get_diagnostics(&uri, &versioned.root)
493 .ok_or(LspError::NoActionableFix)?;
494
495 let mut entries: Vec<_> = fixes.iter().collect();
496 entries.sort_by(|((range_a, _), _), ((range_b, _), _)| {
497 range_a
498 .start
499 .cmp(&range_b.start)
500 .then(range_a.end.cmp(&range_b.end))
501 });
502
503 let mut last = Position {
504 line: 0,
505 character: 0,
506 };
507 let edits: Vec<TextEdit> = entries
510 .into_iter()
511 .filter_map(|((range, _id), rewrite_data)| {
512 if range.start < last {
513 return None;
514 }
515 let first_fix = rewrite_data.fixers.first()?;
516 let fixed = first_fix.fixed.to_string();
517 let range = first_fix.range.as_ref().unwrap_or(range);
519 let edit = TextEdit::new(*range, fixed);
520 last = range.end;
521 Some(edit)
522 })
523 .collect();
524 if edits.is_empty() {
525 return Err(LspError::NoActionableFix);
526 }
527 let mut changes = HashMap::new();
528 changes.insert(uri, edits);
529 Ok(changes)
530 }
531
532 async fn on_code_action(&self, params: CodeActionParams) -> Option<CodeActionResponse> {
533 if let Some(kinds) = params.context.only.as_ref() {
534 if kinds.contains(&CodeActionKind::SOURCE_FIX_ALL) {
535 return self.fix_all_code_action(params.text_document);
536 }
537 }
538 self.quickfix_code_action(params)
539 }
540
541 fn fix_all_code_action(
542 &self,
543 text_document: TextDocumentIdentifier,
544 ) -> Option<CodeActionResponse> {
545 let fixed = self.compute_all_fixes(text_document).ok()?;
546 let edit = WorkspaceEdit::new(fixed);
547 let code_action = CodeAction {
548 title: "Fix by ast-grep".into(),
549 command: None,
550 diagnostics: None,
551 edit: Some(edit),
552 kind: Some(CodeActionKind::new(FIX_ALL_AST_GREP)),
553 is_preferred: None,
554 data: None,
555 disabled: None,
556 };
557 Some(vec![CodeActionOrCommand::CodeAction(code_action)])
558 }
559
560 fn quickfix_code_action(&self, params: CodeActionParams) -> Option<CodeActionResponse> {
561 if params.context.diagnostics.is_empty() {
562 return None;
563 }
564 let text_doc = params.text_document;
565
566 let document = self.map.get(text_doc.uri.as_str())?;
567 let fixes_cache = &document.fixes;
568
569 let response = params
570 .context
571 .diagnostics
572 .into_iter()
573 .filter(|d| {
574 d.source
575 .as_ref()
576 .map(|s| s.contains("ast-grep"))
577 .unwrap_or(false)
578 })
579 .filter_map(|d| diagnostic_to_code_action(&text_doc, d, fixes_cache))
580 .flatten()
581 .map(CodeActionOrCommand::from)
582 .collect();
583 Some(response)
584 }
585
586 fn infer_lang_from_uri(uri: &Uri) -> Option<L> {
588 let path = uri.to_file_path()?;
589 L::from_path(path)
590 }
591
592 async fn on_execute_command(&self, params: ExecuteCommandParams) -> Option<Value> {
593 let ExecuteCommandParams {
594 arguments,
595 command,
596 work_done_progress_params: _,
597 } = params;
598
599 match command.as_ref() {
600 APPLY_ALL_FIXES => {
601 self.on_apply_all_fix(command, arguments).await?;
602 None
603 }
604 _ => {
605 self
606 .client
607 .log_message(MessageType::LOG, format!("Unrecognized command: {command}"))
608 .await;
609 None
610 }
611 }
612 }
613
614 async fn on_apply_all_fix_impl(
615 &self,
616 first: Value,
617 ) -> std::result::Result<WorkspaceEdit, LspError> {
618 let text_doc: TextDocumentItem =
619 serde_json::from_value(first).map_err(LspError::JSONDecodeError)?;
620 let uri = text_doc.uri;
621 let changes = self.compute_all_fixes(TextDocumentIdentifier::new(uri))?;
623 let workspace_edit = WorkspaceEdit {
624 changes: Some(changes),
625 document_changes: None,
626 change_annotations: None,
627 };
628 Ok(workspace_edit)
629 }
630
631 async fn on_apply_all_fix(&self, command: String, arguments: Vec<Value>) -> Option<()> {
632 self
633 .client
634 .log_message(
635 MessageType::INFO,
636 format!("Running ExecuteCommand {command}"),
637 )
638 .await;
639 let first = arguments.first()?.clone();
640 let workspace_edit = match self.on_apply_all_fix_impl(first).await {
641 Ok(workspace_edit) => workspace_edit,
642 Err(error) => {
643 self.report_error(error).await;
644 return None;
645 }
646 };
647 self.client.apply_edit(workspace_edit).await.ok()?;
648 None
649 }
650
651 async fn report_error(&self, error: LspError) {
652 match error {
653 LspError::JSONDecodeError(e) => {
654 self
655 .client
656 .log_message(
657 MessageType::ERROR,
658 format!("JSON deserialization error: {e}"),
659 )
660 .await;
661 }
662 LspError::UnsupportedFileType => {
663 self
664 .client
665 .log_message(MessageType::ERROR, "Unsupported file type")
666 .await;
667 }
668 LspError::NoActionableFix => {
669 self
670 .client
671 .log_message(MessageType::LOG, "No actionable fix")
672 .await;
673 }
674 }
675 }
676
677 async fn register_file_watchers(
679 &self,
680 ) -> std::result::Result<(), tower_lsp_server::jsonrpc::Error> {
681 let yml_watcher = FileSystemWatcher {
682 glob_pattern: GlobPattern::String("**/*.{yml,yaml}".to_string()),
683 kind: Some(WatchKind::Create | WatchKind::Change | WatchKind::Delete),
684 };
685 let registration = Registration {
686 id: "ast-grep-config-watcher".to_string(),
687 method: DidChangeWatchedFiles::METHOD.to_string(),
688 register_options: Some(
689 serde_json::to_value(DidChangeWatchedFilesRegistrationOptions {
690 watchers: vec![yml_watcher],
691 })
692 .map_err(|e| tower_lsp_server::jsonrpc::Error::invalid_params(e.to_string()))?,
693 ),
694 };
695
696 self.client.register_capability(vec![registration]).await
697 }
698
699 async fn reload_rules(&self) -> anyhow::Result<()> {
701 self
702 .client
703 .log_message(MessageType::INFO, "Starting rule reload...")
704 .await;
705
706 match (self.rule_finder)() {
707 Ok(new_rules) => {
708 {
710 let mut rules = self
711 .rules
712 .write()
713 .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
714 *rules = new_rules;
715 }
716
717 self
718 .client
719 .log_message(
720 MessageType::INFO,
721 "Rules reloaded successfully using CLI logic",
722 )
723 .await;
724 }
725 Err(error) => {
726 self
728 .client
729 .show_message(MessageType::ERROR, format!("Failed to load rules: {error}"))
730 .await;
731 self
733 .client
734 .log_message(MessageType::ERROR, format!("Failed to load rules: {error}"))
735 .await;
736 }
737 }
738
739 self.interner.clear();
741
742 self.republish_all_diagnostics().await;
744
745 Ok(())
746 }
747
748 async fn republish_all_diagnostics(&self) {
750 for mut entry in self.map.iter_mut() {
752 let (uri_str, versioned) = entry.pair_mut();
753 let Ok(uri) = uri_str.parse::<Uri>() else {
754 continue;
755 };
756 let diagnostics = self.compute_diagnostics(&uri, versioned);
758 self
759 .client
760 .publish_diagnostics(uri, diagnostics, Some(versioned.version))
761 .await;
762 }
763 }
764}
765
766enum LspError {
767 JSONDecodeError(serde_json::Error),
768 UnsupportedFileType,
769 NoActionableFix,
770}