// python_proto_importer/postprocess/rel_imports.rs
use anyhow::{Context, Result};
use regex::Regex;
use std::fs;
use std::path::Path;
use std::sync::OnceLock;

6#[allow(dead_code)]
9pub fn rewrite_file_for_relative_imports(path: &Path) -> Result<usize> {
10 let content = fs::read_to_string(path)?;
11 let import_re = Regex::new(r"(?m)^import\s+([A-Za-z0-9_\.]+_pb2(?:_grpc)?)\b").unwrap();
12 let from_re =
13 Regex::new(r"(?m)^from\s+([A-Za-z0-9_\.]+)\s+import\s+([A-Za-z0-9_]+_pb2(?:_grpc)?)\b")
14 .unwrap();
15
16 let mut hits = 0usize;
17 hits += import_re.find_iter(&content).count();
18 hits += from_re.find_iter(&content).count();
19
20 Ok(hits)
22}
23
24#[allow(dead_code)]
26pub fn scan_and_report(root: &Path) -> Result<(usize, usize)> {
27 let mut files = 0usize;
28 let mut lines = 0usize;
29 for entry in walkdir::WalkDir::new(root)
30 .into_iter()
31 .filter_map(Result::ok)
32 {
33 let p = entry.path();
34 if p.is_file() && p.extension().and_then(|e| e.to_str()) == Some("py") {
35 files += 1;
36 lines += rewrite_file_for_relative_imports(p)?;
37 }
38 }
39 Ok((files, lines))
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn rewrite_file_for_relative_imports_basic() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.py");

        let content = r#"
import service_pb2
from api.v1 import user_pb2
import other_module
from package import regular_module
import grpc_service_pb2_grpc
"#;
        fs::write(&file_path, content).unwrap();

        let hits = rewrite_file_for_relative_imports(&file_path).unwrap();
        // Only the three *_pb2 / *_pb2_grpc imports count.
        assert_eq!(hits, 3);
    }

    #[test]
    fn rewrite_file_for_relative_imports_no_matches() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.py");

        let content = r#"
import os
from typing import List
import requests
from dataclasses import dataclass
"#;
        fs::write(&file_path, content).unwrap();

        let hits = rewrite_file_for_relative_imports(&file_path).unwrap();
        assert_eq!(hits, 0);
    }

    #[test]
    fn rewrite_file_for_relative_imports_complex_patterns() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.py");

        let content = r#"
# Import statements
import api.v1.service_pb2
import api.v2.user_pb2_grpc
from package.subpackage import module_pb2
from api import payment_pb2_grpc
from . import local_pb2  # already relative, but still counted (see assertion)

# Mixed content
def function():
    pass

import another_service_pb2
"#;
        fs::write(&file_path, content).unwrap();

        let hits = rewrite_file_for_relative_imports(&file_path).unwrap();
        // Six, not five: `from . import local_pb2` also matches, because the `from`
        // pattern accepts `.` as a module path. Relative imports are not skipped.
        assert_eq!(hits, 6);
    }

    #[test]
    fn rewrite_file_for_relative_imports_multiline() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.py");

        let content = "import service_pb2\nfrom api import user_pb2\nimport normal_module";
        fs::write(&file_path, content).unwrap();

        let hits = rewrite_file_for_relative_imports(&file_path).unwrap();
        assert_eq!(hits, 2);
    }

    #[test]
    fn scan_and_report_basic() {
        let dir = tempdir().unwrap();

        let file1 = dir.path().join("service.py");
        let file2 = dir.path().join("api.py");
        // Same import text, but not a `.py` file: must be ignored entirely.
        let file3 = dir.path().join("utils.txt");
        fs::write(&file1, "import service_pb2\nfrom api import user_pb2").unwrap();
        fs::write(&file2, "import payment_pb2_grpc").unwrap();
        fs::write(&file3, "import service_pb2").unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 2);
        assert_eq!(lines, 3);
    }

    #[test]
    fn scan_and_report_nested_directories() {
        let dir = tempdir().unwrap();

        let nested_dir = dir.path().join("services");
        fs::create_dir_all(&nested_dir).unwrap();

        let file1 = dir.path().join("main.py");
        let file2 = nested_dir.join("api.py");

        fs::write(&file1, "import main_service_pb2").unwrap();
        fs::write(&file2, "from proto import api_pb2\nimport grpc_pb2_grpc").unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 2);
        assert_eq!(lines, 3);
    }

    #[test]
    fn scan_and_report_empty_directory() {
        let dir = tempdir().unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        assert_eq!(files, 0);
        assert_eq!(lines, 0);
    }

    #[test]
    fn scan_and_report_no_proto_imports() {
        let dir = tempdir().unwrap();

        let file = dir.path().join("normal.py");
        fs::write(&file, "import os\nfrom typing import List").unwrap();

        let (files, lines) = scan_and_report(dir.path()).unwrap();
        // The file is still visited even though it contributes no hits.
        assert_eq!(files, 1);
        assert_eq!(lines, 0);
    }

    #[test]
    fn rewrite_file_nonexistent_file() {
        // A path inside a fresh temp dir is guaranteed not to exist, on any platform.
        let dir = tempdir().unwrap();
        let missing = dir.path().join("missing.py");

        let result = rewrite_file_for_relative_imports(&missing);
        assert!(result.is_err());
    }
}