py_import_helper/utils/
formatting.rs1use super::parsing::custom_import_sort;
7use crate::types::ImportStatement;
8use std::collections::{HashMap, HashSet};
9
10#[must_use]
12pub fn format_imports(imports: &[ImportStatement]) -> Vec<String> {
13 let mut package_imports: HashMap<String, Vec<&ImportStatement>> = HashMap::new();
14
15 for import in imports {
17 package_imports
18 .entry(import.package.clone())
19 .or_default()
20 .push(import);
21 }
22
23 let mut result = Vec::new();
24 let mut packages: Vec<_> = package_imports.keys().collect();
25 packages.sort();
26
27 for package in packages {
28 let imports_for_package = package_imports
29 .get(package)
30 .expect("BUG: package key must exist in HashMap");
31
32 if imports_for_package.len() == 1 {
33 result.push(imports_for_package[0].statement.clone());
35 } else {
36 result.extend(merge_package_imports(imports_for_package));
38 }
39 }
40
41 result
42}
43
44#[must_use]
46pub fn merge_package_imports(imports: &[&ImportStatement]) -> Vec<String> {
47 let mut all_items = HashSet::new();
48 let package = &imports[0].package;
49
50 for import in imports {
52 all_items.extend(import.items.iter().cloned());
53 }
54
55 if all_items.is_empty() {
56 return imports.iter().map(|i| i.statement.clone()).collect();
58 }
59
60 let mut sorted_items: Vec<_> = all_items.into_iter().collect();
61 sorted_items.sort_by(|a, b| custom_import_sort(a, b));
62
63 if sorted_items.len() <= 3
65 && sorted_items
66 .iter()
67 .map(std::string::String::len)
68 .sum::<usize>()
69 < 60
70 {
71 vec![format!(
73 "from {} import {}",
74 package,
75 sorted_items.join(", ")
76 )]
77 } else {
78 let mut result = vec![format!("from {} import (", package)];
80 for item in sorted_items {
81 result.push(format!(" {item},"));
82 }
83 result.push(")".to_string());
84 result
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ImportCategory, ImportType};

    /// Builds a single-line `from <package> import <items>` statement for tests.
    fn from_import(package: &str, items: &[&str]) -> ImportStatement {
        ImportStatement {
            statement: format!("from {package} import {}", items.join(", ")),
            category: ImportCategory::StandardLibrary,
            import_type: ImportType::From,
            package: package.to_string(),
            items: items.iter().map(|item| (*item).to_string()).collect(),
            is_multiline: false,
        }
    }

    #[test]
    fn test_merge_package_imports() {
        let import1 = from_import("typing", &["Any"]);
        let import2 = from_import("typing", &["Optional"]);

        let merged = merge_package_imports(&[&import1, &import2]);
        assert_eq!(merged.len(), 1);
        assert!(merged[0].starts_with("from typing import "));
        assert!(merged[0].contains("Any"));
        assert!(merged[0].contains("Optional"));
    }

    #[test]
    fn test_merge_package_imports_deduplicates_items() {
        let import1 = from_import("typing", &["Any", "Optional"]);
        let import2 = from_import("typing", &["Any"]);

        let merged = merge_package_imports(&[&import1, &import2]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].matches("Any").count(), 1);
        assert!(merged[0].contains("Optional"));
    }

    #[test]
    fn test_merge_package_imports_wraps_long_lists() {
        let import1 = from_import("typing", &["Any", "Dict"]);
        let import2 = from_import("typing", &["List", "Optional"]);

        let merged = merge_package_imports(&[&import1, &import2]);
        // Four names exceed the single-line limit, so a parenthesised block is produced.
        assert_eq!(merged.len(), 6);
        assert_eq!(merged[0], "from typing import (");
        assert_eq!(merged[5], ")");

        let item_lines = &merged[1..5];
        assert!(item_lines
            .iter()
            .all(|line| line.starts_with(' ') && line.ends_with(',')));

        let mut names: Vec<&str> = item_lines
            .iter()
            .map(|line| line.trim().trim_end_matches(','))
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["Any", "Dict", "List", "Optional"]);
    }

    #[test]
    fn test_format_imports_orders_packages_and_keeps_single_statements() {
        let imports = [
            from_import("typing", &["Any"]),
            from_import("collections", &["OrderedDict"]),
        ];

        assert_eq!(
            format_imports(&imports),
            ["from collections import OrderedDict", "from typing import Any"]
        );
    }

    #[test]
    fn test_format_imports_merges_statements_per_package() {
        let imports = [
            from_import("typing", &["Optional"]),
            from_import("os", &["path"]),
            from_import("typing", &["Any"]),
        ];

        let formatted = format_imports(&imports);
        assert_eq!(formatted.len(), 2);
        assert_eq!(formatted[0], "from os import path");
        assert!(formatted[1].starts_with("from typing import "));
        assert!(formatted[1].contains("Any"));
        assert!(formatted[1].contains("Optional"));
    }
}