1use std::collections::{HashMap, HashSet};
8
9use crate::registry::PackageRegistry;
10use crate::types::{AllCategorizedImports, CategorizedImports, FormattingConfig, ImportSpec};
11use crate::{ImportCategory, ImportSections, ImportStatement, ImportType};
12
13#[derive(Debug)]
15pub struct ImportHelper {
16 sections: ImportSections,
18 category_cache: HashMap<String, ImportCategory>,
20 package_name: Option<String>,
22 local_package_prefixes: HashSet<String>,
24 registry: PackageRegistry,
26 formatting_config: FormattingConfig,
28}
29
30impl ImportHelper {
31 #[must_use]
33 pub fn new() -> Self {
34 Self {
35 sections: ImportSections::default(),
36 category_cache: HashMap::new(),
37 package_name: None,
38 local_package_prefixes: HashSet::new(),
39 registry: PackageRegistry::new(),
40 formatting_config: FormattingConfig::default(),
41 }
42 }
43
44 #[must_use]
46 pub fn with_package_name(package_name: String) -> Self {
47 let mut helper = Self::new();
48 helper.package_name = Some(package_name.clone());
49 helper.local_package_prefixes.insert(package_name);
50 helper
51 }
52
53 #[must_use]
55 pub fn with_formatting_config(config: FormattingConfig) -> Self {
56 Self {
57 sections: ImportSections::default(),
58 category_cache: HashMap::new(),
59 package_name: None,
60 local_package_prefixes: HashSet::new(),
61 registry: PackageRegistry::new(),
62 formatting_config: config,
63 }
64 }
65
66 #[must_use]
68 pub fn with_package_and_config(package_name: String, config: FormattingConfig) -> Self {
69 let mut helper = Self::with_formatting_config(config);
70 helper.package_name = Some(package_name.clone());
71 helper.local_package_prefixes.insert(package_name);
72 helper
73 }
74
75 #[must_use]
77 pub fn formatting_config(&self) -> &FormattingConfig {
78 &self.formatting_config
79 }
80
81 pub fn set_formatting_config(&mut self, config: FormattingConfig) {
83 self.formatting_config = config;
84 }
85
86 #[must_use]
97 pub fn registry(&self) -> &PackageRegistry {
98 &self.registry
99 }
100
101 pub fn registry_mut(&mut self) -> &mut PackageRegistry {
121 &mut self.registry
123 }
124
125 pub fn clear_cache(&mut self) -> &mut Self {
140 self.category_cache.clear();
141 self
142 }
143
144 pub fn add_local_package_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
146 let prefix = prefix.into();
147 self.local_package_prefixes.insert(prefix);
148 self
150 }
151
152 pub fn add_local_package_prefixes(&mut self, prefixes: &[impl AsRef<str>]) -> &mut Self {
154 for prefix in prefixes {
155 self.add_local_package_prefix(prefix.as_ref());
156 }
157 self
158 }
159
160 pub fn add_import(&mut self, spec: &ImportSpec) {
162 let import_statement = if let Some(items) = &spec.items {
163 format!("from {} import {}", spec.package, items.join(", "))
164 } else {
165 format!("import {}", spec.package)
166 };
167
168 if spec.type_checking {
169 self.add_type_checking_import(&import_statement);
170 } else {
171 self.add_regular_import(&import_statement);
172 }
173 }
174
175 pub fn add_import_string(&mut self, import_statement: &str) {
177 self.add_regular_import(import_statement);
178 }
179
180 fn add_regular_import(&mut self, import_statement: &str) {
182 if let Some(import) = self.parse_import(import_statement) {
183 match (&import.category, &import.import_type) {
184 (ImportCategory::Future, _) => self.sections.future.push(import),
185 (ImportCategory::StandardLibrary, ImportType::Direct) => {
186 self.sections.standard_library_direct.push(import)
187 }
188 (ImportCategory::StandardLibrary, ImportType::From) => {
189 self.sections.standard_library_from.push(import)
190 }
191 (ImportCategory::ThirdParty, ImportType::Direct) => {
192 self.sections.third_party_direct.push(import)
193 }
194 (ImportCategory::ThirdParty, ImportType::From) => {
195 self.sections.third_party_from.push(import)
196 }
197 (ImportCategory::Local, ImportType::Direct) => {
198 self.sections.local_direct.push(import)
199 }
200 (ImportCategory::Local, ImportType::From) => self.sections.local_from.push(import),
201 }
202 }
203 }
204
205 pub fn add_from_import(&mut self, package: &str, items: &[&str]) {
208 let import_statement = if items.len() == 1 {
209 format!("from {} import {}", package, items[0])
210 } else {
211 format!("from {} import {}", package, items.join(", "))
212 };
213 self.add_regular_import(&import_statement);
214 }
215
216 pub fn add_from_import_multiline(&mut self, package: &str, items: &[&str]) {
218 if items.is_empty() {
219 return;
220 }
221
222 if items.len() == 1 {
223 self.add_from_import(package, items);
224 return;
225 }
226
227 let mut import_statement = format!("from {} import (\n", package);
228 for item in items {
229 import_statement.push_str(&format!(" {},\n", item));
230 }
231 import_statement.push(')');
232
233 self.add_regular_import(&import_statement);
234 }
235
236 pub fn add_type_checking_from_import(&mut self, package: &str, items: &[&str]) {
239 let import_statement = if items.len() == 1 {
240 format!("from {} import {}", package, items[0])
241 } else {
242 format!("from {} import {}", package, items.join(", "))
243 };
244 self.add_type_checking_import(&import_statement);
245 }
246
247 pub fn add_direct_import(&mut self, module: &str) {
250 let import_statement = format!("import {module}");
251 self.add_regular_import(&import_statement);
252 }
253
254 pub fn add_type_checking_direct_import(&mut self, module: &str) {
257 let import_statement = format!("import {module}");
258 self.add_type_checking_import(&import_statement);
259 }
260
261 pub fn add_type_checking_import(&mut self, import_statement: &str) {
263 if let Some(import) = self.parse_import(import_statement) {
264 match (&import.category, &import.import_type) {
265 (ImportCategory::Future, _) => self.sections.type_checking_future.push(import),
266 (ImportCategory::StandardLibrary, ImportType::Direct) => self
267 .sections
268 .type_checking_standard_library_direct
269 .push(import),
270 (ImportCategory::StandardLibrary, ImportType::From) => self
271 .sections
272 .type_checking_standard_library_from
273 .push(import),
274 (ImportCategory::ThirdParty, ImportType::Direct) => {
275 self.sections.type_checking_third_party_direct.push(import);
276 }
277 (ImportCategory::ThirdParty, ImportType::From) => {
278 self.sections.type_checking_third_party_from.push(import);
279 }
280 (ImportCategory::Local, ImportType::Direct) => {
281 self.sections.type_checking_local_direct.push(import);
282 }
283 (ImportCategory::Local, ImportType::From) => {
284 self.sections.type_checking_local_from.push(import);
285 }
286 }
287
288 self.ensure_type_checking_import_added();
290 }
291 }
292
293 #[must_use]
297 pub fn get_all_categorized(&self) -> AllCategorizedImports {
298 let (future_imports, stdlib_imports, third_party_imports, local_imports) =
300 self.get_categorized();
301
302 let (tc_future, tc_stdlib, tc_third_party, tc_local) = self.get_type_checking_categorized();
304
305 (
306 future_imports,
307 stdlib_imports,
308 third_party_imports,
309 local_imports,
310 tc_future,
311 tc_stdlib,
312 tc_third_party,
313 tc_local,
314 )
315 }
316
317 #[must_use]
324 pub fn get_type_checking_categorized(&self) -> CategorizedImports {
325 self.get_type_checking_categorized_impl()
326 }
327
328 #[must_use]
329 pub fn get_type_checking_categorized_impl(&self) -> CategorizedImports {
330 let mut future_imports = Vec::new();
331 let mut stdlib_imports = Vec::new();
332 let mut third_party_imports = Vec::new();
333 let mut local_imports = Vec::new();
334
335 if !self.sections.type_checking_future.is_empty() {
337 let future = self.format_imports(&self.sections.type_checking_future);
338 future_imports.extend(future);
339 }
340
341 if !self
343 .sections
344 .type_checking_standard_library_direct
345 .is_empty()
346 {
347 let std_direct =
348 self.format_imports(&self.sections.type_checking_standard_library_direct);
349 stdlib_imports.extend(std_direct);
350 }
351 if !self.sections.type_checking_standard_library_from.is_empty() {
352 let std_from = self.format_imports(&self.sections.type_checking_standard_library_from);
353 stdlib_imports.extend(std_from);
354 }
355
356 if !self.sections.type_checking_third_party_direct.is_empty() {
358 let third_direct = self.format_imports(&self.sections.type_checking_third_party_direct);
359 third_party_imports.extend(third_direct);
360 }
361 if !self.sections.type_checking_third_party_from.is_empty() {
362 let third_from = self.format_imports(&self.sections.type_checking_third_party_from);
363 third_party_imports.extend(third_from);
364 }
365
366 if !self.sections.type_checking_local_direct.is_empty() {
368 let local_direct = self.format_imports(&self.sections.type_checking_local_direct);
369 local_imports.extend(local_direct);
370 }
371 if !self.sections.type_checking_local_from.is_empty() {
372 let local_from = self.format_imports(&self.sections.type_checking_local_from);
373 local_imports.extend(local_from);
374 }
375
376 future_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
378 stdlib_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
379 third_party_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
380 local_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
381
382 (
383 future_imports,
384 stdlib_imports,
385 third_party_imports,
386 local_imports,
387 )
388 }
389
390 #[must_use]
393 pub fn get_categorized(&self) -> CategorizedImports {
394 let mut future_imports = Vec::new();
395 let mut stdlib_imports = Vec::new();
396 let mut third_party_imports = Vec::new();
397 let mut local_imports = Vec::new();
398
399 if !self.sections.future.is_empty() {
401 let future = self.format_imports(&self.sections.future);
402 future_imports.extend(future);
403 }
404
405 if !self.sections.standard_library_direct.is_empty() {
407 let std_direct_imports = self.format_imports(&self.sections.standard_library_direct);
408 stdlib_imports.extend(std_direct_imports);
409 }
410 if !self.sections.standard_library_from.is_empty() {
411 let std_from_imports = self.format_imports(&self.sections.standard_library_from);
412 stdlib_imports.extend(std_from_imports);
413 }
414
415 if !self.sections.third_party_direct.is_empty() {
417 let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
418 third_party_imports.extend(third_direct_imports);
419 }
420 if !self.sections.third_party_from.is_empty() {
421 let third_from_imports = self.format_imports(&self.sections.third_party_from);
422 third_party_imports.extend(third_from_imports);
423 }
424
425 if !self.sections.local_direct.is_empty() {
427 let local_direct_imports = self.format_imports(&self.sections.local_direct);
428 local_imports.extend(local_direct_imports);
429 }
430 if !self.sections.local_from.is_empty() {
431 let local_from_imports = self.format_imports(&self.sections.local_from);
432 local_imports.extend(local_from_imports);
433 }
434
435 future_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
437 stdlib_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
438 third_party_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
439 local_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
440
441 (
442 future_imports,
443 stdlib_imports,
444 third_party_imports,
445 local_imports,
446 )
447 }
448
449 pub fn clear(&mut self) -> &mut Self {
471 self.sections = ImportSections::default();
472 self.category_cache.clear();
473 self
474 }
475
476 pub fn reset(&mut self) -> &mut Self {
502 self.sections = ImportSections::default();
503 self.category_cache.clear();
504 self.package_name = None;
505 self.local_package_prefixes.clear();
506 self.registry = PackageRegistry::new();
507 self.formatting_config = FormattingConfig::default();
508 self
509 }
510
511 #[must_use]
513 pub fn is_empty(&self) -> bool {
514 self.sections.future.is_empty()
515 && self.sections.standard_library_direct.is_empty()
516 && self.sections.standard_library_from.is_empty()
517 && self.sections.third_party_direct.is_empty()
518 && self.sections.third_party_from.is_empty()
519 && self.sections.local_direct.is_empty()
520 && self.sections.local_from.is_empty()
521 }
522
523 #[must_use]
525 pub fn is_type_checking_empty(&self) -> bool {
526 self.sections.type_checking_future.is_empty()
527 && self
528 .sections
529 .type_checking_standard_library_direct
530 .is_empty()
531 && self.sections.type_checking_standard_library_from.is_empty()
532 && self.sections.type_checking_third_party_direct.is_empty()
533 && self.sections.type_checking_third_party_from.is_empty()
534 && self.sections.type_checking_local_direct.is_empty()
535 && self.sections.type_checking_local_from.is_empty()
536 }
537
538 #[must_use]
540 pub fn count(&self) -> usize {
541 self.sections.future.len()
542 + self.sections.standard_library_direct.len()
543 + self.sections.standard_library_from.len()
544 + self.sections.third_party_direct.len()
545 + self.sections.third_party_from.len()
546 + self.sections.local_direct.len()
547 + self.sections.local_from.len()
548 }
549
550 #[must_use]
552 pub fn count_type_checking(&self) -> usize {
553 self.sections.type_checking_future.len()
554 + self.sections.type_checking_standard_library_direct.len()
555 + self.sections.type_checking_standard_library_from.len()
556 + self.sections.type_checking_third_party_direct.len()
557 + self.sections.type_checking_third_party_from.len()
558 + self.sections.type_checking_local_direct.len()
559 + self.sections.type_checking_local_from.len()
560 }
561
562 #[must_use]
564 pub fn get_formatted(&self) -> Vec<String> {
565 let mut result = Vec::new();
566 let mut has_previous_section = false;
567
568 if !self.sections.future.is_empty() {
570 let future_imports = self.format_imports(&self.sections.future);
571 result.extend(future_imports);
572 has_previous_section = true;
573 }
574
575 let std_has_direct = !self.sections.standard_library_direct.is_empty();
577 let std_has_from = !self.sections.standard_library_from.is_empty();
578
579 if std_has_direct || std_has_from {
580 if has_previous_section {
581 result.push(String::new()); }
583
584 if std_has_direct {
586 let std_direct_imports =
587 self.format_imports(&self.sections.standard_library_direct);
588 result.extend(std_direct_imports);
589 }
590
591 if std_has_from {
593 let std_from_imports = self.format_imports(&self.sections.standard_library_from);
594 result.extend(std_from_imports);
595 }
596
597 has_previous_section = true;
598 }
599
600 let third_has_direct = !self.sections.third_party_direct.is_empty();
602 let third_has_from = !self.sections.third_party_from.is_empty();
603
604 if third_has_direct || third_has_from {
605 if has_previous_section {
606 result.push(String::new()); }
608
609 if third_has_direct {
611 let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
612 result.extend(third_direct_imports);
613 }
614
615 if third_has_from {
617 let third_from_imports = self.format_imports(&self.sections.third_party_from);
618 result.extend(third_from_imports);
619 }
620
621 has_previous_section = true;
622 }
623
624 let local_has_direct = !self.sections.local_direct.is_empty();
626 let local_has_from = !self.sections.local_from.is_empty();
627
628 if local_has_direct || local_has_from {
629 if has_previous_section {
630 result.push(String::new()); }
632
633 if local_has_direct {
635 let local_direct_imports = self.format_imports(&self.sections.local_direct);
636 result.extend(local_direct_imports);
637 }
638
639 if local_has_from {
641 let local_from_imports = self.format_imports(&self.sections.local_from);
642 result.extend(local_from_imports);
643 }
644 }
645
646 result
647 }
648
649 fn parse_import(&mut self, import_statement: &str) -> Option<ImportStatement> {
651 let trimmed = import_statement.trim();
652 if trimmed.is_empty() {
653 return None;
654 }
655
656 let category = self.categorize_import(trimmed);
657 let import_type = if trimmed.starts_with("from ") {
658 ImportType::From
659 } else {
660 ImportType::Direct
661 };
662 let package = Self::extract_package(trimmed);
663 let items = Self::extract_items(trimmed);
664 let is_multiline = trimmed.contains('(') || trimmed.contains(')');
665
666 let statement = if is_multiline {
669 trimmed.to_string()
670 } else if import_type == ImportType::From && !items.is_empty() {
671 format!("from {} import {}", package, items.join(", "))
672 } else {
673 trimmed.to_string()
674 };
675
676 Some(ImportStatement {
677 statement,
678 category,
679 import_type,
680 package,
681 items,
682 is_multiline,
683 })
684 }
685
686 fn categorize_import(&mut self, import_statement: &str) -> ImportCategory {
688 if import_statement.starts_with("from __future__") {
689 return ImportCategory::Future;
690 }
691
692 let package = Self::extract_package(import_statement);
693
694 if let Some(&cached_category) = self.category_cache.get(&package) {
696 return cached_category;
697 }
698
699 let category = if self.is_local_import(import_statement) {
704 ImportCategory::Local
705 } else if self.is_standard_library_package(&package) {
706 ImportCategory::StandardLibrary
707 } else if self.is_common_third_party_package(&package) {
708 ImportCategory::ThirdParty
709 } else {
710 ImportCategory::ThirdParty
712 };
713
714 self.category_cache.insert(package, category);
715 category
716 }
717
718 fn extract_package(import_statement: &str) -> String {
720 if let Some(from_part) = import_statement.strip_prefix("from ") {
721 if let Some(import_pos) = from_part.find(" import ") {
722 return from_part[..import_pos].trim().to_string();
723 }
724 } else if let Some(import_part) = import_statement.strip_prefix("import ") {
725 return import_part
727 .split_whitespace()
728 .next()
729 .unwrap_or(import_part)
730 .trim()
731 .to_string();
732 }
733
734 import_statement.to_string()
735 }
736
737 fn extract_items(import_statement: &str) -> Vec<String> {
739 if let Some(from_part) = import_statement.strip_prefix("from ") {
740 if let Some(import_pos) = from_part.find(" import ") {
741 let items_part = &from_part[import_pos + 8..];
742 let cleaned = items_part.replace(['(', ')'], "").replace(',', " ");
743 let mut items: Vec<String> = cleaned
744 .split_whitespace()
745 .map(|s| s.trim().to_string())
746 .filter(|s| !s.is_empty())
747 .collect();
748
749 items.sort_by(|a, b| crate::utils::parsing::custom_import_sort(a, b));
751 return items;
752 }
753 } else if let Some(import_part) = import_statement.strip_prefix("import ") {
754 return vec![import_part.trim().to_string()];
756 }
757 Vec::new()
758 }
759
760 fn is_local_import(&self, import_statement: &str) -> bool {
762 if import_statement.contains("from .")
764 || import_statement.contains("from ..")
765 || import_statement.contains("from ...")
766 || import_statement.contains("from ....")
767 {
768 return true;
769 }
770
771 let package = Self::extract_package(import_statement);
772
773 for prefix in &self.local_package_prefixes {
775 if package.starts_with(prefix.as_str()) {
776 return true;
777 }
778 }
779
780 if let Some(pkg_name) = &self.package_name {
782 if package.starts_with(pkg_name) {
783 return true;
784 }
785 }
786
787 false
788 }
789
790 fn is_standard_library_package(&self, package: &str) -> bool {
792 self.registry.is_stdlib(package)
794 }
795
796 fn is_common_third_party_package(&self, package: &str) -> bool {
798 self.registry.is_third_party(package)
800 }
801
802 fn format_imports(&self, imports: &[ImportStatement]) -> Vec<String> {
804 crate::utils::formatting::format_imports(imports, &self.formatting_config)
805 }
806
807 fn sort_import_statements(a: &str, b: &str) -> std::cmp::Ordering {
808 let a_is_import = a.starts_with("import ");
809 let b_is_import = b.starts_with("import ");
810
811 match (a_is_import, b_is_import) {
812 (true, true) | (false, false) => a.cmp(b),
814 (true, false) => std::cmp::Ordering::Less,
816 (false, true) => std::cmp::Ordering::Greater,
818 }
819 }
820
821 fn ensure_type_checking_import_added(&mut self) {
823 let has_type_checking = self.sections.standard_library_from.iter().any(|import| {
825 import.package == "typing" && import.items.contains(&"TYPE_CHECKING".to_string())
826 });
827
828 if !has_type_checking {
829 if let Some(typing_import) = self
831 .sections
832 .standard_library_from
833 .iter_mut()
834 .find(|import| import.package == "typing")
835 {
836 if !typing_import.items.contains(&"TYPE_CHECKING".to_string()) {
838 typing_import.items.push("TYPE_CHECKING".to_string());
839 typing_import
840 .items
841 .sort_by(|a, b| crate::utils::parsing::custom_import_sort(a, b));
842
843 if typing_import.items.len() == 1 {
845 typing_import.statement =
846 format!("from typing import {}", typing_import.items[0]);
847 } else {
848 typing_import.statement =
849 format!("from typing import {}", typing_import.items.join(", "));
850 }
851 }
852 } else {
853 self.add_import_string("from typing import TYPE_CHECKING");
855 }
856 }
857 }
858
859 #[must_use]
861 pub fn clone_config(&self) -> Self {
862 Self {
863 sections: ImportSections::default(),
864 category_cache: self.category_cache.clone(),
865 package_name: self.package_name.clone(),
866 local_package_prefixes: self.local_package_prefixes.clone(),
867 registry: self.registry.clone(),
868 formatting_config: self.formatting_config.clone(),
869 }
870 }
871}
872
873impl ImportHelper {
875 pub fn create_model_imports(&mut self, required_types: &[String]) {
877 self.add_import_string("from pydantic import BaseModel, ConfigDict, Field");
879
880 let mut typing_imports = std::collections::HashSet::new();
882 let mut collections_abc_imports = std::collections::HashSet::new();
883 let mut datetime_imports = Vec::new();
884 let mut decimal_imports = Vec::new();
885
886 for type_name in required_types {
887 match type_name.as_str() {
888 "datetime" | "date" | "time" | "timedelta" => {
889 if !datetime_imports.contains(&type_name.as_str()) {
890 datetime_imports.push(type_name.as_str());
891 }
892 }
893 "Decimal" => {
894 if !decimal_imports.contains(&"Decimal") {
895 decimal_imports.push("Decimal");
896 }
897 }
898 "UUID" => {
899 self.add_import_string("from uuid import UUID");
900 }
901 _ => {
903 let extracted_typing = Self::extract_typing_imports_from_type(type_name);
905 typing_imports.extend(extracted_typing);
906
907 if type_name.contains("Callable") {
909 collections_abc_imports.insert("Callable".to_string());
910 }
911 }
912 }
913 }
914
915 if !datetime_imports.is_empty() {
917 let import_statement = format!("from datetime import {}", datetime_imports.join(", "));
918 self.add_regular_import(&import_statement);
919 }
920
921 if !decimal_imports.is_empty() {
923 self.add_import_string("from decimal import Decimal");
924 }
925
926 if !typing_imports.is_empty() {
928 let mut sorted_typing: Vec<String> = typing_imports.into_iter().collect();
929 sorted_typing.sort();
930 let import_statement = format!("from typing import {}", sorted_typing.join(", "));
931 self.add_regular_import(&import_statement);
932 }
933
934 if !collections_abc_imports.is_empty() {
936 let mut sorted_collections: Vec<String> = collections_abc_imports.into_iter().collect();
937 sorted_collections.sort();
938 let import_statement = format!(
939 "from collections.abc import {}",
940 sorted_collections.join(", ")
941 );
942 self.add_regular_import(&import_statement);
943 }
944 }
945
946 fn extract_typing_imports_from_type(type_str: &str) -> std::collections::HashSet<String> {
950 let mut typing_imports = std::collections::HashSet::new();
951
952 if type_str.contains("Any") {
954 typing_imports.insert("Any".to_string());
955 }
956
957 if type_str.contains("Generic") {
959 typing_imports.insert("Generic".to_string());
960 }
961
962 if type_str.contains("TypeVar") {
964 typing_imports.insert("TypeVar".to_string());
965 }
966
967 if type_str.contains("Protocol") {
969 typing_imports.insert("Protocol".to_string());
970 }
971
972 typing_imports
973 }
974}
975
976impl Default for ImportHelper {
977 fn default() -> Self {
978 Self::new()
979 }
980}
981
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` if any of `lines` contains `needle`.
    fn any_contains(lines: &[String], needle: &str) -> bool {
        lines.iter().any(|line| line.contains(needle))
    }

    /// Returns the tail of `text` starting at the first occurrence of `needle`.
    /// Panics if `needle` does not occur.
    fn tail_from<'a>(text: &'a str, needle: &str) -> &'a str {
        let start = text.find(needle).unwrap();
        &text[start..]
    }

    #[test]
    fn test_import_categorization() {
        let mut helper = ImportHelper::new();

        // Each statement must land in exactly the section named alongside it.
        let cases: [(&str, fn(&ImportSections) -> usize); 5] = [
            ("from __future__ import annotations", |s: &ImportSections| s.future.len()),
            ("from typing import Optional", |s: &ImportSections| {
                s.standard_library_from.len()
            }),
            ("import uuid", |s: &ImportSections| s.standard_library_direct.len()),
            ("from pydantic import BaseModel", |s: &ImportSections| {
                s.third_party_from.len()
            }),
            ("from .models import User", |s: &ImportSections| s.local_from.len()),
        ];

        for (statement, section_len) in cases {
            helper.add_import_string(statement);
            assert_eq!(section_len(&helper.sections), 1);
        }
    }

    #[test]
    fn test_import_merging() {
        let mut helper = ImportHelper::new();
        for name in ["Optional", "Any", "List"] {
            helper.add_import_string(&format!("from typing import {name}"));
        }

        let imports = helper.get_formatted();

        let has_typing = imports.iter().any(|line| line.contains("from typing import"));
        let has_merged = imports
            .iter()
            .any(|line| line.contains("Any") && line.contains("Optional"));
        assert!(has_typing);
        assert!(has_merged);
    }

    #[test]
    fn test_alphabetical_sorting_of_import_items() {
        let mut helper = ImportHelper::new();
        for statement in [
            "from typing import Optional, Any, List, Dict",
            "from collections import defaultdict, OrderedDict, Counter",
            "import uuid",
            "import os",
        ] {
            helper.add_import_string(statement);
        }

        let imports = helper.get_formatted();
        imports.iter().for_each(|line| println!("{line}"));

        let joined = imports.join("\n");
        let position = |needle: &str| joined.find(needle).unwrap();
        let os_pos = position("import os");
        let uuid_pos = position("import uuid");
        let typing_pos = position("from typing import");

        assert!(
            os_pos < typing_pos,
            "Direct imports should come before from imports"
        );
        assert!(
            uuid_pos < typing_pos,
            "Direct imports should come before from imports"
        );

        let typing_section = tail_from(&joined, "from typing");
        let all_present = ["Any", "Dict", "List", "Optional"]
            .iter()
            .all(|name| typing_section.contains(*name));
        assert!(
            all_present,
            "Import should contain all items: Any, Dict, List, Optional in alphabetical order"
        );
    }

    #[test]
    fn test_direct_imports_sorted_alphabetically() {
        let mut helper = ImportHelper::new();
        for module in ["uuid", "os", "sys", "json"] {
            helper.add_direct_import(module);
        }

        let import_lines: Vec<String> = helper
            .get_formatted()
            .into_iter()
            .filter(|line| !line.is_empty() && line.contains("import"))
            .collect();

        assert_eq!(
            import_lines.len(),
            4,
            "Should have 4 direct imports, got: {:?}",
            import_lines
        );

        match import_lines.as_slice() {
            [first, second, third, fourth] => {
                assert!(first.contains("json"), "First should be json, got: {}", first);
                assert!(second.contains("os"), "Second should be os, got: {}", second);
                assert!(third.contains("sys"), "Third should be sys, got: {}", third);
                assert!(fourth.contains("uuid"), "Fourth should be uuid, got: {}", fourth);
            }
            other => unreachable!("length already asserted: {:?}", other),
        }
    }

    #[test]
    fn test_uppercase_priority_in_import_sorting() {
        let mut helper = ImportHelper::new();
        helper.add_import_string("from example import Ab, AA, Aa, AB");

        let imports = helper.get_formatted();
        imports.iter().for_each(|line| println!("{line}"));

        let joined = imports.join("\n");
        let example_section = tail_from(&joined, "from example");

        for name in ["AA", "AB", "Aa", "Ab"] {
            assert!(example_section.contains(name), "Should contain {}", name);
        }
    }

    #[test]
    fn test_comprehensive_case_sorting() {
        let mut helper = ImportHelper::new();
        helper.add_import_string("from test import zz, ZZ, bb, BB, aa, AA, cc, CC");

        let joined = helper.get_formatted().join("\n");
        let test_section = tail_from(&joined, "from test");

        for name in ["AA", "BB", "CC", "ZZ", "aa", "bb", "cc", "zz"] {
            assert!(test_section.contains(name), "Should contain {}", name);
        }
    }

    #[test]
    fn test_type_checking_imports() {
        let mut helper = ImportHelper::with_package_name("mypackage".to_string());

        for statement in [
            "from __future__ import annotations",
            "from typing import Any",
            "from pydantic import BaseModel",
        ] {
            helper.add_import_string(statement);
        }
        for statement in [
            "import httpx",
            "from typing import TYPE_CHECKING",
            "from collections.abc import Callable",
            "from mypackage.models import User",
        ] {
            helper.add_type_checking_import(statement);
        }

        assert_eq!(helper.count(), 3);
        assert_eq!(helper.count_type_checking(), 4);
        assert!(!helper.is_type_checking_empty());

        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();

        assert!(
            any_contains(&third_party, "import httpx"),
            "Should have httpx in third_party"
        );
        assert!(
            any_contains(&stdlib, "from typing import TYPE_CHECKING"),
            "Should have TYPE_CHECKING"
        );
        assert!(
            any_contains(&stdlib, "from collections.abc import Callable"),
            "Should have Callable"
        );
        assert!(
            any_contains(&local, "from mypackage.models import User"),
            "Should have User in local"
        );
        assert!(
            future.is_empty(),
            "Should have no future imports in TYPE_CHECKING"
        );
    }

    #[test]
    fn test_programmatic_import_builders() {
        let mut helper = ImportHelper::new();
        helper.add_from_import("typing", &["Any", "Optional"]);
        helper.add_from_import("json", &["loads"]);
        helper.add_direct_import("sys");
        helper.add_type_checking_from_import("httpx", &["Client", "Response"]);
        helper.add_type_checking_direct_import("logging");

        let (_, stdlib, _, _) = helper.get_categorized();
        let has_typing_pair = stdlib
            .iter()
            .any(|line| line.contains("Any") && line.contains("Optional"));
        assert!(
            has_typing_pair,
            "Should have Any and Optional in typing imports"
        );
        assert!(any_contains(&stdlib, "from json import loads"));
        assert!(any_contains(&stdlib, "import sys"));

        let (_, tc_stdlib, tc_third_party, _) = helper.get_type_checking_categorized();
        assert!(any_contains(
            &tc_third_party,
            "from httpx import Client, Response"
        ));
        assert!(any_contains(&tc_stdlib, "import logging"));
    }

    #[test]
    fn test_type_checking_four_categories() {
        let mut helper = ImportHelper::with_package_name("myapp".to_string());
        for statement in [
            "from __future__ import annotations",
            "from typing import Protocol",
            "from httpx import Client",
            "from myapp.models import User",
        ] {
            helper.add_type_checking_import(statement);
        }

        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();

        assert_eq!(future.len(), 1, "Should have 1 future import");
        assert_eq!(stdlib.len(), 1, "Should have 1 stdlib import");
        assert_eq!(third_party.len(), 1, "Should have 1 third-party import");
        assert_eq!(local.len(), 1, "Should have 1 local import");

        let expectations = [
            (&future, "from __future__ import annotations"),
            (&stdlib, "from typing import Protocol"),
            (&third_party, "from httpx import Client"),
            (&local, "from myapp.models import User"),
        ];
        for (lines, expected) in expectations {
            assert!(lines[0].contains(expected));
        }
    }

    #[test]
    fn test_clear_preserves_configuration() {
        let mut config = FormattingConfig::black_compatible();
        config.force_multiline = true;

        let mut helper = ImportHelper::with_package_name("mypackage".to_string());
        helper.set_formatting_config(config);
        helper.add_local_package_prefix("other_package");

        helper.add_import_string("from typing import Any");
        helper.add_type_checking_import("from httpx import Client");
        assert_eq!((helper.count(), helper.count_type_checking()), (1, 1));

        helper.clear();

        assert!(helper.is_empty());
        assert!(helper.is_type_checking_empty());
        assert_eq!((helper.count(), helper.count_type_checking()), (0, 0));

        assert_eq!(helper.package_name.as_deref(), Some("mypackage"));
        for prefix in ["mypackage", "other_package"] {
            assert!(helper.local_package_prefixes.contains(prefix));
        }
        let kept = helper.formatting_config();
        assert_eq!(kept.line_length, 88);
        assert!(kept.force_multiline);
    }

    #[test]
    fn test_reset_clears_everything() {
        let mut config = FormattingConfig::black_compatible();
        config.force_multiline = true;

        let mut helper = ImportHelper::with_package_name("mypackage".to_string());
        helper.set_formatting_config(config);
        helper.add_local_package_prefix("other_package");

        helper.add_import_string("from typing import Any");
        helper.add_type_checking_import("from httpx import Client");
        assert_eq!((helper.count(), helper.count_type_checking()), (1, 1));

        helper.reset();

        assert!(helper.is_empty());
        assert!(helper.is_type_checking_empty());
        assert_eq!((helper.count(), helper.count_type_checking()), (0, 0));

        assert_eq!(helper.package_name, None);
        assert!(helper.local_package_prefixes.is_empty());
        let restored = helper.formatting_config();
        assert_eq!(restored.line_length, 79);
        assert!(!restored.force_multiline);
    }

    #[test]
    fn test_clear_and_reuse() {
        let mut helper = ImportHelper::with_package_name("mypackage".to_string());

        for statement in ["from typing import Any", "from pydantic import BaseModel"] {
            helper.add_import_string(statement);
        }
        assert_eq!(helper.count(), 2);

        helper.clear();
        assert!(helper.is_empty());

        for statement in ["import json", "import sys"] {
            helper.add_import_string(statement);
        }
        let (_, stdlib, _, _) = helper.get_categorized();
        assert_eq!(stdlib.len(), 2);
        assert!(any_contains(&stdlib, "import json"));
        assert!(any_contains(&stdlib, "import sys"));
    }
}