1use anyhow::{Context, Result, bail};
2use serde::Deserialize;
3use std::fs;
4use std::path::{Path, PathBuf};
5
6#[derive(Debug, Clone, Copy)]
10pub enum Backend {
11 Protoc,
14 Buf,
17}
18
/// Fully resolved settings from `[tool.python_proto_importer]`, with every
/// optional key already replaced by its default (see [`AppConfig::load`]).
// NOTE(review): `dead_code` is allowed without a stated reason; presumably some
// fields are not read anywhere in the crate yet — confirm, then narrow or remove.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct AppConfig {
    /// Code-generation backend (`backend` key); defaults to [`Backend::Protoc`].
    pub backend: Backend,
    /// Python executable (`python_exe` key); defaults to `"python3"`.
    /// Presumably used to run the code generator — confirm with the caller.
    pub python_exe: String,
    /// Include directories (`include` key), kept exactly as written (not
    /// resolved against the config file's directory). Defaults to `["."]`,
    /// which is also used when the key is present but empty.
    pub include: Vec<PathBuf>,
    /// Input patterns (`inputs` key), e.g. `"proto/**/*.proto"`. They are stored
    /// verbatim and not expanded while loading. Defaults to an empty list.
    pub inputs: Vec<String>,
    /// Output path (`out` key); defaults to `generated/python`.
    pub out: PathBuf,
    /// `mypy` key; defaults to `false`.
    /// Presumably enables mypy stub generation — confirm.
    pub generate_mypy: bool,
    /// `mypy_grpc` key; defaults to `false`.
    /// Presumably the gRPC counterpart of `generate_mypy` — confirm.
    pub generate_mypy_grpc: bool,
    /// Options from the `postprocess` sub-table; per-key defaults apply when it is absent.
    pub postprocess: PostProcess,
    /// Commands from the `verify` sub-table; `None` when that table is absent.
    pub verify: Option<Verify>,
}
48
/// Post-processing switches from `[tool.python_proto_importer.postprocess]`.
///
/// The per-field defaults below are applied by [`AppConfig::load`], both for
/// absent keys and when the whole table is missing. Meanings beyond the TOML
/// key are inferred from the names; the code that consumes them is not in this
/// file.
// NOTE(review): `dead_code` is allowed without a stated reason — confirm whether
// every field is consumed elsewhere before removing it.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct PostProcess {
    /// `relative_imports` key, default `true`.
    /// Presumably rewrites imports between generated modules as relative imports.
    pub relative_imports: bool,
    /// `fix_pyi` key, default `true`. Presumably patches generated `.pyi` stubs.
    pub fix_pyi: bool,
    /// `create_package` key, default `true`.
    /// Presumably creates package marker files in the output tree.
    pub create_package: bool,
    /// `exclude_google` key, default `true`.
    /// Presumably leaves `google.*` imports untouched.
    pub exclude_google: bool,
    /// `pyright_header` key, default `false`.
    /// Presumably prepends a pyright directive header to generated files.
    pub pyright_header: bool,
    /// `module_suffixes` key. Defaults to
    /// `["_pb2.py", "_pb2.pyi", "_pb2_grpc.py", "_pb2_grpc.pyi"]`.
    /// Presumably the file-name suffixes that identify generated modules.
    pub module_suffixes: Vec<String>,
}
71
/// Commands from `[tool.python_proto_importer.verify]`.
///
/// Each command is an argument vector such as `["mypy", "--strict"]`. `None`
/// means the key was not set. Presumably these are run against the generated
/// output to type-check it — confirm with the caller.
// NOTE(review): `dead_code` is allowed without a stated reason — confirm whether
// these fields are read elsewhere in the crate.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct Verify {
    /// `mypy_cmd` key: mypy command line, or `None` when not configured.
    pub mypy_cmd: Option<Vec<String>>,
    /// `pyright_cmd` key: pyright command line, or `None` when not configured.
    pub pyright_cmd: Option<Vec<String>>,
}
86
/// Raw shape of `pyproject.toml` as far as this importer is concerned.
///
/// Only the `[tool]` table is deserialized. Every other top-level key is
/// ignored, because unknown fields are not denied.
#[derive(Deserialize)]
struct PyProject {
    /// The `[tool]` table, or `None` when the file has none.
    tool: Option<ToolSection>,
}
92
/// The `[tool]` table. Sub-tables of other tools (e.g. `[tool.other_tool]`)
/// are ignored.
#[derive(Deserialize)]
struct ToolSection {
    /// `[tool.python_proto_importer]`, or `None` when absent.
    // The rename equals the field name, so it only spells out the TOML key.
    #[serde(rename = "python_proto_importer")]
    python_proto_importer: Option<ImporterRoot>,
}
98
/// Contents of `[tool.python_proto_importer]`.
#[derive(Deserialize)]
struct ImporterRoot {
    /// The table's remaining keys (including the `postprocess` sub-table),
    /// deserialized in place via `#[serde(flatten)]`.
    #[serde(flatten)]
    core: ImporterCore,
    /// `[tool.python_proto_importer.verify]` sub-table, or `None` when absent.
    verify: Option<VerifyToml>,
}
105
/// Keys of `[tool.python_proto_importer]` exactly as written. Every key is
/// optional; [`AppConfig::load`] substitutes the defaults.
// `buf_gen_yaml` is parsed but never read by `AppConfig::load`, which is at
// least one reason this struct needs the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Deserialize)]
struct ImporterCore {
    /// Backend name, matched case-insensitively against `"protoc"` / `"buf"`.
    backend: Option<String>,
    python_exe: Option<String>,
    include: Option<Vec<String>>,
    inputs: Option<Vec<String>>,
    out: Option<String>,
    /// Becomes `AppConfig::generate_mypy`.
    mypy: Option<bool>,
    /// Becomes `AppConfig::generate_mypy_grpc`.
    mypy_grpc: Option<bool>,
    /// Accepted but not used when building `AppConfig`. Presumably the path to
    /// a `buf.gen.yaml` for the buf backend — confirm.
    buf_gen_yaml: Option<String>,
    /// `[tool.python_proto_importer.postprocess]` sub-table.
    postprocess: Option<PostProcessToml>,
}
119
/// `[tool.python_proto_importer.postprocess]` exactly as written. Each `None`
/// is replaced by its default when [`AppConfig::load`] builds [`PostProcess`].
// NOTE(review): every field is read by `AppConfig::load`, so this allowance may
// be unnecessary — confirm before removing.
#[allow(dead_code)]
#[derive(Deserialize)]
struct PostProcessToml {
    relative_imports: Option<bool>,
    fix_pyi: Option<bool>,
    create_package: Option<bool>,
    exclude_google: Option<bool>,
    pyright_header: Option<bool>,
    module_suffixes: Option<Vec<String>>,
}
130
/// `[tool.python_proto_importer.verify]` exactly as written. [`AppConfig::load`]
/// copies both fields into [`Verify`] unchanged.
// NOTE(review): both fields are read by `AppConfig::load`, so this allowance may
// be unnecessary — confirm before removing.
#[allow(dead_code)]
#[derive(Deserialize)]
struct VerifyToml {
    mypy_cmd: Option<Vec<String>>,
    pyright_cmd: Option<Vec<String>>,
}
137
138impl AppConfig {
139 pub fn load(pyproject_path: Option<&Path>) -> Result<Self> {
171 let path = match pyproject_path {
172 Some(p) => p.to_path_buf(),
173 None => PathBuf::from("pyproject.toml"),
174 };
175 let content = fs::read_to_string(&path)
176 .with_context(|| format!("failed to read {}", path.display()))?;
177 let root: PyProject = toml::from_str(&content).context("failed to parse pyproject.toml")?;
178 let Some(tool) = root.tool else {
179 bail!("[tool.python_proto_importer] not found");
180 };
181 let Some(importer) = tool.python_proto_importer else {
182 bail!("[tool.python_proto_importer] not found");
183 };
184
185 let backend = match importer
186 .core
187 .backend
188 .as_deref()
189 .unwrap_or("protoc")
190 .to_lowercase()
191 .as_str()
192 {
193 "protoc" => Backend::Protoc,
194 "buf" => Backend::Buf,
195 other => bail!("unsupported backend: {}", other),
196 };
197
198 let python_exe = importer
199 .core
200 .python_exe
201 .unwrap_or_else(|| "python3".to_string());
202 let mut include = importer
203 .core
204 .include
205 .unwrap_or_default()
206 .into_iter()
207 .map(PathBuf::from)
208 .collect::<Vec<_>>();
209
210 if include.is_empty() {
212 include.push(PathBuf::from("."));
213 }
214 let inputs = importer.core.inputs.unwrap_or_default();
215 let out = importer
216 .core
217 .out
218 .map(PathBuf::from)
219 .unwrap_or_else(|| PathBuf::from("generated/python"));
220
221 let generate_mypy = importer.core.mypy.unwrap_or(false);
222 let generate_mypy_grpc = importer.core.mypy_grpc.unwrap_or(false);
223
224 let pp = importer.core.postprocess.unwrap_or(PostProcessToml {
225 relative_imports: Some(true),
226 fix_pyi: Some(true),
227 create_package: Some(true),
228 exclude_google: Some(true),
229 pyright_header: Some(false),
230 module_suffixes: None,
231 });
232 let postprocess = PostProcess {
233 relative_imports: pp.relative_imports.unwrap_or(true),
234 fix_pyi: pp.fix_pyi.unwrap_or(true),
235 create_package: pp.create_package.unwrap_or(true),
236 exclude_google: pp.exclude_google.unwrap_or(true),
237 pyright_header: pp.pyright_header.unwrap_or(false),
238 module_suffixes: pp.module_suffixes.unwrap_or_else(|| {
239 vec![
240 "_pb2.py".into(),
241 "_pb2.pyi".into(),
242 "_pb2_grpc.py".into(),
243 "_pb2_grpc.pyi".into(),
244 ]
245 }),
246 };
247
248 let verify = importer.verify.map(|v| Verify {
249 mypy_cmd: v.mypy_cmd,
250 pyright_cmd: v.pyright_cmd,
251 });
252
253 Ok(Self {
254 backend,
255 python_exe,
256 include,
257 inputs,
258 out,
259 generate_mypy,
260 generate_mypy_grpc,
261 postprocess,
262 verify,
263 })
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::{TempDir, tempdir};

    /// Writes `contents` as `pyproject.toml` in a fresh temporary directory.
    ///
    /// Returns the directory guard together with the file path. Keep the guard
    /// alive while the path is in use, because dropping it deletes the directory.
    fn write_pyproject(contents: &str) -> (TempDir, PathBuf) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("pyproject.toml");
        fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// Loads `path` and returns the error message, failing the test if loading succeeds.
    fn load_error(path: &Path) -> String {
        match AppConfig::load(Some(path)) {
            Ok(config) => panic!("expected an error, got {config:?}"),
            Err(err) => err.to_string(),
        }
    }

    /// Switches the working directory and restores the previous one on drop.
    ///
    /// Restoring happens even when an assertion panics. Otherwise the process
    /// would be left inside a temp dir that is about to be deleted.
    struct CwdGuard(PathBuf);

    impl CwdGuard {
        fn change_to(dir: &Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self(original)
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: panicking in `drop` during an unwind would abort the run.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    #[test]
    fn load_minimal_config() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();

        assert!(matches!(config.backend, Backend::Protoc));
        assert_eq!(config.python_exe, "python3");
        assert_eq!(config.include, vec![PathBuf::from(".")]);
        assert_eq!(config.inputs, vec!["proto/**/*.proto"]);
        assert_eq!(config.out, PathBuf::from("generated/python"));
        assert!(!config.generate_mypy);
        assert!(!config.generate_mypy_grpc);
        assert!(config.postprocess.relative_imports);
        assert!(config.postprocess.fix_pyi);
        assert!(config.postprocess.create_package);
        assert!(config.postprocess.exclude_google);
        assert!(!config.postprocess.pyright_header);
        assert_eq!(
            config.postprocess.module_suffixes,
            vec!["_pb2.py", "_pb2.pyi", "_pb2_grpc.py", "_pb2_grpc.pyi"]
        );
        assert!(config.verify.is_none());
    }

    #[test]
    fn load_full_config() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer]
backend = "buf"
python_exe = "uv"
include = ["proto", "common"]
inputs = ["proto/**/*.proto", "common/**/*.proto"]
out = "src/generated"
mypy = true
mypy_grpc = true

[tool.python_proto_importer.postprocess]
relative_imports = false
fix_pyi = false
create_package = false
exclude_google = false
pyright_header = true
module_suffixes = ["_pb2.py", "_grpc.py"]

[tool.python_proto_importer.verify]
mypy_cmd = ["mypy", "--strict"]
pyright_cmd = ["pyright", "generated"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();

        assert!(matches!(config.backend, Backend::Buf));
        assert_eq!(config.python_exe, "uv");
        assert_eq!(
            config.include,
            vec![PathBuf::from("proto"), PathBuf::from("common")]
        );
        assert_eq!(config.inputs, vec!["proto/**/*.proto", "common/**/*.proto"]);
        assert_eq!(config.out, PathBuf::from("src/generated"));
        assert!(config.generate_mypy);
        assert!(config.generate_mypy_grpc);
        assert!(!config.postprocess.relative_imports);
        assert!(!config.postprocess.fix_pyi);
        assert!(!config.postprocess.create_package);
        assert!(!config.postprocess.exclude_google);
        assert!(config.postprocess.pyright_header);
        assert_eq!(
            config.postprocess.module_suffixes,
            vec!["_pb2.py", "_grpc.py"]
        );

        let verify = config.verify.unwrap();
        assert_eq!(verify.mypy_cmd.unwrap(), vec!["mypy", "--strict"]);
        assert_eq!(verify.pyright_cmd.unwrap(), vec!["pyright", "generated"]);
    }

    #[test]
    fn load_empty_include_defaults_to_current_dir() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]
include = []
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        assert_eq!(config.include, vec![PathBuf::from(".")]);
    }

    #[test]
    fn backend_case_insensitive() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer]
