python_proto_importer/postprocess/
mod.rs1use anyhow::{Context, Result};
36use std::collections::BTreeSet;
37use std::fs;
38use std::path::{Path, PathBuf};
39use walkdir::WalkDir;
40
41pub mod apply;
42pub mod fds;
43pub mod rel_imports;
44
45pub fn add_pyright_header(root: &Path) -> Result<usize> {
74 use std::io::Write;
75 let mut modified = 0usize;
76 for entry in WalkDir::new(root).into_iter().filter_map(Result::ok) {
77 let p = entry.path();
78 if p.is_file() && p.extension().and_then(|s| s.to_str()) == Some("py") {
79 let name = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
80 if !name.ends_with("_pb2.py") && !name.ends_with("_pb2_grpc.py") {
81 continue;
82 }
83 let content = fs::read_to_string(p).with_context(|| format!("read {}", p.display()))?;
84 let header = "# pyright: reportAttributeAccessIssue=false\n# This file is generated by grpcio-tools and may reference grpc.experimental which lacks stubs in types-grpcio.\n";
85 if content.starts_with(header) {
86 continue;
87 }
88 let mut f = fs::OpenOptions::new()
89 .write(true)
90 .truncate(true)
91 .open(p)
92 .with_context(|| format!("open {} for write", p.display()))?;
93 f.write_all(header.as_bytes())?;
94 f.write_all(content.as_bytes())?;
95 modified += 1;
96 }
97 }
98 Ok(modified)
99}
100
101pub fn create_packages(root: &Path) -> Result<usize> {
132 let mut dirs: BTreeSet<PathBuf> = BTreeSet::new();
133 for entry in WalkDir::new(root).into_iter().filter_map(Result::ok) {
134 let path = entry.path();
135 if path.is_dir() {
136 dirs.insert(path.to_path_buf());
137 }
138 }
139
140 let mut created = 0usize;
141 for dir in dirs {
142 let init_py = dir.join("__init__.py");
143 if !init_py.exists() {
144 fs::write(&init_py, b"")
145 .with_context(|| format!("failed to write {}", init_py.display()))?;
146 created += 1;
147 }
148 }
149 Ok(created)
150}
151
#[cfg(test)]
mod tests {
    //! Filesystem-level tests for the post-processing helpers, each run in an
    //! isolated temporary directory.
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn add_pyright_header_to_pb2_files() {
        let dir = tempdir().unwrap();

        let pb2_file = dir.path().join("service_pb2.py");
        let grpc_file = dir.path().join("service_pb2_grpc.py");
        let regular_file = dir.path().join("regular.py");

        fs::write(&pb2_file, "# Generated code\nclass Service:\n pass\n").unwrap();
        fs::write(
            &grpc_file,
            "# Generated gRPC code\nclass ServiceServicer:\n pass\n",
        )
        .unwrap();
        fs::write(&regular_file, "# Regular Python file\nprint('hello')\n").unwrap();

        let modified = add_pyright_header(dir.path()).unwrap();
        assert_eq!(modified, 2);

        let pb2_content = fs::read_to_string(&pb2_file).unwrap();
        assert!(pb2_content.starts_with("# pyright: reportAttributeAccessIssue=false"));
        assert!(pb2_content.contains("# Generated code"));

        let grpc_content = fs::read_to_string(&grpc_file).unwrap();
        assert!(grpc_content.starts_with("# pyright: reportAttributeAccessIssue=false"));

        // Files that are not generated protobuf/gRPC modules stay untouched.
        let regular_content = fs::read_to_string(&regular_file).unwrap();
        assert!(!regular_content.contains("pyright"));
    }

    #[test]
    fn add_pyright_header_skips_existing() {
        let dir = tempdir().unwrap();
        let pb2_file = dir.path().join("service_pb2.py");

        let existing_content = "# pyright: reportAttributeAccessIssue=false\n# This file is generated by grpcio-tools and may reference grpc.experimental which lacks stubs in types-grpcio.\n# Generated code\n";
        fs::write(&pb2_file, existing_content).unwrap();

        let modified = add_pyright_header(dir.path()).unwrap();
        assert_eq!(modified, 0);

        let content = fs::read_to_string(&pb2_file).unwrap();
        assert_eq!(content, existing_content);
    }

