connectrpc_build/
lib.rs

use anyhow::Context;
use proc_macro2::TokenStream;
use prost_build::Module;
use quote::format_ident;
use quote::quote;
use std::collections::BTreeMap;
use std::env;
use std::path::{Path, PathBuf};
use syn::parse_quote;
9
/// Selects which pieces of code the service generator emits for each protobuf service.
///
/// Everything starts disabled ([`GeneratorFeatures::new`] / `Default`). Turn pieces on with the
/// builder methods, or all of them at once with [`GeneratorFeatures::full`].
#[derive(Debug, Clone, Default, Copy)]
pub struct GeneratorFeatures {
    /// `Some` when reqwest-backed clients should be generated, carrying the codec selection.
    reqwest: Option<GeneratorReqwestFeatures>,
    /// Whether the axum server struct and its `into_router` impl should be generated.
    axum: bool,
}

/// Which codecs get their own generated reqwest client struct.
#[derive(Debug, Clone, Copy)]
pub struct GeneratorReqwestFeatures {
    /// Generate a `<Service>ReqwestProtoClient` using `Codec::Proto`.
    pub proto: bool,
    /// Generate a `<Service>ReqwestJsonClient` using `Codec::Json`.
    pub json: bool,
}

impl GeneratorFeatures {
    /// Returns a feature set with every generator disabled (same as `Default`).
    pub fn new() -> Self {
        Self {
            reqwest: None,
            axum: false,
        }
    }

    /// Enables reqwest client generation with `options`, replacing any earlier selection.
    pub fn reqwest(self, options: GeneratorReqwestFeatures) -> Self {
        Self {
            reqwest: Some(options),
            ..self
        }
    }

    /// Enables axum server generation.
    pub fn axum(self) -> Self {
        Self { axum: true, ..self }
    }

