tonic_build/
server.rs

1use std::collections::HashSet;
2
3use super::{Attributes, Method, Service};
4use crate::{
5    format_method_name, format_method_path, format_service_name, generate_doc_comment,
6    generate_doc_comments, naive_snake_case,
7};
8use proc_macro2::{Span, TokenStream};
9use quote::quote;
10use syn::{Ident, Lit, LitStr};
11
12#[allow(clippy::too_many_arguments)]
13pub(crate) fn generate_internal<T: Service>(
14    service: &T,
15    emit_package: bool,
16    proto_path: &str,
17    compile_well_known_types: bool,
18    attributes: &Attributes,
19    disable_comments: &HashSet<String>,
20    use_arc_self: bool,
21    generate_default_stubs: bool,
22) -> TokenStream {
23    let methods = generate_methods(
24        service,
25        emit_package,
26        proto_path,
27        compile_well_known_types,
28        use_arc_self,
29        generate_default_stubs,
30    );
31
32    let server_service = quote::format_ident!("{}Server", service.name());
33    let server_trait = quote::format_ident!("{}", service.name());
34    let server_mod = quote::format_ident!("{}_server", naive_snake_case(service.name()));
35    let trait_attributes = attributes.for_trait(service.name());
36    let generated_trait = generate_trait(
37        service,
38        emit_package,
39        proto_path,
40        compile_well_known_types,
41        server_trait.clone(),
42        disable_comments,
43        use_arc_self,
44        generate_default_stubs,
45        trait_attributes,
46    );
47    let package = if emit_package { service.package() } else { "" };
48    // Transport based implementations
49    let service_name = format_service_name(service, emit_package);
50
51    let service_doc = if disable_comments.contains(&service_name) {
52        TokenStream::new()
53    } else {
54        generate_doc_comments(service.comment())
55    };
56
57    let named = generate_named(&server_service, &service_name);
58    let mod_attributes = attributes.for_mod(package);
59    let struct_attributes = attributes.for_struct(&service_name);
60
61    let configure_compression_methods = quote! {
62        /// Enable decompressing requests with the given encoding.
63        #[must_use]
64        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
65            self.accept_compression_encodings.enable(encoding);
66            self
67        }
68
69        /// Compress responses with the given encoding, if the client supports it.
70        #[must_use]
71        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
72            self.send_compression_encodings.enable(encoding);
73            self
74        }
75    };
76
77    let configure_max_message_size_methods = quote! {
78        /// Limits the maximum size of a decoded message.
79        ///
80        /// Default: `4MB`
81        #[must_use]
82        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
83            self.max_decoding_message_size = Some(limit);
84            self
85        }
86
87        /// Limits the maximum size of an encoded message.
88        ///
89        /// Default: `usize::MAX`
90        #[must_use]
91        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
92            self.max_encoding_message_size = Some(limit);
93            self
94        }
95    };
96
97    quote! {
98        /// Generated server implementations.
99        #(#mod_attributes)*
100        pub mod #server_mod {
101            #![allow(
102                unused_variables,
103                dead_code,
104                missing_docs,
105                clippy::wildcard_imports,
106                // will trigger if compression is disabled
107                clippy::let_unit_value,
108            )]
109            use tonic::codegen::*;
110
111            #generated_trait
112
113            #service_doc
114            #(#struct_attributes)*
115            #[derive(Debug)]
116            pub struct #server_service<T> {
117                inner: Arc<T>,
118                accept_compression_encodings: EnabledCompressionEncodings,
119                send_compression_encodings: EnabledCompressionEncodings,
120                max_decoding_message_size: Option<usize>,
121                max_encoding_message_size: Option<usize>,
122            }
123
124            impl<T> #server_service<T> {
125                pub fn new(inner: T) -> Self {
126                    Self::from_arc(Arc::new(inner))
127                }
128
129                pub fn from_arc(inner: Arc<T>) -> Self {
130                    Self {
131                        inner,
132                        accept_compression_encodings: Default::default(),
133                        send_compression_encodings: Default::default(),
134                        max_decoding_message_size: None,
135                        max_encoding_message_size: None,
136                    }
137                }
138
139                pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
140                where
141                    F: tonic::service::Interceptor,
142                {
143                    InterceptedService::new(Self::new(inner), interceptor)
144                }
145
146                #configure_compression_methods
147
148                #configure_max_message_size_methods
149            }
150
151            impl<T, B> tonic::codegen::Service<http::Request<B>> for #server_service<T>
152                where
153                    T: #server_trait,
154                    B: Body + std::marker::Send + 'static,
155                    B::Error: Into<StdError> + std::marker::Send + 'static,
156            {
157                type Response = http::Response<tonic::body::Body>;
158                type Error = std::convert::Infallible;
159                type Future = BoxFuture<Self::Response, Self::Error>;
160
161                fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
162                    Poll::Ready(Ok(()))
163                }
164
165                fn call(&mut self, req: http::Request<B>) -> Self::Future {
166                    match req.uri().path() {
167                        #methods
168
169                        _ => Box::pin(async move {
170                            let mut response = http::Response::new(tonic::body::Body::default());
171                            let headers = response.headers_mut();
172                            headers.insert(tonic::Status::GRPC_STATUS, (tonic::Code::Unimplemented as i32).into());
173                            headers.insert(http::header::CONTENT_TYPE, tonic::metadata::GRPC_CONTENT_TYPE);
174                            Ok(response)
175                        }),
176                    }
177                }
178            }
179
180            impl<T> Clone for #server_service<T> {
181                fn clone(&self) -> Self {
182                    let inner = self.inner.clone();
183                    Self {
184                        inner,
185                        accept_compression_encodings: self.accept_compression_encodings,
186                        send_compression_encodings: self.send_compression_encodings,
187                        max_decoding_message_size: self.max_decoding_message_size,
188                        max_encoding_message_size: self.max_encoding_message_size,
189                    }
190                }
191            }
192
193            #named
194        }
195    }
196}
197
198#[allow(clippy::too_many_arguments)]
199fn generate_trait<T: Service>(
200    service: &T,
201    emit_package: bool,
202    proto_path: &str,
203    compile_well_known_types: bool,
204    server_trait: Ident,
205    disable_comments: &HashSet<String>,
206    use_arc_self: bool,
207    generate_default_stubs: bool,
208    trait_attributes: Vec<syn::Attribute>,
209) -> TokenStream {
210    let methods = generate_trait_methods(
211        service,
212        emit_package,
213        proto_path,
214        compile_well_known_types,
215        disable_comments,
216        use_arc_self,
217        generate_default_stubs,
218    );
219    let trait_doc = generate_doc_comment(format!(
220        " Generated trait containing gRPC methods that should be implemented for use with {}Server.",
221        service.name()
222    ));
223
224    quote! {
225        #trait_doc
226        #(#trait_attributes)*
227        #[async_trait]
228        pub trait #server_trait : std::marker::Send + std::marker::Sync + 'static {
229            #methods
230        }
231    }
232}
233
234fn generate_trait_methods<T: Service>(
235    service: &T,
236    emit_package: bool,
237    proto_path: &str,
238    compile_well_known_types: bool,
239    disable_comments: &HashSet<String>,
240    use_arc_self: bool,
241    generate_default_stubs: bool,
242) -> TokenStream {
243    let mut stream = TokenStream::new();
244
245    for method in service.methods() {
246        let name = quote::format_ident!("{}", method.name());
247
248        let (req_message, res_message) =
249            method.request_response_name(proto_path, compile_well_known_types);
250
251        let method_doc =
252            if disable_comments.contains(&format_method_name(service, method, emit_package)) {
253                TokenStream::new()
254            } else {
255                generate_doc_comments(method.comment())
256            };
257
258        let self_param = if use_arc_self {
259            quote!(self: std::sync::Arc<Self>)
260        } else {
261            quote!(&self)
262        };
263
264        let method = match (
265            method.client_streaming(),
266            method.server_streaming(),
267            generate_default_stubs,
268        ) {
269            (false, false, true) => {
270                quote! {
271                    #method_doc
272                    async fn #name(#self_param, request: tonic::Request<#req_message>)
273                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
274                        Err(tonic::Status::unimplemented("Not yet implemented"))
275                    }
276                }
277            }
278            (false, false, false) => {
279                quote! {
280                    #method_doc
281                    async fn #name(#self_param, request: tonic::Request<#req_message>)
282                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
283                }
284            }
285            (true, false, true) => {
286                quote! {
287                    #method_doc
288                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
289                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
290                        Err(tonic::Status::unimplemented("Not yet implemented"))
291                    }
292                }
293            }
294            (true, false, false) => {
295                quote! {
296                    #method_doc
297                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
298                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
299                }
300            }
301            (false, true, true) => {
302                quote! {
303                    #method_doc
304                    async fn #name(#self_param, request: tonic::Request<#req_message>)
305                        -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
306                        Err(tonic::Status::unimplemented("Not yet implemented"))
307                    }
308                }
309            }
310            (false, true, false) => {
311                let stream = quote::format_ident!("{}Stream", method.identifier());
312                let stream_doc = generate_doc_comment(format!(
313                    " Server streaming response type for the {} method.",
314                    method.identifier()
315                ));
316
317                quote! {
318                    #stream_doc
319                    type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
320
321                    #method_doc
322                    async fn #name(#self_param, request: tonic::Request<#req_message>)
323                        -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
324                }
325            }
326            (true, true, true) => {
327                quote! {
328                    #method_doc
329                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
330                        -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
331                        Err(tonic::Status::unimplemented("Not yet implemented"))
332                    }
333                }
334            }
335            (true, true, false) => {
336                let stream = quote::format_ident!("{}Stream", method.identifier());
337                let stream_doc = generate_doc_comment(format!(
338                    " Server streaming response type for the {} method.",
339                    method.identifier()
340                ));
341
342                quote! {
343                    #stream_doc
344                    type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
345
346                    #method_doc
347                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
348                        -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
349                }
350            }
351        };
352
353        stream.extend(method);
354    }
355
356    stream
357}
358
359fn generate_named(server_service: &syn::Ident, service_name: &str) -> TokenStream {
360    let service_name = syn::LitStr::new(service_name, proc_macro2::Span::call_site());
361    let name_doc = generate_doc_comment(" Generated gRPC service name");
362
363    quote! {
364        #name_doc
365        pub const SERVICE_NAME: &str = #service_name;
366
367        impl<T> tonic::server::NamedService for #server_service<T> {
368            const NAME: &'static str = SERVICE_NAME;
369        }
370    }
371}
372
373fn generate_methods<T: Service>(
374    service: &T,
375    emit_package: bool,
376    proto_path: &str,
377    compile_well_known_types: bool,
378    use_arc_self: bool,
379    generate_default_stubs: bool,
380) -> TokenStream {
381    let mut stream = TokenStream::new();
382
383    for method in service.methods() {
384        let path = format_method_path(service, method, emit_package);
385        let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
386        let ident = quote::format_ident!("{}", method.name());
387        let server_trait = quote::format_ident!("{}", service.name());
388
389        let method_stream = match (method.client_streaming(), method.server_streaming()) {
390            (false, false) => generate_unary(
391                method,
392                proto_path,
393                compile_well_known_types,
394                ident,
395                server_trait,
396                use_arc_self,
397            ),
398
399            (false, true) => generate_server_streaming(
400                method,
401                proto_path,
402                compile_well_known_types,
403                ident.clone(),
404                server_trait,
405                use_arc_self,
406                generate_default_stubs,
407            ),
408            (true, false) => generate_client_streaming(
409                method,
410                proto_path,
411                compile_well_known_types,
412                ident.clone(),
413                server_trait,
414                use_arc_self,
415            ),
416
417            (true, true) => generate_streaming(
418                method,
419                proto_path,
420                compile_well_known_types,
421                ident.clone(),
422                server_trait,
423                use_arc_self,
424                generate_default_stubs,
425            ),
426        };
427
428        let method = quote! {
429            #method_path => {
430                #method_stream
431            }
432        };
433        stream.extend(method);
434    }
435
436    stream
437}
438
439fn generate_unary<T: Method>(
440    method: &T,
441    proto_path: &str,
442    compile_well_known_types: bool,
443    method_ident: Ident,
444    server_trait: Ident,
445    use_arc_self: bool,
446) -> TokenStream {
447    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
448
449    let service_ident = quote::format_ident!("{}Svc", method.identifier());
450
451    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
452
453    let inner_arg = if use_arc_self {
454        quote!(inner)
455    } else {
456        quote!(&inner)
457    };
458
459    quote! {
460        #[allow(non_camel_case_types)]
461        struct #service_ident<T: #server_trait >(pub Arc<T>);
462
463        impl<T: #server_trait> tonic::server::UnaryService<#request> for #service_ident<T> {
464            type Response = #response;
465            type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
466
467            fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
468                let inner = Arc::clone(&self.0);
469                let fut = async move {
470                    <T as #server_trait>::#method_ident(#inner_arg, request).await
471                };
472                Box::pin(fut)
473            }
474        }
475
476        let accept_compression_encodings = self.accept_compression_encodings;
477        let send_compression_encodings = self.send_compression_encodings;
478        let max_decoding_message_size = self.max_decoding_message_size;
479        let max_encoding_message_size = self.max_encoding_message_size;
480        let inner = self.inner.clone();
481        let fut = async move {
482            let method = #service_ident(inner);
483            let codec = #codec_name::default();
484
485            let mut grpc = tonic::server::Grpc::new(codec)
486                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
487                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
488
489            let res = grpc.unary(method, req).await;
490            Ok(res)
491        };
492
493        Box::pin(fut)
494    }
495}
496
497fn generate_server_streaming<T: Method>(
498    method: &T,
499    proto_path: &str,
500    compile_well_known_types: bool,
501    method_ident: Ident,
502    server_trait: Ident,
503    use_arc_self: bool,
504    generate_default_stubs: bool,
505) -> TokenStream {
506    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
507
508    let service_ident = quote::format_ident!("{}Svc", method.identifier());
509
510    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
511
512    let response_stream = if !generate_default_stubs {
513        let stream = quote::format_ident!("{}Stream", method.identifier());
514        quote!(type ResponseStream = T::#stream)
515    } else {
516        quote!(type ResponseStream = BoxStream<#response>)
517    };
518
519    let inner_arg = if use_arc_self {
520        quote!(inner)
521    } else {
522        quote!(&inner)
523    };
524
525    quote! {
526        #[allow(non_camel_case_types)]
527        struct #service_ident<T: #server_trait >(pub Arc<T>);
528
529        impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
530            type Response = #response;
531            #response_stream;
532            type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
533
534            fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
535                let inner = Arc::clone(&self.0);
536                let fut = async move {
537                    <T as #server_trait>::#method_ident(#inner_arg, request).await
538                };
539                Box::pin(fut)
540            }
541        }
542
543        let accept_compression_encodings = self.accept_compression_encodings;
544        let send_compression_encodings = self.send_compression_encodings;
545        let max_decoding_message_size = self.max_decoding_message_size;
546        let max_encoding_message_size = self.max_encoding_message_size;
547        let inner = self.inner.clone();
548        let fut = async move {
549            let method = #service_ident(inner);
550            let codec = #codec_name::default();
551
552            let mut grpc = tonic::server::Grpc::new(codec)
553                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
554                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
555
556            let res = grpc.server_streaming(method, req).await;
557            Ok(res)
558        };
559
560        Box::pin(fut)
561    }
562}
563
564fn generate_client_streaming<T: Method>(
565    method: &T,
566    proto_path: &str,
567    compile_well_known_types: bool,
568    method_ident: Ident,
569    server_trait: Ident,
570    use_arc_self: bool,
571) -> TokenStream {
572    let service_ident = quote::format_ident!("{}Svc", method.identifier());
573
574    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
575    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
576
577    let inner_arg = if use_arc_self {
578        quote!(inner)
579    } else {
580        quote!(&inner)
581    };
582
583    quote! {
584        #[allow(non_camel_case_types)]
585        struct #service_ident<T: #server_trait >(pub Arc<T>);
586
587        impl<T: #server_trait> tonic::server::ClientStreamingService<#request> for #service_ident<T>
588        {
589            type Response = #response;
590            type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
591
592            fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
593                let inner = Arc::clone(&self.0);
594                let fut = async move {
595                    <T as #server_trait>::#method_ident(#inner_arg, request).await
596                };
597                Box::pin(fut)
598            }
599        }
600
601        let accept_compression_encodings = self.accept_compression_encodings;
602        let send_compression_encodings = self.send_compression_encodings;
603        let max_decoding_message_size = self.max_decoding_message_size;
604        let max_encoding_message_size = self.max_encoding_message_size;
605        let inner = self.inner.clone();
606        let fut = async move {
607            let method = #service_ident(inner);
608            let codec = #codec_name::default();
609
610            let mut grpc = tonic::server::Grpc::new(codec)
611                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
612                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
613
614            let res = grpc.client_streaming(method, req).await;
615            Ok(res)
616        };
617
618        Box::pin(fut)
619    }
620}
621
622fn generate_streaming<T: Method>(
623    method: &T,
624    proto_path: &str,
625    compile_well_known_types: bool,
626    method_ident: Ident,
627    server_trait: Ident,
628    use_arc_self: bool,
629    generate_default_stubs: bool,
630) -> TokenStream {
631    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
632
633    let service_ident = quote::format_ident!("{}Svc", method.identifier());
634
635    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
636
637    let response_stream = if !generate_default_stubs {
638        let stream = quote::format_ident!("{}Stream", method.identifier());
639        quote!(type ResponseStream = T::#stream)
640    } else {
641        quote!(type ResponseStream = BoxStream<#response>)
642    };
643
644    let inner_arg = if use_arc_self {
645        quote!(inner)
646    } else {
647        quote!(&inner)
648    };
649
650    quote! {
651        #[allow(non_camel_case_types)]
652        struct #service_ident<T: #server_trait>(pub Arc<T>);
653
654        impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
655        {
656            type Response = #response;
657            #response_stream;
658            type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
659
660            fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
661                let inner = Arc::clone(&self.0);
662                let fut = async move {
663                    <T as #server_trait>::#method_ident(#inner_arg, request).await
664                };
665                Box::pin(fut)
666            }
667        }
668
669        let accept_compression_encodings = self.accept_compression_encodings;
670        let send_compression_encodings = self.send_compression_encodings;
671        let max_decoding_message_size = self.max_decoding_message_size;
672        let max_encoding_message_size = self.max_encoding_message_size;
673        let inner = self.inner.clone();
674        let fut = async move {
675            let method = #service_ident(inner);
676            let codec = #codec_name::default();
677
678            let mut grpc = tonic::server::Grpc::new(codec)
679                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
680                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
681
682            let res = grpc.streaming(method, req).await;
683            Ok(res)
684        };
685
686        Box::pin(fut)
687    }
688}