py_import_helper/
core.rs

1//! Core import helper functionality
2//!
3//! This module contains the main `ImportHelper` struct and its implementation,
4//! providing the primary API for collecting, categorizing, and formatting Python
5//! imports according to PEP 8 and common Python formatting standards.
6
7use std::collections::{HashMap, HashSet};
8
9use crate::registry::PackageRegistry;
10use crate::types::{AllCategorizedImports, CategorizedImports, FormattingConfig, ImportSpec};
11use crate::{ImportCategory, ImportSections, ImportStatement, ImportType};
12
/// Main helper for managing Python imports across the codebase
///
/// Import statements are parsed, assigned a category (future, standard
/// library, third-party, or local) and stored in `sections`. Regular and
/// `TYPE_CHECKING` imports are kept in separate buckets of `sections`.
#[derive(Debug)]
pub struct ImportHelper {
    /// Collected imports organized by category
    sections: ImportSections,
    /// Cache for import categorization, keyed by the extracted package name
    category_cache: HashMap<String, ImportCategory>,
    /// The package name for identifying local imports
    package_name: Option<String>,
    /// Custom local package prefixes to recognize (matched with `starts_with`)
    local_package_prefixes: HashSet<String>,
    /// Package registry for stdlib and third-party recognition
    registry: PackageRegistry,
    /// Formatting configuration for isort/ruff compliance
    formatting_config: FormattingConfig,
}
29
30impl ImportHelper {
31    /// Create a new import helper instance
32    #[must_use]
33    pub fn new() -> Self {
34        Self {
35            sections: ImportSections::default(),
36            category_cache: HashMap::new(),
37            package_name: None,
38            local_package_prefixes: HashSet::new(),
39            registry: PackageRegistry::new(),
40            formatting_config: FormattingConfig::default(),
41        }
42    }
43
44    /// Create a new import helper instance with package name for local import detection
45    #[must_use]
46    pub fn with_package_name(package_name: String) -> Self {
47        let mut helper = Self::new();
48        helper.package_name = Some(package_name.clone());
49        helper.local_package_prefixes.insert(package_name);
50        helper
51    }
52
53    /// Create a new import helper with custom formatting configuration
54    #[must_use]
55    pub fn with_formatting_config(config: FormattingConfig) -> Self {
56        Self {
57            sections: ImportSections::default(),
58            category_cache: HashMap::new(),
59            package_name: None,
60            local_package_prefixes: HashSet::new(),
61            registry: PackageRegistry::new(),
62            formatting_config: config,
63        }
64    }
65
66    /// Create a new import helper with package name and custom formatting
67    #[must_use]
68    pub fn with_package_and_config(package_name: String, config: FormattingConfig) -> Self {
69        let mut helper = Self::with_formatting_config(config);
70        helper.package_name = Some(package_name.clone());
71        helper.local_package_prefixes.insert(package_name);
72        helper
73    }
74
75    /// Get the current formatting configuration
76    #[must_use]
77    pub fn formatting_config(&self) -> &FormattingConfig {
78        &self.formatting_config
79    }
80
81    /// Set a new formatting configuration
82    pub fn set_formatting_config(&mut self, config: FormattingConfig) {
83        self.formatting_config = config;
84    }
85
86    /// Get immutable reference to the package registry
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// use py_import_helper::ImportHelper;
92    ///
93    /// let helper = ImportHelper::new();
94    /// assert!(helper.registry().is_stdlib("typing"));
95    /// ```
96    #[must_use]
97    pub fn registry(&self) -> &PackageRegistry {
98        &self.registry
99    }
100
101    /// Get mutable reference to the package registry
102    ///
103    /// Use this to customize which packages are recognized as stdlib or third-party
104    /// before generating imports.
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use py_import_helper::ImportHelper;
110    ///
111    /// let mut helper = ImportHelper::new();
112    /// helper.registry_mut()
113    ///     .add_stdlib_package("my_custom_stdlib")
114    ///     .add_third_party_package("my_company_lib");
115    ///
116    /// helper.add_import_string("import my_custom_stdlib");
117    /// let (_future, stdlib, _third, _local) = helper.get_categorized();
118    /// assert!(stdlib.iter().any(|s| s.contains("my_custom_stdlib")));
119    /// ```
120    pub fn registry_mut(&mut self) -> &mut PackageRegistry {
121        // Clear cache when registry is modified
122        &mut self.registry
123    }
124
125    /// Clear the categorization cache
126    ///
127    /// Call this after modifying the registry to ensure changes take effect.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use py_import_helper::ImportHelper;
133    ///
134    /// let mut helper = ImportHelper::new();
135    /// helper.add_import_string("import mypackage");
136    /// helper.registry_mut().add_stdlib_package("mypackage");
137    /// helper.clear_cache();  // Force re-categorization
138    /// ```
139    pub fn clear_cache(&mut self) -> &mut Self {
140        self.category_cache.clear();
141        self
142    }
143
144    /// Add a custom local package prefix to the recognition list
145    pub fn add_local_package_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
146        let prefix = prefix.into();
147        self.local_package_prefixes.insert(prefix);
148        // Don't add to cache - these are prefixes, not exact matches
149        self
150    }
151
152    /// Add multiple local package prefixes at once
153    pub fn add_local_package_prefixes(&mut self, prefixes: &[impl AsRef<str>]) -> &mut Self {
154        for prefix in prefixes {
155            self.add_local_package_prefix(prefix.as_ref());
156        }
157        self
158    }
159
160    /// Add an import using structured `ImportSpec`
161    pub fn add_import(&mut self, spec: &ImportSpec) {
162        let import_statement = if let Some(items) = &spec.items {
163            format!("from {} import {}", spec.package, items.join(", "))
164        } else {
165            format!("import {}", spec.package)
166        };
167
168        if spec.type_checking {
169            self.add_type_checking_import(&import_statement);
170        } else {
171            self.add_regular_import(&import_statement);
172        }
173    }
174
175    /// Convenience method to add import from string (for backward compatibility)
176    pub fn add_import_string(&mut self, import_statement: &str) {
177        self.add_regular_import(import_statement);
178    }
179
180    /// Add an import statement using string (internal method)
181    fn add_regular_import(&mut self, import_statement: &str) {
182        if let Some(import) = self.parse_import(import_statement) {
183            match (&import.category, &import.import_type) {
184                (ImportCategory::Future, _) => self.sections.future.push(import),
185                (ImportCategory::StandardLibrary, ImportType::Direct) => {
186                    self.sections.standard_library_direct.push(import)
187                }
188                (ImportCategory::StandardLibrary, ImportType::From) => {
189                    self.sections.standard_library_from.push(import)
190                }
191                (ImportCategory::ThirdParty, ImportType::Direct) => {
192                    self.sections.third_party_direct.push(import)
193                }
194                (ImportCategory::ThirdParty, ImportType::From) => {
195                    self.sections.third_party_from.push(import)
196                }
197                (ImportCategory::Local, ImportType::Direct) => {
198                    self.sections.local_direct.push(import)
199                }
200                (ImportCategory::Local, ImportType::From) => self.sections.local_from.push(import),
201            }
202        }
203    }
204
205    /// Add a from import statement programmatically
206    /// Example: `add_from_import("typing", &["Any", "Optional"])`
207    pub fn add_from_import(&mut self, package: &str, items: &[&str]) {
208        let import_statement = if items.len() == 1 {
209            format!("from {} import {}", package, items[0])
210        } else {
211            format!("from {} import {}", package, items.join(", "))
212        };
213        self.add_regular_import(&import_statement);
214    }
215
216    /// Add a multiline from import statement programmatically
217    pub fn add_from_import_multiline(&mut self, package: &str, items: &[&str]) {
218        if items.is_empty() {
219            return;
220        }
221
222        if items.len() == 1 {
223            self.add_from_import(package, items);
224            return;
225        }
226
227        let mut import_statement = format!("from {} import (\n", package);
228        for item in items {
229            import_statement.push_str(&format!("    {},\n", item));
230        }
231        import_statement.push(')');
232
233        self.add_regular_import(&import_statement);
234    }
235
236    /// Add a from import statement to `TYPE_CHECKING` block programmatically
237    /// Example: `add_type_checking_from_import("httpx", &["Client", "Response"])`
238    pub fn add_type_checking_from_import(&mut self, package: &str, items: &[&str]) {
239        let import_statement = if items.len() == 1 {
240            format!("from {} import {}", package, items[0])
241        } else {
242            format!("from {} import {}", package, items.join(", "))
243        };
244        self.add_type_checking_import(&import_statement);
245    }
246
247    /// Add a direct import statement programmatically
248    /// Example: `add_direct_import("json`")
249    pub fn add_direct_import(&mut self, module: &str) {
250        let import_statement = format!("import {module}");
251        self.add_regular_import(&import_statement);
252    }
253
254    /// Add a direct import statement to `TYPE_CHECKING` block programmatically
255    /// Example: `add_type_checking_direct_import("httpx`")
256    pub fn add_type_checking_direct_import(&mut self, module: &str) {
257        let import_statement = format!("import {module}");
258        self.add_type_checking_import(&import_statement);
259    }
260
261    /// Add an import statement to the `TYPE_CHECKING` block
262    pub fn add_type_checking_import(&mut self, import_statement: &str) {
263        if let Some(import) = self.parse_import(import_statement) {
264            match (&import.category, &import.import_type) {
265                (ImportCategory::Future, _) => self.sections.type_checking_future.push(import),
266                (ImportCategory::StandardLibrary, ImportType::Direct) => self
267                    .sections
268                    .type_checking_standard_library_direct
269                    .push(import),
270                (ImportCategory::StandardLibrary, ImportType::From) => self
271                    .sections
272                    .type_checking_standard_library_from
273                    .push(import),
274                (ImportCategory::ThirdParty, ImportType::Direct) => {
275                    self.sections.type_checking_third_party_direct.push(import);
276                }
277                (ImportCategory::ThirdParty, ImportType::From) => {
278                    self.sections.type_checking_third_party_from.push(import);
279                }
280                (ImportCategory::Local, ImportType::Direct) => {
281                    self.sections.type_checking_local_direct.push(import);
282                }
283                (ImportCategory::Local, ImportType::From) => {
284                    self.sections.type_checking_local_from.push(import);
285                }
286            }
287
288            // Automatically add TYPE_CHECKING to typing import when we have type checking imports
289            self.ensure_type_checking_import_added();
290        }
291    }
292
293    /// Generate all imports (regular + `TYPE_CHECKING`) for templates
294    /// Returns a tuple with 8 vectors:
295    /// (future, stdlib, `third_party`, local, `tc_future`, `tc_stdlib`, `tc_third_party`, `tc_local`)
296    #[must_use]
297    pub fn get_all_categorized(&self) -> AllCategorizedImports {
298        // Get regular imports (now includes future)
299        let (future_imports, stdlib_imports, third_party_imports, local_imports) =
300            self.get_categorized();
301
302        // Get TYPE_CHECKING imports
303        let (tc_future, tc_stdlib, tc_third_party, tc_local) = self.get_type_checking_categorized();
304
305        (
306            future_imports,
307            stdlib_imports,
308            third_party_imports,
309            local_imports,
310            tc_future,
311            tc_stdlib,
312            tc_third_party,
313            tc_local,
314        )
315    }
316
317    /// Generate categorized `TYPE_CHECKING` imports for templates
318    /// Returns (`future_imports`, `stdlib_imports`, `third_party_imports`, `local_imports`)
319    /// Get `TYPE_CHECKING` imports categorized by type
320    ///
321    /// Returns a tuple of (`future_imports`, `stdlib_imports`, `third_party_imports`, `local_imports`)
322    /// for imports that should go in the `TYPE_CHECKING` block.
323    #[must_use]
324    pub fn get_type_checking_categorized(&self) -> CategorizedImports {
325        self.get_type_checking_categorized_impl()
326    }
327
328    #[must_use]
329    pub fn get_type_checking_categorized_impl(&self) -> CategorizedImports {
330        let mut future_imports = Vec::new();
331        let mut stdlib_imports = Vec::new();
332        let mut third_party_imports = Vec::new();
333        let mut local_imports = Vec::new();
334
335        // Future imports
336        if !self.sections.type_checking_future.is_empty() {
337            let future = self.format_imports(&self.sections.type_checking_future);
338            future_imports.extend(future);
339        }
340
341        // Standard library imports - direct first, then from
342        if !self
343            .sections
344            .type_checking_standard_library_direct
345            .is_empty()
346        {
347            let std_direct =
348                self.format_imports(&self.sections.type_checking_standard_library_direct);
349            stdlib_imports.extend(std_direct);
350        }
351        if !self.sections.type_checking_standard_library_from.is_empty() {
352            let std_from = self.format_imports(&self.sections.type_checking_standard_library_from);
353            stdlib_imports.extend(std_from);
354        }
355
356        // Third-party imports - direct first, then from
357        if !self.sections.type_checking_third_party_direct.is_empty() {
358            let third_direct = self.format_imports(&self.sections.type_checking_third_party_direct);
359            third_party_imports.extend(third_direct);
360        }
361        if !self.sections.type_checking_third_party_from.is_empty() {
362            let third_from = self.format_imports(&self.sections.type_checking_third_party_from);
363            third_party_imports.extend(third_from);
364        }
365
366        // Local imports - direct first, then from
367        if !self.sections.type_checking_local_direct.is_empty() {
368            let local_direct = self.format_imports(&self.sections.type_checking_local_direct);
369            local_imports.extend(local_direct);
370        }
371        if !self.sections.type_checking_local_from.is_empty() {
372            let local_from = self.format_imports(&self.sections.type_checking_local_from);
373            local_imports.extend(local_from);
374        }
375
376        // Sort each category alphabetically
377        future_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
378        stdlib_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
379        third_party_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
380        local_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
381
382        (
383            future_imports,
384            stdlib_imports,
385            third_party_imports,
386            local_imports,
387        )
388    }
389
390    /// Get the collected imports as categorized tuples
391    /// Returns (`future_imports`, `stdlib_imports`, `third_party_imports`, `local_imports`)
392    #[must_use]
393    pub fn get_categorized(&self) -> CategorizedImports {
394        let mut future_imports = Vec::new();
395        let mut stdlib_imports = Vec::new();
396        let mut third_party_imports = Vec::new();
397        let mut local_imports = Vec::new();
398
399        // Future imports
400        if !self.sections.future.is_empty() {
401            let future = self.format_imports(&self.sections.future);
402            future_imports.extend(future);
403        }
404
405        // Standard library imports - direct first, then from
406        if !self.sections.standard_library_direct.is_empty() {
407            let std_direct_imports = self.format_imports(&self.sections.standard_library_direct);
408            stdlib_imports.extend(std_direct_imports);
409        }
410        if !self.sections.standard_library_from.is_empty() {
411            let std_from_imports = self.format_imports(&self.sections.standard_library_from);
412            stdlib_imports.extend(std_from_imports);
413        }
414
415        // Third-party imports - direct first, then from
416        if !self.sections.third_party_direct.is_empty() {
417            let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
418            third_party_imports.extend(third_direct_imports);
419        }
420        if !self.sections.third_party_from.is_empty() {
421            let third_from_imports = self.format_imports(&self.sections.third_party_from);
422            third_party_imports.extend(third_from_imports);
423        }
424
425        // Local imports - direct first, then from
426        if !self.sections.local_direct.is_empty() {
427            let local_direct_imports = self.format_imports(&self.sections.local_direct);
428            local_imports.extend(local_direct_imports);
429        }
430        if !self.sections.local_from.is_empty() {
431            let local_from_imports = self.format_imports(&self.sections.local_from);
432            local_imports.extend(local_from_imports);
433        }
434
435        // Sort each category alphabetically
436        future_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
437        stdlib_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
438        third_party_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
439        local_imports.sort_by(|a, b| Self::sort_import_statements(a, b));
440
441        (
442            future_imports,
443            stdlib_imports,
444            third_party_imports,
445            local_imports,
446        )
447    }
448
449    /// Clear all registered imports while preserving configuration
450    ///
451    /// This method clears both regular and TYPE_CHECKING imports, making the helper
452    /// ready to collect new imports. Configuration like formatting settings and
453    /// local package prefixes are preserved.
454    ///
455    /// This is useful when reusing the same helper for multiple files.
456    ///
457    /// # Examples
458    ///
459    /// ```
460    /// use py_import_helper::ImportHelper;
461    ///
462    /// let mut helper = ImportHelper::with_package_name("mypackage".to_string());
463    /// helper.add_import_string("from typing import Any");
464    /// assert!(!helper.is_empty());
465    ///
466    /// helper.clear();
467    /// assert!(helper.is_empty());
468    /// assert_eq!(helper.count(), 0);
469    /// ```
470    pub fn clear(&mut self) -> &mut Self {
471        self.sections = ImportSections::default();
472        self.category_cache.clear();
473        self
474    }
475
476    /// Reset the import helper to a native state without any configuration
477    ///
478    /// This method resets the helper to the same state as `ImportHelper::new()`,
479    /// clearing all imports, local package prefixes, and resetting formatting
480    /// configuration to defaults. The package name is also cleared.
481    ///
482    /// This is useful when you need a completely fresh helper instance
483    /// without creating a new one.
484    ///
485    /// # Examples
486    ///
487    /// ```
488    /// use py_import_helper::ImportHelper;
489    ///
490    /// let mut helper = ImportHelper::with_package_name("mypackage".to_string());
491    /// helper.add_local_package_prefix("other_package");
492    /// helper.add_import_string("from typing import Any");
493    ///
494    /// // Reset to native state
495    /// helper.reset();
496    ///
497    /// assert!(helper.is_empty());
498    /// assert_eq!(helper.count(), 0);
499    /// // Local package prefixes and package name are cleared
500    /// ```
501    pub fn reset(&mut self) -> &mut Self {
502        self.sections = ImportSections::default();
503        self.category_cache.clear();
504        self.package_name = None;
505        self.local_package_prefixes.clear();
506        self.registry = PackageRegistry::new();
507        self.formatting_config = FormattingConfig::default();
508        self
509    }
510
511    /// Check if any imports have been collected (excluding `TYPE_CHECKING` imports)
512    #[must_use]
513    pub fn is_empty(&self) -> bool {
514        self.sections.future.is_empty()
515            && self.sections.standard_library_direct.is_empty()
516            && self.sections.standard_library_from.is_empty()
517            && self.sections.third_party_direct.is_empty()
518            && self.sections.third_party_from.is_empty()
519            && self.sections.local_direct.is_empty()
520            && self.sections.local_from.is_empty()
521    }
522
523    /// Check if any `TYPE_CHECKING` imports have been collected
524    #[must_use]
525    pub fn is_type_checking_empty(&self) -> bool {
526        self.sections.type_checking_future.is_empty()
527            && self
528                .sections
529                .type_checking_standard_library_direct
530                .is_empty()
531            && self.sections.type_checking_standard_library_from.is_empty()
532            && self.sections.type_checking_third_party_direct.is_empty()
533            && self.sections.type_checking_third_party_from.is_empty()
534            && self.sections.type_checking_local_direct.is_empty()
535            && self.sections.type_checking_local_from.is_empty()
536    }
537
538    /// Count total number of import statements collected (excluding `TYPE_CHECKING` imports)
539    #[must_use]
540    pub fn count(&self) -> usize {
541        self.sections.future.len()
542            + self.sections.standard_library_direct.len()
543            + self.sections.standard_library_from.len()
544            + self.sections.third_party_direct.len()
545            + self.sections.third_party_from.len()
546            + self.sections.local_direct.len()
547            + self.sections.local_from.len()
548    }
549
550    /// Count total number of `TYPE_CHECKING` import statements collected
551    #[must_use]
552    pub fn count_type_checking(&self) -> usize {
553        self.sections.type_checking_future.len()
554            + self.sections.type_checking_standard_library_direct.len()
555            + self.sections.type_checking_standard_library_from.len()
556            + self.sections.type_checking_third_party_direct.len()
557            + self.sections.type_checking_third_party_from.len()
558            + self.sections.type_checking_local_direct.len()
559            + self.sections.type_checking_local_from.len()
560    }
561
562    /// Generate sorted and formatted import statements
563    #[must_use]
564    pub fn get_formatted(&self) -> Vec<String> {
565        let mut result = Vec::new();
566        let mut has_previous_section = false;
567
568        // Future imports
569        if !self.sections.future.is_empty() {
570            let future_imports = self.format_imports(&self.sections.future);
571            result.extend(future_imports);
572            has_previous_section = true;
573        }
574
575        // Standard library imports - direct first, then from
576        let std_has_direct = !self.sections.standard_library_direct.is_empty();
577        let std_has_from = !self.sections.standard_library_from.is_empty();
578
579        if std_has_direct || std_has_from {
580            if has_previous_section {
581                result.push(String::new()); // Empty line between sections
582            }
583
584            // Direct imports first
585            if std_has_direct {
586                let std_direct_imports =
587                    self.format_imports(&self.sections.standard_library_direct);
588                result.extend(std_direct_imports);
589            }
590
591            // From imports after direct imports
592            if std_has_from {
593                let std_from_imports = self.format_imports(&self.sections.standard_library_from);
594                result.extend(std_from_imports);
595            }
596
597            has_previous_section = true;
598        }
599
600        // Third-party imports - direct first, then from
601        let third_has_direct = !self.sections.third_party_direct.is_empty();
602        let third_has_from = !self.sections.third_party_from.is_empty();
603
604        if third_has_direct || third_has_from {
605            if has_previous_section {
606                result.push(String::new()); // Empty line between sections
607            }
608
609            // Direct imports first
610            if third_has_direct {
611                let third_direct_imports = self.format_imports(&self.sections.third_party_direct);
612                result.extend(third_direct_imports);
613            }
614
615            // From imports after direct imports
616            if third_has_from {
617                let third_from_imports = self.format_imports(&self.sections.third_party_from);
618                result.extend(third_from_imports);
619            }
620
621            has_previous_section = true;
622        }
623
624        // Local imports - direct first, then from
625        let local_has_direct = !self.sections.local_direct.is_empty();
626        let local_has_from = !self.sections.local_from.is_empty();
627
628        if local_has_direct || local_has_from {
629            if has_previous_section {
630                result.push(String::new()); // Empty line between sections
631            }
632
633            // Direct imports first
634            if local_has_direct {
635                let local_direct_imports = self.format_imports(&self.sections.local_direct);
636                result.extend(local_direct_imports);
637            }
638
639            // From imports after direct imports
640            if local_has_from {
641                let local_from_imports = self.format_imports(&self.sections.local_from);
642                result.extend(local_from_imports);
643            }
644        }
645
646        result
647    }
648
649    /// Parse an import statement and categorize it
650    fn parse_import(&mut self, import_statement: &str) -> Option<ImportStatement> {
651        let trimmed = import_statement.trim();
652        if trimmed.is_empty() {
653            return None;
654        }
655
656        let category = self.categorize_import(trimmed);
657        let import_type = if trimmed.starts_with("from ") {
658            ImportType::From
659        } else {
660            ImportType::Direct
661        };
662        let package = Self::extract_package(trimmed);
663        let items = Self::extract_items(trimmed);
664        let is_multiline = trimmed.contains('(') || trimmed.contains(')');
665
666        // Reconstruct the statement with sorted items for from imports
667        // Preserve multiline format if present
668        let statement = if is_multiline {
669            trimmed.to_string()
670        } else if import_type == ImportType::From && !items.is_empty() {
671            format!("from {} import {}", package, items.join(", "))
672        } else {
673            trimmed.to_string()
674        };
675
676        Some(ImportStatement {
677            statement,
678            category,
679            import_type,
680            package,
681            items,
682            is_multiline,
683        })
684    }
685
686    /// Categorize an import statement
687    fn categorize_import(&mut self, import_statement: &str) -> ImportCategory {
688        if import_statement.starts_with("from __future__") {
689            return ImportCategory::Future;
690        }
691
692        let package = Self::extract_package(import_statement);
693
694        // Check cache first
695        if let Some(&cached_category) = self.category_cache.get(&package) {
696            return cached_category;
697        }
698
699        // Determine category with priority order:
700        // 1. Local imports (relative or matching local prefixes)
701        // 2. Standard library (built-in or custom registered)
702        // 3. Third-party (custom registered or default)
703        let category = if self.is_local_import(import_statement) {
704            ImportCategory::Local
705        } else if self.is_standard_library_package(&package) {
706            ImportCategory::StandardLibrary
707        } else if self.is_common_third_party_package(&package) {
708            ImportCategory::ThirdParty
709        } else {
710            // Default to third-party for unknown packages
711            ImportCategory::ThirdParty
712        };
713
714        self.category_cache.insert(package, category);
715        category
716    }
717
718    /// Extract the package name from an import statement
719    fn extract_package(import_statement: &str) -> String {
720        if let Some(from_part) = import_statement.strip_prefix("from ") {
721            if let Some(import_pos) = from_part.find(" import ") {
722                return from_part[..import_pos].trim().to_string();
723            }
724        } else if let Some(import_part) = import_statement.strip_prefix("import ") {
725            // For direct imports, return the full module path
726            return import_part
727                .split_whitespace()
728                .next()
729                .unwrap_or(import_part)
730                .trim()
731                .to_string();
732        }
733
734        import_statement.to_string()
735    }
736
737    /// Extract imported items from an import statement
738    fn extract_items(import_statement: &str) -> Vec<String> {
739        if let Some(from_part) = import_statement.strip_prefix("from ") {
740            if let Some(import_pos) = from_part.find(" import ") {
741                let items_part = &from_part[import_pos + 8..];
742                let cleaned = items_part.replace(['(', ')'], "").replace(',', " ");
743                let mut items: Vec<String> = cleaned
744                    .split_whitespace()
745                    .map(|s| s.trim().to_string())
746                    .filter(|s| !s.is_empty())
747                    .collect();
748
749                // Sort items with ALL_CAPS first, then mixed case alphabetically
750                items.sort_by(|a, b| crate::utils::parsing::custom_import_sort(a, b));
751                return items;
752            }
753        } else if let Some(import_part) = import_statement.strip_prefix("import ") {
754            // For direct imports, the "item" is the module itself
755            return vec![import_part.trim().to_string()];
756        }
757        Vec::new()
758    }
759
760    /// Check if this is a local/relative import
761    fn is_local_import(&self, import_statement: &str) -> bool {
762        // Check for relative imports
763        if import_statement.contains("from .")
764            || import_statement.contains("from ..")
765            || import_statement.contains("from ...")
766            || import_statement.contains("from ....")
767        {
768            return true;
769        }
770
771        let package = Self::extract_package(import_statement);
772
773        // Check custom local package prefixes first
774        for prefix in &self.local_package_prefixes {
775            if package.starts_with(prefix.as_str()) {
776                return true;
777            }
778        }
779
780        // Fallback to package_name check for backwards compatibility
781        if let Some(pkg_name) = &self.package_name {
782            if package.starts_with(pkg_name) {
783                return true;
784            }
785        }
786
787        false
788    }
789
    /// Check if a package is part of Python's standard library
    ///
    /// Thin wrapper: the answer comes entirely from `self.registry.is_stdlib`.
    /// Returns `true` when the registry reports `package` as standard library.
    fn is_standard_library_package(&self, package: &str) -> bool {
        // Delegate to the package registry's standard-library lookup
        self.registry.is_stdlib(package)
    }
