// python_proto_importer/postprocess/mod.rs

//! Post-processing steps for generated Python protobuf code.
//!
//! These utilities transform the Python sources produced by protoc/buf after the
//! initial generation phase, so that the output integrates cleanly with your
//! project's package layout and type-checking setup.
//!
//! # Available Post-processors
//!
//! - **Import Rewriting** ([`apply`]): Converts absolute imports to relative imports
//! - **Package Creation** ([`create_packages`]): Automatically creates `__init__.py` files
//! - **Type Checker Headers** ([`add_pyright_header`]): Adds suppression headers for type checkers
//! - **FileDescriptorSet Processing** ([`fds`]): Extracts metadata from protoc output
//! - **Import Analysis** ([`rel_imports`]): Scans and reports import conversion opportunities
//!
//! # Post-processing Pipeline
//!
//! ```no_run
//! use python_proto_importer::postprocess::{create_packages, add_pyright_header};
//! use std::path::Path;
//!
//! let output_dir = Path::new("generated");
//!
//! // 1. Create __init__.py files so every directory is a Python package
//! let packages_created = create_packages(output_dir)?;
//! println!("Created {} __init__.py files", packages_created);
//!
//! // 2. Add type checker suppression headers to the generated modules
//! let headers_added = add_pyright_header(output_dir)?;
//! println!("Added headers to {} files", headers_added);
//!
//! # Ok::<(), anyhow::Error>(())
//! ```

