twurst_build/
lib.rs

1#![doc = include_str!("../README.md")]
2#![doc(
3    test(attr(deny(warnings))),
4    html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5    html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9use self::proto_path_map::ProtoPathMap;
10pub use prost_build as prost;
11use prost_build::{Config, Module, Service, ServiceGenerator};
12use regex::Regex;
13use std::collections::HashSet;
14use std::fmt::Write;
15use std::io::{Error, Result};
16use std::path::{Path, PathBuf};
17use std::{env, fs};
18
19mod proto_path_map;
20
21/// Builds protobuf bindings for Twirp.
22///
23/// Client and server are not enabled by defaults and must be enabled with the [`with_client`](Self::with_client) and [`with_server`](Self::with_server) methods.
24#[derive(Default)]
25pub struct TwirpBuilder {
26    config: Config,
27    generator: TwirpServiceGenerator,
28    type_name_domain: Option<String>,
29}
30
31impl TwirpBuilder {
32    /// Builder with the default prost [`Config`].
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Builder with a custom prost [`Config`].
38    pub fn from_prost(config: Config) -> Self {
39        Self {
40            config,
41            generator: TwirpServiceGenerator::new(),
42            type_name_domain: None,
43        }
44    }
45
46    /// Generates code for the Twirp client.
47    pub fn with_client(mut self) -> Self {
48        self.generator = self.generator.with_client();
49        self
50    }
51
52    /// Generates code for the Twirp server.
53    pub fn with_server(mut self) -> Self {
54        self.generator = self.generator.with_server();
55        self
56    }
57
58    /// Generates code for gRPC alongside Twirp.
59    pub fn with_grpc(mut self) -> Self {
60        self.generator = self.generator.with_grpc();
61        self
62    }
63
64    #[deprecated(
65        since = "0.3.1",
66        note = "replaced with with_default_axum_request_extractor"
67    )]
68    pub fn with_axum_request_extractor(
69        self,
70        name: impl Into<String>,
71        type_name: impl Into<String>,
72    ) -> Self {
73        self.with_default_axum_request_extractor(name, type_name)
74    }
75
76    /// Adds an extra parameter to generated server methods that implements [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
77    ///
78    /// For example
79    /// ```proto
80    /// message Service {
81    ///     rpc Test(TestRequest) returns (TestResponse) {}
82    /// }
83    /// ```
84    /// Compiled with option `.with_default_axum_request_extractor("headers", "::axum::http::HeaderMap")`
85    /// will generate the following code (in every service) allowing to extract the request headers:
86    /// ```ignore
87    /// trait Service {
88    ///     async fn test(request: TestRequest, headers: ::axum::http::HeaderMap) -> Result<TestResponse, TwirpError>;
89    /// }
90    /// ```
91    ///
92    /// Note that the parameter type must implement [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
93    ///
94    /// There is a companion method to this: [`TwirpBuilder::with_service_specific_axum_request_extractor`], which adds request extractors per service,
95    /// rather than for all services given to the build.
96    pub fn with_default_axum_request_extractor(
97        mut self,
98        name: impl Into<String>,
99        type_name: impl Into<String>,
100    ) -> Self {
101        self.generator = self
102            .generator
103            .with_default_axum_request_extractor(name, type_name);
104        self
105    }
106
107    /// Adds an extra parameter to a service's server methods that implements [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
108    ///
109    /// For example, given:
110    /// ```proto
111    /// message ServiceA {
112    ///     rpc Test(TestRequest) returns (TestResponse) {}
113    /// }
114    /// ```
115    ///
116    /// And:
117    ///
118    /// ```proto
119    /// message ServiceB {
120    ///     rpc Test(TestRequest) returns (TestResponse) {}
121    /// }
122    /// ```
123    ///
124    /// When compiled with option `.with_service_specific_axum_request_extractor("headers", "::axum::http::HeaderMap", "ServiceA")`
125    /// will generate the following code extract the request headers in just implementors of `ServiceA`:
126    /// ```ignore
127    /// trait ServiceA {
128    ///     async fn test(request: TestRequest, headers: ::axum::http::HeaderMap) -> Result<TestResponse, TwirpError>;
129    /// }
130    ///
131    /// trait ServiceB {
132    ///     async fn test(request: TestRequest) -> Result<TestResponse, TwirpError>;
133    /// }
134    /// ```
135    ///
136    /// Note that the parameter type must implement [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
137    ///
138    /// Service specific request extractors will overwrite any that are set by: [`TwirpBuilder::with_default_axum_request_extractor`]. They are NOT additive, but you can
139    /// add any default extractors also as service specific ones, for example:
140    /// ```ignore
141    /// let builder = TwirpBuilder::new()
142    ///     .with_server()
143    ///     .with_default_axum_request_extractor(
144    ///         "auth_header",
145    ///         "my_crate::AuthorizationHeader",
146    ///     )
147    ///     .with_service_specific_axum_request_extractor(
148    ///         "auth_header",
149    ///         "my_crate::AuthorizationHeader",
150    ///         "MyService"
151    ///     );
152    ///     .with_service_specific_axum_request_extractor(
153    ///         "request_context",
154    ///         "my_crate::RequestContext",
155    ///         "MyService"
156    ///     );
157    /// ```
158    /// Will generate traits for `MyService` which extract both `auth_header` and
159    /// `request_context`, whilst all others will just have `auth_header`.
160    ///
161    /// The service should be specified by Proto path. For example:
162    ///
163    /// ```ignore
164    /// let mut builder = TwirpBuilder::new().with_server();
165    ///
166    /// // Match any Service called `MyService`
167    /// builder.with_service_specific_axum_request_extractor(
168    ///     "auth_header",
169    ///     "my_crate::AuthorizationHeader",
170    ///     "MyService"
171    /// );
172    ///
173    /// // Match any Service called `MyService` in the package `MyPackage`
174    /// builder.with_service_specific_axum_request_extractor(
175    ///     "auth_header",
176    ///     "my_crate::AuthorizationHeader",
177    ///     ".MyPackage.MyService"
178    /// );
179    ///
180    /// // Match all Services in the package `MyPackage`
181    /// builder.with_service_specific_axum_request_extractor(
182    ///     "auth_header",
183    ///     "my_crate::AuthorizationHeader",
184    ///     ".MyPackage"
185    /// );
186    ///
187    /// // Match _any_ Service.
188    /// //
189    /// // NOTE: This will override the defaults for ALL services. This is useful if you want all
190    /// // services to have an extractor with a subset having additional ones, however it means you cannot
191    /// // have disjoint sets of extractors across services.
192    /// builder.with_service_specific_axum_request_extractor(
193    ///     "auth_header",
194    ///     "my_crate::AuthorizationHeader",
195    ///     "."
196    /// );
197    pub fn with_service_specific_axum_request_extractor(
198        mut self,
199        name: impl Into<String>,
200        type_name: impl Into<String>,
201        service_path: impl Into<String>,
202    ) -> Self {
203        self.generator = self.generator.with_service_specific_axum_request_extractor(
204            name,
205            type_name,
206            service_path,
207        );
208        self
209    }
210
211    /// Customizes the type name domain.
212    ///
213    /// By default, 'type.googleapis.com' is used.
214    pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
215        self.type_name_domain = Some(domain.into());
216        self
217    }
218
219    /// Do compile the protos.
220    pub fn compile_protos(
221        mut self,
222        protos: &[impl AsRef<Path>],
223        includes: &[impl AsRef<Path>],
224    ) -> Result<()> {
225        let out_dir = PathBuf::from(
226            env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
227        );
228
229        // We make sure the script is executed again if a file changed
230        for proto in protos {
231            println!("cargo:rerun-if-changed={}", proto.as_ref().display());
232        }
233
234        self.config
235            .enable_type_names()
236            .type_name_domain(
237                ["."],
238                self.type_name_domain
239                    .as_deref()
240                    .unwrap_or("type.googleapis.com"),
241            )
242            .service_generator(Box::new(self.generator));
243
244        // We configure with prost reflect
245        prost_reflect_build::Builder::new()
246            .file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
247            .configure(&mut self.config, protos, includes)?;
248
249        // We do the build itself while saving the list of modules
250        let config = self.config.skip_protoc_run();
251        let file_descriptor_set = config.load_fds(protos, includes)?;
252        let modules = file_descriptor_set
253            .file
254            .iter()
255            .map(|fd| Module::from_protobuf_package_name(fd.package()))
256            .collect::<HashSet<_>>();
257
258        // We generate the files
259        config.compile_fds(file_descriptor_set)?;
260
261        // TODO(vsiles) consider proper AST parsing in case we need to do something
262        // more robust
263        //
264        // prepare a regex to match `pub mod <module-name> {`
265        let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
266
267        // We add the file descriptor to every file to make reflection work automatically
268        for module in modules {
269            let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
270            if !file_path.exists() {
271                continue; // We ignore not built files
272            }
273            let original_content = fs::read_to_string(&file_path)?;
274
275            // scan for nested modules and insert the right FILE_DESCRIPTOR_SET_BYTES definition
276            let mut modified_content = original_content
277                .lines()
278                .flat_map(|line| {
279                    if let Some(captures) = re.captures(line) {
280                        let indentation = captures.get(1).map_or("", |m| m.as_str());
281                        vec![
282                            line.to_string(),
283                            // if there is no nested type, the next line would generate a warning
284                            format!("    {}{}", indentation, "#[allow(unused_imports)]"),
285                            format!(
286                                "    {}{}",
287                                indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
288                            ),
289                        ]
290                    } else {
291                        vec![line.to_string()]
292                    }
293                })
294                .collect::<Vec<_>>();
295
296            modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
297            let file_content = modified_content.join("\n");
298
299            fs::write(&file_path, &file_content)?;
300        }
301
302        Ok(())
303    }
304}
305
306/// Low level generator for Twirp related code.
307///
308/// This only useful if you want to customize builds. For common use cases, please use [`TwirpBuilder`].
309///
310/// Should be given to [`Config::service_generator`].
311///
312/// Client and server are not enabled by defaults and must be enabled with the [`with_client`](Self::with_client) and [`with_server`](Self::with_server) methods.
313#[derive(Default)]
314struct TwirpServiceGenerator {
315    client: bool,
316    server: bool,
317    grpc: bool,
318    // stores the default extractors as (argument_name, extractor_type)
319    default_request_extractors: Vec<(String, String)>,
320    // stores an extractor for a proto path as (argument_name, extractor_type)
321    matched_request_extractors: ProtoPathMap<(String, String)>,
322}
323
324impl TwirpServiceGenerator {
325    pub fn new() -> Self {
326        Self::default()
327    }
328
329    pub fn with_client(mut self) -> Self {
330        self.client = true;
331        self
332    }
333
334    pub fn with_server(mut self) -> Self {
335        self.server = true;
336        self
337    }
338
339    pub fn with_grpc(mut self) -> Self {
340        self.grpc = true;
341        self
342    }
343
344    pub fn with_default_axum_request_extractor(
345        mut self,
346        name: impl Into<String>,
347        type_name: impl Into<String>,
348    ) -> Self {
349        self.default_request_extractors
350            .push((name.into(), type_name.into()));
351        self
352    }
353
354    // This will override any and all default extractors, but only for the services which match service_proto_path.
355    pub fn with_service_specific_axum_request_extractor(
356        mut self,
357        name: impl Into<String>,
358        type_name: impl Into<String>,
359        service_proto_path: impl Into<String>,
360    ) -> Self {
361        self.matched_request_extractors
362            .insert(service_proto_path.into(), (name.into(), type_name.into()));
363        self
364    }
365}
366
367impl ServiceGenerator for TwirpServiceGenerator {
368    fn generate(&mut self, service: Service, buf: &mut String) {
369        self.do_generate(service, buf)
370            .expect("failed to generate Twirp service")
371    }
372}
373
374impl TwirpServiceGenerator {
375    fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
376        let mut service_matches = self
377            .matched_request_extractors
378            .service_matches(&service)
379            .peekable();
380
381        let extractors: Vec<_> = if service_matches.peek().is_some() {
382            service_matches.collect()
383        } else {
384            self.default_request_extractors.iter().collect()
385        };
386
387        if self.client {
388            writeln!(buf)?;
389            for comment in &service.comments.leading {
390                writeln!(buf, "/// {comment}")?;
391            }
392            if service.options.deprecated.unwrap_or(false) {
393                writeln!(buf, "#[deprecated]")?;
394            }
395            writeln!(buf, "#[derive(Clone)]")?;
396            writeln!(
397                buf,
398                "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
399                service.name
400            )?;
401            writeln!(buf, "    client: ::twurst_client::TwirpHttpClient<C>")?;
402            writeln!(buf, "}}")?;
403            writeln!(buf)?;
404            writeln!(
405                buf,
406                "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
407                service.name
408            )?;
409            writeln!(
410                buf,
411                "    pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
412            )?;
413            writeln!(buf, "        Self {{ client: client.into() }}")?;
414            writeln!(buf, "    }}")?;
415            for method in &service.methods {
416                if method.client_streaming || method.server_streaming {
417                    continue; // Not supported
418                }
419                for comment in &method.comments.leading {
420                    writeln!(buf, "    /// {comment}")?;
421                }
422                if method.options.deprecated.unwrap_or(false) {
423                    writeln!(buf, "#[deprecated]")?;
424                }
425                writeln!(
426                    buf,
427                    "    pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
428                    method.name, method.input_type, method.output_type,
429                )?;
430                writeln!(
431                    buf,
432                    "        self.client.call(\"/{}.{}/{}\", request).await",
433                    service.package, service.proto_name, method.proto_name,
434                )?;
435                writeln!(buf, "    }}")?;
436            }
437            writeln!(buf, "}}")?;
438        }
439
440        if self.server {
441            writeln!(buf)?;
442            for comment in &service.comments.leading {
443                writeln!(buf, "/// {comment}")?;
444            }
445            writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
446            writeln!(buf, "pub trait {} {{", service.name)?;
447            for method in &service.methods {
448                if !self.grpc && (method.client_streaming || method.server_streaming) {
449                    continue; // No streaming
450                }
451                for comment in &method.comments.leading {
452                    writeln!(buf, "    /// {comment}")?;
453                }
454                write!(buf, "    async fn {}(&self, request: ", method.name)?;
455                if method.client_streaming {
456                    write!(
457                        buf,
458                        "impl ::twurst_server::codegen::Stream<Item=Result<{},::twurst_client::TwirpError>> + Send + 'static",
459                        method.input_type,
460                    )?;
461                } else {
462                    write!(buf, "{}", method.input_type)?;
463                }
464                for (arg_name, arg_type) in &extractors {
465                    write!(buf, ", {arg_name}: {arg_type}")?;
466                }
467                writeln!(buf, ") -> Result<")?;
468                if method.server_streaming {
469                    // TODO: move back to `impl` when we will be able to use precise capturing to not capture &self
470                    writeln!(
471                        buf,
472                        "Box<dyn ::twurst_server::codegen::Stream<Item=Result<{}, ::twurst_server::TwirpError>> + Send>",
473                        method.output_type
474                    )?;
475                } else {
476                    writeln!(buf, "{}", method.output_type)?;
477                }
478                writeln!(buf, ", ::twurst_server::TwirpError>;")?;
479            }
480            writeln!(buf)?;
481            writeln!(
482                buf,
483                "    fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
484            )?;
485            writeln!(
486                buf,
487                "        ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
488            )?;
489            for method in &service.methods {
490                if method.client_streaming || method.server_streaming {
491                    writeln!(
492                        buf,
493                        "            .route_streaming(\"/{}.{}/{}\")",
494                        service.package, service.proto_name, method.proto_name,
495                    )?;
496                    continue;
497                }
498                write!(
499                    buf,
500                    "            .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
501                    service.package, service.proto_name, method.proto_name, method.input_type,
502                )?;
503                if extractors.is_empty() {
504                    write!(buf, ", _: ::twurst_server::codegen::RequestParts, _: S")?;
505                } else {
506                    write!(
507                        buf,
508                        ", mut parts: ::twurst_server::codegen::RequestParts, state: S",
509                    )?;
510                }
511                write!(buf, "| {{")?;
512                writeln!(buf, "                async move {{")?;
513                write!(buf, "                    service.{}(request", method.name)?;
514                for (_name, type_name) in &extractors {
515                    write!(
516                        buf,
517                        ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &state).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
518                    )?;
519                }
520                writeln!(buf, ").await")?;
521                writeln!(buf, "                }}")?;
522                writeln!(buf, "            }})")?;
523            }
524            writeln!(buf, "            .build()")?;
525            writeln!(buf, "    }}")?;
526
527            if self.grpc {
528                writeln!(buf)?;
529                writeln!(
530                    buf,
531                    "    fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
532                )?;
533                writeln!(
534                    buf,
535                    "        ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
536                )?;
537                for method in &service.methods {
538                    let method_name = match (method.client_streaming, method.server_streaming) {
539                        (false, false) => "route",
540                        (false, true) => "route_server_streaming",
541                        (true, false) => "route_client_streaming",
542                        (true, true) => "route_streaming",
543                    };
544                    write!(
545                        buf,
546                        "            .{}(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: ",
547                        method_name, service.package, service.proto_name, method.proto_name,
548                    )?;
549                    if method.client_streaming {
550                        write!(
551                            buf,
552                            "::twurst_server::codegen::GrpcClientStream<{}>",
553                            method.input_type,
554                        )?;
555                    } else {
556                        write!(buf, "{}", method.input_type)?;
557                    }
558                    if extractors.is_empty() {
559                        write!(buf, ", _: ::twurst_server::codegen::RequestParts")?;
560                    } else {
561                        write!(buf, ", mut parts: ::twurst_server::codegen::RequestParts")?;
562                    }
563                    write!(buf, "| {{")?;
564                    write!(buf, "                async move {{")?;
565                    if method.server_streaming {
566                        write!(buf, "Ok(Box::into_pin(")?;
567                    }
568                    write!(buf, "service.{}(request", method.name)?;
569                    for (_name, type_name) in &extractors {
570                        write!(
571                            buf,
572                            ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &()).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
573                        )?;
574                    }
575                    write!(buf, ").await")?;
576                    if method.server_streaming {
577                        write!(buf, "?))")?;
578                    }
579                    writeln!(buf, "}}")?;
580                    writeln!(buf, "            }})")?;
581                }
582                writeln!(buf, "            .build()")?;
583                writeln!(buf, "    }}")?;
584            }
585
586            writeln!(buf, "}}")?;
587        }
588
589        Ok(())
590    }
591}