axum_connect_build/
lib.rs1use std::{
2 cell::RefCell,
3 env,
4 io::{BufWriter, Write},
5 ops::Deref,
6 path::{Path, PathBuf},
7 rc::Rc,
8};
9
10use gen::AxumConnectServiceGenerator;
11
12mod gen;
13
/// Inputs for [`axum_connect_codegen`].
///
/// The [`Default`] value has no include directories, no inputs and no extra
/// protoc arguments, and pins `protoc_version` to `"31.1"`.
#[derive(Clone, Debug)]
pub struct AxumConnectGenSettings {
    /// Include directories handed to prost-build alongside `inputs`.
    pub includes: Vec<PathBuf>,
    /// `.proto` files to compile.
    pub inputs: Vec<PathBuf>,
    /// Extra arguments forwarded one by one to protoc via prost-build.
    pub protoc_args: Vec<String>,
    /// When `Some`, protoc of this version is obtained through `protoc_fetcher`
    /// into `OUT_DIR` and `PROTOC` is pointed at it; when `None`, `PROTOC` is
    /// left untouched.
    pub protoc_version: Option<String>,
}

impl Default for AxumConnectGenSettings {
    fn default() -> Self {
        // Version of protoc used unless the caller overrides it.
        const DEFAULT_PROTOC_VERSION: &str = "31.1";

        Self {
            includes: Vec::new(),
            inputs: Vec::new(),
            protoc_args: Vec::new(),
            protoc_version: Some(String::from(DEFAULT_PROTOC_VERSION)),
        }
    }
}
32
33impl AxumConnectGenSettings {
34 pub fn from_directory_recursive<P>(path: P) -> anyhow::Result<Self>
35 where
36 P: Into<PathBuf>,
37 {
38 let path = path.into();
39 let mut settings = Self::default();
40 settings.includes.push(path.clone());
41
42 let mut dirs = vec![path];
44 while let Some(dir) = dirs.pop() {
45 for entry in std::fs::read_dir(dir)? {
46 let entry = entry?;
47 let path = entry.path();
48 if path.is_dir() {
49 dirs.push(path.clone());
50 } else if path.extension().map(|ext| ext == "proto").unwrap_or(false) {
51 settings.inputs.push(path);
52 }
53 }
54 }
55
56 Ok(settings)
57 }
58}
59
60pub fn axum_connect_codegen(settings: AxumConnectGenSettings) -> anyhow::Result<()> {
61 if let Some(version) = &settings.protoc_version {
63 let out_dir = env::var("OUT_DIR").unwrap();
64 let protoc_path = protoc_fetcher::protoc(version, Path::new(&out_dir))?;
65 env::set_var("PROTOC", protoc_path);
66 }
67
68 for input in &settings.inputs {
70 println!("cargo:rerun-if-changed={}", input.display());
71 }
72
73 let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
74
75 let mut conf = prost_build::Config::new();
76
77 conf.compile_well_known_types();
79 conf.file_descriptor_set_path(&descriptor_path);
80 conf.extern_path(".google.protobuf", "::axum_connect::pbjson_types");
81 conf.service_generator(Box::new(AxumConnectServiceGenerator::new()));
82
83 for arg in settings.protoc_args {
85 conf.protoc_arg(arg);
86 }
87
88 conf.compile_protos(&settings.inputs, &settings.includes)
90 .unwrap();
91
92 let descriptor_set = std::fs::read(descriptor_path)?;
94 let mut output: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
95 output.push("FILENAME");
96
97 let files = Rc::new(RefCell::new(vec![]));
99
100 let files_c = files.clone();
101 let writers = pbjson_build::Builder::new()
102 .register_descriptors(&descriptor_set)?
103 .extern_path(".google.protobuf", "::axum_connect::pbjson_types")
104 .generate(&["."], move |package| {
105 output.set_file_name(format!("{}.rs", package));
106 files_c.deref().borrow_mut().push(output.clone());
107
108 let file = std::fs::OpenOptions::new().append(true).create(true).open(&output)?;
109
110 Ok(BufWriter::new(file))
111 })?;
112
113 for (_, mut writer) in writers {
114 writer.flush()?;
115 }
116
117 for file in files.take().into_iter() {
119 let contents = std::fs::read_to_string(&file)?;
120 let contents = contents.replace("pbjson::", "axum_connect::pbjson::");
121 let contents = contents.replace("prost::", "axum_connect::prost::");
122 let contents = contents.replace("serde::", "axum_connect::serde::");
123 std::fs::write(&file, contents)?;
124 }
125
126 Ok(())
127}