grpc_build/
lib.rs

1use anyhow::{anyhow, Context, Ok, Result};
2use prost::Message;
3use prost_build::{protoc_from_env, protoc_include_from_env, Module};
4use prost_types::{FileDescriptorProto, FileDescriptorSet};
5use std::{collections::HashMap, path::Path, process::Command};
6
7pub mod base;
8mod builder;
9pub mod tree;
10pub use builder::Builder;
11
12impl Builder {
13    pub fn build(self, in_dir: impl AsRef<Path>) -> Result<(), anyhow::Error> {
14        let out_dir = self.get_out_dir()?;
15        if !self.force && out_dir.exists() {
16            return Err(anyhow!(
17                "the output directory already exists: {}",
18                out_dir.display()
19            ));
20        }
21
22        base::prepare_out_dir(&out_dir).context("failed to prepare out dir")?;
23
24        match self.file_descriptor_set_path.clone() {
25            Some(file_descriptor_path) => {
26                self.compile(in_dir.as_ref(), &out_dir, &file_descriptor_path)
27                    .context("failed to compile the protos")?;
28            }
29            None => {
30                // Create a temporary directory to host the file descriptor set.
31                // The directory gets cleaned when compilation ends.
32                let tmp = tempfile::Builder::new()
33                    .prefix("grpc-build")
34                    .tempdir()
35                    .context("failed to get tempdir")?;
36                let file_descriptor_path = tmp.path().join("grpc-descriptor-set");
37
38                self.compile(in_dir.as_ref(), &out_dir, &file_descriptor_path)
39                    .context("failed to compile the protos")?;
40            }
41        }
42
43        base::refactor(out_dir).context("failed to refactor the protos")?;
44
45        Ok(())
46    }
47
48    fn compile(
49        self,
50        input_dir: &Path,
51        out_dir: &Path,
52        file_descriptor_path: &Path,
53    ) -> Result<(), anyhow::Error> {
54        self.run_protoc(input_dir.as_ref(), file_descriptor_path)
55            .context("failed to run protoc")?;
56
57        let buf = fs_err::read(file_descriptor_path).context("failed to read file descriptors")?;
58        let file_descriptor_set =
59            FileDescriptorSet::decode(&*buf).context("invalid FileDescriptorSet")?;
60
61        self.generate_services(out_dir, file_descriptor_set)
62            .context("failed to generic tonic services")?;
63        Ok(())
64    }
65
66    fn run_protoc(
67        &self,
68        input_dir: &Path,
69        file_descriptor_path: &Path,
70    ) -> Result<(), anyhow::Error> {
71        let protos = crate::base::get_protos(input_dir, self.follow_links).collect::<Vec<_>>();
72
73        if protos.is_empty() {
74            return Err(anyhow!("no .proto files found in {}", input_dir.display()));
75        }
76
77        let compile_includes: &Path = match input_dir.parent() {
78            None => Path::new("."),
79            Some(parent) => parent,
80        };
81
82        let mut cmd = Command::new(protoc_from_env());
83        cmd.arg("--include_imports")
84            .arg("--include_source_info")
85            .arg("--descriptor_set_out")
86            .arg(file_descriptor_path);
87        cmd.arg("--proto_path").arg(compile_includes);
88
89        if let Some(include) = protoc_include_from_env() {
90            cmd.arg("--proto_path").arg(include);
91        }
92
93        for arg in &self.protoc_args {
94            cmd.arg(arg);
95        }
96
97        for proto in &protos {
98            cmd.arg(proto);
99        }
100
101        eprintln!("Running {cmd:?}");
102
103        let output = cmd.output().context(
104            "failed to invoke protoc (hint: https://docs.rs/prost-build/#sourcing-protoc)",
105        )?;
106
107        if !output.status.success() {
108            eprintln!(
109                "---protoc stderr---\n{}\n------",
110                String::from_utf8_lossy(&output.stderr).trim()
111            );
112
113            return Err(anyhow!(
114                "protoc returned a non-zero exit status: {}",
115                output.status,
116            ));
117        }
118
119        Ok(())
120    }
121
122    fn generate_services(
123        mut self,
124        out_dir: &Path,
125        file_descriptor_set: FileDescriptorSet,
126    ) -> Result<(), anyhow::Error> {
127        let service_generator = self.tonic.service_generator();
128        self.prost.service_generator(service_generator);
129
130        let requests = file_descriptor_set
131            .file
132            .into_iter()
133            .map(|descriptor| {
134                // Add our NamedMessage derive
135                for (name, annotation) in derive_named_messages(&descriptor) {
136                    self.prost.type_attribute(&name, annotation);
137                }
138
139                (
140                    Module::from_protobuf_package_name(descriptor.package()),
141                    descriptor,
142                )
143            })
144            .collect::<Vec<_>>();
145
146        let file_names = requests
147            .iter()
148            .map(|(module, _)| {
149                (
150                    module.clone(),
151                    module.to_file_name_or(self.default_module_name.as_deref().unwrap_or("_")),
152                )
153            })
154            .collect::<HashMap<Module, String>>();
155
156        let modules = self.prost.generate(requests)?;
157        for (module, content) in &modules {
158            let file_name = file_names
159                .get(module)
160                .expect("every module should have a filename");
161            let output_path = out_dir.join(file_name);
162
163            let previous_content = fs_err::read(&output_path);
164
165            // only write the file if the contents have changed
166            if previous_content
167                .map(|previous_content| previous_content != content.as_bytes())
168                .unwrap_or(true)
169            {
170                fs_err::write(output_path, content)?;
171            }
172        }
173
174        Ok(())
175    }
176}
177
178/// Build annotations for the top-level messages in a file,
179fn derive_named_messages(
180    descriptor: &FileDescriptorProto,
181) -> impl Iterator<Item = (String, String)> + '_ {
182    let namespace = descriptor.package();
183    descriptor.message_type.iter().map(|message| {
184        let full_name = fully_qualified_name(namespace, message.name());
185        let derive =
186            format!("#[derive(::grpc_build_core::NamedMessage)] #[name = \"{full_name}\"]");
187        (full_name, derive)
188    })
189}
190
/// Joins a package `namespace` and a message `name` with a `.` separator.
///
/// Leading dots are stripped from `namespace` first; if nothing remains, the
/// bare `name` is returned unchanged.
fn fully_qualified_name(namespace: &str, name: &str) -> String {
    match namespace.trim_start_matches('.') {
        "" => name.to_owned(),
        package => format!("{package}.{name}"),
    }
}