python_proto_importer/
config.rs

1use anyhow::{Context, Result, bail};
2use serde::Deserialize;
3use std::fs;
4use std::path::{Path, PathBuf};
5
/// Code generation backend selection.
///
/// Determines which tool generates Python code from the proto files.
/// `PartialEq`/`Eq` are derived so callers can compare backends with `==`
/// instead of resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Generate with the standard `protoc` compiler.
    ///
    /// This is the currently supported backend and the default when the
    /// configuration does not name one.
    Protoc,
    /// Generate with `buf generate`.
    ///
    /// Accepted in configuration, but planned for a future version and not
    /// yet implemented.
    Buf,
}
18
19/// Main application configuration parsed from pyproject.toml.
20///
21/// Contains all settings needed to run the proto-to-Python code generation
22/// pipeline, including backend selection, file paths, and processing options.
23#[allow(dead_code)]
24#[derive(Debug, Clone)]
25pub struct AppConfig {
26    /// Backend to use for code generation (protoc or buf).
27    pub backend: Backend,
28    /// Python executable to use for generation and verification.
29    /// Can be "python3", "python", "uv", or a custom path.
30    pub python_exe: String,
31    /// Proto import paths (passed as --proto_path to protoc).
32    /// These directories are searched for proto files and their dependencies.
33    pub include: Vec<PathBuf>,
34    /// Glob patterns for proto files to compile.
35    /// Only files matching these patterns will be processed.
36    pub inputs: Vec<String>,
37    /// Output directory for generated Python files.
38    pub out: PathBuf,
39    /// Whether to generate mypy type stubs (.pyi files) using mypy-protobuf.
40    pub generate_mypy: bool,
41    /// Whether to generate gRPC mypy stubs (_grpc.pyi files) using mypy-grpc.
42    pub generate_mypy_grpc: bool,
43    /// Post-processing configuration options.
44    pub postprocess: PostProcess,
45    /// Optional verification configuration (type checking commands).
46    pub verify: Option<Verify>,
47}
48
/// Settings for the rewrite pass applied to freshly generated files.
///
/// Covers import rewriting, package scaffolding (`__init__.py`) and the
/// optional header that silences Pyright.
#[derive(Debug, Clone)]
// `fix_pyi` is reserved for future use, so not every field is necessarily read yet.
#[allow(dead_code)]
pub struct PostProcess {
    /// Rewrite absolute imports between generated modules into relative ones.
    pub relative_imports: bool,
    /// Repair type annotations in `.pyi` files (reserved for future use).
    pub fix_pyi: bool,
    /// Drop an `__init__.py` into every output directory so it is importable.
    /// Disable this for PEP 420 namespace packages.
    pub create_package: bool,
    /// Leave `google.protobuf` imports untouched during relative-import rewriting.
    pub exclude_google: bool,
    /// Prepend a Pyright suppression header to `_pb2.py` and `_pb2_grpc.py` files.
    pub pyright_header: bool,
    /// File-name suffixes selecting which files the pass touches.
    /// The loader's default is `_pb2.py`, `_pb2.pyi`, `_pb2_grpc.py`, `_pb2_grpc.pyi`.
    pub module_suffixes: Vec<String>,
}
71
/// Optional type-checking commands used to validate the generated code.
///
/// Each command is stored as an argv-style list; leaving one as `None`
/// skips that checker entirely.
#[derive(Debug, Clone)]
// NOTE(review): presumably allowed because a consumer may read only one of the commands — confirm.
#[allow(dead_code)]
pub struct Verify {
    /// mypy invocation, e.g. `["mypy", "--strict", "generated"]`; `None` skips mypy.
    pub mypy_cmd: Option<Vec<String>>,
    /// pyright invocation, e.g. `["pyright", "generated/**/*.pyi"]`; `None` skips pyright.
    pub pyright_cmd: Option<Vec<String>>,
}
86
87// --- Raw TOML structures ---
88#[derive(Deserialize)]
89struct PyProject {
90    tool: Option<ToolSection>,
91}
92
93#[derive(Deserialize)]
94struct ToolSection {
95    #[serde(rename = "python_proto_importer")]
96    python_proto_importer: Option<ImporterRoot>,
97}
98
99#[derive(Deserialize)]
100struct ImporterRoot {
101    #[serde(flatten)]
102    core: ImporterCore,
103    verify: Option<VerifyToml>,
104}
105
106#[allow(dead_code)]
107#[derive(Deserialize)]
108struct ImporterCore {
109    backend: Option<String>,
110    python_exe: Option<String>,
111    include: Option<Vec<String>>, // paths/globs
112    inputs: Option<Vec<String>>,  // globs
113    out: Option<String>,
114    mypy: Option<bool>,
115    mypy_grpc: Option<bool>,
116    buf_gen_yaml: Option<String>,
117    postprocess: Option<PostProcessToml>,
118}
119
120#[allow(dead_code)]
121#[derive(Deserialize)]
122struct PostProcessToml {
123    relative_imports: Option<bool>,
124    fix_pyi: Option<bool>,
125    create_package: Option<bool>,
126    exclude_google: Option<bool>,
127    pyright_header: Option<bool>,
128    module_suffixes: Option<Vec<String>>,
129}
130
131#[allow(dead_code)]
132#[derive(Deserialize)]
133struct VerifyToml {
134    mypy_cmd: Option<Vec<String>>,
135    pyright_cmd: Option<Vec<String>>,
136}
137
138impl AppConfig {
139    /// Load configuration from a pyproject.toml file.
140    ///
141    /// Parses the TOML configuration file and validates the settings,
142    /// applying defaults where values are not specified.
143    ///
144    /// # Arguments
145    ///
146    /// * `pyproject_path` - Optional path to the pyproject.toml file.
147    ///   If None, looks for "pyproject.toml" in the current directory.
148    ///
149    /// # Returns
150    ///
151    /// Returns the parsed and validated configuration, or an error if:
152    /// - The file cannot be read
153    /// - The TOML is malformed
154    /// - Required configuration sections are missing
155    /// - Configuration values are invalid
156    ///
157    /// # Example
158    ///
159    /// ```no_run
160    /// use python_proto_importer::config::AppConfig;
161    /// use std::path::Path;
162    ///
163    /// // Load from default location
164    /// let config = AppConfig::load(None)?;
165    ///
166    /// // Load from custom path
167    /// let config = AppConfig::load(Some(Path::new("custom.toml")))?;
168    /// # Ok::<(), anyhow::Error>(())
169    /// ```
170    pub fn load(pyproject_path: Option<&Path>) -> Result<Self> {
171        let path = match pyproject_path {
172            Some(p) => p.to_path_buf(),
173            None => PathBuf::from("pyproject.toml"),
174        };
175        let content = fs::read_to_string(&path)
176            .with_context(|| format!("failed to read {}", path.display()))?;
177        let root: PyProject = toml::from_str(&content).context("failed to parse pyproject.toml")?;
178        let Some(tool) = root.tool else {
179            bail!("[tool.python_proto_importer] not found");
180        };
181        let Some(importer) = tool.python_proto_importer else {
182            bail!("[tool.python_proto_importer] not found");
183        };
184
185        let backend = match importer
186            .core
187            .backend
188            .as_deref()
189            .unwrap_or("protoc")
190            .to_lowercase()
191            .as_str()
192        {
193            "protoc" => Backend::Protoc,
194            "buf" => Backend::Buf,
195            other => bail!("unsupported backend: {}", other),
196        };
197
198        let python_exe = importer
199            .core
200            .python_exe
201            .unwrap_or_else(|| "python3".to_string());
202        let mut include = importer
203            .core
204            .include
205            .unwrap_or_default()
206            .into_iter()
207            .map(PathBuf::from)
208            .collect::<Vec<_>>();
209
210        // If include is empty, use current directory as default
211        if include.is_empty() {
212            include.push(PathBuf::from("."));
213        }
214        let inputs = importer.core.inputs.unwrap_or_default();
215        let out = importer
216            .core
217            .out
218            .map(PathBuf::from)
219            .unwrap_or_else(|| PathBuf::from("generated/python"));
220
221        let generate_mypy = importer.core.mypy.unwrap_or(false);
222        let generate_mypy_grpc = importer.core.mypy_grpc.unwrap_or(false);
223
224        let pp = importer.core.postprocess.unwrap_or(PostProcessToml {
225            relative_imports: Some(true),
226            fix_pyi: Some(true),
227            create_package: Some(true),
228            exclude_google: Some(true),
229            pyright_header: Some(false),
230            module_suffixes: None,
231        });
232        let postprocess = PostProcess {
233            relative_imports: pp.relative_imports.unwrap_or(true),
234            fix_pyi: pp.fix_pyi.unwrap_or(true),
235            create_package: pp.create_package.unwrap_or(true),
236            exclude_google: pp.exclude_google.unwrap_or(true),
237            pyright_header: pp.pyright_header.unwrap_or(false),
238            module_suffixes: pp.module_suffixes.unwrap_or_else(|| {
239                vec![
240                    "_pb2.py".into(),
241                    "_pb2.pyi".into(),
242                    "_pb2_grpc.py".into(),
243                    "_pb2_grpc.pyi".into(),
244                ]
245            }),
246        };
247
248        let verify = importer.verify.map(|v| Verify {
249            mypy_cmd: v.mypy_cmd,
250            pyright_cmd: v.pyright_cmd,
251        });
252
253        Ok(Self {
254            backend,
255            python_exe,
256            include,
257            inputs,
258            out,
259            generate_mypy,
260            generate_mypy_grpc,
261            postprocess,
262            verify,
263        })
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::{TempDir, tempdir};

    /// Writes `content` to `pyproject.toml` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive while the file is needed:
    /// dropping it deletes the directory.
    fn write_config(content: &str) -> (TempDir, PathBuf) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("pyproject.toml");
        fs::write(&path, content).unwrap();
        (dir, path)
    }

    /// Restores the process working directory when dropped.
    ///
    /// Restoring in `Drop` guarantees the directory is put back even when an
    /// assertion panics part-way through a test.
    struct CwdGuard(PathBuf);

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Ignore errors: panicking in Drop while already unwinding would
            // abort the whole test binary.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    #[test]
    fn load_minimal_config() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();

        assert!(matches!(config.backend, Backend::Protoc));
        assert_eq!(config.python_exe, "python3");
        assert_eq!(config.include, vec![PathBuf::from(".")]);
        assert_eq!(config.inputs, vec!["proto/**/*.proto"]);
        assert_eq!(config.out, PathBuf::from("generated/python"));
        assert!(!config.generate_mypy);
        assert!(!config.generate_mypy_grpc);
        assert!(config.postprocess.relative_imports);
        assert!(config.postprocess.fix_pyi);
        assert!(config.postprocess.create_package);
        assert!(config.postprocess.exclude_google);
        assert!(!config.postprocess.pyright_header);
        assert_eq!(
            config.postprocess.module_suffixes,
            vec!["_pb2.py", "_pb2.pyi", "_pb2_grpc.py", "_pb2_grpc.pyi"]
        );
        assert!(config.verify.is_none());
    }

    #[test]
    fn empty_postprocess_table_uses_defaults() {
        // An empty table must behave exactly like a missing one.
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]

