1use rustc_hash::{FxHashMap, FxHashSet};
6
/// Top-level module names that ship with the CPython standard library.
///
/// `Known::new` loads this list into the set consulted by `Known::is_stdlib`;
/// membership is an exact, case-sensitive match on the top-level name.
/// Platform-specific modules (POSIX-only such as `fcntl`, `pwd`, `termios`,
/// and Windows-only such as `msvcrt`, `winreg`, `winsound`) are included
/// because a project may import them under a platform guard.
/// NOTE(review): keep in sync with `sys.stdlib_module_names` of the Python
/// versions this tool targets.
const STDLIB: &[&str] = &[
    "__future__",
    "__main__",
    "_thread",
    "abc",
    "argparse",
    "array",
    "ast",
    "asyncio",
    "atexit",
    "base64",
    "bdb",
    "binascii",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "imaplib",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtplib",
    "socket",
    "socketserver",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
131
/// Import names whose distribution is published under a different name.
///
/// Keys are top-level import names exactly as written in `import`
/// statements (case-sensitive); values are distribution names, which
/// `Known::dist_for_import` passes through `normalize_dist`. A name that
/// already normalizes to its own distribution name (e.g. `requests`,
/// `flask_login`) does not need an entry here.
const ALIASES: &[(&str, &str)] = &[
    ("cv2", "opencv-python"),
    ("PIL", "pillow"),
    ("yaml", "pyyaml"),
    ("sklearn", "scikit-learn"),
    ("bs4", "beautifulsoup4"),
    ("dateutil", "python-dateutil"),
    ("dotenv", "python-dotenv"),
    ("jose", "python-jose"),
    ("attr", "attrs"),
    ("git", "gitpython"),
    ("OpenSSL", "pyopenssl"),
    ("serial", "pyserial"),
    ("Crypto", "pycryptodome"),
    // NOTE(review): `google` is a namespace shared by many distributions
    // (protobuf, google-cloud-*); this mapping is only a best guess.
    ("google", "google-api-python-client"),
    ("jwt", "pyjwt"),
    ("MySQLdb", "mysqlclient"),
    ("psycopg2", "psycopg2-binary"),
    ("docx", "python-docx"),
    ("pptx", "python-pptx"),
    ("skimage", "scikit-image"),
    ("zmq", "pyzmq"),
    ("Bio", "biopython"),
    ("fitz", "pymupdf"),
    ("magic", "python-magic"),
    ("slugify", "python-slugify"),
    ("nacl", "pynacl"),
    ("dns", "dnspython"),
    ("grpc", "grpcio"),
    ("gi", "pygobject"),
    ("wx", "wxpython"),
    ("usb", "pyusb"),
    ("OpenGL", "pyopengl"),
    ("Cryptodome", "pycryptodomex"),
    ("pkg_resources", "setuptools"),
    ("mpl_toolkits", "matplotlib"),
    ("websocket", "websocket-client"),
    ("kafka", "kafka-python"),
    ("telegram", "python-telegram-bot"),
    ("socketio", "python-socketio"),
    ("ldap", "python-ldap"),
    ("github", "pygithub"),
    ("cairo", "pycairo"),
    ("Xlib", "python-xlib"),
    ("snappy", "python-snappy"),
    ("memcache", "python-memcached"),
    ("speech_recognition", "speechrecognition"),
    ("discord", "discord.py"),
    // pywin32 ships several top-level modules.
    ("win32api", "pywin32"),
    ("win32con", "pywin32"),
    ("win32gui", "pywin32"),
    ("pythoncom", "pywin32"),
];
156
157pub struct Known {
158 stdlib: FxHashSet<&'static str>,
159 alias: FxHashMap<&'static str, &'static str>,
161}
162
163impl Known {
164 pub fn new() -> Self {
165 Known {
166 stdlib: STDLIB.iter().copied().collect(),
167 alias: ALIASES.iter().copied().collect(),
168 }
169 }
170
171 pub fn is_stdlib(&self, top_level: &str) -> bool {
172 self.stdlib.contains(top_level)
173 }
174
175 pub fn dist_for_import(&self, top_level: &str) -> String {
178 if let Some(d) = self.alias.get(top_level) {
179 return normalize_dist(d);
180 }
181 normalize_dist(top_level)
182 }
183}
184
185impl Default for Known {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
/// Normalizes a distribution name in the manner of PEP 503: ASCII letters are
/// lowercased and every run of `-`, `_` or `.` becomes a single `-`.
///
/// Unlike PEP 503, separators at the very start or end are dropped, so
/// `"_foo_"` becomes `"foo"`. Non-ASCII characters are passed through unchanged.
pub fn normalize_dist(name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    // A separator only sets this flag; the single '-' is written when the next
    // non-separator character arrives and something has already been emitted.
    // That collapses runs and discards leading/trailing separators in one pass.
    let mut pending_sep = false;
    for ch in name.chars() {
        if matches!(ch, '-' | '_' | '.') {
            pending_sep = true;
            continue;
        }
        if pending_sep && !normalized.is_empty() {
            normalized.push('-');
        }
        pending_sep = false;
        normalized.push(ch.to_ascii_lowercase());
    }
    normalized
}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_is_pep503() {
        assert_eq!(normalize_dist("Flask_SQLAlchemy"), "flask-sqlalchemy");
        assert_eq!(normalize_dist("scikit.learn"), "scikit-learn");
    }

    /// Runs of mixed separators collapse to one '-', and separators at either
    /// end are removed.
    #[test]
    fn normalize_collapses_and_trims_separators() {
        assert_eq!(normalize_dist("a__b"), "a-b");
        assert_eq!(normalize_dist("A-_.B"), "a-b");
        assert_eq!(normalize_dist("_private_"), "private");
        assert_eq!(normalize_dist(""), "");
        assert_eq!(normalize_dist("-._"), "");
    }

    /// Normalizing an already-normalized name must not change it.
    #[test]
    fn normalize_is_idempotent() {
        for name in ["Flask_SQLAlchemy", "zope.interface", "Py--YAML", "x"].iter() {
            let once = normalize_dist(name);
            assert_eq!(normalize_dist(&once), once);
        }
    }

    #[test]
    fn aliases_and_stdlib() {
        let k = Known::new();
        assert!(k.is_stdlib("os"));
        assert!(k.is_stdlib("__future__"));
        assert!(!k.is_stdlib("numpy"));
        assert_eq!(k.dist_for_import("cv2"), "opencv-python");
        assert_eq!(k.dist_for_import("PIL"), "pillow");
        assert_eq!(k.dist_for_import("requests"), "requests");
    }

    /// Imports without an alias still go through name normalization.
    #[test]
    fn unaliased_imports_are_normalized() {
        assert_eq!(Known::new().dist_for_import("Flask_Login"), "flask-login");
    }

    #[test]
    fn default_matches_new() {
        let k = Known::default();
        assert!(k.is_stdlib("sys"));
        assert_eq!(k.dist_for_import("yaml"), "pyyaml");
    }
}