backend = "PROTOC"
inputs = ["proto/**/*.proto"]
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        assert!(matches!(config.backend, Backend::Protoc));
    }

    #[test]
    fn unsupported_backend_fails() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer]
backend = "unsupported"
inputs = ["proto/**/*.proto"]
"#,
        );

        assert!(load_error(&config_path).contains("unsupported backend"));
    }

    #[test]
    fn missing_config_section_fails() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.other_tool]
something = "value"
"#,
        );

        assert!(load_error(&config_path).contains("[tool.python_proto_importer] not found"));
    }

    #[test]
    fn missing_file_fails() {
        // A path inside a fresh temp dir is guaranteed not to exist, independent of the CWD.
        let dir = tempdir().unwrap();
        let missing = dir.path().join("nonexistent.toml");

        assert!(load_error(&missing).contains("failed to read"));
    }

    #[test]
    fn invalid_toml_fails() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer
# Missing closing bracket
inputs = ["proto/**/*.proto"]
"#,
        );

        assert!(load_error(&config_path).contains("failed to parse"));
    }

    #[test]
    fn load_default_path() {
        let (dir, _) = write_pyproject(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]
"#,
        );
        // The CWD is process-wide state shared with concurrently running tests.
        // The other tests here address their files through temp-dir paths, not
        // CWD-relative names. `_cwd` is declared after `dir`, so it is dropped
        // (restoring the CWD) before the temp dir is deleted.
        let _cwd = CwdGuard::change_to(dir.path());

        let config = AppConfig::load(None).unwrap();
        assert_eq!(config.inputs, vec!["proto/**/*.proto"]);
    }

    #[test]
    fn verify_section_optional() {
        let (_dir, config_path) = write_pyproject(
            r#"
[tool.python_proto_importer]
inputs = ["proto/**/*.proto"]

[tool.python_proto_importer.verify]
mypy_cmd = ["mypy"]
# pyright_cmd intentionally omitted
"#,
        );

        let config = AppConfig::load(Some(&config_path)).unwrap();
        let verify = config.verify.unwrap();
        assert_eq!(verify.mypy_cmd.unwrap(), vec!["mypy"]);
        assert!(verify.pyright_cmd.is_none());
    }
}