1use std::collections::{HashMap, HashSet};
8
9use crate::registry::PackageRegistry;
10use crate::types::{AllCategorizedImports, CategorizedImports, ImportSpec};
11use crate::{ImportCategory, ImportSections, ImportStatement, ImportType};
12
13#[derive(Debug)]
15pub struct ImportHelper {
16 sections: ImportSections,
18 category_cache: HashMap<String, ImportCategory>,
20 package_name: Option<String>,
22 local_package_prefixes: HashSet<String>,
24 registry: PackageRegistry,
26}
27
28impl ImportHelper {
29 #[must_use]
31 pub fn new() -> Self {
32 Self {
33 sections: ImportSections::default(),
34 category_cache: HashMap::new(),
35 package_name: None,
36 local_package_prefixes: HashSet::new(),
37 registry: PackageRegistry::new(),
38 }
39 }
40
41 #[must_use]
43 pub fn with_package_name(package_name: String) -> Self {
44 let mut helper = Self::new();
45 helper.package_name = Some(package_name.clone());
46 helper.local_package_prefixes.insert(package_name);
47 helper
48 }
49
50 #[must_use]
61 pub fn registry(&self) -> &PackageRegistry {
62 &self.registry
63 }
64
65 pub fn registry_mut(&mut self) -> &mut PackageRegistry {
85 &mut self.registry
87 }
88
89 pub fn clear_cache(&mut self) -> &mut Self {
104 self.category_cache.clear();
105 self
106 }
107
108 pub fn add_local_package_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
110 let prefix = prefix.into();
111 self.local_package_prefixes.insert(prefix);
112 self
114 }
115
116 pub fn add_local_package_prefixes(&mut self, prefixes: &[impl AsRef<str>]) -> &mut Self {
118 for prefix in prefixes {
119 self.add_local_package_prefix(prefix.as_ref());
120 }
121 self
122 }
123
124 pub fn add_import(&mut self, spec: &ImportSpec) {
126 let import_statement = if let Some(items) = &spec.items {
127 format!("from {} import {}", spec.package, items.join(", "))
128 } else {
129 format!("import {}", spec.package)
130 };
131
132 if spec.type_checking {
133 self.add_type_checking_import(&import_statement);
134 } else {
135 self.add_regular_import(&import_statement);
136 }
137 }
138
139 pub fn add_import_string(&mut self, import_statement: &str) {
141 self.add_regular_import(import_statement);
142 }
143
144 fn add_regular_import(&mut self, import_statement: &str) {
146 if let Some(import) = self.parse_import(import_statement) {
147 match (&import.category, &import.import_type) {
148 (ImportCategory::Future, _) => self.sections.future.push(import),
149 (ImportCategory::StandardLibrary, ImportType::Direct) => {
150 self.sections.standard_library_direct.push(import)
151 }
152 (ImportCategory::StandardLibrary, ImportType::From) => {
153 self.sections.standard_library_from.push(import)
154 }
155 (ImportCategory::ThirdParty, ImportType::Direct) => {
156 self.sections.third_party_direct.push(import)
157 }
158 (ImportCategory::ThirdParty, ImportType::From) => {
159 self.sections.third_party_from.push(import)
160 }
161 (ImportCategory::Local, ImportType::Direct) => {
162 self.sections.local_direct.push(import)
163 }
164 (ImportCategory::Local, ImportType::From) => self.sections.local_from.push(import),
165 }
166 }
167 }
168
169 pub fn add_from_import(&mut self, package: &str, items: &[&str]) {
172 let import_statement = if items.len() == 1 {
173 format!("from {} import {}", package, items[0])
174 } else {
175 format!("from {} import {}", package, items.join(", "))
176 };
177 self.add_regular_import(&import_statement);
178 }
179
180 pub fn add_type_checking_from_import(&mut self, package: &str, items: &[&str]) {
183 let import_statement = if items.len() == 1 {
184 format!("from {} import {}", package, items[0])
185 } else {
186 format!("from {} import {}", package, items.join(", "))
187 };
188 self.add_type_checking_import(&import_statement);
189 }
190
191 pub fn add_direct_import(&mut self, module: &str) {
194 let import_statement = format!("import {module}");
195 self.add_regular_import(&import_statement);
196 }
197
198 pub fn add_type_checking_direct_import(&mut self, module: &str) {
201 let import_statement = format!("import {module}");
202 self.add_type_checking_import(&import_statement);
203 }
204
205 pub fn add_type_checking_import(&mut self, import_statement: &str) {
207 if let Some(import) = self.parse_import(import_statement) {
208 match (&import.category, &import.import_type) {
209 (ImportCategory::Future, _) => self.sections.type_checking_future.push(import),
210 (ImportCategory::StandardLibrary, ImportType::Direct) => self
211 .sections
212 .type_checking_standard_library_direct
213 .push(import),
214 (ImportCategory::StandardLibrary, ImportType::From) => self
215 .sections
216 .type_checking_standard_library_from
217 .push(import),
218 (ImportCategory::ThirdParty, ImportType::Direct) => {
219 self.sections.type_checking_third_party_direct.push(import);
220 }
221 (ImportCategory::ThirdParty, ImportType::From) => {
222 self.sections.type_checking_third_party_from.push(import);
223 }
224 (ImportCategory::Local, ImportType::Direct) => {
225 self.sections.type_checking_local_direct.push(import);
226 }
227 (ImportCategory::Local, ImportType::From) => {
228 self.sections.type_checking_local_from.push(import);
229 }
230 }
231
232 self.ensure_type_checking_import_added();
234 }
235 }
236
237 #[must_use]
241 pub fn get_all_categorized(&self) -> AllCategorizedImports {
242 let (future_imports, stdlib_imports, third_party_imports, local_imports) =
244 self.get_categorized();
245
246 let (tc_future, tc_stdlib, tc_third_party, tc_local) = self.get_type_checking_categorized();
248
249 (
250 future_imports,
251 stdlib_imports,
252 third_party_imports,
253 local_imports,
254 tc_future,
255 tc_stdlib,
256 tc_third_party,
257 tc_local,
258 )
259 }
260
261 #[must_use]
268 pub fn get_type_checking_categorized(&self) -> CategorizedImports {
269 self.get_type_checking_categorized_impl()
270 }
271
272 #[must_use]
273 pub fn get_type_checking_categorized_impl(&self) -> CategorizedImports {
274 let mut future_imports = Vec::new();
275 let mut stdlib_imports = Vec::new();
276 let mut third_party_imports = Vec::new();
277 let mut local_imports = Vec::new();
278
279 if !self.sections.type_checking_future.is_empty() {
281 let future = self.format_imports(&self.sections.type_checking_future);
282 future_imports.extend(future);
283 }
284
285 if !self
287 .sections
288 .type_checking_standard_library_direct
289 .is_empty()
290 {
291 let std_direct =
292 self.format_imports(&self.sections.type_checking_standard_library_direct);
293 stdlib_imports.extend(std_direct);
294 }
295 if !self.sections.type_checking_standard_library_from.is_empty() {
296 let std_from = self.format_imports(&self.sections.type_checking_standard_library_from);
297 stdlib_imports.extend(std_from);
298 }
299
300 if !self.sections.type_checking_third_party_direct.is_empty() {
302 let third_direct = self.format_imports(&self.sections.type_checking_third_party_direct);
303 third_party_imports.extend(third_direct);
304 }
305 if !self.sections.type_checking_third_party_from.is_empty() {
306 let third_from = self.format_imports(&self.sections.type_checking_third_party_from);
307 third_party_imports.extend(third_from);
308 }
309
310 if !self.sections.type_checking_local_direct.is_empty() {
312 let local_direct = self.format_imports(&self.sections.type_checking_local_direct);
313 local_imports.extend(local_direct);
314 }
315 if !self.sections.type_checking_local_from.is_empty() {
316 let local_from = self.format_imports(&self.sections.type_checking_local_from);
317 local_imports.extend(local_from);
318 }
319
320 future_imports.sort();
322 stdlib_imports.sort();
323 third_party_imports.sort();
324 local_imports.sort();
325
326 (
327 future_imports,
328 stdlib_imports,
329 third_party_imports,
330 local_imports,
331 )
332 }
333
334 #[must_use]
337 pub fn get_categorized(&self) -> CategorizedImports {
338 let mut future_imports = Vec::new();
339 let mut stdlib_imports = Vec::new();
340 let mut third_party_imports = Vec::new();
341 let mut local_imports = Vec::new();
342
343 if !self.sections.future.is_empty() {
345 let future = self.format_imports(&self.sections.future);
346 future_imports.extend(future);
347 }
348
349 if !self.sections.standard_library_direct.is_empty() {
351 let std_direct_imports = self.format_imports(&self.sections.standard_library_direct);
352 stdlib_imports.extend(std_direct_imports);
353 }
354 if !self.sections.standard_library_from.is_empty() {
355 let std_from_imports = self.format_imports(&self.sections.standard_library_from);
356 stdlib_imports.extend(std_from_imports);
357 }
358
359 if !self.sections.third_party_direct.is_empty() {
361 let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
362 third_party_imports.extend(third_direct_imports);
363 }
364 if !self.sections.third_party_from.is_empty() {
365 let third_from_imports = self.format_imports(&self.sections.third_party_from);
366 third_party_imports.extend(third_from_imports);
367 }
368
369 if !self.sections.local_direct.is_empty() {
371 let local_direct_imports = self.format_imports(&self.sections.local_direct);
372 local_imports.extend(local_direct_imports);
373 }
374 if !self.sections.local_from.is_empty() {
375 let local_from_imports = self.format_imports(&self.sections.local_from);
376 local_imports.extend(local_from_imports);
377 }
378
379 future_imports.sort();
381 stdlib_imports.sort();
382 third_party_imports.sort();
383 local_imports.sort();
384
385 (
386 future_imports,
387 stdlib_imports,
388 third_party_imports,
389 local_imports,
390 )
391 }
392
393 pub fn reset(&mut self) -> &mut Self {
396 self.sections = ImportSections::default();
397 self
398 }
399
400 #[must_use]
402 pub fn is_empty(&self) -> bool {
403 self.sections.future.is_empty()
404 && self.sections.standard_library_direct.is_empty()
405 && self.sections.standard_library_from.is_empty()
406 && self.sections.third_party_direct.is_empty()
407 && self.sections.third_party_from.is_empty()
408 && self.sections.local_direct.is_empty()
409 && self.sections.local_from.is_empty()
410 }
411
412 #[must_use]
414 pub fn is_type_checking_empty(&self) -> bool {
415 self.sections.type_checking_future.is_empty()
416 && self
417 .sections
418 .type_checking_standard_library_direct
419 .is_empty()
420 && self.sections.type_checking_standard_library_from.is_empty()
421 && self.sections.type_checking_third_party_direct.is_empty()
422 && self.sections.type_checking_third_party_from.is_empty()
423 && self.sections.type_checking_local_direct.is_empty()
424 && self.sections.type_checking_local_from.is_empty()
425 }
426
427 #[must_use]
429 pub fn count(&self) -> usize {
430 self.sections.future.len()
431 + self.sections.standard_library_direct.len()
432 + self.sections.standard_library_from.len()
433 + self.sections.third_party_direct.len()
434 + self.sections.third_party_from.len()
435 + self.sections.local_direct.len()
436 + self.sections.local_from.len()
437 }
438
439 #[must_use]
441 pub fn count_type_checking(&self) -> usize {
442 self.sections.type_checking_future.len()
443 + self.sections.type_checking_standard_library_direct.len()
444 + self.sections.type_checking_standard_library_from.len()
445 + self.sections.type_checking_third_party_direct.len()
446 + self.sections.type_checking_third_party_from.len()
447 + self.sections.type_checking_local_direct.len()
448 + self.sections.type_checking_local_from.len()
449 }
450
451 #[must_use]
453 pub fn get_formatted(&self) -> Vec<String> {
454 let mut result = Vec::new();
455 let mut has_previous_section = false;
456
457 if !self.sections.future.is_empty() {
459 let future_imports = self.format_imports(&self.sections.future);
460 result.extend(future_imports);
461 has_previous_section = true;
462 }
463
464 let std_has_direct = !self.sections.standard_library_direct.is_empty();
466 let std_has_from = !self.sections.standard_library_from.is_empty();
467
468 if std_has_direct || std_has_from {
469 if has_previous_section {
470 result.push(String::new()); }
472
473 if std_has_direct {
475 let std_direct_imports =
476 self.format_imports(&self.sections.standard_library_direct);
477 result.extend(std_direct_imports);
478 }
479
480 if std_has_from {
482 let std_from_imports = self.format_imports(&self.sections.standard_library_from);
483 result.extend(std_from_imports);
484 }
485
486 has_previous_section = true;
487 }
488
489 let third_has_direct = !self.sections.third_party_direct.is_empty();
491 let third_has_from = !self.sections.third_party_from.is_empty();
492
493 if third_has_direct || third_has_from {
494 if has_previous_section {
495 result.push(String::new()); }
497
498 if third_has_direct {
500 let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
501 result.extend(third_direct_imports);
502 }
503
504 if third_has_from {
506 let third_from_imports = self.format_imports(&self.sections.third_party_from);
507 result.extend(third_from_imports);
508 }
509
510 has_previous_section = true;
511 }
512
513 let local_has_direct = !self.sections.local_direct.is_empty();
515 let local_has_from = !self.sections.local_from.is_empty();
516
517 if local_has_direct || local_has_from {
518 if has_previous_section {
519 result.push(String::new()); }
521
522 if local_has_direct {
524 let local_direct_imports = self.format_imports(&self.sections.local_direct);
525 result.extend(local_direct_imports);
526 }
527
528 if local_has_from {
530 let local_from_imports = self.format_imports(&self.sections.local_from);
531 result.extend(local_from_imports);
532 }
533 }
534
535 result
536 }
537
538 fn parse_import(&mut self, import_statement: &str) -> Option<ImportStatement> {
540 let trimmed = import_statement.trim();
541 if trimmed.is_empty() {
542 return None;
543 }
544
545 let category = self.categorize_import(trimmed);
546 let import_type = if trimmed.starts_with("from ") {
547 ImportType::From
548 } else {
549 ImportType::Direct
550 };
551 let package = Self::extract_package(trimmed);
552 let items = Self::extract_items(trimmed);
553 let is_multiline = trimmed.contains('(') || trimmed.contains(')');
554
555 let statement = if import_type == ImportType::From && !items.is_empty() {
557 format!("from {} import {}", package, items.join(", "))
558 } else {
559 trimmed.to_string()
560 };
561
562 Some(ImportStatement {
563 statement,
564 category,
565 import_type,
566 package,
567 items,
568 is_multiline,
569 })
570 }
571
572 fn categorize_import(&mut self, import_statement: &str) -> ImportCategory {
574 if import_statement.starts_with("from __future__") {
575 return ImportCategory::Future;
576 }
577
578 let package = Self::extract_package(import_statement);
579
580 if let Some(&cached_category) = self.category_cache.get(&package) {
582 return cached_category;
583 }
584
585 let category = if self.is_local_import(import_statement) {
590 ImportCategory::Local
591 } else if self.is_standard_library_package(&package) {
592 ImportCategory::StandardLibrary
593 } else if self.is_common_third_party_package(&package) {
594 ImportCategory::ThirdParty
595 } else {
596 ImportCategory::ThirdParty
598 };
599
600 self.category_cache.insert(package, category);
601 category
602 }
603
604 fn extract_package(import_statement: &str) -> String {
606 if let Some(from_part) = import_statement.strip_prefix("from ") {
607 if let Some(import_pos) = from_part.find(" import ") {
608 return from_part[..import_pos].trim().to_string();
609 }
610 } else if let Some(import_part) = import_statement.strip_prefix("import ") {
611 return import_part
613 .split_whitespace()
614 .next()
615 .unwrap_or(import_part)
616 .trim()
617 .to_string();
618 }
619
620 import_statement.to_string()
621 }
622
623 fn extract_items(import_statement: &str) -> Vec<String> {
625 if let Some(from_part) = import_statement.strip_prefix("from ") {
626 if let Some(import_pos) = from_part.find(" import ") {
627 let items_part = &from_part[import_pos + 8..];
628 let cleaned = items_part.replace(['(', ')'], "").replace(',', " ");
629 let mut items: Vec<String> = cleaned
630 .split_whitespace()
631 .map(|s| s.trim().to_string())
632 .filter(|s| !s.is_empty())
633 .collect();
634
635 items.sort_by(|a, b| Self::custom_import_sort(a, b));
637 return items;
638 }
639 } else if let Some(import_part) = import_statement.strip_prefix("import ") {
640 return vec![import_part.trim().to_string()];
642 }
643 Vec::new()
644 }
645
646 fn is_local_import(&self, import_statement: &str) -> bool {
648 if import_statement.contains("from .")
650 || import_statement.contains("from ..")
651 || import_statement.contains("from ...")
652 || import_statement.contains("from ....")
653 {
654 return true;
655 }
656
657 let package = Self::extract_package(import_statement);
658
659 for prefix in &self.local_package_prefixes {
661 if package.starts_with(prefix.as_str()) {
662 return true;
663 }
664 }
665
666 if let Some(pkg_name) = &self.package_name {
668 if package.starts_with(pkg_name) {
669 return true;
670 }
671 }
672
673 false
674 }
675
676 fn is_standard_library_package(&self, package: &str) -> bool {
678 self.registry.is_stdlib(package)
680 }
681
682 fn is_common_third_party_package(&self, package: &str) -> bool {
684 self.registry.is_third_party(package)
686 }
687
688 #[allow(clippy::unused_self)]
690 fn format_imports(&self, imports: &[ImportStatement]) -> Vec<String> {
691 let mut package_imports: HashMap<String, Vec<&ImportStatement>> = HashMap::new();
692
693 for import in imports {
695 package_imports
696 .entry(import.package.clone())
697 .or_default()
698 .push(import);
699 }
700
701 let mut result = Vec::new();
702 let mut packages: Vec<_> = package_imports.keys().collect();
703 packages.sort();
704
705 for package in packages {
706 let imports_for_package = package_imports.get(package).unwrap();
707
708 if imports_for_package.len() == 1 {
709 result.push(imports_for_package[0].statement.clone());
711 } else {
712 result.extend(Self::merge_package_imports(imports_for_package));
714 }
715 }
716
717 result
718 }
719
720 fn merge_package_imports(imports: &[&ImportStatement]) -> Vec<String> {
722 let mut all_items = HashSet::new();
723 let package = &imports[0].package;
724
725 for import in imports {
727 all_items.extend(import.items.iter().cloned());
728 }
729
730 if all_items.is_empty() {
731 return imports.iter().map(|i| i.statement.clone()).collect();
733 }
734
735 let mut sorted_items: Vec<_> = all_items.into_iter().collect();
736 sorted_items.sort_by(|a, b| Self::custom_import_sort(a, b));
737
738 if sorted_items.len() <= 3 && sorted_items.iter().map(String::len).sum::<usize>() < 60 {
740 vec![format!(
742 "from {} import {}",
743 package,
744 sorted_items.join(", ")
745 )]
746 } else {
747 let mut result = vec![format!("from {} import (", package)];
749 for item in sorted_items {
750 result.push(format!(" {item},"));
751 }
752 result.push(")".to_string());
753 result
754 }
755 }
756
757 fn custom_import_sort(a: &str, b: &str) -> std::cmp::Ordering {
759 let a_is_all_caps = a.chars().all(|c| c.is_uppercase() || !c.is_alphabetic());
760 let b_is_all_caps = b.chars().all(|c| c.is_uppercase() || !c.is_alphabetic());
761
762 match (a_is_all_caps, b_is_all_caps) {
763 (true, true) | (false, false) => a.cmp(b),
765 (true, false) => std::cmp::Ordering::Less,
767 (false, true) => std::cmp::Ordering::Greater,
769 }
770 }
771
772 fn ensure_type_checking_import_added(&mut self) {
774 let has_type_checking = self.sections.standard_library_from.iter().any(|import| {
776 import.package == "typing" && import.items.contains(&"TYPE_CHECKING".to_string())
777 });
778
779 if !has_type_checking {
780 if let Some(typing_import) = self
782 .sections
783 .standard_library_from
784 .iter_mut()
785 .find(|import| import.package == "typing")
786 {
787 if !typing_import.items.contains(&"TYPE_CHECKING".to_string()) {
789 typing_import.items.push("TYPE_CHECKING".to_string());
790 typing_import
791 .items
792 .sort_by(|a, b| Self::custom_import_sort(a, b));
793
794 if typing_import.items.len() == 1 {
796 typing_import.statement =
797 format!("from typing import {}", typing_import.items[0]);
798 } else {
799 typing_import.statement =
800 format!("from typing import {}", typing_import.items.join(", "));
801 }
802 }
803 } else {
804 self.add_import_string("from typing import TYPE_CHECKING");
806 }
807 }
808 }
809
810 #[must_use]
812 pub fn clone_config(&self) -> Self {
813 Self {
814 sections: ImportSections::default(),
815 category_cache: self.category_cache.clone(),
816 package_name: self.package_name.clone(),
817 local_package_prefixes: self.local_package_prefixes.clone(),
818 registry: self.registry.clone(),
819 }
820 }
821}
822
823impl ImportHelper {
825 pub fn create_model_imports(&mut self, required_types: &[String]) {
827 self.add_import_string("from pydantic import BaseModel, ConfigDict, Field");
829
830 let mut typing_imports = std::collections::HashSet::new();
832 let mut collections_abc_imports = std::collections::HashSet::new();
833 let mut datetime_imports = Vec::new();
834 let mut decimal_imports = Vec::new();
835
836 for type_name in required_types {
837 match type_name.as_str() {
838 "datetime" | "date" | "time" | "timedelta" => {
839 if !datetime_imports.contains(&type_name.as_str()) {
840 datetime_imports.push(type_name.as_str());
841 }
842 }
843 "Decimal" => {
844 if !decimal_imports.contains(&"Decimal") {
845 decimal_imports.push("Decimal");
846 }
847 }
848 "UUID" => {
849 self.add_import_string("from uuid import UUID");
850 }
851 _ => {
853 let extracted_typing = Self::extract_typing_imports_from_type(type_name);
855 typing_imports.extend(extracted_typing);
856
857 if type_name.contains("Callable") {
859 collections_abc_imports.insert("Callable".to_string());
860 }
861 }
862 }
863 }
864
865 if !datetime_imports.is_empty() {
867 let import_statement = format!("from datetime import {}", datetime_imports.join(", "));
868 self.add_regular_import(&import_statement);
869 }
870
871 if !decimal_imports.is_empty() {
873 self.add_import_string("from decimal import Decimal");
874 }
875
876 if !typing_imports.is_empty() {
878 let mut sorted_typing: Vec<String> = typing_imports.into_iter().collect();
879 sorted_typing.sort();
880 let import_statement = format!("from typing import {}", sorted_typing.join(", "));
881 self.add_regular_import(&import_statement);
882 }
883
884 if !collections_abc_imports.is_empty() {
886 let mut sorted_collections: Vec<String> = collections_abc_imports.into_iter().collect();
887 sorted_collections.sort();
888 let import_statement = format!(
889 "from collections.abc import {}",
890 sorted_collections.join(", ")
891 );
892 self.add_regular_import(&import_statement);
893 }
894 }
895
896 fn extract_typing_imports_from_type(type_str: &str) -> std::collections::HashSet<String> {
900 let mut typing_imports = std::collections::HashSet::new();
901
902 if type_str.contains("Any") {
904 typing_imports.insert("Any".to_string());
905 }
906
907 if type_str.contains("Generic") {
909 typing_imports.insert("Generic".to_string());
910 }
911
912 if type_str.contains("TypeVar") {
914 typing_imports.insert("TypeVar".to_string());
915 }
916
917 if type_str.contains("Protocol") {
919 typing_imports.insert("Protocol".to_string());
920 }
921
922 typing_imports
923 }
924}
925
926impl Default for ImportHelper {
927 fn default() -> Self {
928 Self::new()
929 }
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    /// Each kind of statement lands in the section matching its category.
    #[test]
    fn test_import_categorization() {
        let mut helper = ImportHelper::new();

        helper.add_import_string("from __future__ import annotations");
        assert_eq!(helper.sections.future.len(), 1);

        helper.add_import_string("from typing import Optional");
        assert_eq!(helper.sections.standard_library_from.len(), 1);

        helper.add_import_string("import uuid");
        assert_eq!(helper.sections.standard_library_direct.len(), 1);

        helper.add_import_string("from pydantic import BaseModel");
        assert_eq!(helper.sections.third_party_from.len(), 1);

        helper.add_import_string("from .models import User");
        assert_eq!(helper.sections.local_from.len(), 1);
    }

    /// Several `from typing import X` statements are merged into one.
    #[test]
    fn test_import_merging() {
        let mut helper = ImportHelper::new();
        for statement in [
            "from typing import Optional",
            "from typing import Any",
            "from typing import List",
        ] {
            helper.add_import_string(statement);
        }

        let imports = helper.get_formatted();

        assert!(
            imports.iter().any(|i| i.contains("from typing import")),
            "expected a typing import, got: {imports:?}"
        );
        assert!(
            imports
                .iter()
                .any(|i| i.contains("Any") && i.contains("Optional")),
            "expected merged typing items, got: {imports:?}"
        );
    }

    /// Direct imports precede `from` imports, and items are sorted.
    #[test]
    fn test_alphabetical_sorting_of_import_items() {
        let mut helper = ImportHelper::new();
        helper.add_import_string("from typing import Optional, Any, List, Dict");
        helper.add_import_string("from collections import defaultdict, OrderedDict, Counter");
        helper.add_import_string("import uuid");
        helper.add_import_string("import os");

        let imports = helper.get_formatted();
        let import_str = imports.join("\n");

        let os_pos = import_str
            .find("import os")
            .expect("`import os` should be emitted");
        let uuid_pos = import_str
            .find("import uuid")
            .expect("`import uuid` should be emitted");
        let typing_pos = import_str
            .find("from typing import")
            .expect("a typing import should be emitted");

        assert!(
            os_pos < typing_pos,
            "Direct imports should come before from imports"
        );
        assert!(
            uuid_pos < typing_pos,
            "Direct imports should come before from imports"
        );

        let typing_import = imports
            .iter()
            .find(|s| s.contains("from typing import"))
            .expect("a typing import should be emitted");

        assert!(
            typing_import.contains("Any, Dict, List, Optional")
                || typing_import.contains("(\n Any,\n Dict,\n List,\n Optional,\n)"),
            "Import items should be sorted alphabetically, got: {typing_import}"
        );
    }

    /// Plain `import x` lines come out in alphabetical order.
    #[test]
    fn test_direct_imports_sorted_alphabetically() {
        let mut helper = ImportHelper::new();
        for module in ["uuid", "os", "sys", "json"] {
            helper.add_import_string(&format!("import {module}"));
        }

        let imports = helper.get_formatted();
        let import_lines: Vec<&String> = imports
            .iter()
            .filter(|s| s.starts_with("import "))
            .collect();

        assert_eq!(import_lines.len(), 4, "got: {imports:?}");
        assert!(import_lines[0].contains("json"));
        assert!(import_lines[1].contains("os"));
        assert!(import_lines[2].contains("sys"));
        assert!(import_lines[3].contains("uuid"));
    }

    /// ALL-CAPS names sort before mixed-case names.
    #[test]
    fn test_uppercase_priority_in_import_sorting() {
        let mut helper = ImportHelper::new();
        helper.add_import_string("from example import Ab, AA, Aa, AB");

        let imports = helper.get_formatted();
        let example_import = imports
            .iter()
            .find(|s| s.contains("from example import"))
            .expect("the example import should be emitted");

        assert!(
            example_import.contains("AA, AB, Aa, Ab"),
            "Import items should be sorted with uppercase priority, got: {example_import}"
        );
    }

    /// Uppercase priority applies across all letters, not only ties.
    #[test]
    fn test_comprehensive_case_sorting() {
        let mut helper = ImportHelper::new();
        helper.add_import_string("from test import zz, ZZ, bb, BB, aa, AA, cc, CC");

        let imports = helper.get_formatted();
        let test_import = imports
            .iter()
            .find(|s| s.contains("from test import"))
            .expect("the test import should be emitted");

        assert!(
            test_import.contains("AA, BB, CC, ZZ, aa, bb, cc, zz"),
            "Import items should be sorted with uppercase priority across all letters, got: {test_import}"
        );
    }

    /// TYPE_CHECKING imports are counted and categorized separately from
    /// regular imports.
    #[test]
    fn test_type_checking_imports() {
        let mut helper = ImportHelper::with_package_name("mypackage".to_string());

        helper.add_import_string("from __future__ import annotations");
        helper.add_import_string("from typing import Any");
        helper.add_import_string("from pydantic import BaseModel");

        helper.add_type_checking_import("import httpx");
        helper.add_type_checking_import("from typing import TYPE_CHECKING");
        helper.add_type_checking_import("from collections.abc import Callable");
        helper.add_type_checking_import("from mypackage.models import User");

        assert_eq!(helper.count(), 3);
        assert_eq!(helper.count_type_checking(), 4);
        assert!(!helper.is_type_checking_empty());

        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();

        assert!(
            third_party.iter().any(|s| s.contains("import httpx")),
            "Should have httpx in third_party"
        );
        assert!(
            stdlib
                .iter()
                .any(|s| s.contains("from typing import TYPE_CHECKING")),
            "Should have TYPE_CHECKING"
        );
        assert!(
            stdlib
                .iter()
                .any(|s| s.contains("from collections.abc import Callable")),
            "Should have Callable"
        );
        assert!(
            local
                .iter()
                .any(|s| s.contains("from mypackage.models import User")),
            "Should have User in local"
        );
        assert!(
            future.is_empty(),
            "Should have no future imports in TYPE_CHECKING"
        );
    }

    /// The programmatic builders produce the expected statements.
    #[test]
    fn test_programmatic_import_builders() {
        let mut helper = ImportHelper::new();

        helper.add_from_import("typing", &["Any", "Optional"]);
        helper.add_from_import("json", &["loads"]);
        helper.add_direct_import("sys");
        helper.add_type_checking_from_import("httpx", &["Client", "Response"]);
        helper.add_type_checking_direct_import("logging");

        let (_, stdlib, _, _) = helper.get_categorized();
        assert!(
            stdlib
                .iter()
                .any(|s| s.contains("Any") && s.contains("Optional")),
            "Should have Any and Optional in typing imports"
        );
        assert!(stdlib.iter().any(|s| s.contains("from json import loads")));
        assert!(stdlib.iter().any(|s| s.contains("import sys")));

        let (_, tc_stdlib, tc_third_party, _) = helper.get_type_checking_categorized();
        assert!(tc_third_party
            .iter()
            .any(|s| s.contains("from httpx import Client, Response")));
        assert!(tc_stdlib.iter().any(|s| s.contains("import logging")));
    }

    /// All four categories are populated for TYPE_CHECKING imports.
    #[test]
    fn test_type_checking_four_categories() {
        let mut helper = ImportHelper::with_package_name("myapp".to_string());

        helper.add_type_checking_import("from __future__ import annotations");
        helper.add_type_checking_import("from typing import Protocol");
        helper.add_type_checking_import("from httpx import Client");
        helper.add_type_checking_import("from myapp.models import User");

        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();

        assert_eq!(future.len(), 1, "Should have 1 future import");
        assert_eq!(stdlib.len(), 1, "Should have 1 stdlib import");
        assert_eq!(third_party.len(), 1, "Should have 1 third-party import");
        assert_eq!(local.len(), 1, "Should have 1 local import");

        assert!(future[0].contains("from __future__ import annotations"));
        assert!(stdlib[0].contains("from typing import Protocol"));
        assert!(third_party[0].contains("from httpx import Client"));
        assert!(local[0].contains("from myapp.models import User"));
    }
}