dictator_python/
imports.rs

1//! Import ordering checks for Python sources (PEP 8 compliant).
2
3use dictator_decree_abi::{Diagnostic, Diagnostics, Span};
4use memchr::memchr_iter;
5
/// Import group used when checking the order of Python imports.
///
/// The explicit discriminants are the rank of each group in the expected ordering,
/// lowest first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportType {
    /// Python standard library: os, sys, json, etc.
    Stdlib = 0,
    /// External packages: requests, django, etc.
    ThirdParty = 1,
    /// Relative imports: . or ..
    Local = 2,
}

impl ImportType {
    /// Returns this group's rank: stdlib (0), then third party (1), then local (2).
    const fn order(self) -> u8 {
        // Lossless cast: the discriminants declared on the enum are 0, 1 and 2.
        self as u8
    }
}
22
23pub fn check_import_ordering(source: &str, diags: &mut Diagnostics) {
24    let bytes = source.as_bytes();
25    let mut imports: Vec<(usize, usize, ImportType)> = Vec::new();
26    let mut line_start = 0;
27
28    for nl in memchr_iter(b'\n', bytes) {
29        let line = &source[line_start..nl];
30        let trimmed = line.trim();
31
32        if let Some(import_type) = parse_import_line(trimmed) {
33            imports.push((line_start, nl, import_type));
34        }
35
36        // Stop at first non-import, non-comment, non-blank, non-docstring line
37        if !trimmed.is_empty()
38            && !trimmed.starts_with("import")
39            && !trimmed.starts_with("from")
40            && !trimmed.starts_with('#')
41            && !trimmed.starts_with("\"\"\"")
42            && !trimmed.starts_with("'''")
43            && !trimmed.ends_with("\"\"\"")
44            && !trimmed.ends_with("'''")
45            && !trimmed.starts_with("__future__")
46        {
47            break;
48        }
49
50        line_start = nl + 1;
51    }
52
53    // Check import order
54    if imports.len() > 1 {
55        let mut last_type = ImportType::Stdlib;
56
57        for (start, end, import_type) in &imports {
58            if import_type.order() < last_type.order() {
59                diags.push(Diagnostic {
60                    rule: "python/import-order".to_string(),
61                    message: format!(
62                        "Import order violation: {import_type:?} import after {last_type:?} import. Expected order: stdlib -> third_party -> local"
63                    ),
64                    enforced: false,
65                    span: Span::new(*start, *end),
66                });
67            }
68
69            last_type = *import_type;
70        }
71    }
72}
73
74fn parse_import_line(line: &str) -> Option<ImportType> {
75    if !line.starts_with("import") && !line.starts_with("from") {
76        return None;
77    }
78
79    // Handle "from X import Y" style
80    if line.starts_with("from") {
81        let from_keyword = "from ";
82        if let Some(pos) = line.find(from_keyword) {
83            let after_from = &line[pos + from_keyword.len()..];
84            let module_name = after_from.split_whitespace().next()?.trim_end_matches(',');
85
86            return Some(classify_module(module_name));
87        }
88    }
89
90    // Handle "import X" style
91    if line.starts_with("import") {
92        let import_keyword = "import ";
93        if let Some(pos) = line.find(import_keyword) {
94            let after_import = &line[pos + import_keyword.len()..];
95            let module_name = after_import
96                .split([',', ';'])
97                .next()?
98                .split_whitespace()
99                .next()?
100                .trim_end_matches(',');
101
102            return Some(classify_module(module_name));
103        }
104    }
105
106    None
107}
108
109#[must_use]
110pub fn classify_module(module_name: &str) -> ImportType {
111    // Local imports start with . or ..
112    if module_name.starts_with('.') {
113        return ImportType::Local;
114    }
115
116    // Get the top-level package name
117    let top_level = module_name.split('.').next().unwrap_or(module_name);
118
119    if is_python_stdlib(top_level) {
120        ImportType::Stdlib
121    } else {
122        ImportType::ThirdParty
123    }
124}
125
/// Returns `true` when `module` is the name of a top-level Python standard-library module.
///
/// The comparison is exact and case-sensitive, and `module` must be a top-level name
/// (`os`, not `os.path`). Modules that were removed from recent Python releases (for
/// example `distutils`, `imp`, `telnetlib`) are still listed so that sources written for
/// older interpreters are grouped correctly.
///
/// `typing_extensions` is deliberately absent: it is distributed on PyPI rather than with
/// the interpreter, so it belongs to the third-party group.
// The function is one long literal table; splitting it would not make it clearer.
#[allow(clippy::too_many_lines)]
#[must_use]
pub fn is_python_stdlib(module: &str) -> bool {
    matches!(
        module,
        "__future__"
            | "__main__"
            | "_thread"
            | "abc"
            | "aifc"
            | "argparse"
            | "array"
            | "ast"
            | "asynchat"
            | "asyncio"
            | "asyncore"
            | "atexit"
            | "audioop"
            | "base64"
            | "bdb"
            | "binascii"
            | "binhex"
            | "bisect"
            | "builtins"
            | "bz2"
            | "cProfile"
            | "calendar"
            | "cgi"
            | "cgitb"
            | "chunk"
            | "cmath"
            | "cmd"
            | "code"
            | "codecs"
            | "codeop"
            | "collections"
            | "colorsys"
            | "compileall"
            | "concurrent"
            | "configparser"
            | "contextlib"
            | "contextvars"
            | "copy"
            | "copyreg"
            | "crypt"
            | "csv"
            | "ctypes"
            | "curses"
            | "dataclasses"
            | "datetime"
            | "dbm"
            | "decimal"
            | "difflib"
            | "dis"
            | "distutils"
            | "doctest"
            | "email"
            | "encodings"
            | "ensurepip"
            | "enum"
            | "errno"
            | "faulthandler"
            | "fcntl"
            | "filecmp"
            | "fileinput"
            | "fnmatch"
            | "fractions"
            | "ftplib"
            | "functools"
            | "gc"
            | "getopt"
            | "getpass"
            | "gettext"
            | "glob"
            | "graphlib"
            | "grp"
            | "gzip"
            | "hashlib"
            | "heapq"
            | "hmac"
            | "html"
            | "http"
            | "imaplib"
            | "imghdr"
            | "imp"
            | "importlib"
            | "inspect"
            | "io"
            | "ipaddress"
            | "itertools"
            | "json"
            | "keyword"
            | "lib2to3"
            | "linecache"
            | "locale"
            | "logging"
            | "lzma"
            | "mailbox"
            | "mailcap"
            | "marshal"
            | "math"
            | "mimetypes"
            | "mmap"
            | "modulefinder"
            | "msilib"
            | "msvcrt"
            | "multiprocessing"
            | "netrc"
            | "nis"
            | "nntplib"
            | "ntpath"
            | "numbers"
            | "operator"
            | "optparse"
            | "os"
            | "ossaudiodev"
            | "pathlib"
            | "pdb"
            | "pickle"
            | "pickletools"
            | "pipes"
            | "pkgutil"
            | "platform"
            | "plistlib"
            | "poplib"
            | "posix"
            | "posixpath"
            | "pprint"
            | "profile"
            | "pstats"
            | "pty"
            | "pwd"
            | "py_compile"
            | "pyclbr"
            | "pydoc"
            | "queue"
            | "quopri"
            | "random"
            | "re"
            | "readline"
            | "reprlib"
            | "resource"
            | "rlcompleter"
            | "runpy"
            | "sched"
            | "secrets"
            | "select"
            | "selectors"
            | "shelve"
            | "shlex"
            | "shutil"
            | "signal"
            | "site"
            | "smtpd"
            | "smtplib"
            | "sndhdr"
            | "socket"
            | "socketserver"
            | "spwd"
            | "sqlite3"
            | "ssl"
            | "stat"
            | "statistics"
            | "string"
            | "stringprep"
            | "struct"
            | "subprocess"
            | "sunau"
            | "symtable"
            | "sys"
            | "sysconfig"
            | "syslog"
            | "tabnanny"
            | "tarfile"
            | "telnetlib"
            | "tempfile"
            | "termios"
            | "test"
            | "textwrap"
            | "threading"
            | "time"
            | "timeit"
            | "tkinter"
            | "token"
            | "tokenize"
            | "tomllib"
            | "trace"
            | "traceback"
            | "tracemalloc"
            | "tty"
            | "turtle"
            | "turtledemo"
            | "types"
            | "typing"
            | "unicodedata"
            | "unittest"
            | "urllib"
            | "uu"
            | "uuid"
            | "venv"
            | "warnings"
            | "wave"
            | "weakref"
            | "webbrowser"
            | "winreg"
            | "winsound"
            | "wsgiref"
            | "xdrlib"
            | "xml"
            | "xmlrpc"
            | "zipapp"
            | "zipfile"
            | "zipimport"
            | "zlib"
            | "zoneinfo"
    )
}