    /// Enables everything: the axum server plus reqwest clients for both codecs.
    pub fn full(self) -> Self {
        let all_codecs = GeneratorReqwestFeatures {
            proto: true,
            json: true,
        };
        self.axum().reqwest(all_codecs)
    }
}
46
47#[derive(Debug, Clone)]
48pub struct Settings {
49    pub includes: Vec<PathBuf>,
50    pub inputs: Vec<PathBuf>,
51    pub protoc_args: Vec<String>,
52    pub protoc_version: String,
53
54    pub features: GeneratorFeatures,
55    // Map of protobuf package prefixes to Rust module paths for extern types.
56    pub extern_paths: BTreeMap<String, String>,
57}
58
59impl Default for Settings {
60    fn default() -> Self {
61        Self {
62            includes: Vec::new(),
63            inputs: Vec::new(),
64            protoc_args: Vec::new(),
65            protoc_version: "31.1".to_string(),
66
67            features: GeneratorFeatures::default().full(),
68            extern_paths: {
69                let mut m = BTreeMap::new();
70                m.insert(".google.protobuf".to_string(), "::pbjson_types".to_string());
71                m
72            },
73        }
74    }
75}
76
77impl Settings {
78    pub fn from_directory_recursive<P>(path: P) -> anyhow::Result<Self>
79    where
80        P: Into<PathBuf>,
81    {
82        let path = path.into();
83        let mut settings = Self::default();
84        settings.includes.push(path.clone());
85
86        // Recursively add all files that end in ".proto" to the inputs.
87        let mut dirs = vec![path];
88        while let Some(dir) = dirs.pop() {
89            for entry in std::fs::read_dir(dir)? {
90                let entry = entry?;
91                let path = entry.path();
92                if path.is_dir() {
93                    dirs.push(path.clone());
94                } else if path.extension().map(|ext| ext == "proto").unwrap_or(false) {
95                    settings.inputs.push(path);
96                }
97            }
98        }
99
100        Ok(settings)
101    }
102
103    pub fn generate(&self) -> anyhow::Result<BTreeMap<String, String>> {
104        let out_dir = env::var("OUT_DIR").unwrap();
105        let protoc_path = protoc_fetcher::protoc(&self.protoc_version, Path::new(&out_dir))?;
106        unsafe {
107            env::set_var("PROTOC", protoc_path);
108        }
109
110        // Tell cargo to rerun the build if any of the inputs change.
111        for input in &self.inputs {
112            println!("cargo:rerun-if-changed={}", input.display());
113        }
114
115        let descriptor_path =
116            PathBuf::from(env::var("OUT_DIR").unwrap()).join("file_descriptor.bin");
117
118        let mut conf = prost_build::Config::new();
119
120        // Standard prost configuration
121        conf.compile_well_known_types();
122        conf.file_descriptor_set_path(&descriptor_path);
123        conf.default_package_filename("connectrpc.rs");
124        conf.service_generator(Box::new(service_generator(self.features)));
125
126        // Apply extern paths
127        for (proto_prefix, rust_path) in &self.extern_paths {
128            conf.extern_path(proto_prefix, rust_path);
129        }
130
131        // Arg configuration
132        for arg in &self.protoc_args {
133            conf.protoc_arg(arg);
134        }
135
136        // Have to load and generate by hand because I don't care about hacks.
137        let file_descriptor_set = conf.load_fds(&self.inputs, &self.includes)?;
138        let requests = file_descriptor_set
139            .file
140            .into_iter()
141            .map(|descriptor| {
142                (
143                    Module::from_protobuf_package_name(descriptor.package()),
144                    descriptor,
145                )
146            })
147            .collect::<Vec<_>>();
148
149        let mut modules = conf
150            .generate(requests)?
151            .into_iter()
152            .map(|(m, c)| (m.parts().collect::<Vec<_>>().join("."), c))
153            .collect::<BTreeMap<_, _>>();
154
155        let descriptor_set = std::fs::read(&descriptor_path)?;
156
157        let mut builder = pbjson_build::Builder::new();
158        builder.register_descriptors(&descriptor_set)?;
159
160        for (path, rust_path) in &self.extern_paths {
161            builder.extern_path(path, rust_path);
162        }
163        let writers = builder.generate(&["."], move |_package| Ok(Vec::new()))?;
164
165        for (package, output) in writers {
166            modules.entry(package.to_string()).and_modify(|c| {
167                c.push('\n');
168                c.push_str(&String::from_utf8_lossy(&output));
169            });
170        }
171
172        Ok(modules)
173    }
174}
175
176struct Service {
177    /// The name of the server trait, as parsed into a Rust identifier.
178    rpc_trait_name: syn::Ident,
179
180    /// The fully qualified protobuf name of this Service.
181    fqn: String,
182
183    /// The methods that make up this service.
184    methods: Vec<Method>,
185}
186
187struct Method {
188    /// The name of the method, as parsed into a Rust identifier.
189    name: syn::Ident,
190
191    /// The name of the method as it appears in the protobuf definition.
192    proto_name: String,
193
194    /// The input type of this method.
195    input_type: syn::Type,
196
197    /// The output type of this method.
198    output_type: syn::Type,
199}
200
201impl Service {
202    fn from_prost(s: prost_build::Service) -> Self {
203        let fqn = format!("{}.{}", s.package, s.proto_name);
204        let rpc_trait_name = format_ident!("{}", &s.name);
205        let methods = s
206            .methods
207            .into_iter()
208            .map(|m| Method::from_prost(&s.package, &s.proto_name, m))
209            .collect();
210
211        Self {
212            rpc_trait_name,
213            fqn,
214            methods,
215        }
216    }
217}
218
219impl Method {
220    fn from_prost(pkg_name: &str, svc_name: &str, m: prost_build::Method) -> Self {
221        let as_type = |s| -> syn::Type {
222            let Ok(typ) = syn::parse_str::<syn::Type>(s) else {
223                panic!(
224                    "connectrpc-client-build build failed generated invalid Rust while processing {pkg}.{svc}/{name}). this is a bug in connectrpc-client-build, please file a GitHub issue",
225                    pkg = pkg_name,
226                    svc = svc_name,
227                    name = m.proto_name,
228                );
229            };
230            typ
231        };
232
233        let input_type = as_type(&m.input_type);
234        let output_type = as_type(&m.output_type);
235        let name = format_ident!("{}", m.name);
236        let message = m.proto_name;
237
238        Self {
239            name,
240            proto_name: message,
241            input_type,
242            output_type,
243        }
244    }
245}
246
247#[derive(Clone, Copy)]
248pub struct ServiceGenerator {
249    features: GeneratorFeatures,
250}
251
252pub fn service_generator(features: GeneratorFeatures) -> ServiceGenerator {
253    ServiceGenerator { features }
254}
255
256impl prost_build::ServiceGenerator for ServiceGenerator {
257    fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
258        if self.features.reqwest.is_none() && !self.features.axum {
259            return;
260        }
261
262        let service = Service::from_prost(service);
263        let mut use_items: Vec<syn::ItemUse> = vec![];
264        use_items.push(parse_quote! {
265            pub use ::connectrpc;
266        });
267
268        let mut tokens: Vec<syn::Item> = Vec::new();
269
270        let async_service_trait = format_ident!("{}AsyncService", service.rpc_trait_name);
271        tokens.push(generate_async_service_trait(&service, &async_service_trait).into());
272
273        if let Some(reqwest_features) = self.features.reqwest {
274            let mut generates = vec![];
275            if reqwest_features.proto {
276                let codec = format_ident!("Proto");
277                let client = format_ident!("{}Reqwest{}Client", service.rpc_trait_name, codec);
278                generates.push((client, codec));
279            }
280            if reqwest_features.json {
281                let codec = format_ident!("Json");
282                let client = format_ident!("{}Reqwest{}Client", service.rpc_trait_name, codec);
283                generates.push((client, codec));
284            }
285            if !generates.is_empty() {
286                use_items.push(parse_quote! {
287                    use ::connectrpc::client::AsyncUnaryClient;
288                })
289            }
290            for (client, codec) in generates {
291                tokens.push(generate_reqwest_client_struct(&service, &client).into());
292                tokens.push(generate_reqwest_client_impl(&service, &client, &codec).into());
293                tokens.push(
294                    generate_reqwest_client_trait_impl(&service, &async_service_trait, &client)
295                        .into(),
296                );
297            }
298        }
299
300        if self.features.axum {
301            let server_struct = format_ident!("{}AxumServer", service.rpc_trait_name);
302            tokens.push(generate_axum_server_struct(&service, &server_struct).into());
303            tokens.push(generate_axum_server_impl(&service, &server_struct).into());
304        }
305
306        let ast: syn::File = parse_quote! {
307            // Auto-generated by connectrpc-client-build. Do not edit.
308            #(#use_items)*
309
310            #(#tokens)*
311
312        };
313
314        let code = prettyplease::unparse(&ast);
315        buf.push_str(&code);
316    }
317}
318
319fn generate_async_service_trait(service: &Service, service_trait: &syn::Ident) -> syn::ItemTrait {
320    let mut trait_methods: Vec<syn::TraitItemFn> = Vec::with_capacity(service.methods.len());
321
322    for m in &service.methods {
323        let name = &m.name;
324        let input_type = &m.input_type;
325        let output_type = &m.output_type;
326
327        trait_methods.push(parse_quote! {
328                fn #name(
329                    &self,
330                    request: ::connectrpc::UnaryRequest<#input_type>
331                ) -> impl std::future::Future<Output = ::connectrpc::Result<::connectrpc::UnaryResponse<#output_type>>> + Send + '_;
332            });
333    }
334
335    parse_quote! {
336        pub trait #service_trait: Send + Sync {
337            #(#trait_methods)*
338        }
339    }
340}
341
342fn generate_reqwest_client_struct(service: &Service, client: &syn::Ident) -> syn::ItemStruct {
343    let mut client_fields: Vec<TokenStream> = Vec::with_capacity(service.methods.len());
344
345    for m in &service.methods {
346        let name = &m.name;
347
348        client_fields.push(quote! {
349            pub #name: ::connectrpc::ReqwestClient,
350        });
351    }
352
353    parse_quote! {
354        #[derive(Clone)]
355        pub struct #client {
356            #(#client_fields)*
357        }
358    }
359}
360
361fn generate_reqwest_client_impl(
362    service: &Service,
363    struct_name: &syn::Ident,
364    codec_ident: &syn::Ident,
365) -> syn::ItemImpl {
366    let mut client_inits: Vec<TokenStream> = Vec::with_capacity(service.methods.len());
367
368    for m in &service.methods {
369        let name = &m.name;
370
371        client_inits.push(quote! {
372            #name: ::connectrpc::ReqwestClient::new(client.clone(), base_uri.clone(), ::connectrpc::codec::Codec::#codec_ident)?,
373        });
374    }
375
376    parse_quote! {
377        impl #struct_name {
378            pub fn new(client: ::reqwest::Client, base_uri: ::connectrpc::http::Uri) -> ::connectrpc::Result<Self> {
379                Ok(Self {
380                    #(#client_inits)*
381                })
382            }
383        }
384    }
385}
386
387fn generate_reqwest_client_trait_impl(
388    service: &Service,
389    trait_name: &syn::Ident,
390    struct_name: &syn::Ident,
391) -> syn::ItemImpl {
392    let mut client_methods: Vec<syn::ImplItemFn> = Vec::with_capacity(service.methods.len());
393    for m in &service.methods {
394        let name = &m.name;
395        let input_type = &m.input_type;
396        let output_type = &m.output_type;
397        let path = format!("/{}/{}", service.fqn, m.proto_name);
398
399        client_methods.push(parse_quote! {
400            async fn #name(
401                &self,
402                request: ::connectrpc::UnaryRequest<#input_type>
403            ) -> ::connectrpc::Result<::connectrpc::UnaryResponse<#output_type>> {
404                self.#name.call_unary(#path, request).await
405            }
406        });
407    }
408
409    parse_quote! {
410        impl #trait_name for #struct_name {
411            #(#client_methods)*
412        }
413    }
414}
415
416fn generate_axum_server_struct(service: &Service, struct_name: &syn::Ident) -> syn::ItemStruct {
417    let mut handlers = Vec::with_capacity(service.methods.len());
418    for i in 1..=service.methods.len() {
419        handlers.push(format_ident!("H{i}"));
420    }
421
422    let mut handler_constraints = Vec::with_capacity(service.methods.len());
423    let mut fields = Vec::with_capacity(service.methods.len());
424
425    for (i, method) in service.methods.iter().enumerate() {
426        let name = &method.name;
427        let input_type = &method.input_type;
428        let output_type = &method.output_type;
429
430        let handler = &handlers[i];
431        handler_constraints.push(quote! {
432            #handler: ::connectrpc::server::axum::RpcUnaryHandler<#input_type, #output_type, S>,
433        });
434
435        fields.push(quote! {
436            pub #name: #handler,
437        });
438    }
439
440    parse_quote! {
441        pub struct #struct_name<S, #(#handlers,)*>
442        where
443            S: Send + Sync + Clone + 'static,
444            #(#handler_constraints)*
445        {
446            pub state: S,
447            #(#fields)*
448        }
449    }
450}
451
452fn generate_axum_server_impl(service: &Service, struct_name: &syn::Ident) -> syn::ItemImpl {
453    let mut route_inits = Vec::with_capacity(service.methods.len());
454    // let mut route_adds = Vec::with_capacity(service.methods.len());
455
456    let mut handlers = Vec::with_capacity(service.methods.len());
457    for i in 1..=service.methods.len() {
458        handlers.push(format_ident!("H{i}"));
459    }
460
461    let mut handler_constraints = Vec::with_capacity(service.methods.len());
462    for (i, method) in service.methods.iter().enumerate() {
463        let name = &method.name;
464        let path = format!("/{}/{}", service.fqn, method.proto_name);
465        let input_type = &method.input_type;
466        let output_type = &method.output_type;
467
468        route_inits.push(quote! {
469            let #name = self.#name;
470            let cs = common_server.clone();
471            router = router.route(
472                #path,
473                ::axum::routing::any(move |::axum::extract::State(state): ::axum::extract::State<S>, req: ::axum::extract::Request| async move {
474                    #name.call(req, state, cs).await
475                })
476            );
477        });
478
479        let handler = &handlers[i];
480        handler_constraints.push(quote! {
481            #handler: ::connectrpc::server::axum::RpcUnaryHandler<#input_type, #output_type, S>,
482        });
483    }
484
485    parse_quote! {
486        impl<S, #(#handlers,)*> #struct_name<S, #(#handlers,)*>
487        where
488            S: Send + Sync + Clone + 'static,
489            #(#handler_constraints)*
490        {
491            pub fn into_router(
492                self,
493            ) -> ::axum::Router {
494                let mut router = ::axum::Router::new();
495                let common_server = ::connectrpc::server::CommonServer::new();
496                #(#route_inits)*
497                router.with_state(self.state)
498            }
499        }
500    }
501}