py_import_helper/
core.rs

1//! Core import helper functionality
2//!
3//! This module contains the main `ImportHelper` struct and its implementation,
4//! providing the primary API for collecting, categorizing, and formatting Python
5//! imports according to PEP 8 and common Python formatting standards.
6
7use std::collections::{HashMap, HashSet};
8
9use crate::registry::PackageRegistry;
10use crate::types::{AllCategorizedImports, CategorizedImports, ImportSpec};
11use crate::{ImportCategory, ImportSections, ImportStatement, ImportType};
12
/// Main helper for collecting, categorizing, and formatting Python import statements.
///
/// Imports are added as strings or structured specs and classified into future,
/// standard-library, third-party, and local groups. Each group is split into
/// `import x` and `from x import y` forms, with a parallel set of sections for imports
/// that belong in the `TYPE_CHECKING` block. They are then rendered back as sorted,
/// merged statements.
#[derive(Debug)]
pub struct ImportHelper {
    /// Collected imports, bucketed by category, import form (direct / from), and
    /// whether they belong in the `TYPE_CHECKING` block.
    sections: ImportSections,
    /// Memoized category per package name, so a package seen again skips the
    /// local-prefix and registry checks.
    category_cache: HashMap<String, ImportCategory>,
    /// The project's own package name, if set via `with_package_name`; imports whose
    /// package starts with it are treated as local.
    package_name: Option<String>,
    /// String prefixes that mark a package as local (contains `package_name` when set).
    local_package_prefixes: HashSet<String>,
    /// Registry deciding which packages count as standard library or third-party.
    registry: PackageRegistry,
}
27
28impl ImportHelper {
29    /// Create a new import helper instance
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            sections: ImportSections::default(),
34            category_cache: HashMap::new(),
35            package_name: None,
36            local_package_prefixes: HashSet::new(),
37            registry: PackageRegistry::new(),
38        }
39    }
40
41    /// Create a new import helper instance with package name for local import detection
42    #[must_use]
43    pub fn with_package_name(package_name: String) -> Self {
44        let mut helper = Self::new();
45        helper.package_name = Some(package_name.clone());
46        helper.local_package_prefixes.insert(package_name);
47        helper
48    }
49
50    /// Get immutable reference to the package registry
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use py_import_helper::ImportHelper;
56    ///
57    /// let helper = ImportHelper::new();
58    /// assert!(helper.registry().is_stdlib("typing"));
59    /// ```
60    #[must_use]
61    pub fn registry(&self) -> &PackageRegistry {
62        &self.registry
63    }
64
65    /// Get mutable reference to the package registry
66    ///
67    /// Use this to customize which packages are recognized as stdlib or third-party
68    /// before generating imports.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use py_import_helper::ImportHelper;
74    ///
75    /// let mut helper = ImportHelper::new();
76    /// helper.registry_mut()
77    ///     .add_stdlib_package("my_custom_stdlib")
78    ///     .add_third_party_package("my_company_lib");
79    ///
80    /// helper.add_import_string("import my_custom_stdlib");
81    /// let (_future, stdlib, _third, _local) = helper.get_categorized();
82    /// assert!(stdlib.iter().any(|s| s.contains("my_custom_stdlib")));
83    /// ```
84    pub fn registry_mut(&mut self) -> &mut PackageRegistry {
85        // Clear cache when registry is modified
86        &mut self.registry
87    }
88
89    /// Clear the categorization cache
90    ///
91    /// Call this after modifying the registry to ensure changes take effect.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use py_import_helper::ImportHelper;
97    ///
98    /// let mut helper = ImportHelper::new();
99    /// helper.add_import_string("import mypackage");
100    /// helper.registry_mut().add_stdlib_package("mypackage");
101    /// helper.clear_cache();  // Force re-categorization
102    /// ```
103    pub fn clear_cache(&mut self) -> &mut Self {
104        self.category_cache.clear();
105        self
106    }
107
108    /// Add a custom local package prefix to the recognition list
109    pub fn add_local_package_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
110        let prefix = prefix.into();
111        self.local_package_prefixes.insert(prefix);
112        // Don't add to cache - these are prefixes, not exact matches
113        self
114    }
115
116    /// Add multiple local package prefixes at once
117    pub fn add_local_package_prefixes(&mut self, prefixes: &[impl AsRef<str>]) -> &mut Self {
118        for prefix in prefixes {
119            self.add_local_package_prefix(prefix.as_ref());
120        }
121        self
122    }
123
124    /// Add an import using structured `ImportSpec`
125    pub fn add_import(&mut self, spec: &ImportSpec) {
126        let import_statement = if let Some(items) = &spec.items {
127            format!("from {} import {}", spec.package, items.join(", "))
128        } else {
129            format!("import {}", spec.package)
130        };
131
132        if spec.type_checking {
133            self.add_type_checking_import(&import_statement);
134        } else {
135            self.add_regular_import(&import_statement);
136        }
137    }
138
139    /// Convenience method to add import from string (for backward compatibility)
140    pub fn add_import_string(&mut self, import_statement: &str) {
141        self.add_regular_import(import_statement);
142    }
143
144    /// Add an import statement using string (internal method)
145    fn add_regular_import(&mut self, import_statement: &str) {
146        if let Some(import) = self.parse_import(import_statement) {
147            match (&import.category, &import.import_type) {
148                (ImportCategory::Future, _) => self.sections.future.push(import),
149                (ImportCategory::StandardLibrary, ImportType::Direct) => {
150                    self.sections.standard_library_direct.push(import)
151                }
152                (ImportCategory::StandardLibrary, ImportType::From) => {
153                    self.sections.standard_library_from.push(import)
154                }
155                (ImportCategory::ThirdParty, ImportType::Direct) => {
156                    self.sections.third_party_direct.push(import)
157                }
158                (ImportCategory::ThirdParty, ImportType::From) => {
159                    self.sections.third_party_from.push(import)
160                }
161                (ImportCategory::Local, ImportType::Direct) => {
162                    self.sections.local_direct.push(import)
163                }
164                (ImportCategory::Local, ImportType::From) => self.sections.local_from.push(import),
165            }
166        }
167    }
168
169    /// Add a from import statement programmatically
170    /// Example: `add_from_import("typing", &["Any", "Optional"])`
171    pub fn add_from_import(&mut self, package: &str, items: &[&str]) {
172        let import_statement = if items.len() == 1 {
173            format!("from {} import {}", package, items[0])
174        } else {
175            format!("from {} import {}", package, items.join(", "))
176        };
177        self.add_regular_import(&import_statement);
178    }
179
180    /// Add a from import statement to `TYPE_CHECKING` block programmatically
181    /// Example: `add_type_checking_from_import("httpx", &["Client", "Response"])`
182    pub fn add_type_checking_from_import(&mut self, package: &str, items: &[&str]) {
183        let import_statement = if items.len() == 1 {
184            format!("from {} import {}", package, items[0])
185        } else {
186            format!("from {} import {}", package, items.join(", "))
187        };
188        self.add_type_checking_import(&import_statement);
189    }
190
191    /// Add a direct import statement programmatically
192    /// Example: `add_direct_import("json`")
193    pub fn add_direct_import(&mut self, module: &str) {
194        let import_statement = format!("import {module}");
195        self.add_regular_import(&import_statement);
196    }
197
198    /// Add a direct import statement to `TYPE_CHECKING` block programmatically
199    /// Example: `add_type_checking_direct_import("httpx`")
200    pub fn add_type_checking_direct_import(&mut self, module: &str) {
201        let import_statement = format!("import {module}");
202        self.add_type_checking_import(&import_statement);
203    }
204
205    /// Add an import statement to the `TYPE_CHECKING` block
206    pub fn add_type_checking_import(&mut self, import_statement: &str) {
207        if let Some(import) = self.parse_import(import_statement) {
208            match (&import.category, &import.import_type) {
209                (ImportCategory::Future, _) => self.sections.type_checking_future.push(import),
210                (ImportCategory::StandardLibrary, ImportType::Direct) => self
211                    .sections
212                    .type_checking_standard_library_direct
213                    .push(import),
214                (ImportCategory::StandardLibrary, ImportType::From) => self
215                    .sections
216                    .type_checking_standard_library_from
217                    .push(import),
218                (ImportCategory::ThirdParty, ImportType::Direct) => {
219                    self.sections.type_checking_third_party_direct.push(import);
220                }
221                (ImportCategory::ThirdParty, ImportType::From) => {
222                    self.sections.type_checking_third_party_from.push(import);
223                }
224                (ImportCategory::Local, ImportType::Direct) => {
225                    self.sections.type_checking_local_direct.push(import);
226                }
227                (ImportCategory::Local, ImportType::From) => {
228                    self.sections.type_checking_local_from.push(import);
229                }
230            }
231
232            // Automatically add TYPE_CHECKING to typing import when we have type checking imports
233            self.ensure_type_checking_import_added();
234        }
235    }
236
237    /// Generate all imports (regular + `TYPE_CHECKING`) for templates
238    /// Returns a tuple with 8 vectors:
239    /// (future, stdlib, `third_party`, local, `tc_future`, `tc_stdlib`, `tc_third_party`, `tc_local`)
240    #[must_use]
241    pub fn get_all_categorized(&self) -> AllCategorizedImports {
242        // Get regular imports (now includes future)
243        let (future_imports, stdlib_imports, third_party_imports, local_imports) =
244            self.get_categorized();
245
246        // Get TYPE_CHECKING imports
247        let (tc_future, tc_stdlib, tc_third_party, tc_local) = self.get_type_checking_categorized();
248
249        (
250            future_imports,
251            stdlib_imports,
252            third_party_imports,
253            local_imports,
254            tc_future,
255            tc_stdlib,
256            tc_third_party,
257            tc_local,
258        )
259    }
260
261    /// Generate categorized `TYPE_CHECKING` imports for templates
262    /// Returns (`future_imports`, `stdlib_imports`, `third_party_imports`, `local_imports`)
263    /// Get `TYPE_CHECKING` imports categorized by type
264    ///
265    /// Returns a tuple of (`future_imports`, `stdlib_imports`, `third_party_imports`, `local_imports`)
266    /// for imports that should go in the `TYPE_CHECKING` block.
267    #[must_use]
268    pub fn get_type_checking_categorized(&self) -> CategorizedImports {
269        self.get_type_checking_categorized_impl()
270    }
271
272    #[must_use]
273    pub fn get_type_checking_categorized_impl(&self) -> CategorizedImports {
274        let mut future_imports = Vec::new();
275        let mut stdlib_imports = Vec::new();
276        let mut third_party_imports = Vec::new();
277        let mut local_imports = Vec::new();
278
279        // Future imports
280        if !self.sections.type_checking_future.is_empty() {
281            let future = self.format_imports(&self.sections.type_checking_future);
282            future_imports.extend(future);
283        }
284
285        // Standard library imports - direct first, then from
286        if !self
287            .sections
288            .type_checking_standard_library_direct
289            .is_empty()
290        {
291            let std_direct =
292                self.format_imports(&self.sections.type_checking_standard_library_direct);
293            stdlib_imports.extend(std_direct);
294        }
295        if !self.sections.type_checking_standard_library_from.is_empty() {
296            let std_from = self.format_imports(&self.sections.type_checking_standard_library_from);
297            stdlib_imports.extend(std_from);
298        }
299
300        // Third-party imports - direct first, then from
301        if !self.sections.type_checking_third_party_direct.is_empty() {
302            let third_direct = self.format_imports(&self.sections.type_checking_third_party_direct);
303            third_party_imports.extend(third_direct);
304        }
305        if !self.sections.type_checking_third_party_from.is_empty() {
306            let third_from = self.format_imports(&self.sections.type_checking_third_party_from);
307            third_party_imports.extend(third_from);
308        }
309
310        // Local imports - direct first, then from
311        if !self.sections.type_checking_local_direct.is_empty() {
312            let local_direct = self.format_imports(&self.sections.type_checking_local_direct);
313            local_imports.extend(local_direct);
314        }
315        if !self.sections.type_checking_local_from.is_empty() {
316            let local_from = self.format_imports(&self.sections.type_checking_local_from);
317            local_imports.extend(local_from);
318        }
319
320        // Sort each category alphabetically
321        future_imports.sort();
322        stdlib_imports.sort();
323        third_party_imports.sort();
324        local_imports.sort();
325
326        (
327            future_imports,
328            stdlib_imports,
329            third_party_imports,
330            local_imports,
331        )
332    }
333
334    /// Get the collected imports as categorized tuples
335    /// Returns (`future_imports`, `stdlib_imports`, `third_party_imports`, `local_imports`)
336    #[must_use]
337    pub fn get_categorized(&self) -> CategorizedImports {
338        let mut future_imports = Vec::new();
339        let mut stdlib_imports = Vec::new();
340        let mut third_party_imports = Vec::new();
341        let mut local_imports = Vec::new();
342
343        // Future imports
344        if !self.sections.future.is_empty() {
345            let future = self.format_imports(&self.sections.future);
346            future_imports.extend(future);
347        }
348
349        // Standard library imports - direct first, then from
350        if !self.sections.standard_library_direct.is_empty() {
351            let std_direct_imports = self.format_imports(&self.sections.standard_library_direct);
352            stdlib_imports.extend(std_direct_imports);
353        }
354        if !self.sections.standard_library_from.is_empty() {
355            let std_from_imports = self.format_imports(&self.sections.standard_library_from);
356            stdlib_imports.extend(std_from_imports);
357        }
358
359        // Third-party imports - direct first, then from
360        if !self.sections.third_party_direct.is_empty() {
361            let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
362            third_party_imports.extend(third_direct_imports);
363        }
364        if !self.sections.third_party_from.is_empty() {
365            let third_from_imports = self.format_imports(&self.sections.third_party_from);
366            third_party_imports.extend(third_from_imports);
367        }
368
369        // Local imports - direct first, then from
370        if !self.sections.local_direct.is_empty() {
371            let local_direct_imports = self.format_imports(&self.sections.local_direct);
372            local_imports.extend(local_direct_imports);
373        }
374        if !self.sections.local_from.is_empty() {
375            let local_from_imports = self.format_imports(&self.sections.local_from);
376            local_imports.extend(local_from_imports);
377        }
378
379        // Sort each category alphabetically
380        future_imports.sort();
381        stdlib_imports.sort();
382        third_party_imports.sort();
383        local_imports.sort();
384
385        (
386            future_imports,
387            stdlib_imports,
388            third_party_imports,
389            local_imports,
390        )
391    }
392
393    /// Reset the import sections while preserving configuration
394    /// Useful when reusing the same helper for multiple files
395    pub fn reset(&mut self) -> &mut Self {
396        self.sections = ImportSections::default();
397        self
398    }
399
400    /// Check if any imports have been collected (excluding `TYPE_CHECKING` imports)
401    #[must_use]
402    pub fn is_empty(&self) -> bool {
403        self.sections.future.is_empty()
404            && self.sections.standard_library_direct.is_empty()
405            && self.sections.standard_library_from.is_empty()
406            && self.sections.third_party_direct.is_empty()
407            && self.sections.third_party_from.is_empty()
408            && self.sections.local_direct.is_empty()
409            && self.sections.local_from.is_empty()
410    }
411
412    /// Check if any `TYPE_CHECKING` imports have been collected
413    #[must_use]
414    pub fn is_type_checking_empty(&self) -> bool {
415        self.sections.type_checking_future.is_empty()
416            && self
417                .sections
418                .type_checking_standard_library_direct
419                .is_empty()
420            && self.sections.type_checking_standard_library_from.is_empty()
421            && self.sections.type_checking_third_party_direct.is_empty()
422            && self.sections.type_checking_third_party_from.is_empty()
423            && self.sections.type_checking_local_direct.is_empty()
424            && self.sections.type_checking_local_from.is_empty()
425    }
426
427    /// Count total number of import statements collected (excluding `TYPE_CHECKING` imports)
428    #[must_use]
429    pub fn count(&self) -> usize {
430        self.sections.future.len()
431            + self.sections.standard_library_direct.len()
432            + self.sections.standard_library_from.len()
433            + self.sections.third_party_direct.len()
434            + self.sections.third_party_from.len()
435            + self.sections.local_direct.len()
436            + self.sections.local_from.len()
437    }
438
439    /// Count total number of `TYPE_CHECKING` import statements collected
440    #[must_use]
441    pub fn count_type_checking(&self) -> usize {
442        self.sections.type_checking_future.len()
443            + self.sections.type_checking_standard_library_direct.len()
444            + self.sections.type_checking_standard_library_from.len()
445            + self.sections.type_checking_third_party_direct.len()
446            + self.sections.type_checking_third_party_from.len()
447            + self.sections.type_checking_local_direct.len()
448            + self.sections.type_checking_local_from.len()
449    }
450
451    /// Generate sorted and formatted import statements
452    #[must_use]
453    pub fn get_formatted(&self) -> Vec<String> {
454        let mut result = Vec::new();
455        let mut has_previous_section = false;
456
457        // Future imports
458        if !self.sections.future.is_empty() {
459            let future_imports = self.format_imports(&self.sections.future);
460            result.extend(future_imports);
461            has_previous_section = true;
462        }
463
464        // Standard library imports - direct first, then from
465        let std_has_direct = !self.sections.standard_library_direct.is_empty();
466        let std_has_from = !self.sections.standard_library_from.is_empty();
467
468        if std_has_direct || std_has_from {
469            if has_previous_section {
470                result.push(String::new()); // Empty line between sections
471            }
472
473            // Direct imports first
474            if std_has_direct {
475                let std_direct_imports =
476                    self.format_imports(&self.sections.standard_library_direct);
477                result.extend(std_direct_imports);
478            }
479
480            // From imports after direct imports
481            if std_has_from {
482                let std_from_imports = self.format_imports(&self.sections.standard_library_from);
483                result.extend(std_from_imports);
484            }
485
486            has_previous_section = true;
487        }
488
489        // Third-party imports - direct first, then from
490        let third_has_direct = !self.sections.third_party_direct.is_empty();
491        let third_has_from = !self.sections.third_party_from.is_empty();
492
493        if third_has_direct || third_has_from {
494            if has_previous_section {
495                result.push(String::new()); // Empty line between sections
496            }
497
498            // Direct imports first
499            if third_has_direct {
500                let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
501                result.extend(third_direct_imports);
502            }
503
504            // From imports after direct imports
505            if third_has_from {
506                let third_from_imports = self.format_imports(&self.sections.third_party_from);
507                result.extend(third_from_imports);
508            }
509
510            has_previous_section = true;
511        }
512
513        // Local imports - direct first, then from
514        let local_has_direct = !self.sections.local_direct.is_empty();
515        let local_has_from = !self.sections.local_from.is_empty();
516
517        if local_has_direct || local_has_from {
518            if has_previous_section {
519                result.push(String::new()); // Empty line between sections
520            }
521
522            // Direct imports first
523            if local_has_direct {
524                let local_direct_imports = self.format_imports(&self.sections.local_direct);
525                result.extend(local_direct_imports);
526            }
527
528            // From imports after direct imports
529            if local_has_from {
530                let local_from_imports = self.format_imports(&self.sections.local_from);
531                result.extend(local_from_imports);
532            }
533        }
534
535        result
536    }
537
538    /// Parse an import statement and categorize it
539    fn parse_import(&mut self, import_statement: &str) -> Option<ImportStatement> {
540        let trimmed = import_statement.trim();
541        if trimmed.is_empty() {
542            return None;
543        }
544
545        let category = self.categorize_import(trimmed);
546        let import_type = if trimmed.starts_with("from ") {
547            ImportType::From
548        } else {
549            ImportType::Direct
550        };
551        let package = Self::extract_package(trimmed);
552        let items = Self::extract_items(trimmed);
553        let is_multiline = trimmed.contains('(') || trimmed.contains(')');
554
555        // Reconstruct the statement with sorted items for from imports
556        let statement = if import_type == ImportType::From && !items.is_empty() {
557            format!("from {} import {}", package, items.join(", "))
558        } else {
559            trimmed.to_string()
560        };
561
562        Some(ImportStatement {
563            statement,
564            category,
565            import_type,
566            package,
567            items,
568            is_multiline,
569        })
570    }
571
572    /// Categorize an import statement
573    fn categorize_import(&mut self, import_statement: &str) -> ImportCategory {
574        if import_statement.starts_with("from __future__") {
575            return ImportCategory::Future;
576        }
577
578        let package = Self::extract_package(import_statement);
579
580        // Check cache first
581        if let Some(&cached_category) = self.category_cache.get(&package) {
582            return cached_category;
583        }
584
585        // Determine category with priority order:
586        // 1. Local imports (relative or matching local prefixes)
587        // 2. Standard library (built-in or custom registered)
588        // 3. Third-party (custom registered or default)
589        let category = if self.is_local_import(import_statement) {
590            ImportCategory::Local
591        } else if self.is_standard_library_package(&package) {
592            ImportCategory::StandardLibrary
593        } else if self.is_common_third_party_package(&package) {
594            ImportCategory::ThirdParty
595        } else {
596            // Default to third-party for unknown packages
597            ImportCategory::ThirdParty
598        };
599
600        self.category_cache.insert(package, category);
601        category
602    }
603
604    /// Extract the package name from an import statement
605    fn extract_package(import_statement: &str) -> String {
606        if let Some(from_part) = import_statement.strip_prefix("from ") {
607            if let Some(import_pos) = from_part.find(" import ") {
608                return from_part[..import_pos].trim().to_string();
609            }
610        } else if let Some(import_part) = import_statement.strip_prefix("import ") {
611            // For direct imports, return the full module path
612            return import_part
613                .split_whitespace()
614                .next()
615                .unwrap_or(import_part)
616                .trim()
617                .to_string();
618        }
619
620        import_statement.to_string()
621    }
622
623    /// Extract imported items from an import statement
624    fn extract_items(import_statement: &str) -> Vec<String> {
625        if let Some(from_part) = import_statement.strip_prefix("from ") {
626            if let Some(import_pos) = from_part.find(" import ") {
627                let items_part = &from_part[import_pos + 8..];
628                let cleaned = items_part.replace(['(', ')'], "").replace(',', " ");
629                let mut items: Vec<String> = cleaned
630                    .split_whitespace()
631                    .map(|s| s.trim().to_string())
632                    .filter(|s| !s.is_empty())
633                    .collect();
634
635                // Sort items with ALL_CAPS first, then mixed case alphabetically
636                items.sort_by(|a, b| Self::custom_import_sort(a, b));
637                return items;
638            }
639        } else if let Some(import_part) = import_statement.strip_prefix("import ") {
640            // For direct imports, the "item" is the module itself
641            return vec![import_part.trim().to_string()];
642        }
643        Vec::new()
644    }
645
646    /// Check if this is a local/relative import
647    fn is_local_import(&self, import_statement: &str) -> bool {
648        // Check for relative imports
649        if import_statement.contains("from .")
650            || import_statement.contains("from ..")
651            || import_statement.contains("from ...")
652            || import_statement.contains("from ....")
653        {
654            return true;
655        }
656
657        let package = Self::extract_package(import_statement);
658
659        // Check custom local package prefixes first
660        for prefix in &self.local_package_prefixes {
661            if package.starts_with(prefix.as_str()) {
662                return true;
663            }
664        }
665
666        // Fallback to package_name check for backwards compatibility
667        if let Some(pkg_name) = &self.package_name {
668            if package.starts_with(pkg_name) {
669                return true;
670            }
671        }
672
673        false
674    }
675
676    /// Check if a package is part of Python's standard library
677    fn is_standard_library_package(&self, package: &str) -> bool {
678        // Check against the constant list of standard library modules
679        self.registry.is_stdlib(package)
680    }
681
682    /// Check if a package is a common third-party package
683    fn is_common_third_party_package(&self, package: &str) -> bool {
684        // Check against the constant list of common third-party packages
685        self.registry.is_third_party(package)
686    }
687
688    /// Format a list of imports, merging same-package imports where appropriate
689    #[allow(clippy::unused_self)]
690    fn format_imports(&self, imports: &[ImportStatement]) -> Vec<String> {
691        let mut package_imports: HashMap<String, Vec<&ImportStatement>> = HashMap::new();
692
693        // Group imports by package
694        for import in imports {
695            package_imports
696                .entry(import.package.clone())
697                .or_default()
698                .push(import);
699        }
700
701        let mut result = Vec::new();
702        let mut packages: Vec<_> = package_imports.keys().collect();
703        packages.sort();
704
705        for package in packages {
706            let imports_for_package = package_imports.get(package).unwrap();
707
708            if imports_for_package.len() == 1 {
709                // Single import, use as-is
710                result.push(imports_for_package[0].statement.clone());
711            } else {
712                // Multiple imports from same package, merge if possible
713                result.extend(Self::merge_package_imports(imports_for_package));
714            }
715        }
716
717        result
718    }
719
720    /// Merge multiple imports from the same package
721    fn merge_package_imports(imports: &[&ImportStatement]) -> Vec<String> {
722        let mut all_items = HashSet::new();
723        let package = &imports[0].package;
724
725        // Collect all items being imported from this package
726        for import in imports {
727            all_items.extend(import.items.iter().cloned());
728        }
729
730        if all_items.is_empty() {
731            // Simple "import package" statements
732            return imports.iter().map(|i| i.statement.clone()).collect();
733        }
734
735        let mut sorted_items: Vec<_> = all_items.into_iter().collect();
736        sorted_items.sort_by(|a, b| Self::custom_import_sort(a, b));
737
738        // Format as single line or multi-line based on length
739        if sorted_items.len() <= 3 && sorted_items.iter().map(String::len).sum::<usize>() < 60 {
740            // Single line
741            vec![format!(
742                "from {} import {}",
743                package,
744                sorted_items.join(", ")
745            )]
746        } else {
747            // Multi-line with parentheses
748            let mut result = vec![format!("from {} import (", package)];
749            for item in sorted_items {
750                result.push(format!("    {item},"));
751            }
752            result.push(")".to_string());
753            result
754        }
755    }
756
757    /// Custom sorting for import items: `ALL_CAPS` first (alphabetically), then mixed case (alphabetically)
758    fn custom_import_sort(a: &str, b: &str) -> std::cmp::Ordering {
759        let a_is_all_caps = a.chars().all(|c| c.is_uppercase() || !c.is_alphabetic());
760        let b_is_all_caps = b.chars().all(|c| c.is_uppercase() || !c.is_alphabetic());
761
762        match (a_is_all_caps, b_is_all_caps) {
763            // Both are ALL_CAPS or both are mixed case - sort alphabetically
764            (true, true) | (false, false) => a.cmp(b),
765            // a is ALL_CAPS, b is mixed case - a comes first
766            (true, false) => std::cmp::Ordering::Less,
767            // a is mixed case, b is ALL_CAPS - b comes first
768            (false, true) => std::cmp::Ordering::Greater,
769        }
770    }
771
772    /// Automatically add `TYPE_CHECKING` to typing import when type checking imports are used
773    fn ensure_type_checking_import_added(&mut self) {
774        // Check if we already have a typing import with TYPE_CHECKING
775        let has_type_checking = self.sections.standard_library_from.iter().any(|import| {
776            import.package == "typing" && import.items.contains(&"TYPE_CHECKING".to_string())
777        });
778
779        if !has_type_checking {
780            // Check if we have any typing import that we can modify
781            if let Some(typing_import) = self
782                .sections
783                .standard_library_from
784                .iter_mut()
785                .find(|import| import.package == "typing")
786            {
787                // Add TYPE_CHECKING to existing typing import
788                if !typing_import.items.contains(&"TYPE_CHECKING".to_string()) {
789                    typing_import.items.push("TYPE_CHECKING".to_string());
790                    typing_import
791                        .items
792                        .sort_by(|a, b| Self::custom_import_sort(a, b));
793
794                    // Update the statement string
795                    if typing_import.items.len() == 1 {
796                        typing_import.statement =
797                            format!("from typing import {}", typing_import.items[0]);
798                    } else {
799                        typing_import.statement =
800                            format!("from typing import {}", typing_import.items.join(", "));
801                    }
802                }
803            } else {
804                // No typing import exists, add one with just TYPE_CHECKING
805                self.add_import_string("from typing import TYPE_CHECKING");
806            }
807        }
808    }
809
810    /// Clone configuration without imports (useful for creating multiple helpers with same config)
811    #[must_use]
812    pub fn clone_config(&self) -> Self {
813        Self {
814            sections: ImportSections::default(),
815            category_cache: self.category_cache.clone(),
816            package_name: self.package_name.clone(),
817            local_package_prefixes: self.local_package_prefixes.clone(),
818            registry: self.registry.clone(),
819        }
820    }
821}
822
823/// Convenience functions for common import operations
824impl ImportHelper {
825    /// Create imports for a model file with required type imports
826    pub fn create_model_imports(&mut self, required_types: &[String]) {
827        // Standard model imports
828        self.add_import_string("from pydantic import BaseModel, ConfigDict, Field");
829
830        // Collect all typing imports needed across all types
831        let mut typing_imports = std::collections::HashSet::new();
832        let mut collections_abc_imports = std::collections::HashSet::new();
833        let mut datetime_imports = Vec::new();
834        let mut decimal_imports = Vec::new();
835
836        for type_name in required_types {
837            match type_name.as_str() {
838                "datetime" | "date" | "time" | "timedelta" => {
839                    if !datetime_imports.contains(&type_name.as_str()) {
840                        datetime_imports.push(type_name.as_str());
841                    }
842                }
843                "Decimal" => {
844                    if !decimal_imports.contains(&"Decimal") {
845                        decimal_imports.push("Decimal");
846                    }
847                }
848                "UUID" => {
849                    self.add_import_string("from uuid import UUID");
850                }
851                // For complex types, extract typing imports
852                _ => {
853                    // Check if this type contains typing elements
854                    let extracted_typing = Self::extract_typing_imports_from_type(type_name);
855                    typing_imports.extend(extracted_typing);
856
857                    // Check for collections.abc imports
858                    if type_name.contains("Callable") {
859                        collections_abc_imports.insert("Callable".to_string());
860                    }
861                }
862            }
863        }
864
865        // Add datetime imports if any were found
866        if !datetime_imports.is_empty() {
867            let import_statement = format!("from datetime import {}", datetime_imports.join(", "));
868            self.add_regular_import(&import_statement);
869        }
870
871        // Add decimal imports if any were found
872        if !decimal_imports.is_empty() {
873            self.add_import_string("from decimal import Decimal");
874        }
875
876        // Add typing imports if any were found (only Any, Generic, TypeVar, Protocol)
877        if !typing_imports.is_empty() {
878            let mut sorted_typing: Vec<String> = typing_imports.into_iter().collect();
879            sorted_typing.sort();
880            let import_statement = format!("from typing import {}", sorted_typing.join(", "));
881            self.add_regular_import(&import_statement);
882        }
883
884        // Add collections.abc imports if any were found (e.g., Callable)
885        if !collections_abc_imports.is_empty() {
886            let mut sorted_collections: Vec<String> = collections_abc_imports.into_iter().collect();
887            sorted_collections.sort();
888            let import_statement = format!(
889                "from collections.abc import {}",
890                sorted_collections.join(", ")
891            );
892            self.add_regular_import(&import_statement);
893        }
894    }
895
896    /// Extract typing imports from a complex type string
897    /// This handles types like list[Any], dict[str, Any], etc.
898    /// Only imports what's actually needed for Python 3.13+ (Any, Generic, `TypeVar`, Protocol)
899    fn extract_typing_imports_from_type(type_str: &str) -> std::collections::HashSet<String> {
900        let mut typing_imports = std::collections::HashSet::new();
901
902        // Check for Any type (used in generics and standalone)
903        if type_str.contains("Any") {
904            typing_imports.insert("Any".to_string());
905        }
906
907        // Check for Generic type (used for generic classes)
908        if type_str.contains("Generic") {
909            typing_imports.insert("Generic".to_string());
910        }
911
912        // Check for TypeVar usage
913        if type_str.contains("TypeVar") {
914            typing_imports.insert("TypeVar".to_string());
915        }
916
917        // Check for Protocol type (structural subtyping)
918        if type_str.contains("Protocol") {
919            typing_imports.insert("Protocol".to_string());
920        }
921
922        typing_imports
923    }
924}
925
926impl Default for ImportHelper {
927    fn default() -> Self {
928        Self::new()
929    }
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_import_categorization() {
        let mut helper = ImportHelper::new();

        // Test future imports
        helper.add_import_string("from __future__ import annotations");
        assert_eq!(helper.sections.future.len(), 1);

        // Test standard library from import
        helper.add_import_string("from typing import Optional");
        assert_eq!(helper.sections.standard_library_from.len(), 1);

        // Test standard library direct import
        helper.add_import_string("import uuid");
        assert_eq!(helper.sections.standard_library_direct.len(), 1);

        // Test third party
        helper.add_import_string("from pydantic import BaseModel");
        assert_eq!(helper.sections.third_party_from.len(), 1);

        // Test local imports
        helper.add_import_string("from .models import User");
        assert_eq!(helper.sections.local_from.len(), 1);
    }

    #[test]
    fn test_import_merging() {
        let mut helper = ImportHelper::new();

        helper.add_import_string("from typing import Optional");
        helper.add_import_string("from typing import Any");
        helper.add_import_string("from typing import List");

        let imports = helper.get_formatted();

        // The three statements must collapse into exactly one `typing` import.
        let typing_imports: Vec<&String> = imports
            .iter()
            .filter(|i| i.contains("from typing import"))
            .collect();
        assert_eq!(
            typing_imports.len(),
            1,
            "Expected a single merged typing import, got: {imports:?}"
        );

        let merged = typing_imports[0];
        for name in ["Any", "List", "Optional"] {
            assert!(
                merged.contains(name),
                "Merged typing import should contain {name}, got: {merged}"
            );
        }
    }

    #[test]
    fn test_alphabetical_sorting_of_import_items() {
        let mut helper = ImportHelper::new();

        // Test with unsorted items in a from import
        helper.add_import_string("from typing import Optional, Any, List, Dict");
        helper.add_import_string("from collections import defaultdict, OrderedDict, Counter");
        helper.add_import_string("import uuid");
        helper.add_import_string("import os");

        let imports = helper.get_formatted();

        // Check that direct imports come before from imports
        let import_str = imports.join("\n");
        let os_pos = import_str.find("import os").unwrap();
        let uuid_pos = import_str.find("import uuid").unwrap();
        let typing_pos = import_str.find("from typing import").unwrap();

        // Direct imports should come before from imports
        assert!(
            os_pos < typing_pos,
            "Direct imports should come before from imports"
        );
        assert!(
            uuid_pos < typing_pos,
            "Direct imports should come before from imports"
        );

        // Check that items within from imports are sorted alphabetically
        let typing_import = imports
            .iter()
            .find(|s| s.contains("from typing import"))
            .unwrap();

        // Should be: "from typing import Any, Dict, List, Optional"
        assert!(
            typing_import.contains("Any, Dict, List, Optional")
                || typing_import.contains("(\n    Any,\n    Dict,\n    List,\n    Optional,\n)"),
            "Import items should be sorted alphabetically, got: {}",
            typing_import
        );
    }

    #[test]
    fn test_direct_imports_sorted_alphabetically() {
        let mut helper = ImportHelper::new();

        helper.add_import_string("import uuid");
        helper.add_import_string("import os");
        helper.add_import_string("import sys");
        helper.add_import_string("import json");

        let imports = helper.get_formatted();

        // Should be sorted: json, os, sys, uuid
        let import_lines: Vec<String> = imports
            .iter()
            .filter(|s| s.starts_with("import "))
            .cloned()
            .collect();

        assert_eq!(import_lines.len(), 4);
        assert!(import_lines[0].contains("json"));
        assert!(import_lines[1].contains("os"));
        assert!(import_lines[2].contains("sys"));
        assert!(import_lines[3].contains("uuid"));
    }

    #[test]
    fn test_uppercase_priority_in_import_sorting() {
        let mut helper = ImportHelper::new();

        // Test with mixed case items - uppercase should come before lowercase for same letter
        helper.add_import_string("from example import Ab, AA, Aa, AB");

        let imports = helper.get_formatted();

        let example_import = imports
            .iter()
            .find(|s| s.contains("from example import"))
            .unwrap();

        // Should be: "from example import AA, AB, Aa, Ab"
        // (uppercase A's first, then lowercase a's)
        assert!(
            example_import.contains("AA, AB, Aa, Ab"),
            "Import items should be sorted with uppercase priority, got: {}",
            example_import
        );
    }

    #[test]
    fn test_comprehensive_case_sorting() {
        let mut helper = ImportHelper::new();

        // Test with multiple letters and mixed cases
        helper.add_import_string("from test import zz, ZZ, bb, BB, aa, AA, cc, CC");

        let imports = helper.get_formatted();

        let test_import = imports
            .iter()
            .find(|s| s.contains("from test import"))
            .unwrap();

        // Should be: "from test import AA, BB, CC, ZZ, aa, bb, cc, zz"
        // (all uppercase first in alphabetical order, then all lowercase)
        assert!(
            test_import.contains("AA, BB, CC, ZZ, aa, bb, cc, zz"),
            "Import items should be sorted with uppercase priority across all letters, got: {}",
            test_import
        );
    }

    #[test]
    fn test_custom_import_sort_puts_all_caps_first() {
        // ALL_CAPS names (including ones with underscores) precede every other name;
        // each group is ordered by plain string comparison.
        let mut items = vec!["Any", "TYPE_CHECKING", "cast", "Optional", "T"];
        items.sort_by(|a, b| ImportHelper::custom_import_sort(a, b));
        assert_eq!(items, vec!["T", "TYPE_CHECKING", "Any", "Optional", "cast"]);
    }

    #[test]
    fn test_type_checking_imports() {
        let mut helper = ImportHelper::with_package_name("mypackage".to_string());

        // Add regular imports
        helper.add_import_string("from __future__ import annotations");
        helper.add_import_string("from typing import Any");
        helper.add_import_string("from pydantic import BaseModel");

        // Add TYPE_CHECKING imports
        helper.add_type_checking_import("import httpx");
        helper.add_type_checking_import("from typing import TYPE_CHECKING");
        helper.add_type_checking_import("from collections.abc import Callable");
        helper.add_type_checking_import("from mypackage.models import User");

        // Check that regular imports are counted correctly
        assert_eq!(helper.count(), 3);
        assert_eq!(helper.count_type_checking(), 4);
        assert!(!helper.is_type_checking_empty());

        // Generate TYPE_CHECKING imports
        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();

        // Verify categorization
        assert!(
            third_party.iter().any(|s| s.contains("import httpx")),
            "Should have httpx in third_party"
        );
        assert!(
            stdlib
                .iter()
                .any(|s| s.contains("from typing import TYPE_CHECKING")),
            "Should have TYPE_CHECKING"
        );
        assert!(
            stdlib
                .iter()
                .any(|s| s.contains("from collections.abc import Callable")),
            "Should have Callable"
        );
        assert!(
            local
                .iter()
                .any(|s| s.contains("from mypackage.models import User")),
            "Should have User in local"
        );
        assert!(
            future.is_empty(),
            "Should have no future imports in TYPE_CHECKING"
        );
    }

    #[test]
    fn test_programmatic_import_builders() {
        let mut helper = ImportHelper::new();

        // Test add_from_import
        helper.add_from_import("typing", &["Any", "Optional"]);
        helper.add_from_import("json", &["loads"]);

        // Test add_direct_import
        helper.add_direct_import("sys");

        // Test TYPE_CHECKING builders
        helper.add_type_checking_from_import("httpx", &["Client", "Response"]);
        helper.add_type_checking_direct_import("logging");

        // Verify regular imports
        let (_, stdlib, _, _) = helper.get_categorized();
        assert!(
            stdlib
                .iter()
                .any(|s| s.contains("Any") && s.contains("Optional")),
            "Should have Any and Optional in typing imports"
        );
        assert!(stdlib.iter().any(|s| s.contains("from json import loads")));
        assert!(stdlib.iter().any(|s| s.contains("import sys")));

        // Verify TYPE_CHECKING imports
        let (_, tc_stdlib, tc_third_party, _) = helper.get_type_checking_categorized();
        assert!(tc_third_party
            .iter()
            .any(|s| s.contains("from httpx import Client, Response")));
        assert!(tc_stdlib.iter().any(|s| s.contains("import logging")));
    }

    #[test]
    fn test_type_checking_four_categories() {
        let mut helper = ImportHelper::with_package_name("myapp".to_string());

        // Add imports in all four categories for TYPE_CHECKING
        helper.add_type_checking_import("from __future__ import annotations");
        helper.add_type_checking_import("from typing import Protocol");
        helper.add_type_checking_import("from httpx import Client");
        helper.add_type_checking_import("from myapp.models import User");

        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();

        // Verify all four categories
        assert_eq!(future.len(), 1, "Should have 1 future import");
        assert_eq!(stdlib.len(), 1, "Should have 1 stdlib import");
        assert_eq!(third_party.len(), 1, "Should have 1 third-party import");
        assert_eq!(local.len(), 1, "Should have 1 local import");

        assert!(future[0].contains("from __future__ import annotations"));
        assert!(stdlib[0].contains("from typing import Protocol"));
        assert!(third_party[0].contains("from httpx import Client"));
        assert!(local[0].contains("from myapp.models import User"));
    }
}