1#![doc = include_str!("../README.md")]
2#![doc(
3 test(attr(deny(warnings))),
4 html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5 html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9pub use prost_build as prost;
10use prost_build::{Config, Module, Service, ServiceGenerator};
11use regex::Regex;
12use std::collections::HashSet;
13use std::fmt::Write;
14use std::io::{Error, Result};
15use std::path::{Path, PathBuf};
16use std::{env, fs};
17
#[derive(Default)]
/// Builder that collects Twirp code generation options and then compiles
/// `.proto` files.
///
/// Options are switched on with the `with_*` methods; the generation itself
/// is run by `compile_protos`.
pub struct TwirpBuilder {
    /// `prost-build` configuration used to load and compile the protos.
    config: Config,
    /// Service generator carrying the client/server/gRPC flags and the
    /// request extractors; it is handed to `config` in `compile_protos`.
    generator: TwirpServiceGenerator,
    /// Domain passed to `Config::type_name_domain`; when `None`,
    /// `type.googleapis.com` is used.
    type_name_domain: Option<String>,
}
27
28impl TwirpBuilder {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn from_prost(config: Config) -> Self {
36 Self {
37 config,
38 generator: TwirpServiceGenerator::new(),
39 type_name_domain: None,
40 }
41 }
42
43 pub fn with_client(mut self) -> Self {
45 self.generator = self.generator.with_client();
46 self
47 }
48
49 pub fn with_server(mut self) -> Self {
51 self.generator = self.generator.with_server();
52 self
53 }
54
55 pub fn with_grpc(mut self) -> Self {
57 self.generator = self.generator.with_grpc();
58 self
59 }
60
61 pub fn with_axum_request_extractor(
79 mut self,
80 name: impl Into<String>,
81 type_name: impl Into<String>,
82 ) -> Self {
83 self.generator = self.generator.with_axum_request_extractor(name, type_name);
84 self
85 }
86
87 pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
91 self.type_name_domain = Some(domain.into());
92 self
93 }
94
95 pub fn compile_protos(
97 mut self,
98 protos: &[impl AsRef<Path>],
99 includes: &[impl AsRef<Path>],
100 ) -> Result<()> {
101 let out_dir = PathBuf::from(
102 env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
103 );
104
105 for proto in protos {
107 println!("cargo:rerun-if-changed={}", proto.as_ref().display());
108 }
109 self.config
110 .enable_type_names()
111 .type_name_domain(
112 ["."],
113 self.type_name_domain
114 .as_deref()
115 .unwrap_or("type.googleapis.com"),
116 )
117 .service_generator(Box::new(self.generator));
118
119 prost_reflect_build::Builder::new()
121 .file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
122 .configure(&mut self.config, protos, includes)?;
123
124 let config = self.config.skip_protoc_run();
126 let file_descriptor_set = config.load_fds(protos, includes)?;
127 let modules = file_descriptor_set
128 .file
129 .iter()
130 .map(|fd| Module::from_protobuf_package_name(fd.package()))
131 .collect::<HashSet<_>>();
132
133 config.compile_fds(file_descriptor_set)?;
135
136 let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
141
142 for module in modules {
144 let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
145 if !file_path.exists() {
146 continue; }
148 let original_content = fs::read_to_string(&file_path)?;
149
150 let mut modified_content = original_content
152 .lines()
153 .flat_map(|line| {
154 if let Some(captures) = re.captures(line) {
155 let indentation = captures.get(1).map_or("", |m| m.as_str());
156 vec![
157 line.to_string(),
158 format!(" {}{}", indentation, "#[allow(unused_imports)]"),
160 format!(
161 " {}{}",
162 indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
163 ),
164 ]
165 } else {
166 vec![line.to_string()]
167 }
168 })
169 .collect::<Vec<_>>();
170
171 modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
172 let file_content = modified_content.join("\n");
173
174 fs::write(&file_path, &file_content)?;
175 }
176
177 Ok(())
178 }
179}
180
#[derive(Default)]
/// Options of the Twirp service generator.
struct TwirpServiceGenerator {
    /// Generate a client struct for each service.
    client: bool,
    /// Generate a server trait for each service.
    server: bool,
    /// Add gRPC support to the generated server trait.
    grpc: bool,
    /// `(argument name, type name)` pairs added to server trait methods.
    request_extractors: Vec<(String, String)>,
}

