twurst_build/
lib.rs

1#![doc = include_str!("../README.md")]
2#![doc(test(attr(deny(warnings))))]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4
5pub use prost_build as prost;
6use prost_build::{Config, Module, Service, ServiceGenerator};
7use regex::Regex;
8use std::collections::HashSet;
9use std::fmt::Write;
10use std::io::{Error, Result};
11use std::path::{Path, PathBuf};
12use std::{env, fs};
13
14/// Builds protobuf bindings for Twirp.
15///
16/// Client and server are not enabled by defaults and must be enabled with the [`with_client`](Self::with_client) and [`with_server`](Self::with_server) methods.
17#[derive(Default)]
18pub struct TwirpBuilder {
19    config: Config,
20    generator: TwirpServiceGenerator,
21    type_name_domain: Option<String>,
22}
23
24impl TwirpBuilder {
25    /// Builder with the default prost [`Config`].
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Builder with a custom prost [`Config`].
31    pub fn from_prost(config: Config) -> Self {
32        Self {
33            config,
34            generator: TwirpServiceGenerator::new(),
35            type_name_domain: None,
36        }
37    }
38
39    /// Generates code for the Twirp client.
40    pub fn with_client(mut self) -> Self {
41        self.generator = self.generator.with_client();
42        self
43    }
44
45    /// Generates code for the Twirp server.
46    pub fn with_server(mut self) -> Self {
47        self.generator = self.generator.with_server();
48        self
49    }
50
51    /// Generates code for gRPC alongside Twirp.
52    pub fn with_grpc(mut self) -> Self {
53        self.generator = self.generator.with_grpc();
54        self
55    }
56
57    /// Adds an extra parameter to generated server methods that implements [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
58    ///
59    /// For example
60    /// ```proto
61    /// message Service {
62    ///     rpc Test(TestRequest) returns (TestResponse) {}
63    /// }
64    /// ```
65    /// Compiled with option `.with_axum_request_extractor("headers", "::axum::http::HeaderMap")`
66    /// will generate the following code allowing to extract the request headers:
67    /// ```ignore
68    /// trait Service {
69    ///     async fn test(request: TestRequest, headers: ::axum::http::HeaderMap) -> Result<TestResponse, TwirpError>;
70    /// }
71    /// ```
72    ///
73    /// Note that the parameter type must implement [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
74    pub fn with_axum_request_extractor(
75        mut self,
76        name: impl Into<String>,
77        type_name: impl Into<String>,
78    ) -> Self {
79        self.generator = self.generator.with_axum_request_extractor(name, type_name);
80        self
81    }
82
83    /// Customizes the type name domain.
84    ///
85    /// By default, 'type.googleapis.com' is used.
86    pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
87        self.type_name_domain = Some(domain.into());
88        self
89    }
90
91    /// Do compile the protos.
92    pub fn compile_protos(
93        mut self,
94        protos: &[impl AsRef<Path>],
95        includes: &[impl AsRef<Path>],
96    ) -> Result<()> {
97        let out_dir = PathBuf::from(
98            env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
99        );
100
101        // We make sure the script is executed again if a file changed
102        for proto in protos {
103            println!("cargo:rerun-if-changed={}", proto.as_ref().display());
104        }
105        self.config
106            .enable_type_names()
107            .type_name_domain(
108                ["."],
109                self.type_name_domain
110                    .as_deref()
111                    .unwrap_or("type.googleapis.com"),
112            )
113            .service_generator(Box::new(self.generator));
114
115        // We configure with prost reflect
116        prost_reflect_build::Builder::new()
117            .file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
118            .configure(&mut self.config, protos, includes)?;
119
120        // We do the build itself while saving the list of modules
121        let config = self.config.skip_protoc_run();
122        let file_descriptor_set = config.load_fds(protos, includes)?;
123        let modules = file_descriptor_set
124            .file
125            .iter()
126            .map(|fd| Module::from_protobuf_package_name(fd.package()))
127            .collect::<HashSet<_>>();
128
129        // We generate the files
130        config.compile_fds(file_descriptor_set)?;
131
132        // TODO(vsiles) consider proper AST parsing in case we need to do something
133        // more robust
134        //
135        // prepare a regex to match `pub mod <module-name> {`
136        let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
137
138        // We add the file descriptor to every file to make reflection work automatically
139        for module in modules {
140            let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
141            if !file_path.exists() {
142                continue; // We ignore not built files
143            }
144            let original_content = fs::read_to_string(&file_path)?;
145
146            // scan for nested modules and insert the right FILE_DESCRIPTOR_SET_BYTES definition
147            let mut modified_content = original_content
148                .lines()
149                .flat_map(|line| {
150                    if let Some(captures) = re.captures(line) {
151                        let indentation = captures.get(1).map_or("", |m| m.as_str());
152                        vec![
153                            line.to_string(),
154                            // if there is no nested type, the next line would generate a warning
155                            format!("    {}{}", indentation, "#[allow(unused_imports)]"),
156                            format!(
157                                "    {}{}",
158                                indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
159                            ),
160                        ]
161                    } else {
162                        vec![line.to_string()]
163                    }
164                })
165                .collect::<Vec<_>>();
166
167            modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
168            let file_content = modified_content.join("\n");
169
170            fs::write(&file_path, &file_content)?;
171        }
172
173        Ok(())
174    }
175}
176
/// Low level generator for Twirp related code.
///
/// This is only useful if you want to customize builds. For common use cases, please use [`TwirpBuilder`].
///
/// Should be given to [`Config::service_generator`].
///
/// Client and server are not enabled by default and must be enabled with the [`with_client`](Self::with_client) and [`with_server`](Self::with_server) methods.
#[derive(Default)]
struct TwirpServiceGenerator {
    /// Whether to generate the `<Service>Client` struct.
    client: bool,
    /// Whether to generate the server trait and its Twirp router.
    server: bool,
    /// Whether to also generate gRPC support (streaming methods and `into_grpc_router`).
    grpc: bool,
    /// Extra server-method parameters as `(name, type)` pairs, extracted from the request parts.
    request_extractors: Vec<(String, String)>,
}
191
192impl TwirpServiceGenerator {
193    pub fn new() -> Self {
194        Self::default()
195    }
196
197    pub fn with_client(mut self) -> Self {
198        self.client = true;
199        self
200    }
201
202    pub fn with_server(mut self) -> Self {
203        self.server = true;
204        self
205    }
206
207    pub fn with_grpc(mut self) -> Self {
208        self.grpc = true;
209        self
210    }
211
212    pub fn with_axum_request_extractor(
213        mut self,
214        name: impl Into<String>,
215        type_name: impl Into<String>,
216    ) -> Self {
217        self.request_extractors
218            .push((name.into(), type_name.into()));
219        self
220    }
221}
222
223impl ServiceGenerator for TwirpServiceGenerator {
224    fn generate(&mut self, service: Service, buf: &mut String) {
225        self.do_generate(service, buf)
226            .expect("failed to generate Twirp service")
227    }
228}
229
230impl TwirpServiceGenerator {
231    fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
232        if self.client {
233            writeln!(buf)?;
234            for comment in &service.comments.leading {
235                writeln!(buf, "/// {comment}")?;
236            }
237            if service.options.deprecated.unwrap_or(false) {
238                writeln!(buf, "#[deprecated]")?;
239            }
240            writeln!(buf, "#[derive(Clone)]")?;
241            writeln!(
242                buf,
243                "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
244                service.name
245            )?;
246            writeln!(buf, "    client: ::twurst_client::TwirpHttpClient<C>")?;
247            writeln!(buf, "}}")?;
248            writeln!(buf)?;
249            writeln!(
250                buf,
251                "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
252                service.name
253            )?;
254            writeln!(
255                buf,
256                "    pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
257            )?;
258            writeln!(buf, "        Self {{ client: client.into() }}")?;
259            writeln!(buf, "    }}")?;
260            for method in &service.methods {
261                if method.client_streaming || method.server_streaming {
262                    continue; // Not supported
263                }
264                for comment in &method.comments.leading {
265                    writeln!(buf, "    /// {comment}")?;
266                }
267                if method.options.deprecated.unwrap_or(false) {
268                    writeln!(buf, "#[deprecated]")?;
269                }
270                writeln!(
271                    buf,
272                    "    pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
273                    method.name, method.input_type, method.output_type,
274                )?;
275                writeln!(
276                    buf,
277                    "        self.client.call(\"/{}.{}/{}\", request).await",
278                    service.package, service.proto_name, method.proto_name,
279                )?;
280                writeln!(buf, "    }}")?;
281            }
282            writeln!(buf, "}}")?;
283        }
284
285        if self.server {
286            writeln!(buf)?;
287            for comment in &service.comments.leading {
288                writeln!(buf, "/// {comment}")?;
289            }
290            writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
291            writeln!(buf, "pub trait {} {{", service.name)?;
292            for method in &service.methods {
293                if !self.grpc && (method.client_streaming || method.server_streaming) {
294                    continue; // No streaming
295                }
296                for comment in &method.comments.leading {
297                    writeln!(buf, "    /// {comment}")?;
298                }
299                write!(buf, "    async fn {}(&self, request: ", method.name)?;
300                if method.client_streaming {
301                    write!(
302                        buf,
303                        "impl ::twurst_server::codegen::Stream<Item=Result<{},::twurst_client::TwirpError>> + Send + 'static",
304                        method.input_type,
305                    )?;
306                } else {
307                    write!(buf, "{}", method.input_type)?;
308                }
309                for (arg_name, arg_type) in &self.request_extractors {
310                    write!(buf, ", {arg_name}: {arg_type}")?;
311                }
312                writeln!(buf, ") -> Result<")?;
313                if method.server_streaming {
314                    // TODO: move back to `impl` when we will be able to use precise capturing to not capture &self
315                    writeln!(buf, "Box<dyn ::twurst_server::codegen::Stream<Item=Result<{}, ::twurst_server::TwirpError>> + Send>", method.output_type)?;
316                } else {
317                    writeln!(buf, "{}", method.output_type)?;
318                }
319                writeln!(buf, ", ::twurst_server::TwirpError>;")?;
320            }
321            writeln!(buf)?;
322            writeln!(
323                buf,
324                "    fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
325            )?;
326            writeln!(
327                buf,
328                "        ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
329            )?;
330            for method in &service.methods {
331                if method.client_streaming || method.server_streaming {
332                    writeln!(
333                        buf,
334                        "            .route_streaming(\"/{}.{}/{}\")",
335                        service.package, service.proto_name, method.proto_name,
336                    )?;
337                    continue;
338                }
339                write!(
340                        buf,
341                        "            .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
342                        service.package, service.proto_name, method.proto_name, method.input_type,
343                    )?;
344                if self.request_extractors.is_empty() {
345                    write!(buf, ", _: ::twurst_server::codegen::RequestParts, _: S")?;
346                } else {
347                    write!(
348                        buf,
349                        ", mut parts: ::twurst_server::codegen::RequestParts, state: S",
350                    )?;
351                }
352                write!(buf, "| {{")?;
353                writeln!(buf, "                async move {{")?;
354                write!(buf, "                    service.{}(request", method.name)?;
355                for (_name, type_name) in &self.request_extractors {
356                    write!(
357                            buf,
358                            ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &state).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
359                        )?;
360                }
361                writeln!(buf, ").await")?;
362                writeln!(buf, "                }}")?;
363                writeln!(buf, "            }})")?;
364            }
365            writeln!(buf, "            .build()")?;
366            writeln!(buf, "    }}")?;
367
368            if self.grpc {
369                writeln!(buf)?;
370                writeln!(
371                    buf,
372                    "    fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
373                )?;
374                writeln!(
375                    buf,
376                    "        ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
377                )?;
378                for method in &service.methods {
379                    let method_name = match (method.client_streaming, method.server_streaming) {
380                        (false, false) => "route",
381                        (false, true) => "route_server_streaming",
382                        (true, false) => "route_client_streaming",
383                        (true, true) => "route_streaming",
384                    };
385                    write!(
386                        buf,
387                        "            .{}(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: ",method_name,
388                        service.package, service.proto_name, method.proto_name,
389                    )?;
390                    if method.client_streaming {
391                        write!(
392                            buf,
393                            "::twurst_server::codegen::GrpcClientStream<{}>",
394                            method.input_type,
395                        )?;
396                    } else {
397                        write!(buf, "{}", method.input_type)?;
398                    }
399                    if self.request_extractors.is_empty() {
400                        write!(buf, ", _: ::twurst_server::codegen::RequestParts")?;
401                    } else {
402                        write!(buf, ", mut parts: ::twurst_server::codegen::RequestParts")?;
403                    }
404                    write!(buf, "| {{")?;
405                    write!(buf, "                async move {{")?;
406                    if method.server_streaming {
407                        write!(buf, "Ok(Box::into_pin(")?;
408                    }
409                    write!(buf, "service.{}(request", method.name)?;
410                    for (_name, type_name) in &self.request_extractors {
411                        write!(
412                            buf,
413                            ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &()).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
414                        )?;
415                    }
416                    write!(buf, ").await")?;
417                    if method.server_streaming {
418                        write!(buf, "?))")?;
419                    }
420                    writeln!(buf, "}}")?;
421                    writeln!(buf, "            }})")?;
422                }
423                writeln!(buf, "            .build()")?;
424                writeln!(buf, "    }}")?;
425            }
426
427            writeln!(buf, "}}")?;
428        }
429
430        Ok(())
431    }
432}