[tool.python_proto_importer.postprocess]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();

        assert!(config.postprocess.relative_imports);
        assert!(config.postprocess.fix_pyi);
        assert!(config.postprocess.create_package);
        assert!(config.postprocess.exclude_google);
        assert!(!config.postprocess.pyright_header);
        assert_eq!(
            config.postprocess.module_suffixes,
            vec!["_pb2.py", "_pb2.pyi", "_pb2_grpc.py", "_pb2_grpc.pyi"]
        );
    }

    #[test]
    fn load_full_config() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
backend = "buf"
python_exe = "uv"
include = ["proto", "common"]
inputs = ["proto/**/*.proto", "common/**/*.proto"]
out = "src/generated"
mypy = true
mypy_grpc = true

[tool.python_proto_importer.postprocess]
relative_imports = false
fix_pyi = false
create_package = false
exclude_google = false
pyright_header = true
module_suffixes = ["_pb2.py", "_grpc.py"]

[tool.python_proto_importer.verify]
mypy_cmd = ["mypy", "--strict"]
pyright_cmd = ["pyright", "generated"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();

        assert!(matches!(config.backend, Backend::Buf));
        assert_eq!(config.python_exe, "uv");
        assert_eq!(
            config.include,
            vec![PathBuf::from("proto"), PathBuf::from("common")]
        );
        assert_eq!(config.inputs, vec!["proto/**/*.proto", "common/**/*.proto"]);
        assert_eq!(config.out, PathBuf::from("src/generated"));
        assert!(config.generate_mypy);
        assert!(config.generate_mypy_grpc);
        assert!(!config.postprocess.relative_imports);
        assert!(!config.postprocess.fix_pyi);
        assert!(!config.postprocess.create_package);
        assert!(!config.postprocess.exclude_google);
        assert!(config.postprocess.pyright_header);
        assert_eq!(
            config.postprocess.module_suffixes,
            vec!["_pb2.py", "_grpc.py"]
        );

        let verify = config.verify.unwrap();
        assert_eq!(verify.mypy_cmd.unwrap(), vec!["mypy", "--strict"]);
        assert_eq!(verify.pyright_cmd.unwrap(), vec!["pyright", "generated"]);
    }

    #[test]
    fn load_empty_include_defaults_to_current_dir() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]