impl TwirpServiceGenerator {
    /// Creates a generator with every option off and no request extractor.
    pub fn new() -> Self {
        Self {
            client: false,
            server: false,
            grpc: false,
            request_extractors: Vec::new(),
        }
    }

    /// Turns client generation on.
    pub fn with_client(self) -> Self {
        Self {
            client: true,
            ..self
        }
    }

    /// Turns server generation on.
    pub fn with_server(self) -> Self {
        Self {
            server: true,
            ..self
        }
    }

    /// Turns gRPC support on.
    pub fn with_grpc(self) -> Self {
        Self { grpc: true, ..self }
    }

    /// Appends a `(name, type_name)` request extractor, keeping the ones
    /// registered before it.
    pub fn with_axum_request_extractor(
        mut self,
        name: impl Into<String>,
        type_name: impl Into<String>,
    ) -> Self {
        let extractor = (name.into(), type_name.into());
        self.request_extractors.push(extractor);
        self
    }
}
226
227impl ServiceGenerator for TwirpServiceGenerator {
228 fn generate(&mut self, service: Service, buf: &mut String) {
229 self.do_generate(service, buf)
230 .expect("failed to generate Twirp service")
231 }
232}
233
234impl TwirpServiceGenerator {
235 fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
236 if self.client {
237 writeln!(buf)?;
238 for comment in &service.comments.leading {
239 writeln!(buf, "/// {comment}")?;
240 }
241 if service.options.deprecated.unwrap_or(false) {
242 writeln!(buf, "#[deprecated]")?;
243 }
244 writeln!(buf, "#[derive(Clone)]")?;
245 writeln!(
246 buf,
247 "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
248 service.name
249 )?;
250 writeln!(buf, " client: ::twurst_client::TwirpHttpClient<C>")?;
251 writeln!(buf, "}}")?;
252 writeln!(buf)?;
253 writeln!(
254 buf,
255 "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
256 service.name
257 )?;
258 writeln!(
259 buf,
260 " pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
261 )?;
262 writeln!(buf, " Self {{ client: client.into() }}")?;
263 writeln!(buf, " }}")?;
264 for method in &service.methods {
265 if method.client_streaming || method.server_streaming {
266 continue; }
268 for comment in &method.comments.leading {
269 writeln!(buf, " /// {comment}")?;
270 }
271 if method.options.deprecated.unwrap_or(false) {
272 writeln!(buf, "#[deprecated]")?;
273 }
274 writeln!(
275 buf,
276 " pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
277 method.name, method.input_type, method.output_type,
278 )?;
279 writeln!(
280 buf,
281 " self.client.call(\"/{}.{}/{}\", request).await",
282 service.package, service.proto_name, method.proto_name,
283 )?;
284 writeln!(buf, " }}")?;
285 }
286 writeln!(buf, "}}")?;
287 }
288
289 if self.server {
290 writeln!(buf)?;
291 for comment in &service.comments.leading {
292 writeln!(buf, "/// {comment}")?;
293 }
294 writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
295 writeln!(buf, "pub trait {} {{", service.name)?;
296 for method in &service.methods {
297 if !self.grpc && (method.client_streaming || method.server_streaming) {
298 continue; }
300 for comment in &method.comments.leading {
301 writeln!(buf, " /// {comment}")?;
302 }
303 write!(buf, " async fn {}(&self, request: ", method.name)?;
304 if method.client_streaming {
305 write!(
306 buf,
307 "impl ::twurst_server::codegen::Stream<Item=Result<{},::twurst_client::TwirpError>> + Send + 'static",
308 method.input_type,
309 )?;
310 } else {
311 write!(buf, "{}", method.input_type)?;
312 }
313 for (arg_name, arg_type) in &self.request_extractors {
314 write!(buf, ", {arg_name}: {arg_type}")?;
315 }
316 writeln!(buf, ") -> Result<")?;
317 if method.server_streaming {
318 writeln!(buf, "Box<dyn ::twurst_server::codegen::Stream<Item=Result<{}, ::twurst_server::TwirpError>> + Send>", method.output_type)?;
320 } else {
321 writeln!(buf, "{}", method.output_type)?;
322 }
323 writeln!(buf, ", ::twurst_server::TwirpError>;")?;
324 }
325 writeln!(buf)?;
326 writeln!(
327 buf,
328 " fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
329 )?;
330 writeln!(
331 buf,
332 " ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
333 )?;
334 for method in &service.methods {
335 if method.client_streaming || method.server_streaming {
336 writeln!(
337 buf,
338 " .route_streaming(\"/{}.{}/{}\")",
339 service.package, service.proto_name, method.proto_name,
340 )?;
341 continue;
342 }
343 write!(
344 buf,
345 " .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
346 service.package, service.proto_name, method.proto_name, method.input_type,
347 )?;
348 if self.request_extractors.is_empty() {
349 write!(buf, ", _: ::twurst_server::codegen::RequestParts, _: S")?;
350 } else {
351 write!(
352 buf,
353 ", mut parts: ::twurst_server::codegen::RequestParts, state: S",
354 )?;
355 }
356 write!(buf, "| {{")?;
357 writeln!(buf, " async move {{")?;
358 write!(buf, " service.{}(request", method.name)?;
359 for (_name, type_name) in &self.request_extractors {
360 write!(
361 buf,
362 ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &state).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
363 )?;
364 }
365 writeln!(buf, ").await")?;
366 writeln!(buf, " }}")?;
367 writeln!(buf, " }})")?;
368 }
369 writeln!(buf, " .build()")?;
370 writeln!(buf, " }}")?;
371
372 if self.grpc {
373 writeln!(buf)?;
374 writeln!(
375 buf,
376 " fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
377 )?;
378 writeln!(
379 buf,
380 " ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
381 )?;
382 for method in &service.methods {
383 let method_name = match (method.client_streaming, method.server_streaming) {
384 (false, false) => "route",
385 (false, true) => "route_server_streaming",
386 (true, false) => "route_client_streaming",
387 (true, true) => "route_streaming",
388 };
389 write!(
390 buf,
391 " .{}(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: ",method_name,
392 service.package, service.proto_name, method.proto_name,
393 )?;
394 if method.client_streaming {
395 write!(
396 buf,
397 "::twurst_server::codegen::GrpcClientStream<{}>",
398 method.input_type,
399 )?;
400 } else {
401 write!(buf, "{}", method.input_type)?;
402 }
403 if self.request_extractors.is_empty() {
404 write!(buf, ", _: ::twurst_server::codegen::RequestParts")?;
405 } else {
406 write!(buf, ", mut parts: ::twurst_server::codegen::RequestParts")?;
407 }
408 write!(buf, "| {{")?;
409 write!(buf, " async move {{")?;
410 if method.server_streaming {
411 write!(buf, "Ok(Box::into_pin(")?;
412 }
413 write!(buf, "service.{}(request", method.name)?;
414 for (_name, type_name) in &self.request_extractors {
415 write!(
416 buf,
417 ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &()).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
418 )?;
419 }
420 write!(buf, ").await")?;
421 if method.server_streaming {
422 write!(buf, "?))")?;
423 }
424 writeln!(buf, "}}")?;
425 writeln!(buf, " }})")?;
426 }
427 writeln!(buf, " .build()")?;
428 writeln!(buf, " }}")?;
429 }
430
431 writeln!(buf, "}}")?;
432 }
433
434 Ok(())
435 }
436}