1use proc_macro2::TokenStream;
2use prost_build::Module;
3use quote::format_ident;
4use quote::quote;
5use std::collections::BTreeMap;
6use std::env;
7use std::path::{Path, PathBuf};
8use syn::parse_quote;
9
/// Selects which connectrpc service code the generator emits.
///
/// The [`Default`] value (also returned by [`GeneratorFeatures::new`]) enables
/// nothing; opt in with the builder-style methods or enable everything with
/// [`GeneratorFeatures::full`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct GeneratorFeatures {
    /// Reqwest client options; `None` disables reqwest client generation.
    reqwest: Option<GeneratorReqwestFeatures>,
    /// Whether to generate the `<Service>AxumServer` struct and its router.
    axum: bool,
}

/// Codecs for which a reqwest client struct is generated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GeneratorReqwestFeatures {
    /// Generate a `<Service>ReqwestProtoClient` using `connectrpc::codec::Codec::Proto`.
    pub proto: bool,
    /// Generate a `<Service>ReqwestJsonClient` using `connectrpc::codec::Codec::Json`.
    pub json: bool,
}

impl GeneratorFeatures {
    /// Creates a feature set with every feature disabled (same as `Default`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables reqwest client generation with `options`, replacing any options
    /// set earlier.
    #[must_use]
    pub fn reqwest(mut self, options: GeneratorReqwestFeatures) -> Self {
        self.reqwest = Some(options);
        self
    }

    /// Enables axum server generation.
    #[must_use]
    pub fn axum(mut self) -> Self {
        self.axum = true;
        self
    }

