1use super::Backend;
32use crate::fixtures::import_analysis::{
33 adapt_type_for_consumer, can_merge_into, classify_import_statement,
34 find_sorted_insert_position, import_line_sort_key, import_sort_key, parse_import_layout,
35 ImportGroup, ImportKind, ImportLayout,
36};
37use crate::fixtures::string_utils::parameter_has_annotation;
38use crate::fixtures::types::TypeImportSpec;
39use std::collections::{HashMap, HashSet};
40use tower_lsp_server::jsonrpc::Result;
41use tower_lsp_server::ls_types::*;
42use tracing::{info, warn};
43
/// Prefix placed at the start of every code action title built in this file.
const TITLE_PREFIX: &str = "pytest-ls";

/// Kind attached to the cursor-based "add type annotation" source action.
const SOURCE_PYTEST_LSP: CodeActionKind = CodeActionKind::new("source.pytest-ls");

/// Kind attached to the "add all fixture type annotations" fix-all action.
const SOURCE_FIX_ALL_PYTEST_LSP: CodeActionKind = CodeActionKind::new("source.fixAll.pytest-ls");
54
55fn kind_requested(only: &Option<Vec<CodeActionKind>>, action_kind: &CodeActionKind) -> bool {
69 let Some(ref kinds) = only else {
70 return true; };
72 let action_str = action_kind.as_str();
73 kinds.iter().any(|k| {
74 let k_str = k.as_str();
75 action_str == k_str || action_str.starts_with(&format!("{}.", k_str))
77 })
78}
79
80fn emit_kind_import_edits(
89 layout: &ImportLayout,
90 new_from_imports: &HashMap<String, Vec<String>>,
91 new_bare_imports: &[String],
92 group: Option<&ImportGroup>,
93 fallback_insert_line: u32,
94 edits: &mut Vec<TextEdit>,
95) {
96 let mut unmerged_from: Vec<(String, Vec<String>)> = Vec::new();
98
99 let mut modules: Vec<&String> = new_from_imports.keys().collect();
100 modules.sort();
101
102 let line_strs = layout.line_strs();
103
104 for module in modules {
105 let new_names = &new_from_imports[module];
106
107 if let Some(fi) = layout.find_matching_from_import(module) {
108 if can_merge_into(fi) {
109 let mut all_names: Vec<String> = fi.name_strings();
113 for n in new_names {
114 if !all_names.iter().any(|existing| existing.trim() == n.trim()) {
115 all_names.push(n.clone());
116 }
117 }
118 all_names.sort_by(|a, b| {
119 import_sort_key(a)
120 .to_lowercase()
121 .cmp(&import_sort_key(b).to_lowercase())
122 });
123 all_names.dedup();
124
125 let merged_line = format!("from {} import {}", module, all_names.join(", "));
126 info!(
127 "Merging import into existing line {}: {}",
128 fi.line, merged_line
129 );
130
131 let end_char = layout.line(fi.end_line).len() as u32;
134 edits.push(TextEdit {
135 range: Range {
136 start: Position {
137 line: fi.line as u32,
138 character: 0,
139 },
140 end: Position {
141 line: fi.end_line as u32,
142 character: end_char,
143 },
144 },
145 new_text: merged_line,
146 });
147 } else {
148 unmerged_from.push((module.clone(), new_names.clone()));
150 }
151 } else {
152 unmerged_from.push((module.clone(), new_names.clone()));
153 }
154 }
155
156 struct NewImport {
162 sort_key: (u8, String),
163 text: String,
164 }
165
166 let mut new_imports: Vec<NewImport> = Vec::new();
167
168 for stmt in new_bare_imports {
170 new_imports.push(NewImport {
171 sort_key: import_line_sort_key(stmt),
172 text: stmt.clone(),
173 });
174 }
175
176 for (module, names) in &unmerged_from {
178 let mut sorted_names = names.clone();
179 sorted_names.sort_by(|a, b| {
180 import_sort_key(a)
181 .to_lowercase()
182 .cmp(&import_sort_key(b).to_lowercase())
183 });
184 let text = format!("from {} import {}", module, sorted_names.join(", "));
185 new_imports.push(NewImport {
186 sort_key: import_line_sort_key(&text),
187 text,
188 });
189 }
190
191 new_imports.sort_by(|a, b| a.sort_key.cmp(&b.sort_key));
194
195 for ni in &new_imports {
196 let insert_line = match group {
197 Some(g) => find_sorted_insert_position(&line_strs, g, &ni.sort_key),
198 None => fallback_insert_line,
199 };
200 info!("Adding new import line at {}: {}", insert_line, ni.text);
201 edits.push(TextEdit {
202 range: Backend::create_point_range(insert_line, 0),
203 new_text: format!("{}\n", ni.text),
204 });
205 }
206}
207
208fn build_import_edits(
218 layout: &ImportLayout,
219 specs: &[&TypeImportSpec],
220 existing_imports: &HashSet<String>,
221) -> Vec<TextEdit> {
222 let groups = &layout.groups;
223
224 let mut stdlib_from: HashMap<String, Vec<String>> = HashMap::new();
226 let mut tp_from: HashMap<String, Vec<String>> = HashMap::new();
227 let mut stdlib_bare: Vec<String> = Vec::new();
228 let mut tp_bare: Vec<String> = Vec::new();
229 let mut seen_names: HashSet<&str> = HashSet::new();
230
231 for spec in specs {
232 if existing_imports.contains(&spec.check_name) {
233 info!("Import '{}' already present, skipping", spec.check_name);
234 continue;
235 }
236 if !seen_names.insert(&spec.check_name) {
237 continue;
238 }
239
240 let kind = classify_import_statement(&spec.import_statement);
241
242 if let Some(rest) = spec.import_statement.strip_prefix("from ") {
243 if let Some((module, name)) = rest.split_once(" import ") {
244 let module = module.trim();
245 let name = name.trim();
246 if !module.is_empty() && !name.is_empty() {
247 match kind {
248 ImportKind::Future | ImportKind::Stdlib => &mut stdlib_from,
256 ImportKind::ThirdParty => &mut tp_from,
257 }
258 .entry(module.to_string())
259 .or_default()
260 .push(name.to_string());
261 continue;
262 }
263 }
264 }
265 match kind {
267 ImportKind::Future | ImportKind::Stdlib => &mut stdlib_bare,
268 ImportKind::ThirdParty => &mut tp_bare,
269 }
270 .push(spec.import_statement.clone());
271 }
272
273 let has_new_stdlib = !stdlib_from.is_empty() || !stdlib_bare.is_empty();
274 let has_new_tp = !tp_from.is_empty() || !tp_bare.is_empty();
275
276 if !has_new_stdlib && !has_new_tp {
277 return vec![];
278 }
279
280 let last_stdlib_group = groups.iter().rev().find(|g| g.kind == ImportKind::Stdlib);
283 let first_tp_group = groups.iter().find(|g| g.kind == ImportKind::ThirdParty);
284 let last_tp_group = groups
285 .iter()
286 .rev()
287 .find(|g| g.kind == ImportKind::ThirdParty);
288 let last_future_group = groups.iter().rev().find(|g| g.kind == ImportKind::Future);
289
290 let will_insert_stdlib =
295 stdlib_from
296 .keys()
297 .any(|m| match layout.find_matching_from_import(m) {
298 None => true,
299 Some(fi) => !can_merge_into(fi),
300 })
301 || !stdlib_bare.is_empty();
302 let will_insert_tp = tp_from
303 .keys()
304 .any(|m| match layout.find_matching_from_import(m) {
305 None => true,
306 Some(fi) => !can_merge_into(fi),
307 })
308 || !tp_bare.is_empty();
309
310 let mut edits: Vec<TextEdit> = Vec::new();
311
312 if has_new_stdlib {
314 let fallback_line = match (last_stdlib_group, first_tp_group) {
315 (Some(sg), _) => (sg.last_line + 1) as u32,
316 (None, Some(tpg)) => tpg.first_line as u32,
317 (None, None) => last_future_group
321 .map(|fg| (fg.last_line + 1) as u32)
322 .unwrap_or(0),
323 };
324
325 if will_insert_stdlib
329 && last_stdlib_group.is_none()
330 && last_future_group.is_some()
331 && first_tp_group.is_none()
332 {
333 edits.push(TextEdit {
334 range: Backend::create_point_range(fallback_line, 0),
335 new_text: "\n".to_string(),
336 });
337 }
338
339 emit_kind_import_edits(
340 layout,
341 &stdlib_from,
342 &stdlib_bare,
343 last_stdlib_group,
344 fallback_line,
345 &mut edits,
346 );
347
348 if will_insert_stdlib && last_stdlib_group.is_none() && first_tp_group.is_some() {
351 edits.push(TextEdit {
352 range: Backend::create_point_range(fallback_line, 0),
353 new_text: "\n".to_string(),
354 });
355 }
356 }
357
358 if has_new_tp {
360 let fallback_line = match (last_tp_group, last_stdlib_group) {
361 (Some(tpg), _) => (tpg.last_line + 1) as u32,
362 (None, Some(sg)) => (sg.last_line + 1) as u32,
363 (None, None) => 0,
364 };
365
366 if will_insert_tp
369 && last_tp_group.is_none()
370 && (last_stdlib_group.is_some() || will_insert_stdlib)
371 {
372 edits.push(TextEdit {
373 range: Backend::create_point_range(fallback_line, 0),
374 new_text: "\n".to_string(),
375 });
376 }
377
378 emit_kind_import_edits(
379 layout,
380 &tp_from,
381 &tp_bare,
382 last_tp_group,
383 fallback_line,
384 &mut edits,
385 );
386 }
387
388 edits
389}
390
391impl Backend {
394 pub async fn handle_code_action(
396 &self,
397 params: CodeActionParams,
398 ) -> Result<Option<CodeActionResponse>> {
399 let uri = params.text_document.uri;
400 let range = params.range;
401 let context = params.context;
402
403 info!(
404 "code_action request: uri={:?}, diagnostics={}, only={:?}",
405 uri,
406 context.diagnostics.len(),
407 context.only
408 );
409
410 let Some(file_path) = self.uri_to_path(&uri) else {
411 info!("Returning None for code_action request: could not resolve URI");
412 return Ok(None);
413 };
414
415 let Some(content) = self.fixture_db.get_file_content(&file_path) else {
418 info!("Returning None: file content not in cache");
419 return Ok(None);
420 };
421 let lines: Vec<&str> = content.lines().collect();
422
423 let existing_imports = self
426 .fixture_db
427 .imports
428 .get(&file_path)
429 .map(|entry| entry.value().clone())
430 .unwrap_or_default();
431
432 let consumer_import_map = self.fixture_db.get_name_to_import_map(&file_path, &content);
437
438 let layout = parse_import_layout(&content);
441
442 let mut actions: Vec<CodeActionOrCommand> = Vec::new();
443
444 if kind_requested(&context.only, &CodeActionKind::QUICKFIX) {
449 let undeclared = self.fixture_db.get_undeclared_fixtures(&file_path);
450 info!("Found {} undeclared fixtures in file", undeclared.len());
451
452 for diagnostic in &context.diagnostics {
453 info!(
454 "Processing diagnostic: code={:?}, range={:?}",
455 diagnostic.code, diagnostic.range
456 );
457
458 let Some(NumberOrString::String(code)) = &diagnostic.code else {
459 continue;
460 };
461 if code != "undeclared-fixture" {
462 continue;
463 }
464
465 let diag_line = Self::lsp_line_to_internal(diagnostic.range.start.line);
466 let diag_char = diagnostic.range.start.character as usize;
467
468 info!(
469 "Looking for undeclared fixture at line={}, char={}",
470 diag_line, diag_char
471 );
472
473 let Some(fixture) = undeclared
474 .iter()
475 .find(|f| f.line == diag_line && f.start_char == diag_char)
476 else {
477 continue;
478 };
479
480 info!("Found matching fixture: {}", fixture.name);
481
482 let fixture_def = self
484 .fixture_db
485 .resolve_fixture_for_file(&file_path, &fixture.name);
486
487 let (type_suffix, return_type_imports) = match &fixture_def {
488 Some(def) => {
489 if let Some(rt) = &def.return_type {
490 let (adapted, remaining) = adapt_type_for_consumer(
491 rt,
492 &def.return_type_imports,
493 &consumer_import_map,
494 );
495 (format!(": {}", adapted), remaining)
496 } else {
497 (String::new(), vec![])
498 }
499 }
500 None => (String::new(), vec![]),
501 };
502
503 let Some(insertion) = self
508 .fixture_db
509 .get_function_param_insertion_info(&file_path, fixture.function_line)
510 else {
511 warn!(
512 "Could not find parameter insertion point for '{}' at {:?}:{}",
513 fixture.name, file_path, fixture.function_line
514 );
515 continue;
516 };
517
518 let insert_line = Self::internal_line_to_lsp(insertion.line);
519 let insert_char = insertion.char_pos as u32;
520
521 let param_text = match &insertion.multiline_indent {
522 Some(indent) => {
523 if insertion.needs_comma {
524 format!(",\n{}{}{}", indent, fixture.name, type_suffix)
527 } else {
528 format!("\n{}{}{},", indent, fixture.name, type_suffix)
531 }
532 }
533 None => {
534 if insertion.needs_comma {
535 format!(", {}{}", fixture.name, type_suffix)
536 } else {
537 format!("{}{}", fixture.name, type_suffix)
538 }
539 }
540 };
541
542 let spec_refs: Vec<&TypeImportSpec> = return_type_imports.iter().collect();
544 let mut all_edits = build_import_edits(&layout, &spec_refs, &existing_imports);
545
546 all_edits.push(TextEdit {
549 range: Self::create_point_range(insert_line, insert_char),
550 new_text: param_text,
551 });
552
553 let edit = WorkspaceEdit {
554 changes: Some(vec![(uri.clone(), all_edits)].into_iter().collect()),
555 document_changes: None,
556 change_annotations: None,
557 };
558
559 let display_type = type_suffix.strip_prefix(": ").unwrap_or("");
561 let title = if !display_type.is_empty() {
562 format!(
563 "{}: Add '{}' fixture parameter ({})",
564 TITLE_PREFIX, fixture.name, display_type
565 )
566 } else {
567 format!("{}: Add '{}' fixture parameter", TITLE_PREFIX, fixture.name)
568 };
569
570 let action = CodeAction {
571 title,
572 kind: Some(CodeActionKind::QUICKFIX),
573 diagnostics: Some(vec![diagnostic.clone()]),
574 edit: Some(edit),
575 command: None,
576 is_preferred: Some(true),
577 disabled: None,
578 data: None,
579 };
580
581 info!("Created code action: {}", action.title);
582 actions.push(CodeActionOrCommand::CodeAction(action));
583 }
584 }
585
586 let want_source = kind_requested(&context.only, &SOURCE_PYTEST_LSP);
591 let want_fix_all = kind_requested(&context.only, &SOURCE_FIX_ALL_PYTEST_LSP);
592
593 let need_fixture_map = want_source || want_fix_all;
594
595 if need_fixture_map {
596 if let Some(ref usages) = self.fixture_db.usages.get(&file_path) {
597 let available = self.fixture_db.get_available_fixtures(&file_path);
598 let fixture_map: std::collections::HashMap<&str, _> = available
599 .iter()
600 .filter_map(|def| def.return_type.as_ref().map(|_rt| (def.name.as_str(), def)))
601 .collect();
602
603 if !fixture_map.is_empty() {
604 if want_source {
610 let cursor_line_internal = Self::lsp_line_to_internal(range.start.line);
611
612 for usage in usages.iter() {
613 if !usage.is_parameter {
617 continue;
618 }
619
620 if usage.line != cursor_line_internal {
621 continue;
622 }
623
624 let cursor_char = range.start.character as usize;
625 if cursor_char < usage.start_char || cursor_char > usage.end_char {
626 continue;
627 }
628
629 if parameter_has_annotation(&lines, usage.line, usage.end_char) {
630 continue;
631 }
632
633 let Some(def) = fixture_map.get(usage.name.as_str()) else {
634 continue;
635 };
636
637 let return_type = def.return_type.as_deref().unwrap();
641
642 let (adapted_type, adapted_imports) = adapt_type_for_consumer(
644 return_type,
645 &def.return_type_imports,
646 &consumer_import_map,
647 );
648
649 info!(
650 "Cursor-based annotation action for '{}': {}",
651 usage.name, adapted_type
652 );
653
654 let spec_refs: Vec<&TypeImportSpec> = adapted_imports.iter().collect();
656 let mut all_edits =
657 build_import_edits(&layout, &spec_refs, &existing_imports);
658
659 let lsp_line = Self::internal_line_to_lsp(usage.line);
660 all_edits.push(TextEdit {
661 range: Self::create_point_range(lsp_line, usage.end_char as u32),
662 new_text: format!(": {}", adapted_type),
663 });
664
665 let ws_edit = WorkspaceEdit {
666 changes: Some(vec![(uri.clone(), all_edits)].into_iter().collect()),
667 document_changes: None,
668 change_annotations: None,
669 };
670
671 let title = format!(
672 "{}: Add type annotation for fixture '{}'",
673 TITLE_PREFIX, usage.name
674 );
675
676 let action = CodeAction {
677 title: title.clone(),
678 kind: Some(SOURCE_PYTEST_LSP),
679 diagnostics: None,
680 edit: Some(ws_edit),
681 command: None,
682 is_preferred: Some(true),
683 disabled: None,
684 data: None,
685 };
686 info!("Created source.pytest-ls action: {}", title);
687 actions.push(CodeActionOrCommand::CodeAction(action));
688 }
689 }
690
691 if want_fix_all {
697 let mut all_adapted_imports: Vec<TypeImportSpec> = Vec::new();
699 let mut annotation_edits: Vec<TextEdit> = Vec::new();
700 let mut annotated_count: usize = 0;
701
702 for usage in usages.iter() {
703 if !usage.is_parameter {
707 continue;
708 }
709
710 if parameter_has_annotation(&lines, usage.line, usage.end_char) {
711 continue;
712 }
713
714 let Some(def) = fixture_map.get(usage.name.as_str()) else {
715 continue;
716 };
717
718 let return_type = def.return_type.as_deref().unwrap();
722
723 let (adapted_type, adapted_imports) = adapt_type_for_consumer(
725 return_type,
726 &def.return_type_imports,
727 &consumer_import_map,
728 );
729
730 all_adapted_imports.extend(adapted_imports);
733
734 let lsp_line = Self::internal_line_to_lsp(usage.line);
736 annotation_edits.push(TextEdit {
737 range: Self::create_point_range(lsp_line, usage.end_char as u32),
738 new_text: format!(": {}", adapted_type),
739 });
740
741 annotated_count += 1;
742 }
743
744 if !annotation_edits.is_empty() {
745 let spec_refs: Vec<&TypeImportSpec> =
746 all_adapted_imports.iter().collect();
747 let mut all_edits =
748 build_import_edits(&layout, &spec_refs, &existing_imports);
749 all_edits.extend(annotation_edits);
750
751 let ws_edit = WorkspaceEdit {
752 changes: Some(vec![(uri.clone(), all_edits)].into_iter().collect()),
753 document_changes: None,
754 change_annotations: None,
755 };
756
757 let title = format!(
758 "{}: Add all fixture type annotations ({} fixture{})",
759 TITLE_PREFIX,
760 annotated_count,
761 if annotated_count == 1 { "" } else { "s" }
762 );
763
764 let action = CodeAction {
765 title: title.clone(),
766 kind: Some(SOURCE_FIX_ALL_PYTEST_LSP),
767 diagnostics: None,
768 edit: Some(ws_edit),
769 command: None,
770 is_preferred: Some(false),
771 disabled: None,
772 data: None,
773 };
774
775 info!("Created source.fixAll.pytest-ls action: {}", title);
776 actions.push(CodeActionOrCommand::CodeAction(action));
777 }
778 }
779 }
780 }
781 }
782
783 if !actions.is_empty() {
786 info!("Returning {} code actions", actions.len());
787 return Ok(Some(actions));
788 }
789
790 info!("Returning None for code_action request");
791 Ok(None)
792 }
793}
794
795#[cfg(test)]
796mod tests {
797 use super::*;
798 use crate::fixtures::import_analysis::parse_import_layout;
799
800 fn layout_from_lines(lines: &[&str]) -> ImportLayout {
804 parse_import_layout(&lines.join("\n"))
805 }
806
807 #[test]
810 fn test_kind_requested_no_filter_accepts_everything() {
811 assert!(kind_requested(&None, &CodeActionKind::QUICKFIX));
812 assert!(kind_requested(&None, &SOURCE_PYTEST_LSP));
813 assert!(kind_requested(&None, &SOURCE_FIX_ALL_PYTEST_LSP));
814 }
815
816 #[test]
817 fn test_kind_requested_exact_match() {
818 let only = Some(vec![CodeActionKind::QUICKFIX]);
819 assert!(kind_requested(&only, &CodeActionKind::QUICKFIX));
820 assert!(!kind_requested(&only, &SOURCE_PYTEST_LSP));
821 }
822
823 #[test]
824 fn test_kind_requested_parent_source_matches_children() {
825 let only = Some(vec![CodeActionKind::SOURCE]);
826 assert!(kind_requested(&only, &SOURCE_PYTEST_LSP));
827 assert!(kind_requested(&only, &SOURCE_FIX_ALL_PYTEST_LSP));
828 assert!(!kind_requested(&only, &CodeActionKind::QUICKFIX));
829 }
830
831 #[test]
832 fn test_kind_requested_parent_source_fix_all_matches_child() {
833 let only = Some(vec![CodeActionKind::SOURCE_FIX_ALL]);
834 assert!(kind_requested(&only, &SOURCE_FIX_ALL_PYTEST_LSP));
835 assert!(!kind_requested(&only, &SOURCE_PYTEST_LSP));
836 }
837
838 #[test]
839 fn test_kind_requested_specific_child_does_not_match_sibling() {
840 let only = Some(vec![SOURCE_PYTEST_LSP]);
841 assert!(kind_requested(&only, &SOURCE_PYTEST_LSP));
842 assert!(!kind_requested(&only, &SOURCE_FIX_ALL_PYTEST_LSP));
843 }
844
845 #[test]
846 fn test_kind_requested_multiple_filters() {
847 let only = Some(vec![
848 CodeActionKind::QUICKFIX,
849 CodeActionKind::SOURCE_FIX_ALL,
850 ]);
851 assert!(kind_requested(&only, &CodeActionKind::QUICKFIX));
852 assert!(kind_requested(&only, &SOURCE_FIX_ALL_PYTEST_LSP));
853 assert!(!kind_requested(&only, &SOURCE_PYTEST_LSP));
854 }
855
856 #[test]
857 fn test_kind_requested_quickfix_only_rejects_source() {
858 let only = Some(vec![CodeActionKind::QUICKFIX]);
859 assert!(!kind_requested(&only, &SOURCE_PYTEST_LSP));
860 assert!(!kind_requested(&only, &SOURCE_FIX_ALL_PYTEST_LSP));
861 }
862
863 #[test]
866 fn test_build_import_edits_merge_into_existing() {
867 let lines = vec![
868 "import pytest",
869 "from typing import Optional",
870 "",
871 "def test(): pass",
872 ];
873 let layout = layout_from_lines(&lines);
874 let spec = TypeImportSpec {
875 check_name: "Any".to_string(),
876 import_statement: "from typing import Any".to_string(),
877 };
878 let existing: HashSet<String> = HashSet::new();
879 let edits = build_import_edits(&layout, &[&spec], &existing);
880
881 assert_eq!(edits.len(), 1);
882 assert_eq!(edits[0].range.start.line, 1);
883 assert_eq!(edits[0].range.start.character, 0);
884 assert_eq!(edits[0].range.end.line, 1);
885 assert_eq!(edits[0].new_text, "from typing import Any, Optional");
886 }
887
888 #[test]
889 fn test_build_import_edits_skips_already_imported() {
890 let lines = vec!["from typing import Any"];
891 let layout = layout_from_lines(&lines);
892 let spec = TypeImportSpec {
893 check_name: "Any".to_string(),
894 import_statement: "from typing import Any".to_string(),
895 };
896 let mut existing: HashSet<String> = HashSet::new();
897 existing.insert("Any".to_string());
898 let edits = build_import_edits(&layout, &[&spec], &existing);
899 assert!(edits.is_empty());
900 }
901
902 #[test]
903 fn test_build_import_edits_merge_multiple_into_existing() {
904 let lines = vec!["from typing import Union", "", "def test(): pass"];
905 let layout = layout_from_lines(&lines);
906 let spec1 = TypeImportSpec {
907 check_name: "Any".to_string(),
908 import_statement: "from typing import Any".to_string(),
909 };
910 let spec2 = TypeImportSpec {
911 check_name: "Optional".to_string(),
912 import_statement: "from typing import Optional".to_string(),
913 };
914 let existing: HashSet<String> = HashSet::new();
915 let edits = build_import_edits(&layout, &[&spec1, &spec2], &existing);
916 assert_eq!(edits.len(), 1);
917 assert_eq!(edits[0].new_text, "from typing import Any, Optional, Union");
918 }
919
920 #[test]
921 fn test_build_import_edits_merge_preserves_alias() {
922 let lines = vec!["from pathlib import Path as P", "", "def test(): pass"];
923 let layout = layout_from_lines(&lines);
924 let spec = TypeImportSpec {
925 check_name: "PurePath".to_string(),
926 import_statement: "from pathlib import PurePath".to_string(),
927 };
928 let existing: HashSet<String> = HashSet::new();
929 let edits = build_import_edits(&layout, &[&spec], &existing);
930 assert_eq!(edits.len(), 1);
931 assert_eq!(edits[0].new_text, "from pathlib import Path as P, PurePath");
932 }
933
934 #[test]
935 fn test_build_import_edits_deduplicates_specs() {
936 let lines = vec!["import pytest", "", "def test(): pass"];
937 let layout = layout_from_lines(&lines);
938 let spec1 = TypeImportSpec {
939 check_name: "Path".to_string(),
940 import_statement: "from pathlib import Path".to_string(),
941 };
942 let spec2 = TypeImportSpec {
943 check_name: "Path".to_string(),
944 import_statement: "from pathlib import Path".to_string(),
945 };
946 let existing: HashSet<String> = HashSet::new();
947 let edits = build_import_edits(&layout, &[&spec1, &spec2], &existing);
948 let import_edits: Vec<_> = edits
949 .iter()
950 .filter(|e| e.new_text.contains("Path"))
951 .collect();
952 assert_eq!(import_edits.len(), 1);
953 assert_eq!(import_edits[0].new_text, "from pathlib import Path\n");
954 }
955
956 #[test]
957 fn test_build_import_edits_merge_into_multi_name_existing() {
958 let lines = vec!["from os import path, othermodule", "", "def test(): pass"];
959 let layout = layout_from_lines(&lines);
960 let spec = TypeImportSpec {
961 check_name: "getcwd".to_string(),
962 import_statement: "from os import getcwd".to_string(),
963 };
964 let existing: HashSet<String> = HashSet::new();
965 let edits = build_import_edits(&layout, &[&spec], &existing);
966 assert_eq!(edits.len(), 1);
967 assert_eq!(
968 edits[0].new_text,
969 "from os import getcwd, othermodule, path"
970 );
971 }
972
973 #[test]
974 fn test_build_import_edits_merge_strips_comment() {
975 let lines = vec![
976 "from typing import Any # needed for X",
977 "",
978 "def test(): pass",
979 ];
980 let layout = layout_from_lines(&lines);
981 let spec = TypeImportSpec {
982 check_name: "Optional".to_string(),
983 import_statement: "from typing import Optional".to_string(),
984 };
985 let existing: HashSet<String> = HashSet::new();
986 let edits = build_import_edits(&layout, &[&spec], &existing);
987 assert_eq!(edits.len(), 1);
988 assert_eq!(edits[0].new_text, "from typing import Any, Optional");
989 assert!(
990 !edits[0].new_text.contains('#'),
991 "merged line must not contain the original comment"
992 );
993 }
994
995 #[test]
996 fn test_build_import_edits_multiline_import_merged() {
997 let lines = vec![
1000 "from typing import (",
1001 " Any,",
1002 " Optional,",
1003 ")",
1004 "",
1005 "def test(): pass",
1006 ];
1007 let layout = layout_from_lines(&lines);
1008 let spec = TypeImportSpec {
1009 check_name: "Union".to_string(),
1010 import_statement: "from typing import Union".to_string(),
1011 };
1012 let existing: HashSet<String> = HashSet::new();
1013 let edits = build_import_edits(&layout, &[&spec], &existing);
1014
1015 assert_eq!(edits.len(), 1);
1017 assert_eq!(edits[0].range.start.line, 0);
1018 assert_eq!(edits[0].range.start.character, 0);
1019 assert_eq!(edits[0].range.end.line, 3);
1020 assert_eq!(edits[0].new_text, "from typing import Any, Optional, Union");
1021 }
1022
1023 #[test]
1026 fn test_stdlib_import_into_existing_stdlib_group() {
1027 let lines = vec![
1028 "import time",
1029 "",
1030 "import pytest",
1031 "from vcc.framework import fixture",
1032 "",
1033 "LOGGING_TIME = 2",
1034 ];
1035 let layout = layout_from_lines(&lines);
1036 let spec = TypeImportSpec {
1037 check_name: "Any".to_string(),
1038 import_statement: "from typing import Any".to_string(),
1039 };
1040 let existing: HashSet<String> = HashSet::new();
1041 let edits = build_import_edits(&layout, &[&spec], &existing);
1042 assert_eq!(edits.len(), 1);
1043 assert_eq!(edits[0].range.start.line, 1);
1044 assert_eq!(edits[0].new_text, "from typing import Any\n");
1045 }
1046
1047 #[test]
1048 fn test_stdlib_import_before_third_party_when_no_stdlib_group() {
1049 let lines = vec![
1050 "import pytest",
1051 "from vcc.framework import fixture",
1052 "",
1053 "def test(): pass",
1054 ];
1055 let layout = layout_from_lines(&lines);
1056 let spec = TypeImportSpec {
1057 check_name: "Any".to_string(),
1058 import_statement: "from typing import Any".to_string(),
1059 };
1060 let existing: HashSet<String> = HashSet::new();
1061 let edits = build_import_edits(&layout, &[&spec], &existing);
1062 assert_eq!(edits.len(), 2);
1063 assert_eq!(edits[0].new_text, "from typing import Any\n");
1064 assert_eq!(edits[0].range.start.line, 0);
1065 assert_eq!(edits[1].new_text, "\n");
1066 assert_eq!(edits[1].range.start.line, 0);
1067 }
1068
1069 #[test]
1070 fn test_third_party_import_after_stdlib_when_no_tp_group() {
1071 let lines = vec!["import os", "import time", "", "def test(): pass"];
1072 let layout = layout_from_lines(&lines);
1073 let spec = TypeImportSpec {
1074 check_name: "FlaskClient".to_string(),
1075 import_statement: "from flask.testing import FlaskClient".to_string(),
1076 };
1077 let existing: HashSet<String> = HashSet::new();
1078 let edits = build_import_edits(&layout, &[&spec], &existing);
1079 assert_eq!(edits.len(), 2);
1080 assert_eq!(edits[0].new_text, "\n");
1081 assert_eq!(edits[0].range.start.line, 2);
1082 assert_eq!(edits[1].new_text, "from flask.testing import FlaskClient\n");
1083 assert_eq!(edits[1].range.start.line, 2);
1084 }
1085
1086 #[test]
1087 fn test_third_party_import_into_existing_tp_group() {
1088 let lines = vec!["import time", "", "import pytest", "", "def test(): pass"];
1089 let layout = layout_from_lines(&lines);
1090 let spec = TypeImportSpec {
1091 check_name: "FlaskClient".to_string(),
1092 import_statement: "from flask.testing import FlaskClient".to_string(),
1093 };
1094 let existing: HashSet<String> = HashSet::new();
1095 let edits = build_import_edits(&layout, &[&spec], &existing);
1096 assert_eq!(edits.len(), 1);
1097 assert_eq!(edits[0].range.start.line, 3);
1098 assert_eq!(edits[0].new_text, "from flask.testing import FlaskClient\n");
1099 }
1100
1101 #[test]
1102 fn test_no_imports_at_all() {
1103 let lines = vec!["def test(): pass"];
1104 let layout = layout_from_lines(&lines);
1105 let spec = TypeImportSpec {
1106 check_name: "Path".to_string(),
1107 import_statement: "from pathlib import Path".to_string(),
1108 };
1109 let existing: HashSet<String> = HashSet::new();
1110 let edits = build_import_edits(&layout, &[&spec], &existing);
1111 assert_eq!(edits.len(), 1);
1112 assert_eq!(edits[0].range.start.line, 0);
1113 assert_eq!(edits[0].new_text, "from pathlib import Path\n");
1114 }
1115
1116 #[test]
1117 fn test_both_stdlib_and_tp_imports_no_existing_groups() {
1118 let lines = vec!["def test(): pass"];
1119 let layout = layout_from_lines(&lines);
1120 let spec_stdlib = TypeImportSpec {
1121 check_name: "Any".to_string(),
1122 import_statement: "from typing import Any".to_string(),
1123 };
1124 let spec_tp = TypeImportSpec {
1125 check_name: "FlaskClient".to_string(),
1126 import_statement: "from flask.testing import FlaskClient".to_string(),
1127 };
1128 let existing: HashSet<String> = HashSet::new();
1129 let edits = build_import_edits(&layout, &[&spec_stdlib, &spec_tp], &existing);
1130 assert_eq!(edits.len(), 3);
1131 assert_eq!(edits[0].new_text, "from typing import Any\n");
1132 assert_eq!(edits[1].new_text, "\n");
1133 assert_eq!(edits[2].new_text, "from flask.testing import FlaskClient\n");
1134 }
1135
1136 #[test]
1137 fn test_bare_stdlib_import_sorted_within_group() {
1138 let lines = vec![
1139 "import os",
1140 "import time",
1141 "",
1142 "import pytest",
1143 "",
1144 "def test(): pass",
1145 ];
1146 let layout = layout_from_lines(&lines);
1147 let spec = TypeImportSpec {
1148 check_name: "pathlib".to_string(),
1149 import_statement: "import pathlib".to_string(),
1150 };
1151 let existing: HashSet<String> = HashSet::new();
1152 let edits = build_import_edits(&layout, &[&spec], &existing);
1153 assert_eq!(edits.len(), 1);
1154 assert_eq!(edits[0].range.start.line, 1);
1155 assert_eq!(edits[0].new_text, "import pathlib\n");
1156 }
1157
1158 #[test]
1159 fn test_from_import_sorts_after_bare_imports_in_group() {
1160 let lines = vec!["import os", "import time", "", "def test(): pass"];
1161 let layout = layout_from_lines(&lines);
1162 let spec = TypeImportSpec {
1163 check_name: "Any".to_string(),
1164 import_statement: "from typing import Any".to_string(),
1165 };
1166 let existing: HashSet<String> = HashSet::new();
1167 let edits = build_import_edits(&layout, &[&spec], &existing);
1168 assert_eq!(edits.len(), 1);
1169 assert_eq!(edits[0].range.start.line, 2);
1170 assert_eq!(edits[0].new_text, "from typing import Any\n");
1171 }
1172
1173 #[test]
1174 fn test_mixed_stdlib_from_imports_grouped() {
1175 let lines = vec!["import time", "", "import pytest", "", "def test(): pass"];
1176 let layout = layout_from_lines(&lines);
1177 let spec1 = TypeImportSpec {
1178 check_name: "Any".to_string(),
1179 import_statement: "from typing import Any".to_string(),
1180 };
1181 let spec2 = TypeImportSpec {
1182 check_name: "Optional".to_string(),
1183 import_statement: "from typing import Optional".to_string(),
1184 };
1185 let existing: HashSet<String> = HashSet::new();
1186 let edits = build_import_edits(&layout, &[&spec1, &spec2], &existing);
1187 assert_eq!(edits.len(), 1);
1188 assert_eq!(edits[0].range.start.line, 1);
1189 assert_eq!(edits[0].new_text, "from typing import Any, Optional\n");
1190 }
1191
1192 #[test]
1193 fn test_tp_from_import_sorted_before_existing() {
1194 let lines = vec![
1195 "import time",
1196 "",
1197 "import pytest",
1198 "from vcc.conxtfw.framework.pytest.fixtures.component import fixture",
1199 "",
1200 "LOGGING_TIME = 2",
1201 ];
1202 let layout = layout_from_lines(&lines);
1203 let spec = TypeImportSpec {
1204 check_name: "conx_canoe".to_string(),
1205 import_statement: "from vcc import conx_canoe".to_string(),
1206 };
1207 let existing: HashSet<String> = HashSet::new();
1208 let edits = build_import_edits(&layout, &[&spec], &existing);
1209 assert_eq!(edits.len(), 1);
1210 assert_eq!(edits[0].range.start.line, 3);
1211 assert_eq!(edits[0].new_text, "from vcc import conx_canoe\n");
1212 }
1213
/// `from typing import Any` must be placed at line 1, column 0.
///
/// Fixture: `import time` (line 0), blank (line 1), then `import pytest` and a
/// `from vcc.conxtfw…` import. The test pins the new import to the position
/// right after `import time`, not next to the `pytest`/`vcc` lines.
#[test]
fn test_user_scenario_stdlib_into_correct_group() {
    let source = vec![
        "import time",
        "",
        "import pytest",
        "from vcc.conxtfw.framework.pytest.fixtures.component import fixture",
        "",
        "LOGGING_TIME = 2",
    ];
    let layout = layout_from_lines(&source);

    let any_spec = TypeImportSpec {
        check_name: String::from("Any"),
        import_statement: String::from("from typing import Any"),
    };

    let preexisting = HashSet::<String>::new();
    let edits = build_import_edits(&layout, &[&any_spec], &preexisting);

    assert_eq!(edits.len(), 1);
    let inserted = &edits[0];
    let insert_at = &inserted.range.start;
    assert_eq!(insert_at.line, 1);
    assert_eq!(insert_at.character, 0);
    assert_eq!(inserted.new_text, "from typing import Any\n");
}
1236
/// Three imports of different shapes requested at once must yield three
/// independent edits, each at its own expected line.
///
/// Against a file starting `import time` / blank / `import pytest` /
/// `from vcc.conxtfw… import fixture`, the test expects:
/// - `import pathlib` at line 0,
/// - `from typing import Any` at line 1,
/// - `from vcc import conx_canoe` at line 3.
#[test]
fn test_user_scenario_fix_all_multi_import() {
    let source = vec![
        "import time",
        "",
        "import pytest",
        "from vcc.conxtfw.framework.pytest.fixtures.component import fixture",
        "",
        "LOGGING_TIME = 2",
    ];
    let layout = layout_from_lines(&source);

    let typing_spec = TypeImportSpec {
        check_name: String::from("Any"),
        import_statement: String::from("from typing import Any"),
    };
    let pathlib_spec = TypeImportSpec {
        check_name: String::from("pathlib"),
        import_statement: String::from("import pathlib"),
    };
    let vcc_spec = TypeImportSpec {
        check_name: String::from("conx_canoe"),
        import_statement: String::from("from vcc import conx_canoe"),
    };

    let preexisting = HashSet::<String>::new();
    let requested = [&typing_spec, &pathlib_spec, &vcc_spec];
    let edits = build_import_edits(&layout, &requested, &preexisting);
    assert_eq!(edits.len(), 3);

    // The order of the returned edits is not pinned; look each one up by a
    // substring unique to its inserted text.
    let edit_mentioning = |needle: &str| {
        edits
            .iter()
            .find(|edit| edit.new_text.contains(needle))
            .unwrap()
    };

    let for_pathlib = edit_mentioning("pathlib");
    assert_eq!(for_pathlib.range.start.line, 0);
    assert_eq!(for_pathlib.new_text, "import pathlib\n");

    let for_typing = edit_mentioning("typing");
    assert_eq!(for_typing.range.start.line, 1);
    assert_eq!(for_typing.new_text, "from typing import Any\n");

    let for_vcc = edit_mentioning("conx_canoe");
    assert_eq!(for_vcc.range.start.line, 3);
    assert_eq!(for_vcc.new_text, "from vcc import conx_canoe\n");
}
1286
/// With a leading `from __future__` import, a new `from typing import Any`
/// must not be placed next to it.
///
/// Fixture: `from __future__ import annotations` (line 0), blank, `import os`
/// (line 2), `import time` (line 3), blank (line 4), `import pytest`. The test
/// expects the single edit at line 4, i.e. after `import time`.
#[test]
fn test_future_import_skipped_for_stdlib_insertion() {
    let source = vec![
        "from __future__ import annotations",
        "",
        "import os",
        "import time",
        "",
        "import pytest",
        "",
        "def test(): pass",
    ];
    let layout = layout_from_lines(&source);

    let any_spec = TypeImportSpec {
        check_name: String::from("Any"),
        import_statement: String::from("from typing import Any"),
    };

    let preexisting = HashSet::<String>::new();
    let edits = build_import_edits(&layout, &[&any_spec], &preexisting);

    assert_eq!(edits.len(), 1);
    let inserted = &edits[0];
    assert_eq!(inserted.range.start.line, 4);
    assert_eq!(inserted.new_text, "from typing import Any\n");
}
1312
/// When the only existing import is `from __future__ import annotations`, the
/// new `from typing import Any` must not be inserted on line 0.
///
/// Two edits are expected: the first is a lone `"\n"` at line 1 (presumably a
/// blank separator line), and one of them carries the `typing` import at a
/// line strictly greater than 0.
#[test]
fn test_stdlib_not_inserted_before_future_import() {
    let source = vec!["from __future__ import annotations", "", "def test(): pass"];
    let layout = layout_from_lines(&source);

    let any_spec = TypeImportSpec {
        check_name: String::from("Any"),
        import_statement: String::from("from typing import Any"),
    };

    let preexisting = HashSet::<String>::new();
    let edits = build_import_edits(&layout, &[&any_spec], &preexisting);

    assert_eq!(edits.len(), 2);

    // First edit: a bare newline at line 1.
    let leading = &edits[0];
    assert_eq!(leading.range.start.line, 1);
    assert_eq!(leading.new_text, "\n");

    // The import itself, located by content rather than by index.
    let typing_edit = edits
        .iter()
        .find(|edit| edit.new_text.contains("typing"))
        .expect("expected a typing import edit");
    let inserted_at = typing_edit.range.start.line;
    assert!(
        inserted_at > 0,
        "stdlib import was inserted at line {}, which is before \
         `from __future__ import annotations` at line 0",
        inserted_at,
    );
    assert_eq!(typing_edit.new_text, "from typing import Any\n");
}
1347
/// Same expectation as the blank-line variant, but here the `__future__`
/// import is immediately followed by code (`def test(): pass` on line 1).
///
/// Two edits are expected: the first is a lone `"\n"` at line 1, and the
/// `typing` import edit must start at a line strictly greater than 0.
#[test]
fn test_stdlib_not_inserted_before_future_import_no_blank_line() {
    let source = vec!["from __future__ import annotations", "def test(): pass"];
    let layout = layout_from_lines(&source);

    let any_spec = TypeImportSpec {
        check_name: String::from("Any"),
        import_statement: String::from("from typing import Any"),
    };

    let preexisting = HashSet::<String>::new();
    let edits = build_import_edits(&layout, &[&any_spec], &preexisting);

    assert_eq!(edits.len(), 2);

    // First edit: a bare newline at line 1.
    let leading = &edits[0];
    assert_eq!(leading.range.start.line, 1);
    assert_eq!(leading.new_text, "\n");

    // The import itself, located by content rather than by index.
    let typing_edit = edits
        .iter()
        .find(|edit| edit.new_text.contains("typing"))
        .expect("expected a typing import edit");
    let inserted_at = typing_edit.range.start.line;
    assert!(
        inserted_at > 0,
        "stdlib import was inserted at line {}, which is before \
         `from __future__ import annotations` at line 0",
        inserted_at,
    );
    assert_eq!(typing_edit.new_text, "from typing import Any\n");
}
1376
/// A `typing` import and a `flask.testing` import requested together must
/// produce two edits at different lines.
///
/// Fixture: `import os` (line 0), blank (line 1), `import pytest` (line 2),
/// blank (line 3). Expected: `from typing import Any` at line 1 and
/// `from flask.testing import FlaskClient` at line 3.
#[test]
fn test_different_modules_stdlib_and_tp() {
    let source = vec!["import os", "", "import pytest", "", "def test(): pass"];
    let layout = layout_from_lines(&source);

    let typing_spec = TypeImportSpec {
        check_name: String::from("Any"),
        import_statement: String::from("from typing import Any"),
    };
    let flask_spec = TypeImportSpec {
        check_name: String::from("FlaskClient"),
        import_statement: String::from("from flask.testing import FlaskClient"),
    };

    let preexisting = HashSet::<String>::new();
    let edits = build_import_edits(&layout, &[&typing_spec, &flask_spec], &preexisting);
    assert_eq!(edits.len(), 2);

    // Edit order is not pinned; find each edit by a substring of its text.
    let edit_mentioning = |needle: &str| {
        edits
            .iter()
            .find(|edit| edit.new_text.contains(needle))
            .unwrap()
    };

    let for_typing = edit_mentioning("typing");
    assert_eq!(for_typing.range.start.line, 1);
    assert_eq!(for_typing.new_text, "from typing import Any\n");

    let for_flask = edit_mentioning("flask");
    assert_eq!(for_flask.range.start.line, 3);
    assert_eq!(for_flask.new_text, "from flask.testing import FlaskClient\n");
}
1402}