axum_connect_build/
lib.rs

1use std::{
2    cell::RefCell,
3    env,
4    io::{BufWriter, Write},
5    ops::Deref,
6    path::{Path, PathBuf},
7    rc::Rc,
8};
9
10use gen::AxumConnectServiceGenerator;
11
12mod gen;
13
/// Configuration for [`axum_connect_codegen`].
///
/// Use [`AxumConnectGenSettings::from_directory_recursive`] to pick up every `.proto` file under a
/// directory, or fill the fields in by hand starting from [`Default::default`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AxumConnectGenSettings {
    /// Include directories handed to `prost_build` (and thus `protoc`) alongside the inputs.
    pub includes: Vec<PathBuf>,
    /// The `.proto` files to compile.
    pub inputs: Vec<PathBuf>,
    /// Extra arguments forwarded one by one to `protoc` via `prost_build::Config::protoc_arg`.
    pub protoc_args: Vec<String>,
    /// The `protoc` release to fetch into `OUT_DIR` before compiling.
    ///
    /// When `None`, nothing is fetched and the `PROTOC` environment variable is left untouched,
    /// so `prost_build` falls back to its own `protoc` discovery.
    pub protoc_version: Option<String>,
}

/// `protoc` release fetched by [`AxumConnectGenSettings::default`].
const DEFAULT_PROTOC_VERSION: &str = "31.1";

impl Default for AxumConnectGenSettings {
    /// Empty includes, inputs and extra `protoc` arguments.
    ///
    /// The `protoc` version is pinned to [`DEFAULT_PROTOC_VERSION`], so a `protoc` is fetched
    /// rather than taken from the environment.
    fn default() -> Self {
        Self {
            includes: Vec::new(),
            inputs: Vec::new(),
            protoc_args: Vec::new(),
            protoc_version: Some(DEFAULT_PROTOC_VERSION.to_owned()),
        }
    }
}
32
33impl AxumConnectGenSettings {
34    pub fn from_directory_recursive<P>(path: P) -> anyhow::Result<Self>
35    where
36        P: Into<PathBuf>,
37    {
38        let path = path.into();
39        let mut settings = Self::default();
40        settings.includes.push(path.clone());
41
42        // Recursively add all files that end in ".proto" to the inputs.
43        let mut dirs = vec![path];
44        while let Some(dir) = dirs.pop() {
45            for entry in std::fs::read_dir(dir)? {
46                let entry = entry?;
47                let path = entry.path();
48                if path.is_dir() {
49                    dirs.push(path.clone());
50                } else if path.extension().map(|ext| ext == "proto").unwrap_or(false) {
51                    settings.inputs.push(path);
52                }
53            }
54        }
55
56        Ok(settings)
57    }
58}
59
60pub fn axum_connect_codegen(settings: AxumConnectGenSettings) -> anyhow::Result<()> {
61    // Fetch protoc
62    if let Some(version) = &settings.protoc_version {
63        let out_dir = env::var("OUT_DIR").unwrap();
64        let protoc_path = protoc_fetcher::protoc(version, Path::new(&out_dir))?;
65        env::set_var("PROTOC", protoc_path);
66    }
67
68    // Instruct cargo to re-run if any of the proto files change
69    for input in &settings.inputs {
70        println!("cargo:rerun-if-changed={}", input.display());
71    }
72
73    let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
74
75    let mut conf = prost_build::Config::new();
76
77    // Standard prost configuration
78    conf.compile_well_known_types();
79    conf.file_descriptor_set_path(&descriptor_path);
80    conf.extern_path(".google.protobuf", "::axum_connect::pbjson_types");
81    conf.service_generator(Box::new(AxumConnectServiceGenerator::new()));
82
83    // Arg configuration
84    for arg in settings.protoc_args {
85        conf.protoc_arg(arg);
86    }
87
88    // File configuration
89    conf.compile_protos(&settings.inputs, &settings.includes)
90        .unwrap();
91
92    // Use pbjson to generate the Serde impls, and inline them with the Prost files.
93    let descriptor_set = std::fs::read(descriptor_path)?;
94    let mut output: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
95    output.push("FILENAME");
96
97    // TODO: This is a nasty hack. Get rid of it. Idk how without dumping Prost and pbjson though.
98    let files = Rc::new(RefCell::new(vec![]));
99
100    let files_c = files.clone();
101    let writers = pbjson_build::Builder::new()
102        .register_descriptors(&descriptor_set)?
103        .extern_path(".google.protobuf", "::axum_connect::pbjson_types")
104        .generate(&["."], move |package| {
105            output.set_file_name(format!("{}.rs", package));
106            files_c.deref().borrow_mut().push(output.clone());
107
108            let file = std::fs::OpenOptions::new().append(true).create(true).open(&output)?;
109
110            Ok(BufWriter::new(file))
111        })?;
112
113    for (_, mut writer) in writers {
114        writer.flush()?;
115    }
116
117    // Now second part of the nasty hack, replace a few namespaces with re-exported ones.
118    for file in files.take().into_iter() {
119        let contents = std::fs::read_to_string(&file)?;
120        let contents = contents.replace("pbjson::", "axum_connect::pbjson::");
121        let contents = contents.replace("prost::", "axum_connect::prost::");
122        let contents = contents.replace("serde::", "axum_connect::serde::");
123        std::fs::write(&file, contents)?;
124    }
125
126    Ok(())
127}