    /// Enables every feature: reqwest clients for both codecs and the axum server.
    #[must_use]
    pub fn full(self) -> Self {
        self.reqwest(GeneratorReqwestFeatures {
            proto: true,
            json: true,
        })
        .axum()
    }
}
46
47#[derive(Debug, Clone)]
48pub struct Settings {
49 pub includes: Vec<PathBuf>,
50 pub inputs: Vec<PathBuf>,
51 pub protoc_args: Vec<String>,
52 pub protoc_version: String,
53
54 pub features: GeneratorFeatures,
55 pub extern_paths: BTreeMap<String, String>,
57}
58
59impl Default for Settings {
60 fn default() -> Self {
61 Self {
62 includes: Vec::new(),
63 inputs: Vec::new(),
64 protoc_args: Vec::new(),
65 protoc_version: "31.1".to_string(),
66
67 features: GeneratorFeatures::default().full(),
68 extern_paths: {
69 let mut m = BTreeMap::new();
70 m.insert(".google.protobuf".to_string(), "::pbjson_types".to_string());
71 m
72 },
73 }
74 }
75}
76
77impl Settings {
78 pub fn from_directory_recursive<P>(path: P) -> anyhow::Result<Self>
79 where
80 P: Into<PathBuf>,
81 {
82 let path = path.into();
83 let mut settings = Self::default();
84 settings.includes.push(path.clone());
85
86 let mut dirs = vec![path];
88 while let Some(dir) = dirs.pop() {
89 for entry in std::fs::read_dir(dir)? {
90 let entry = entry?;
91 let path = entry.path();
92 if path.is_dir() {
93 dirs.push(path.clone());
94 } else if path.extension().map(|ext| ext == "proto").unwrap_or(false) {
95 settings.inputs.push(path);
96 }
97 }
98 }
99
100 Ok(settings)
101 }
102
103 pub fn generate(&self) -> anyhow::Result<BTreeMap<String, String>> {
104 let out_dir = env::var("OUT_DIR").unwrap();
105 let protoc_path = protoc_fetcher::protoc(&self.protoc_version, Path::new(&out_dir))?;
106 unsafe {
107 env::set_var("PROTOC", protoc_path);
108 }
109
110 for input in &self.inputs {
112 println!("cargo:rerun-if-changed={}", input.display());
113 }
114
115 let descriptor_path =
116 PathBuf::from(env::var("OUT_DIR").unwrap()).join("file_descriptor.bin");
117
118 let mut conf = prost_build::Config::new();
119
120 conf.compile_well_known_types();
122 conf.file_descriptor_set_path(&descriptor_path);
123 conf.default_package_filename("connectrpc.rs");
124 conf.service_generator(Box::new(service_generator(self.features)));
125
126 for (proto_prefix, rust_path) in &self.extern_paths {
128 conf.extern_path(proto_prefix, rust_path);
129 }
130
131 for arg in &self.protoc_args {
133 conf.protoc_arg(arg);
134 }
135
136 let file_descriptor_set = conf.load_fds(&self.inputs, &self.includes)?;
138 let requests = file_descriptor_set
139 .file
140 .into_iter()
141 .map(|descriptor| {
142 (
143 Module::from_protobuf_package_name(descriptor.package()),
144 descriptor,
145 )
146 })
147 .collect::<Vec<_>>();
148
149 let mut modules = conf
150 .generate(requests)?
151 .into_iter()
152 .map(|(m, c)| (m.parts().collect::<Vec<_>>().join("."), c))
153 .collect::<BTreeMap<_, _>>();
154
155 let descriptor_set = std::fs::read(&descriptor_path)?;
156
157 let mut builder = pbjson_build::Builder::new();
158 builder.register_descriptors(&descriptor_set)?;
159
160 for (path, rust_path) in &self.extern_paths {
161 builder.extern_path(path, rust_path);
162 }
163 let writers = builder.generate(&["."], move |_package| Ok(Vec::new()))?;
164
165 for (package, output) in writers {
166 modules.entry(package.to_string()).and_modify(|c| {
167 c.push('\n');
168 c.push_str(&String::from_utf8_lossy(&output));
169 });
170 }
171
172 Ok(modules)
173 }
174}
175
176struct Service {
177 rpc_trait_name: syn::Ident,
179
180 fqn: String,
182
183 methods: Vec<Method>,
185}
186
187struct Method {
188 name: syn::Ident,
190
191 proto_name: String,
193
194 input_type: syn::Type,
196
197 output_type: syn::Type,
199}
200
201impl Service {
202 fn from_prost(s: prost_build::Service) -> Self {
203 let fqn = format!("{}.{}", s.package, s.proto_name);
204 let rpc_trait_name = format_ident!("{}", &s.name);
205 let methods = s
206 .methods
207 .into_iter()
208 .map(|m| Method::from_prost(&s.package, &s.proto_name, m))
209 .collect();
210
211 Self {
212 rpc_trait_name,
213 fqn,
214 methods,
215 }
216 }
217}
218
219impl Method {
220 fn from_prost(pkg_name: &str, svc_name: &str, m: prost_build::Method) -> Self {
221 let as_type = |s| -> syn::Type {
222 let Ok(typ) = syn::parse_str::<syn::Type>(s) else {
223 panic!(
224 "connectrpc-client-build build failed generated invalid Rust while processing {pkg}.{svc}/{name}). this is a bug in connectrpc-client-build, please file a GitHub issue",
225 pkg = pkg_name,
226 svc = svc_name,
227 name = m.proto_name,
228 );
229 };
230 typ
231 };
232
233 let input_type = as_type(&m.input_type);
234 let output_type = as_type(&m.output_type);
235 let name = format_ident!("{}", m.name);
236 let message = m.proto_name;
237
238 Self {
239 name,
240 proto_name: message,
241 input_type,
242 output_type,
243 }
244 }
245}
246
247#[derive(Clone, Copy)]
248pub struct ServiceGenerator {
249 features: GeneratorFeatures,
250}
251
252pub fn service_generator(features: GeneratorFeatures) -> ServiceGenerator {
253 ServiceGenerator { features }
254}
255
256impl prost_build::ServiceGenerator for ServiceGenerator {
257 fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
258 if self.features.reqwest.is_none() && !self.features.axum {
259 return;
260 }
261
262 let service = Service::from_prost(service);
263 let mut use_items: Vec<syn::ItemUse> = vec![];
264 use_items.push(parse_quote! {
265 pub use ::connectrpc;
266 });
267
268 let mut tokens: Vec<syn::Item> = Vec::new();
269
270 let async_service_trait = format_ident!("{}AsyncService", service.rpc_trait_name);
271 tokens.push(generate_async_service_trait(&service, &async_service_trait).into());
272
273 if let Some(reqwest_features) = self.features.reqwest {
274 let mut generates = vec![];
275 if reqwest_features.proto {
276 let codec = format_ident!("Proto");
277 let client = format_ident!("{}Reqwest{}Client", service.rpc_trait_name, codec);
278 generates.push((client, codec));
279 }
280 if reqwest_features.json {
281 let codec = format_ident!("Json");
282 let client = format_ident!("{}Reqwest{}Client", service.rpc_trait_name, codec);
283 generates.push((client, codec));
284 }
285 if !generates.is_empty() {
286 use_items.push(parse_quote! {
287 use ::connectrpc::client::AsyncUnaryClient;
288 })
289 }
290 for (client, codec) in generates {
291 tokens.push(generate_reqwest_client_struct(&service, &client).into());
292 tokens.push(generate_reqwest_client_impl(&service, &client, &codec).into());
293 tokens.push(
294 generate_reqwest_client_trait_impl(&service, &async_service_trait, &client)
295 .into(),
296 );
297 }
298 }
299
300 if self.features.axum {
301 let server_struct = format_ident!("{}AxumServer", service.rpc_trait_name);
302 tokens.push(generate_axum_server_struct(&service, &server_struct).into());
303 tokens.push(generate_axum_server_impl(&service, &server_struct).into());
304 }
305
306 let ast: syn::File = parse_quote! {
307 #(#use_items)*
309
310 #(#tokens)*
311
312 };
313
314 let code = prettyplease::unparse(&ast);
315 buf.push_str(&code);
316 }
317}
318
319fn generate_async_service_trait(service: &Service, service_trait: &syn::Ident) -> syn::ItemTrait {
320 let mut trait_methods: Vec<syn::TraitItemFn> = Vec::with_capacity(service.methods.len());
321
322 for m in &service.methods {
323 let name = &m.name;
324 let input_type = &m.input_type;
325 let output_type = &m.output_type;
326
327 trait_methods.push(parse_quote! {
328 fn #name(
329 &self,
330 request: ::connectrpc::UnaryRequest<#input_type>
331 ) -> impl std::future::Future<Output = ::connectrpc::Result<::connectrpc::UnaryResponse<#output_type>>> + Send + '_;
332 });
333 }
334
335 parse_quote! {
336 pub trait #service_trait: Send + Sync {
337 #(#trait_methods)*
338 }
339 }
340}
341
342fn generate_reqwest_client_struct(service: &Service, client: &syn::Ident) -> syn::ItemStruct {
343 let mut client_fields: Vec<TokenStream> = Vec::with_capacity(service.methods.len());
344
345 for m in &service.methods {
346 let name = &m.name;
347
348 client_fields.push(quote! {
349 pub #name: ::connectrpc::ReqwestClient,
350 });
351 }
352
353 parse_quote! {
354 #[derive(Clone)]
355 pub struct #client {
356 #(#client_fields)*
357 }
358 }
359}
360
361fn generate_reqwest_client_impl(
362 service: &Service,
363 struct_name: &syn::Ident,
364 codec_ident: &syn::Ident,
365) -> syn::ItemImpl {
366 let mut client_inits: Vec<TokenStream> = Vec::with_capacity(service.methods.len());
367
368 for m in &service.methods {
369 let name = &m.name;
370
371 client_inits.push(quote! {
372 #name: ::connectrpc::ReqwestClient::new(client.clone(), base_uri.clone(), ::connectrpc::codec::Codec::#codec_ident)?,
373 });
374 }
375
376 parse_quote! {
377 impl #struct_name {
378 pub fn new(client: ::reqwest::Client, base_uri: ::connectrpc::http::Uri) -> ::connectrpc::Result<Self> {
379 Ok(Self {
380 #(#client_inits)*
381 })
382 }
383 }
384 }
385}
386
387fn generate_reqwest_client_trait_impl(
388 service: &Service,
389 trait_name: &syn::Ident,
390 struct_name: &syn::Ident,
391) -> syn::ItemImpl {
392 let mut client_methods: Vec<syn::ImplItemFn> = Vec::with_capacity(service.methods.len());
393 for m in &service.methods {
394 let name = &m.name;
395 let input_type = &m.input_type;
396 let output_type = &m.output_type;
397 let path = format!("/{}/{}", service.fqn, m.proto_name);
398
399 client_methods.push(parse_quote! {
400 async fn #name(
401 &self,
402 request: ::connectrpc::UnaryRequest<#input_type>
403 ) -> ::connectrpc::Result<::connectrpc::UnaryResponse<#output_type>> {
404 self.#name.call_unary(#path, request).await
405 }
406 });
407 }
408
409 parse_quote! {
410 impl #trait_name for #struct_name {
411 #(#client_methods)*
412 }
413 }
414}
415
416fn generate_axum_server_struct(service: &Service, struct_name: &syn::Ident) -> syn::ItemStruct {
417 let mut handlers = Vec::with_capacity(service.methods.len());
418 for i in 1..=service.methods.len() {
419 handlers.push(format_ident!("H{i}"));
420 }
421
422 let mut handler_constraints = Vec::with_capacity(service.methods.len());
423 let mut fields = Vec::with_capacity(service.methods.len());
424
425 for (i, method) in service.methods.iter().enumerate() {
426 let name = &method.name;
427 let input_type = &method.input_type;
428 let output_type = &method.output_type;
429
430 let handler = &handlers[i];
431 handler_constraints.push(quote! {
432 #handler: ::connectrpc::server::axum::RpcUnaryHandler<#input_type, #output_type, S>,
433 });
434
435 fields.push(quote! {
436 pub #name: #handler,
437 });
438 }
439
440 parse_quote! {
441 pub struct #struct_name<S, #(#handlers,)*>
442 where
443 S: Send + Sync + Clone + 'static,
444 #(#handler_constraints)*
445 {
446 pub state: S,
447 #(#fields)*
448 }
449 }
450}
451
452fn generate_axum_server_impl(service: &Service, struct_name: &syn::Ident) -> syn::ItemImpl {
453 let mut route_inits = Vec::with_capacity(service.methods.len());
454 let mut handlers = Vec::with_capacity(service.methods.len());
457 for i in 1..=service.methods.len() {
458 handlers.push(format_ident!("H{i}"));
459 }
460
461 let mut handler_constraints = Vec::with_capacity(service.methods.len());
462 for (i, method) in service.methods.iter().enumerate() {
463 let name = &method.name;
464 let path = format!("/{}/{}", service.fqn, method.proto_name);
465 let input_type = &method.input_type;
466 let output_type = &method.output_type;
467
468 route_inits.push(quote! {
469 let #name = self.#name;
470 let cs = common_server.clone();
471 router = router.route(
472 #path,
473 ::axum::routing::any(move |::axum::extract::State(state): ::axum::extract::State<S>, req: ::axum::extract::Request| async move {
474 #name.call(req, state, cs).await
475 })
476 );
477 });
478
479 let handler = &handlers[i];
480 handler_constraints.push(quote! {
481 #handler: ::connectrpc::server::axum::RpcUnaryHandler<#input_type, #output_type, S>,
482 });
483 }
484
485 parse_quote! {
486 impl<S, #(#handlers,)*> #struct_name<S, #(#handlers,)*>
487 where
488 S: Send + Sync + Clone + 'static,
489 #(#handler_constraints)*
490 {
491 pub fn into_router(
492 self,
493 ) -> ::axum::Router {
494 let mut router = ::axum::Router::new();
495 let common_server = ::connectrpc::server::CommonServer::new();
496 #(#route_inits)*
497 router.with_state(self.state)
498 }
499 }
500 }
501}