Skip to main content

dictator_python/
imports.rs

1//! Import ordering checks for Python sources (PEP 8 compliant).
2
3use dictator_decree_abi::{Diagnostic, Diagnostics, Span};
4use memchr::memchr_iter;
5
/// PEP 8 import group that a single import statement belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportType {
    /// Python standard library modules such as `os`, `sys` or `json`.
    Stdlib,
    /// External packages such as `requests` or `django`.
    ThirdParty,
    /// Relative imports, i.e. module paths starting with `.` or `..`.
    Local,
}

impl ImportType {
    /// Rank of this group in the expected sequence; lower ranks must come first.
    const fn order(self) -> u8 {
        match self {
            ImportType::Local => 2,
            ImportType::ThirdParty => 1,
            ImportType::Stdlib => 0,
        }
    }
}
22
23pub fn check_import_ordering(source: &str, diags: &mut Diagnostics) {
24    let bytes = source.as_bytes();
25    let mut imports: Vec<(usize, usize, ImportType)> = Vec::new();
26    let mut line_start = 0;
27
28    for nl in memchr_iter(b'\n', bytes) {
29        let line = &source[line_start..nl];
30        let trimmed = line.trim();
31
32        if let Some(import_type) = parse_import_line(trimmed) {
33            imports.push((line_start, nl, import_type));
34        }
35
36        // Stop at first non-import, non-comment, non-blank, non-docstring line
37        if !trimmed.is_empty()
38            && !trimmed.starts_with("import")
39            && !trimmed.starts_with("from")
40            && !trimmed.starts_with('#')
41            && !trimmed.starts_with("\"\"\"")
42            && !trimmed.starts_with("'''")
43            && !trimmed.ends_with("\"\"\"")
44            && !trimmed.ends_with("'''")
45            && !trimmed.starts_with("__future__")
46        {
47            break;
48        }
49
50        line_start = nl + 1;
51    }
52
53    // Check import order
54    if imports.len() > 1 {
55        let mut last_type = ImportType::Stdlib;
56
57        for (start, end, import_type) in &imports {
58            if import_type.order() < last_type.order() {
59                diags.push(Diagnostic {
60                    rule: "python/import-order".to_string(),
61                    message: format!(
62                        "Import order violation: {import_type:?} import after \
63                         {last_type:?} import. Expected: stdlib -> third_party -> local"
64                    ),
65                    enforced: false,
66                    span: Span::new(*start, *end),
67                });
68            }
69
70            last_type = *import_type;
71        }
72    }
73}
74
75fn parse_import_line(line: &str) -> Option<ImportType> {
76    if !line.starts_with("import") && !line.starts_with("from") {
77        return None;
78    }
79
80    // Handle "from X import Y" style
81    if line.starts_with("from") {
82        let from_keyword = "from ";
83        if let Some(pos) = line.find(from_keyword) {
84            let after_from = &line[pos + from_keyword.len()..];
85            let module_name = after_from.split_whitespace().next()?.trim_end_matches(',');
86
87            return Some(classify_module(module_name));
88        }
89    }
90
91    // Handle "import X" style
92    if line.starts_with("import") {
93        let import_keyword = "import ";
94        if let Some(pos) = line.find(import_keyword) {
95            let after_import = &line[pos + import_keyword.len()..];
96            let module_name = after_import
97                .split([',', ';'])
98                .next()?
99                .split_whitespace()
100                .next()?
101                .trim_end_matches(',');
102
103            return Some(classify_module(module_name));
104        }
105    }
106
107    None
108}
109
110#[must_use]
111pub fn classify_module(module_name: &str) -> ImportType {
112    // Local imports start with . or ..
113    if module_name.starts_with('.') {
114        return ImportType::Local;
115    }
116
117    // Get the top-level package name
118    let top_level = module_name.split('.').next().unwrap_or(module_name);
119
120    if is_python_stdlib(top_level) {
121        ImportType::Stdlib
122    } else {
123        ImportType::ThirdParty
124    }
125}
126
/// Returns `true` if `module` is the top-level name of a Python standard-library
/// module. The comparison is case-sensitive, e.g. `cProfile`.
///
/// The list is hand-maintained and approximates CPython's
/// `sys.stdlib_module_names`.
///
/// Some modules removed from recent CPython releases are kept so that legacy
/// code still classifies as stdlib. Examples are `distutils`, `imp`, `asyncore`
/// and `telnetlib`.
///
/// `typing_extensions` is intentionally absent: it is a PyPI backport package,
/// so it belongs to the third-party group.
// One arm per module name; splitting the table would not make it easier to read.
#[allow(clippy::too_many_lines)]
#[must_use]
pub fn is_python_stdlib(module: &str) -> bool {
    matches!(
        module,
        "__future__"
            | "__main__"
            | "_thread"
            | "abc"
            | "aifc"
            | "antigravity"
            | "argparse"
            | "array"
            | "ast"
            | "asynchat"
            | "asyncio"
            | "asyncore"
            | "atexit"
            | "audioop"
            | "base64"
            | "bdb"
            | "binascii"
            | "binhex"
            | "bisect"
            | "builtins"
            | "bz2"
            | "cProfile"
            | "calendar"
            | "cgi"
            | "cgitb"
            | "chunk"
            | "cmath"
            | "cmd"
            | "code"
            | "codecs"
            | "codeop"
            | "collections"
            | "colorsys"
            | "compileall"
            | "concurrent"
            | "configparser"
            | "contextlib"
            | "contextvars"
            | "copy"
            | "copyreg"
            | "crypt"
            | "csv"
            | "ctypes"
            | "curses"
            | "dataclasses"
            | "datetime"
            | "dbm"
            | "decimal"
            | "difflib"
            | "dis"
            | "distutils"
            | "doctest"
            | "email"
            | "encodings"
            | "ensurepip"
            | "enum"
            | "errno"
            | "faulthandler"
            | "fcntl"
            | "filecmp"
            | "fileinput"
            | "fnmatch"
            | "fractions"
            | "ftplib"
            | "functools"
            | "gc"
            | "genericpath"
            | "getopt"
            | "getpass"
            | "gettext"
            | "glob"
            | "graphlib"
            | "grp"
            | "gzip"
            | "hashlib"
            | "heapq"
            | "hmac"
            | "html"
            | "http"
            | "idlelib"
            | "imaplib"
            | "imghdr"
            | "imp"
            | "importlib"
            | "inspect"
            | "io"
            | "ipaddress"
            | "itertools"
            | "json"
            | "keyword"
            | "lib2to3"
            | "linecache"
            | "locale"
            | "logging"
            | "lzma"
            | "mailbox"
            | "mailcap"
            | "marshal"
            | "math"
            | "mimetypes"
            | "mmap"
            | "modulefinder"
            | "msilib"
            | "msvcrt"
            | "multiprocessing"
            | "netrc"
            | "nis"
            | "nntplib"
            | "nt"
            | "ntpath"
            | "numbers"
            | "opcode"
            | "operator"
            | "optparse"
            | "os"
            | "ossaudiodev"
            | "pathlib"
            | "pdb"
            | "pickle"
            | "pickletools"
            | "pipes"
            | "pkgutil"
            | "platform"
            | "plistlib"
            | "poplib"
            | "posix"
            | "posixpath"
            | "pprint"
            | "profile"
            | "pstats"
            | "pty"
            | "pwd"
            | "py_compile"
            | "pyclbr"
            | "pydoc"
            | "pydoc_data"
            | "pyexpat"
            | "queue"
            | "quopri"
            | "random"
            | "re"
            | "readline"
            | "reprlib"
            | "resource"
            | "rlcompleter"
            | "runpy"
            | "sched"
            | "secrets"
            | "select"
            | "selectors"
            | "shelve"
            | "shlex"
            | "shutil"
            | "signal"
            | "site"
            | "smtpd"
            | "smtplib"
            | "sndhdr"
            | "socket"
            | "socketserver"
            | "spwd"
            | "sqlite3"
            | "sre_compile"
            | "sre_constants"
            | "sre_parse"
            | "ssl"
            | "stat"
            | "statistics"
            | "string"
            | "stringprep"
            | "struct"
            | "subprocess"
            | "sunau"
            | "symtable"
            | "sys"
            | "sysconfig"
            | "syslog"
            | "tabnanny"
            | "tarfile"
            | "telnetlib"
            | "tempfile"
            | "termios"
            | "test"
            | "textwrap"
            | "this"
            | "threading"
            | "time"
            | "timeit"
            | "tkinter"
            | "token"
            | "tokenize"
            | "tomllib"
            | "trace"
            | "traceback"
            | "tracemalloc"
            | "tty"
            | "turtle"
            | "turtledemo"
            | "types"
            | "typing"
            | "unicodedata"
            | "unittest"
            | "urllib"
            | "uu"
            | "uuid"
            | "venv"
            | "warnings"
            | "wave"
            | "weakref"
            | "webbrowser"
            | "winreg"
            | "winsound"
            | "wsgiref"
            | "xdrlib"
            | "xml"
            | "xmlrpc"
            | "zipapp"
            | "zipfile"
            | "zipimport"
            | "zlib"
            | "zoneinfo"
    )
}