35use anyhow::{Context, Result};
36use std::collections::BTreeSet;
37use std::fs;
38use std::path::{Path, PathBuf};
39use walkdir::WalkDir;
40
41pub mod apply;
42pub mod fds;
43pub mod rel_imports;
44
45/// Add Pyright suppression headers to generated Python protobuf files.
46///
47/// This function adds type checker suppression headers to generated `_pb2.py` and `_pb2_grpc.py`
48/// files. These headers help suppress false positive warnings from type checkers when working
49/// with dynamically generated protobuf code that may reference experimental APIs.
50///
51/// # Arguments
52///
53/// * `root` - Root directory to recursively scan for protobuf Python files
54///
55/// # Returns
56///
57/// Returns the number of files that were modified with headers.
58///
59/// # Behavior
60///
61/// - Only modifies files ending with `_pb2.py` or `_pb2_grpc.py`
62/// - Skips files that already have the suppression header
63/// - Recursively processes all subdirectories
64/// - Preserves existing file content, only prepending the header
65///
66/// # Header Content
67///
68/// Adds the following header to suppress common type checker issues:
69/// ```python
70/// # pyright: reportAttributeAccessIssue=false
71/// # This file is generated by grpcio-tools and may reference grpc.experimental which lacks stubs in types-grpcio.
72/// ```
73pub fn add_pyright_header(root: &Path) -> Result<usize> {
74    use std::io::Write;
75    let mut modified = 0usize;
76    for entry in WalkDir::new(root).into_iter().filter_map(Result::ok) {
77        let p = entry.path();
78        if p.is_file() && p.extension().and_then(|s| s.to_str()) == Some("py") {
79            let name = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
80            if !name.ends_with("_pb2.py") && !name.ends_with("_pb2_grpc.py") {
81                continue;
82            }
83            let content = fs::read_to_string(p).with_context(|| format!("read {}", p.display()))?;
84            let header = "# pyright: reportAttributeAccessIssue=false\n# This file is generated by grpcio-tools and may reference grpc.experimental which lacks stubs in types-grpcio.\n";
85            if content.starts_with(header) {
86                continue;
87            }
88            let mut f = fs::OpenOptions::new()
89                .write(true)
90                .truncate(true)
91                .open(p)
92                .with_context(|| format!("open {} for write", p.display()))?;
93            f.write_all(header.as_bytes())?;
94            f.write_all(content.as_bytes())?;
95            modified += 1;
96        }
97    }
98    Ok(modified)
99}
100
101/// Create `__init__.py` files for all directories in the output tree.
102///
103/// This function ensures that all directories in the generated output have `__init__.py`
104/// files, making them proper Python packages. This is essential for Python's import
105/// system to recognize directories as packages and enables proper module imports.
106///
107/// # Arguments
108///
109/// * `root` - Root directory to recursively process for package creation
110///
111/// # Returns
112///
113/// Returns the number of `__init__.py` files that were created.
114///
115/// # Behavior
116///
117/// - Recursively scans all directories under the root path
118/// - Creates empty `__init__.py` files in directories that don't have them
119/// - Skips directories that already have `__init__.py` files
120/// - Uses `BTreeSet` for consistent ordering of directory processing
121///
122/// # Package Types
123///
124/// This creates regular packages (with `__init__.py`). For namespace packages
125/// (PEP 420), disable package creation in your configuration:
126///
127/// ```toml
128/// [tool.python_proto_importer.postprocess]
129/// create_package = false
130/// ```
131pub fn create_packages(root: &Path) -> Result<usize> {
132    let mut dirs: BTreeSet<PathBuf> = BTreeSet::new();
133    for entry in WalkDir::new(root).into_iter().filter_map(Result::ok) {
134        let path = entry.path();
135        if path.is_dir() {
136            dirs.insert(path.to_path_buf());
137        }
138    }
139
140    let mut created = 0usize;
141    for dir in dirs {
142        let init_py = dir.join("__init__.py");
143        if !init_py.exists() {
144            fs::write(&init_py, b"")
145                .with_context(|| format!("failed to write {}", init_py.display()))?;
146            created += 1;
147        }
148    }
149    Ok(created)
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// First line of the suppression header written by `add_pyright_header`.
    const PYRIGHT_LINE: &str = "# pyright: reportAttributeAccessIssue=false";

    /// Complete two-line header as it appears at the top of an already-processed file.
    const FULL_HEADER: &str = "# pyright: reportAttributeAccessIssue=false\n# This file is generated by grpcio-tools and may reference grpc.experimental which lacks stubs in types-grpcio.\n";

    /// Reads a file produced during a test, panicking if it cannot be read.
    fn read(path: &std::path::Path) -> String {
        fs::read_to_string(path).unwrap()
    }

    #[test]
    fn add_pyright_header_to_pb2_files() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();

        let pb2 = root.join("service_pb2.py");
        let grpc = root.join("service_pb2_grpc.py");
        let plain = root.join("regular.py");

        fs::write(&pb2, "# Generated code\nclass Service:\n    pass\n").unwrap();
        fs::write(&grpc, "# Generated gRPC code\nclass ServiceServicer:\n    pass\n").unwrap();
        fs::write(&plain, "# Regular Python file\nprint('hello')\n").unwrap();

        // Only the `_pb2.py` and `_pb2_grpc.py` modules qualify.
        assert_eq!(add_pyright_header(root).unwrap(), 2);

        let pb2_text = read(&pb2);
        assert!(pb2_text.starts_with(PYRIGHT_LINE));
        assert!(pb2_text.contains("# Generated code"));

        assert!(read(&grpc).starts_with(PYRIGHT_LINE));

        // Hand-written modules are left alone.
        assert!(!read(&plain).contains("pyright"));
    }

    #[test]
    fn add_pyright_header_skips_existing() {
        let tmp = tempdir().unwrap();
        let pb2 = tmp.path().join("service_pb2.py");

        let already_processed = format!("{}# Generated code\n", FULL_HEADER);
        fs::write(&pb2, &already_processed).unwrap();

        // A file that already starts with the header is neither counted nor rewritten.
        assert_eq!(add_pyright_header(tmp.path()).unwrap(), 0);
        assert_eq!(read(&pb2), already_processed);
    }

    #[test]
    fn add_pyright_header_nested_directories() {
        let tmp = tempdir().unwrap();
        let services = tmp.path().join("services");
        fs::create_dir_all(&services).unwrap();

        let pb2 = services.join("api_pb2.py");
        fs::write(&pb2, "# Generated code\n").unwrap();

        assert_eq!(add_pyright_header(tmp.path()).unwrap(), 1);
        assert!(read(&pb2).starts_with(PYRIGHT_LINE));
    }

    #[test]
    fn create_packages_all_directories() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();

        for sub in ["services", "services/api", "services/auth", "common"] {
            fs::create_dir_all(root.join(sub)).unwrap();
        }

        // The root plus the four nested directories.
        assert_eq!(create_packages(root).unwrap(), 5);

        for rel in ["", "services", "services/api", "services/auth", "common"] {
            assert!(
                root.join(rel).join("__init__.py").exists(),
                "missing __init__.py in {:?}",
                rel
            );
        }
    }

    #[test]
    fn create_packages_skips_existing() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();

        let services = root.join("services");
        fs::create_dir_all(&services).unwrap();
        let existing = services.join("__init__.py");
        fs::write(&existing, "# Existing content").unwrap();

        // Only the root directory still needs a package marker.
        assert_eq!(create_packages(root).unwrap(), 1);

        // The pre-existing file keeps its content; the new one is empty.
        assert_eq!(read(&existing), "# Existing content");
        assert_eq!(read(&root.join("__init__.py")), "");
    }

    #[test]
    fn create_packages_empty_directory() {
        let tmp = tempdir().unwrap();

        // Even a directory with no children becomes a package.
        assert_eq!(create_packages(tmp.path()).unwrap(), 1);
        assert!(tmp.path().join("__init__.py").exists());
    }

    #[test]
    fn add_pyright_header_file_extension_filtering() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();

        let samples = [
            ("service_pb2.py", "# Python file"),
            ("service_pb2.pyi", "# Stub file"),
            ("service_pb2.txt", "# Text file"),
            ("service.py", "# Regular Python"),
        ];
        for (file, body) in samples {
            fs::write(root.join(file), body).unwrap();
        }

        // Only the `.py` file with the generated-module suffix is touched.
        assert_eq!(add_pyright_header(root).unwrap(), 1);

        assert!(read(&root.join("service_pb2.py")).starts_with("# pyright"));
        assert!(!read(&root.join("service_pb2.pyi")).contains("pyright"));
        assert!(!read(&root.join("service.py")).contains("pyright"));
    }
}