    #[test]
    fn add_pyright_header_is_idempotent() {
        let dir = tempdir().unwrap();
        let pb2_file = dir.path().join("service_pb2.py");
        fs::write(&pb2_file, "# Generated code\n").unwrap();

        assert_eq!(add_pyright_header(dir.path()).unwrap(), 1);
        let first = fs::read_to_string(&pb2_file).unwrap();
        assert!(first.starts_with("# pyright: reportAttributeAccessIssue=false"));
        assert!(first.ends_with("# Generated code\n"));

        // A second run must neither count nor change the already-patched file.
        assert_eq!(add_pyright_header(dir.path()).unwrap(), 0);
        assert_eq!(fs::read_to_string(&pb2_file).unwrap(), first);
    }

    #[test]
    fn add_pyright_header_nested_directories() {
        let dir = tempdir().unwrap();
        let nested_dir = dir.path().join("services");
        fs::create_dir_all(&nested_dir).unwrap();

        let pb2_file = nested_dir.join("api_pb2.py");
        fs::write(&pb2_file, "# Generated code\n").unwrap();

        let modified = add_pyright_header(dir.path()).unwrap();
        assert_eq!(modified, 1);

        let content = fs::read_to_string(&pb2_file).unwrap();
        assert!(content.starts_with("# pyright: reportAttributeAccessIssue=false"));
    }

    #[test]
    fn create_packages_all_directories() {
        let dir = tempdir().unwrap();

        let nested_dirs = ["services", "services/api", "services/auth", "common"];
        for nested in &nested_dirs {
            fs::create_dir_all(dir.path().join(nested)).unwrap();
        }

        // Four nested directories plus the root itself.
        let created = create_packages(dir.path()).unwrap();
        assert_eq!(created, 5);

        assert!(dir.path().join("__init__.py").exists());
        assert!(dir.path().join("services/__init__.py").exists());
        assert!(dir.path().join("services/api/__init__.py").exists());
        assert!(dir.path().join("services/auth/__init__.py").exists());
        assert!(dir.path().join("common/__init__.py").exists());
    }

    #[test]
    fn create_packages_skips_existing() {
        let dir = tempdir().unwrap();

        let nested_dir = dir.path().join("services");
        fs::create_dir_all(&nested_dir).unwrap();
        fs::write(nested_dir.join("__init__.py"), "# Existing content").unwrap();

        // Only the root is missing an `__init__.py`.
        let created = create_packages(dir.path()).unwrap();
        assert_eq!(created, 1);

        let content = fs::read_to_string(nested_dir.join("__init__.py")).unwrap();
        assert_eq!(content, "# Existing content");

        let root_content = fs::read_to_string(dir.path().join("__init__.py")).unwrap();
        assert_eq!(root_content, "");
    }

    #[test]
    fn create_packages_is_idempotent() {
        let dir = tempdir().unwrap();
        fs::create_dir_all(dir.path().join("services/api")).unwrap();

        assert_eq!(create_packages(dir.path()).unwrap(), 3);
        assert_eq!(create_packages(dir.path()).unwrap(), 0);
    }

    #[test]
    fn create_packages_empty_directory() {
        let dir = tempdir().unwrap();
        let created = create_packages(dir.path()).unwrap();
        assert_eq!(created, 1);
        assert!(dir.path().join("__init__.py").exists());
    }

    #[test]
    fn add_pyright_header_file_extension_filtering() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("service_pb2.py"), "# Python file").unwrap();
        fs::write(dir.path().join("service_pb2.pyi"), "# Stub file").unwrap();
        fs::write(dir.path().join("service_pb2.txt"), "# Text file").unwrap();
        fs::write(dir.path().join("service.py"), "# Regular Python").unwrap();

        let modified = add_pyright_header(dir.path()).unwrap();
        assert_eq!(modified, 1);

        let py_content = fs::read_to_string(dir.path().join("service_pb2.py")).unwrap();
        assert!(py_content.starts_with("# pyright"));

        let pyi_content = fs::read_to_string(dir.path().join("service_pb2.pyi")).unwrap();
        assert!(!pyi_content.contains("pyright"));

        let regular_content = fs::read_to_string(dir.path().join("service.py")).unwrap();
        assert!(!regular_content.contains("pyright"));
    }
}