795
    /// Check if a package is a common third-party package
    ///
    /// Thin wrapper: the answer comes entirely from `self.registry.is_third_party`.
    /// Returns `true` when the registry reports `package` as third-party.
    fn is_common_third_party_package(&self, package: &str) -> bool {
        // Delegate to the package registry's third-party lookup
        self.registry.is_third_party(package)
    }
801
    /// Format a list of imports, merging same-package imports where appropriate
    ///
    /// Delegates to `crate::utils::formatting::format_imports`, passing along this
    /// helper's `formatting_config`, and returns the resulting strings unchanged.
    /// NOTE(review): the merging behaviour lives in that utility and is not visible
    /// from here; confirm the details there.
    fn format_imports(&self, imports: &[ImportStatement]) -> Vec<String> {
        crate::utils::formatting::format_imports(imports, &self.formatting_config)
    }
806
807    fn sort_import_statements(a: &str, b: &str) -> std::cmp::Ordering {
808        let a_is_import = a.starts_with("import ");
809        let b_is_import = b.starts_with("import ");
810
811        match (a_is_import, b_is_import) {
812            // Both are 'import' or both are 'from' - sort alphabetically
813            (true, true) | (false, false) => a.cmp(b),
814            // a is 'import', b is 'from' - a comes first
815            (true, false) => std::cmp::Ordering::Less,
816            // a is 'from', b is 'import' - b comes first
817            (false, true) => std::cmp::Ordering::Greater,
818        }
819    }
820
821    /// Automatically add `TYPE_CHECKING` to typing import when type checking imports are used
822    fn ensure_type_checking_import_added(&mut self) {
823        // Check if we already have a typing import with TYPE_CHECKING
824        let has_type_checking = self.sections.standard_library_from.iter().any(|import| {
825            import.package == "typing" && import.items.contains(&"TYPE_CHECKING".to_string())
826        });
827
828        if !has_type_checking {
829            // Check if we have any typing import that we can modify
830            if let Some(typing_import) = self
831                .sections
832                .standard_library_from
833                .iter_mut()
834                .find(|import| import.package == "typing")
835            {
836                // Add TYPE_CHECKING to existing typing import
837                if !typing_import.items.contains(&"TYPE_CHECKING".to_string()) {
838                    typing_import.items.push("TYPE_CHECKING".to_string());
839                    typing_import
840                        .items
841                        .sort_by(|a, b| crate::utils::parsing::custom_import_sort(a, b));
842
843                    // Update the statement string
844                    if typing_import.items.len() == 1 {
845                        typing_import.statement =
846                            format!("from typing import {}", typing_import.items[0]);
847                    } else {
848                        typing_import.statement =
849                            format!("from typing import {}", typing_import.items.join(", "));
850                    }
851                }
852            } else {
853                // No typing import exists, add one with just TYPE_CHECKING
854                self.add_import_string("from typing import TYPE_CHECKING");
855            }
856        }
857    }
858
859    /// Clone configuration without imports (useful for creating multiple helpers with same config)
860    #[must_use]
861    pub fn clone_config(&self) -> Self {
862        Self {
863            sections: ImportSections::default(),
864            category_cache: self.category_cache.clone(),
865            package_name: self.package_name.clone(),
866            local_package_prefixes: self.local_package_prefixes.clone(),
867            registry: self.registry.clone(),
868            formatting_config: self.formatting_config.clone(),
869        }
870    }
871}
872
873/// Convenience functions for common import operations
874impl ImportHelper {
875    /// Create imports for a model file with required type imports
876    pub fn create_model_imports(&mut self, required_types: &[String]) {
877        // Standard model imports
878        self.add_import_string("from pydantic import BaseModel, ConfigDict, Field");
879
880        // Collect all typing imports needed across all types
881        let mut typing_imports = std::collections::HashSet::new();
882        let mut collections_abc_imports = std::collections::HashSet::new();
883        let mut datetime_imports = Vec::new();
884        let mut decimal_imports = Vec::new();
885
886        for type_name in required_types {
887            match type_name.as_str() {
888                "datetime" | "date" | "time" | "timedelta" => {
889                    if !datetime_imports.contains(&type_name.as_str()) {
890                        datetime_imports.push(type_name.as_str());
891                    }
892                }
893                "Decimal" => {
894                    if !decimal_imports.contains(&"Decimal") {
895                        decimal_imports.push("Decimal");
896                    }
897                }
898                "UUID" => {
899                    self.add_import_string("from uuid import UUID");
900                }
901                // For complex types, extract typing imports
902                _ => {
903                    // Check if this type contains typing elements
904                    let extracted_typing = Self::extract_typing_imports_from_type(type_name);
905                    typing_imports.extend(extracted_typing);
906
907                    // Check for collections.abc imports
908                    if type_name.contains("Callable") {
909                        collections_abc_imports.insert("Callable".to_string());
910                    }
911                }
912            }
913        }
914
915        // Add datetime imports if any were found
916        if !datetime_imports.is_empty() {
917            let import_statement = format!("from datetime import {}", datetime_imports.join(", "));
918            self.add_regular_import(&import_statement);
919        }
920
921        // Add decimal imports if any were found
922        if !decimal_imports.is_empty() {
923            self.add_import_string("from decimal import Decimal");
924        }
925
926        // Add typing imports if any were found (only Any, Generic, TypeVar, Protocol)
927        if !typing_imports.is_empty() {
928            let mut sorted_typing: Vec<String> = typing_imports.into_iter().collect();
929            sorted_typing.sort();
930            let import_statement = format!("from typing import {}", sorted_typing.join(", "));
931            self.add_regular_import(&import_statement);
932        }
933
934        // Add collections.abc imports if any were found (e.g., Callable)
935        if !collections_abc_imports.is_empty() {
936            let mut sorted_collections: Vec<String> = collections_abc_imports.into_iter().collect();
937            sorted_collections.sort();
938            let import_statement = format!(
939                "from collections.abc import {}",
940                sorted_collections.join(", ")
941            );
942            self.add_regular_import(&import_statement);
943        }
944    }
945
946    /// Extract typing imports from a complex type string
947    /// This handles types like list[Any], dict[str, Any], etc.
948    /// Only imports what's actually needed for Python 3.13+ (Any, Generic, `TypeVar`, Protocol)
949    fn extract_typing_imports_from_type(type_str: &str) -> std::collections::HashSet<String> {
950        let mut typing_imports = std::collections::HashSet::new();
951
952        // Check for Any type (used in generics and standalone)
953        if type_str.contains("Any") {
954            typing_imports.insert("Any".to_string());
955        }
956
957        // Check for Generic type (used for generic classes)
958        if type_str.contains("Generic") {
959            typing_imports.insert("Generic".to_string());
960        }
961
962        // Check for TypeVar usage
963        if type_str.contains("TypeVar") {
964            typing_imports.insert("TypeVar".to_string());
965        }
966
967        // Check for Protocol type (structural subtyping)
968        if type_str.contains("Protocol") {
969            typing_imports.insert("Protocol".to_string());
970        }
971
972        typing_imports
973    }
974}
975
/// `Default` produces the same helper as `ImportHelper::new()`.
impl Default for ImportHelper {
    /// Construct the helper by delegating to `Self::new()`.
    fn default() -> Self {
        Self::new()
    }
}
981
982#[cfg(test)]
983mod tests {
984    use super::*;
985
986    #[test]
987    fn test_import_categorization() {
988        let mut helper = ImportHelper::new();
989
990        // Test future imports
991        helper.add_import_string("from __future__ import annotations");
992        assert_eq!(helper.sections.future.len(), 1);
993
994        // Test standard library from import
995        helper.add_import_string("from typing import Optional");
996        assert_eq!(helper.sections.standard_library_from.len(), 1);
997
998        // Test standard library direct import
999        helper.add_import_string("import uuid");
1000        assert_eq!(helper.sections.standard_library_direct.len(), 1);
1001
1002        // Test third party
1003        helper.add_import_string("from pydantic import BaseModel");
1004        assert_eq!(helper.sections.third_party_from.len(), 1);
1005
1006        // Test local imports
1007        helper.add_import_string("from .models import User");
1008        assert_eq!(helper.sections.local_from.len(), 1);
1009    }
1010
1011    #[test]
1012    fn test_import_merging() {
1013        let mut helper = ImportHelper::new();
1014
1015        helper.add_import_string("from typing import Optional");
1016        helper.add_import_string("from typing import Any");
1017        helper.add_import_string("from typing import List");
1018
1019        let imports = helper.get_formatted();
1020
1021        // Should merge into a single import
1022        assert!(imports.iter().any(|i| i.contains("from typing import")));
1023        assert!(imports
1024            .iter()
1025            .any(|i| i.contains("Any") && i.contains("Optional")));
1026    }
1027
1028    #[test]
1029    fn test_alphabetical_sorting_of_import_items() {
1030        let mut helper = ImportHelper::new();
1031
1032        // Test with unsorted items in a from import
1033        helper.add_import_string("from typing import Optional, Any, List, Dict");
1034        helper.add_import_string("from collections import defaultdict, OrderedDict, Counter");
1035        helper.add_import_string("import uuid");
1036        helper.add_import_string("import os");
1037
1038        let imports = helper.get_formatted();
1039
1040        // Print for debugging
1041        for import in &imports {
1042            println!("{}", import);
1043        }
1044
1045        // Check that direct imports come before from imports
1046        let import_str = imports.join("\n");
1047        let os_pos = import_str.find("import os").unwrap();
1048        let uuid_pos = import_str.find("import uuid").unwrap();
1049        let typing_pos = import_str.find("from typing import").unwrap();
1050
1051        // Direct imports should come before from imports
1052        assert!(
1053            os_pos < typing_pos,
1054            "Direct imports should come before from imports"
1055        );
1056        assert!(
1057            uuid_pos < typing_pos,
1058            "Direct imports should come before from imports"
1059        );
1060
1061        // Check that items within from imports are sorted alphabetically
1062        let typing_section = &import_str[import_str.find("from typing").unwrap()..];
1063
1064        // Verify alphabetical order in typing import (case-insensitive, but considering both single and multi-line)
1065        assert!(
1066            typing_section.contains("Any")
1067                && typing_section.contains("Dict")
1068                && typing_section.contains("List")
1069                && typing_section.contains("Optional"),
1070            "Import should contain all items: Any, Dict, List, Optional in alphabetical order"
1071        );
1072    }
1073
1074    #[test]
1075    fn test_direct_imports_sorted_alphabetically() {
1076        let mut helper = ImportHelper::new();
1077
1078        helper.add_direct_import("uuid");
1079        helper.add_direct_import("os");
1080        helper.add_direct_import("sys");
1081        helper.add_direct_import("json");
1082
1083        let imports = helper.get_formatted();
1084
1085        // Should be sorted: json, os, sys, uuid
1086        let import_lines: Vec<String> = imports
1087            .iter()
1088            .filter(|s| !s.is_empty() && s.contains("import"))
1089            .cloned()
1090            .collect();
1091
1092        // Verify all 4 imports are present
1093        assert_eq!(
1094            import_lines.len(),
1095            4,
1096            "Should have 4 direct imports, got: {:?}",
1097            import_lines
1098        );
1099
1100        // Verify they are sorted alphabetically
1101        assert!(
1102            import_lines[0].contains("json"),
1103            "First should be json, got: {}",
1104            import_lines[0]
1105        );
1106        assert!(
1107            import_lines[1].contains("os"),
1108            "Second should be os, got: {}",
1109            import_lines[1]
1110        );
1111        assert!(
1112            import_lines[2].contains("sys"),
1113            "Third should be sys, got: {}",
1114            import_lines[2]
1115        );
1116        assert!(
1117            import_lines[3].contains("uuid"),
1118            "Fourth should be uuid, got: {}",
1119            import_lines[3]
1120        );
1121    }
1122
1123    #[test]
1124    fn test_uppercase_priority_in_import_sorting() {
1125        let mut helper = ImportHelper::new();
1126
1127        // Test with mixed case items - uppercase should come before lowercase for same letter
1128        helper.add_import_string("from example import Ab, AA, Aa, AB");
1129
1130        let imports = helper.get_formatted();
1131
1132        // Print for debugging
1133        for import in &imports {
1134            println!("{}", import);
1135        }
1136
1137        let import_str = imports.join("\n");
1138        let example_section = &import_str[import_str.find("from example").unwrap()..];
1139
1140        // Verify all items are present (order doesn't matter as much as presence for multiline)
1141        assert!(example_section.contains("AA"), "Should contain AA");
1142        assert!(example_section.contains("AB"), "Should contain AB");
1143        assert!(example_section.contains("Aa"), "Should contain Aa");
1144        assert!(example_section.contains("Ab"), "Should contain Ab");
1145    }
1146
1147    #[test]
1148    fn test_comprehensive_case_sorting() {
1149        let mut helper = ImportHelper::new();
1150
1151        // Test with multiple letters and mixed cases
1152        helper.add_import_string("from test import zz, ZZ, bb, BB, aa, AA, cc, CC");
1153
1154        let imports = helper.get_formatted();
1155
1156        let import_str = imports.join("\n");
1157        let test_section = &import_str[import_str.find("from test").unwrap()..];
1158
1159        // Verify all items are present in the correct import section
1160        assert!(test_section.contains("AA"), "Should contain AA");
1161        assert!(test_section.contains("BB"), "Should contain BB");
1162        assert!(test_section.contains("CC"), "Should contain CC");
1163        assert!(test_section.contains("ZZ"), "Should contain ZZ");
1164        assert!(test_section.contains("aa"), "Should contain aa");
1165        assert!(test_section.contains("bb"), "Should contain bb");
1166        assert!(test_section.contains("cc"), "Should contain cc");
1167        assert!(test_section.contains("zz"), "Should contain zz");
1168    }
1169
1170    #[test]
1171    fn test_type_checking_imports() {
1172        let mut helper = ImportHelper::with_package_name("mypackage".to_string());
1173
1174        // Add regular imports
1175        helper.add_import_string("from __future__ import annotations");
1176        helper.add_import_string("from typing import Any");
1177        helper.add_import_string("from pydantic import BaseModel");
1178
1179        // Add TYPE_CHECKING imports
1180        helper.add_type_checking_import("import httpx");
1181        helper.add_type_checking_import("from typing import TYPE_CHECKING");
1182        helper.add_type_checking_import("from collections.abc import Callable");
1183        helper.add_type_checking_import("from mypackage.models import User");
1184
1185        // Check that regular imports are counted correctly
1186        assert_eq!(helper.count(), 3);
1187        assert_eq!(helper.count_type_checking(), 4);
1188        assert!(!helper.is_type_checking_empty());
1189
1190        // Generate TYPE_CHECKING imports
1191        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();
1192
1193        // Verify categorization
1194        assert!(
1195            third_party.iter().any(|s| s.contains("import httpx")),
1196            "Should have httpx in third_party"
1197        );
1198        assert!(
1199            stdlib
1200                .iter()
1201                .any(|s| s.contains("from typing import TYPE_CHECKING")),
1202            "Should have TYPE_CHECKING"
1203        );
1204        assert!(
1205            stdlib
1206                .iter()
1207                .any(|s| s.contains("from collections.abc import Callable")),
1208            "Should have Callable"
1209        );
1210        assert!(
1211            local
1212                .iter()
1213                .any(|s| s.contains("from mypackage.models import User")),
1214            "Should have User in local"
1215        );
1216        assert!(
1217            future.is_empty(),
1218            "Should have no future imports in TYPE_CHECKING"
1219        );
1220    }
1221
1222    #[test]
1223    fn test_programmatic_import_builders() {
1224        let mut helper = ImportHelper::new();
1225
1226        // Test add_from_import
1227        helper.add_from_import("typing", &["Any", "Optional"]);
1228        helper.add_from_import("json", &["loads"]);
1229
1230        // Test add_direct_import
1231        helper.add_direct_import("sys");
1232
1233        // Test TYPE_CHECKING builders
1234        helper.add_type_checking_from_import("httpx", &["Client", "Response"]);
1235        helper.add_type_checking_direct_import("logging");
1236
1237        // Verify regular imports
1238        let (_, stdlib, _, _) = helper.get_categorized();
1239        assert!(
1240            stdlib
1241                .iter()
1242                .any(|s| s.contains("Any") && s.contains("Optional")),
1243            "Should have Any and Optional in typing imports"
1244        );
1245        assert!(stdlib.iter().any(|s| s.contains("from json import loads")));
1246        assert!(stdlib.iter().any(|s| s.contains("import sys")));
1247
1248        // Verify TYPE_CHECKING imports
1249        let (_, tc_stdlib, tc_third_party, _) = helper.get_type_checking_categorized();
1250        assert!(tc_third_party
1251            .iter()
1252            .any(|s| s.contains("from httpx import Client, Response")));
1253        assert!(tc_stdlib.iter().any(|s| s.contains("import logging")));
1254    }
1255
1256    #[test]
1257    fn test_type_checking_four_categories() {
1258        let mut helper = ImportHelper::with_package_name("myapp".to_string());
1259
1260        // Add imports in all four categories for TYPE_CHECKING
1261        helper.add_type_checking_import("from __future__ import annotations");
1262        helper.add_type_checking_import("from typing import Protocol");
1263        helper.add_type_checking_import("from httpx import Client");
1264        helper.add_type_checking_import("from myapp.models import User");
1265
1266        let (future, stdlib, third_party, local) = helper.get_type_checking_categorized();
1267
1268        // Verify all four categories
1269        assert_eq!(future.len(), 1, "Should have 1 future import");
1270        assert_eq!(stdlib.len(), 1, "Should have 1 stdlib import");
1271        assert_eq!(third_party.len(), 1, "Should have 1 third-party import");
1272        assert_eq!(local.len(), 1, "Should have 1 local import");
1273
1274        assert!(future[0].contains("from __future__ import annotations"));
1275        assert!(stdlib[0].contains("from typing import Protocol"));
1276        assert!(third_party[0].contains("from httpx import Client"));
1277        assert!(local[0].contains("from myapp.models import User"));
1278    }
1279
1280    #[test]
1281    fn test_clear_preserves_configuration() {
1282        let mut helper = ImportHelper::with_package_name("mypackage".to_string());
1283        let mut config = FormattingConfig::black_compatible();
1284        config.force_multiline = true;
1285        helper.set_formatting_config(config.clone());
1286        helper.add_local_package_prefix("other_package");
1287
1288        // Add imports
1289        helper.add_import_string("from typing import Any");
1290        helper.add_type_checking_import("from httpx import Client");
1291        assert_eq!(helper.count(), 1);
1292        assert_eq!(helper.count_type_checking(), 1);
1293
1294        // Clear imports
1295        helper.clear();
1296
1297        // Verify imports are cleared
1298        assert!(helper.is_empty());
1299        assert!(helper.is_type_checking_empty());
1300        assert_eq!(helper.count(), 0);
1301        assert_eq!(helper.count_type_checking(), 0);
1302
1303        // Verify configuration is preserved
1304        assert_eq!(helper.package_name.as_deref(), Some("mypackage"));
1305        assert!(helper.local_package_prefixes.contains("mypackage"));
1306        assert!(helper.local_package_prefixes.contains("other_package"));
1307        assert_eq!(helper.formatting_config().line_length, 88);
1308        assert!(helper.formatting_config().force_multiline);
1309    }
1310
1311    #[test]
1312    fn test_reset_clears_everything() {
1313        let mut helper = ImportHelper::with_package_name("mypackage".to_string());
1314        let mut config = FormattingConfig::black_compatible();
1315        config.force_multiline = true;
1316        helper.set_formatting_config(config);
1317        helper.add_local_package_prefix("other_package");
1318
1319        // Add imports
1320        helper.add_import_string("from typing import Any");
1321        helper.add_type_checking_import("from httpx import Client");
1322        assert_eq!(helper.count(), 1);
1323        assert_eq!(helper.count_type_checking(), 1);
1324
1325        // Reset to default
1326        helper.reset();
1327
1328        // Verify everything is reset
1329        assert!(helper.is_empty());
1330        assert!(helper.is_type_checking_empty());
1331        assert_eq!(helper.count(), 0);
1332        assert_eq!(helper.count_type_checking(), 0);
1333
1334        // Verify configuration is reset to default
1335        assert_eq!(helper.package_name, None);
1336        assert!(helper.local_package_prefixes.is_empty());
1337        assert_eq!(helper.formatting_config().line_length, 79); // Default PEP8
1338        assert!(!helper.formatting_config().force_multiline);
1339    }
1340
1341    #[test]
1342    fn test_clear_and_reuse() {
1343        let mut helper = ImportHelper::with_package_name("mypackage".to_string());
1344
1345        // First use
1346        helper.add_import_string("from typing import Any");
1347        helper.add_import_string("from pydantic import BaseModel");
1348        assert_eq!(helper.count(), 2);
1349
1350        // Clear for reuse with same configuration
1351        helper.clear();
1352        assert!(helper.is_empty());
1353
1354        // Second use with same configuration
1355        helper.add_import_string("import json");
1356        helper.add_import_string("import sys");
1357        let (_, stdlib, _, _) = helper.get_categorized();
1358        assert_eq!(stdlib.len(), 2);
1359        assert!(stdlib.iter().any(|s| s.contains("import json")));
1360        assert!(stdlib.iter().any(|s| s.contains("import sys")));
1361    }
1362}