// python_proto_importer/postprocess/rel_imports.rs

use anyhow::{Context, Result};
use regex::Regex;
use std::fs;
use std::path::Path;
use std::sync::OnceLock;

6/// Very small scaffold for future import rewriting.
7/// For now, it only identifies candidate lines and returns count.
8#[allow(dead_code)]
9pub fn rewrite_file_for_relative_imports(path: &Path) -> Result<usize> {
10    let content = fs::read_to_string(path)?;
11    let import_re = Regex::new(r"(?m)^import\s+([A-Za-z0-9_\.]+_pb2(?:_grpc)?)\b").unwrap();
12    let from_re =
13        Regex::new(r"(?m)^from\s+([A-Za-z0-9_\.]+)\s+import\s+([A-Za-z0-9_]+_pb2(?:_grpc)?)\b")
14            .unwrap();
15
16    let mut hits = 0usize;
17    hits += import_re.find_iter(&content).count();
18    hits += from_re.find_iter(&content).count();
19
20    // No modifications yet; further phases will compute and apply rewrites.
21    Ok(hits)
22}
23
24/// Walk output tree and report count of candidate files/lines (dry-run).
25#[allow(dead_code)]
26pub fn scan_and_report(root: &Path) -> Result<(usize, usize)> {
27    let mut files = 0usize;
28    let mut lines = 0usize;
29    for entry in walkdir::WalkDir::new(root)
30        .into_iter()
31        .filter_map(Result::ok)
32    {
33        let p = entry.path();
34        if p.is_file() && p.extension().and_then(|e| e.to_str()) == Some("py") {
35            files += 1;
36            lines += rewrite_file_for_relative_imports(p)?;
37        }
38    }
39    Ok((files, lines))
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `content` to `test.py` in a fresh temp dir and returns its candidate count.
    fn count_candidates(content: &str) -> usize {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.py");
        fs::write(&file_path, content).unwrap();
        rewrite_file_for_relative_imports(&file_path).unwrap()
    }

    #[test]
    fn rewrite_file_for_relative_imports_basic() {
        let content = r#"
import service_pb2
from api.v1 import user_pb2
import other_module
from package import regular_module
import grpc_service_pb2_grpc
"#;
        // Should match: service_pb2, user_pb2, grpc_service_pb2_grpc
        assert_eq!(count_candidates(content), 3);
    }

    #[test]
    fn rewrite_file_for_relative_imports_no_matches() {
        let content = r#"
import os
from typing import List
import requests
from dataclasses import dataclass
"#;
        assert_eq!(count_candidates(content), 0);
    }

    #[test]
    fn rewrite_file_for_relative_imports_complex_patterns() {
        let content = r#"
# Import statements
import api.v1.service_pb2
import api.v2.user_pb2_grpc
from package.subpackage import module_pb2
from api import payment_pb2_grpc
from . import local_pb2  # already relative, yet still counted by the scaffold

# Mixed content
def function():
    pass

import another_service_pb2
"#;
        // Matches service_pb2, user_pb2_grpc, module_pb2, payment_pb2_grpc and
        // another_service_pb2, plus local_pb2: the package part of the `from`
        // pattern accepts `.`, so the already-relative line is counted as well.
        assert_eq!(count_candidates(content), 6);
    }

    #[test]
    fn rewrite_file_for_relative_imports_multiline() {
        let content = "import service_pb2\nfrom api import user_pb2\nimport normal_module";
        assert_eq!(count_candidates(content), 2); // service_pb2 and user_pb2
    }

    #[test]
    fn rewrite_file_for_relative_imports_scaffold_limits() {
        let content = concat!(
            "from api.v1 import user_pb2 as api_dot_v1_dot_user__pb2\n",
            "import a_pb2, b_pb2\n",
            "import service_pb2x\n",
            "    import indented_pb2\n",
            "from api import helpers, other_pb2\n",
        );
        // Counted: the aliased `from` import, and the comma list once (first name
        // only). Not counted: the look-alike `_pb2x` name, the indented import,
        // and a `_pb2` module that is not the first imported name.
        assert_eq!(count_candidates(content), 2);
    }

    #[test]
    fn scan_and_report_basic() {
        let dir = tempdir().unwrap();

        // Create Python files with proto imports
        let file1 = dir.path().join("service.py");
        let file2 = dir.path().join("api.py");
        let file3 = dir.path().join("utils.txt"); // Non-Python file

        fs::write(&file1, "import service_pb2\nfrom api import user_pb2").unwrap();
        fs::write(&file2, "import payment_pb2_grpc").unwrap();
        fs::write(&file3, "import service_pb2").unwrap(); // Should be ignored

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 2); // Only .py files counted
        assert_eq!(lines, 3); // Total proto import lines
    }

    #[test]
    fn scan_and_report_nested_directories() {
        let dir = tempdir().unwrap();

        // Create nested structure
        let nested_dir = dir.path().join("services");
        fs::create_dir_all(&nested_dir).unwrap();

        let file1 = dir.path().join("main.py");
        let file2 = nested_dir.join("api.py");

        fs::write(&file1, "import main_service_pb2").unwrap();
        fs::write(&file2, "from proto import api_pb2\nimport grpc_pb2_grpc").unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 2);
        assert_eq!(lines, 3); // 1 + 2 imports
    }

    #[test]
    fn scan_and_report_empty_directory() {
        let dir = tempdir().unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 0);
        assert_eq!(lines, 0);
    }

    #[test]
    fn scan_and_report_no_proto_imports() {
        let dir = tempdir().unwrap();

        let file = dir.path().join("normal.py");
        fs::write(&file, "import os\nfrom typing import List").unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 1); // File is counted
        assert_eq!(lines, 0); // But no proto import lines
    }

    #[test]
    fn rewrite_file_nonexistent_file() {
        // A path inside a fresh temp dir is guaranteed not to exist on any platform,
        // unlike a hard-coded absolute path.
        let dir = tempdir().unwrap();
        let missing = dir.path().join("missing.py");
        assert!(rewrite_file_for_relative_imports(&missing).is_err());
    }
}