kanamaru_build/prost/codegen/
traits.rs

1use std::collections::HashSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::quote;
6use quote::ToTokens;
7use syn::Ident;
8
9use crate::utils::generate_doc_comment;
10use crate::utils::{format_method_name, generate_doc_comments, Method, Service};
11
12pub struct GenerateTraitService<'a, S> {
13    pub service: &'a S,
14    pub emit_package: bool,
15    pub proto_path: &'a str,
16    pub compile_well_known_types: bool,
17    pub service_trait: Ident,
18    pub disable_comments: &'a HashSet<String>,
19    pub use_arc_self: bool,
20    pub generate_default_stubs: bool,
21}
22
23impl<S: Service> GenerateTraitService<'_, S> {
24    pub fn generate_methods(&self) -> TokenStream {
25        let mut stream = TokenStream::new();
26
27        for method in self.service.methods() {
28            let name = quote::format_ident!("{}", method.name());
29
30            let (req_message, res_message) =
31                method.request_response_name(self.proto_path, self.compile_well_known_types);
32            let method_doc = if self.disable_comments.contains(&format_method_name(
33                self.service,
34                method,
35                self.emit_package,
36            )) {
37                TokenStream::new()
38            } else {
39                generate_doc_comments(method.comment())
40            };
41            let self_param = if self.use_arc_self {
42                quote!(self: std::sync::Arc<Self>)
43            } else {
44                quote!(&self)
45            };
46            let not_implemented = quote! {
47                Err(kanamaru::Status::unimplemented("Not implemented"))
48            };
49            let method_tokens: TokenStream = match (
50                method.client_streaming(),
51                method.server_streaming(),
52                self.generate_default_stubs,
53            ) {
54                (true, true, true) => {
55                    let stream =
56                        quote::format_ident!("{}Stream", method.identifier().to_upper_camel_case());
57                    let stream_doc = generate_doc_comment(format!(
58                        " Server streaming response type for the {} method.",
59                        method.identifier()
60                    ));
61                    quote! {
62                        #stream_doc
63                        type #stream: kanamaru::codegen::tokio_stream::Stream<Item = std::result::Result<IpcMessage<#res_message>, kanamaru::Status>> + std::marker::Send + 'static;
64
65                        #method_doc
66                        async fn #name<R: Runtime>(#self_param, request: kanamaru::StreamingRequest<R, #req_message>)
67                            -> std::result::Result<kanamaru::StreamingResponse<#res_message, Self::#stream>, kanamaru::Status> {
68                            #not_implemented
69                        }
70                    }
71                }
72                (true, true, false) => {
73                    let stream =
74                        quote::format_ident!("{}Stream", method.identifier().to_upper_camel_case());
75                    let stream_doc = generate_doc_comment(format!(
76                        " Server streaming response type for the {} method.",
77                        method.identifier()
78                    ));
79                    quote! {
80                        #stream_doc
81                        type #stream: kanamaru::codegen::tokio_stream::Stream<Item = std::result::Result<IpcMessage<#res_message>, kanamaru::Status>> + std::marker::Send + 'static;
82
83                        #method_doc
84                        async fn #name<R: Runtime>(#self_param, request: kanamaru::StreamingRequest<R, #req_message>)
85                            -> std::result::Result<kanamaru::StreamingResponse<#res_message, Self::#stream>, kanamaru::Status>;
86                    }
87                }
88                (true, false, true) => {
89                    quote! {
90                        #method_doc
91                        async fn #name<R: Runtime>(#self_param, request: kanamaru::StreamingRequest<R, #req_message>)
92                            -> std::result::Result<kanamaru::UnaryResponse<#res_message>, kanamaru::Status> {
93                            #not_implemented
94                        }
95                    }
96                }
97                (true, false, false) => {
98                    quote! {
99                        #method_doc
100                        async fn #name<R: Runtime>(#self_param, request: kanamaru::StreamingRequest<R, #req_message>)
101                            -> std::result::Result<kanamaru::UnaryResponse<#res_message>, kanamaru::Status>;
102                    }
103                }
104                (false, true, true) => {
105                    let stream =
106                        quote::format_ident!("{}Stream", method.identifier().to_upper_camel_case());
107                    let stream_doc = generate_doc_comment(format!(
108                        " Server streaming response type for the {} method.",
109                        method.identifier()
110                    ));
111                    quote! {
112                        #stream_doc
113                        type #stream: kanamaru::codegen::tokio_stream::Stream<Item = std::result::Result<IpcMessage<#res_message>, kanamaru::Status>> + std::marker::Send + 'static;
114
115                        #method_doc
116                        async fn #name<R: Runtime>(#self_param, request: kanamaru::UnaryRequest<R, #req_message>)
117                            -> std::result::Result<kanamaru::StreamingResponse<#res_message, Self::#stream>, kanamaru::Status> {
118                            #not_implemented
119                        }
120                    }
121                }
122                (false, true, false) => {
123                    let stream =
124                        quote::format_ident!("{}Stream", method.identifier().to_upper_camel_case());
125                    let stream_doc = generate_doc_comment(format!(
126                        " Server streaming response type for the {} method.",
127                        method.identifier()
128                    ));
129                    quote! {
130                        #stream_doc
131                        type #stream: kanamaru::codegen::tokio_stream::Stream<Item = std::result::Result<IpcMessage<#res_message>, kanamaru::Status>> + std::marker::Send + 'static;
132
133                        #method_doc
134                        async fn #name<R: Runtime>(#self_param, request: kanamaru::UnaryRequest<R, #req_message>)
135                            -> std::result::Result<kanamaru::StreamingResponse<#res_message, Self::#stream>, kanamaru::Status>;
136                    }
137                }
138                (false, false, true) => {
139                    quote! {
140                        #method_doc
141                        async fn #name<R: Runtime>(#self_param, request: kanamaru::UnaryRequest<R, #req_message>)
142                            -> std::result::Result<kanamaru::UnaryResponse<#res_message>, kanamaru::Status>{
143                            #not_implemented
144                        }
145                    }
146                }
147                (false, false, false) => {
148                    quote! {
149                        #method_doc
150                        async fn #name<R: Runtime>(#self_param, request: kanamaru::UnaryRequest<R, #req_message>)
151                            -> std::result::Result<kanamaru::UnaryResponse<#res_message>, kanamaru::Status>;
152                    }
153                }
154            };
155            stream.extend(method_tokens);
156        }
157        stream
158    }
159}
160
161impl<S: Service> ToTokens for GenerateTraitService<'_, S> {
162    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
163        let methods = self.generate_methods();
164        let trait_doc = generate_doc_comment(format!(
165            " Generated trait containing gRPC methods that should be implemented for use with {}Responder.",
166            self.service.name()
167        ));
168        let server_trait = &self.service_trait;
169        let _trait = quote! {
170            #trait_doc
171            #[async_trait]
172            pub trait #server_trait : std::marker::Send + std::marker::Sync + 'static {
173                #methods
174            }
175        };
176        tokens.extend(_trait);
177    }
178}