include = []
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        assert_eq!(config.include, vec![PathBuf::from(".")]);
    }

    #[test]
    fn backend_case_insensitive() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
backend = "PROTOC"
inputs = ["proto/**/*.proto"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        assert!(matches!(config.backend, Backend::Protoc));
    }

    #[test]
    fn backend_buf_mixed_case() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
backend = "Buf"
inputs = ["proto/**/*.proto"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        assert!(matches!(config.backend, Backend::Buf));
    }

    #[test]
    fn unsupported_backend_fails() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
backend = "unsupported"
inputs = ["proto/**/*.proto"]
"#,
        );

        let err = AppConfig::load(Some(&config_path)).expect_err("backend must be rejected");
        assert!(err.to_string().contains("unsupported backend"));
    }

    #[test]
    fn missing_config_section_fails() {
        let (_dir, config_path) = write_config(
            r#"
[tool.other_tool]
something = "value"
"#,
        );

        let err = AppConfig::load(Some(&config_path)).expect_err("section is missing");
        assert!(
            err.to_string()
                .contains("[tool.python_proto_importer] not found")
        );
    }

    #[test]
    fn missing_tool_table_fails() {
        let (_dir, config_path) = write_config(
            r#"
[project]
name = "demo"
"#,
        );

        let err = AppConfig::load(Some(&config_path)).expect_err("[tool] is missing");
        assert!(
            err.to_string()
                .contains("[tool.python_proto_importer] not found")
        );
    }

    #[test]
    fn missing_file_fails() {
        // Use an absolute path inside a fresh temp dir so the result cannot
        // depend on the (process-global) current working directory.
        let dir = tempdir().unwrap();
        let missing = dir.path().join("nonexistent.toml");

        let err = AppConfig::load(Some(&missing)).expect_err("file does not exist");
        assert!(err.to_string().contains("failed to read"));
    }

    #[test]
    fn invalid_toml_fails() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer
# Missing closing bracket
inputs = ["proto/**/*.proto"]
"#,
        );

        let err = AppConfig::load(Some(&config_path)).expect_err("TOML is malformed");
        assert!(err.to_string().contains("failed to parse"));
    }

    #[test]
    fn load_default_path() {
        let (dir, _config_path) = write_config(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]
"#,
        );

        // The working directory is shared by every test thread; the other tests
        // in this module only use absolute temp paths, so they are unaffected.
        // Declared after `dir`, so it is dropped (restoring the CWD) before the
        // temp dir is deleted.
        let _cwd = CwdGuard(std::env::current_dir().unwrap());
        std::env::set_current_dir(dir.path()).unwrap();

        let config = AppConfig::load(None).unwrap();
        assert_eq!(config.inputs, vec!["proto/**/*.proto"]);
    }

    #[test]
    fn verify_section_optional() {
        let (_dir, config_path) = write_config(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]

[tool.python_proto_importer.verify]
mypy_cmd = ["mypy"]
# pyright_cmd intentionally omitted
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        let verify = config.verify.unwrap();
        assert_eq!(verify.mypy_cmd.unwrap(), vec!["mypy"]);
        assert!(verify.pyright_cmd.is_none());
    }
}