1use super::{Method, Service};
2use crate::{generate_doc_comment, generate_doc_comments, naive_snake_case, Builder};
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{Ident, Lit, LitStr};
6
7pub fn generate<T: Service>(service: &T, config: &Builder) -> TokenStream {
12 let attributes = &config.server_attributes;
13 let methods = generate_methods(service, config, false);
14 let json_methods = generate_methods(service, config, true);
15
16 let server_service = quote::format_ident!("{}Server", service.name());
17 let server_trait = quote::format_ident!("{}Rpc", service.name());
18 let server_mod = quote::format_ident!("{}_server", naive_snake_case(service.name()));
19 let service_name = Lit::Str(LitStr::new(service.name(), Span::call_site()));
20 let supported_methods = generate_supported_methods(service, config);
21 let method_enum = generate_methods_enum(service, config);
22 let generated_trait = generate_trait(service, config, server_trait.clone());
23 let service_doc = generate_doc_comments(service.comment());
24 let mod_attributes = attributes.for_mod(service.package());
25 let struct_attributes = attributes.for_struct(service.identifier());
26
27 quote! {
28 #(#mod_attributes)*
30 pub mod #server_mod {
31 use alloc::vec::Vec;
32
33 #method_enum
34
35 #generated_trait
36
37 #service_doc
38 #(#struct_attributes)*
39 #[derive(Debug)]
40 pub struct #server_service<T: #server_trait> {
41 inner: T,
42 }
43
44 impl<T: #server_trait> #server_service<T> {
45 pub fn new(inner: T) -> Self {
46 Self {
47 inner,
48 }
49 }
50
51 pub async fn dispatch_request(self, path: &str, _data: impl AsRef<[u8]>) -> Result<Vec<u8>, ::prpc::server::Error> {
52 #![allow(clippy::let_unit_value)]
53 match path {
54 #methods
55 _ => anyhow::bail!("Service not found: {path}"),
56 }
57 }
58
59 pub async fn dispatch_json_request(self, path: &str, _data: impl AsRef<[u8]>, _query: bool) -> Result<Vec<u8>, ::prpc::server::Error> {
60 #![allow(clippy::let_unit_value)]
61 match path {
62 #json_methods
63 _ => anyhow::bail!("Service not found: {path}"),
64 }
65 }
66 #supported_methods
67 }
68
69 impl<T: #server_trait> ::prpc::server::NamedService for #server_service<T> {
70 const NAME: &'static str = #service_name;
71 }
72 impl<T: #server_trait> ::prpc::server::Service for #server_service<T> {
73 type Methods = &'static [&'static str];
74 fn methods() -> Self::Methods {
75 Self::supported_methods()
76 }
77 async fn dispatch_request(self, path: &str, data: impl AsRef<[u8]>, json: bool, query: bool) -> Result<Vec<u8>, ::prpc::server::Error> {
78 if json {
79 self.dispatch_json_request(path, data, query).await
80 } else {
81 self.dispatch_request(path, data).await
82 }
83 }
84 }
85 impl<T: #server_trait> From<T> for #server_service<T> {
86 fn from(inner: T) -> Self {
87 Self::new(inner)
88 }
89 }
90 }
91 }
92}
93
94fn generate_trait<T: Service>(service: &T, config: &Builder, server_trait: Ident) -> TokenStream {
95 let methods =
96 generate_trait_methods(service, &config.proto_path, config.compile_well_known_types);
97 let trait_doc = generate_doc_comment(format!(
98 "Generated trait containing RPC methods that should be implemented for use with {}Server.",
99 service.name()
100 ));
101
102 quote! {
103 #trait_doc
104 pub trait #server_trait {
105 #methods
106 }
107 }
108}
109
110fn generate_trait_methods<T: Service>(
111 service: &T,
112 proto_path: &str,
113 compile_well_known_types: bool,
114) -> TokenStream {
115 let mut stream = TokenStream::new();
116
117 for method in service.methods() {
118 let name = quote::format_ident!("{}", method.name());
119
120 let (req_message, res_message) =
121 method.request_response_name(proto_path, compile_well_known_types);
122
123 let method_doc = generate_doc_comments(method.comment());
124
125 let method = match (method.client_streaming(), method.server_streaming()) {
126 (false, false) => {
127 template_quote::quote! {
128 #method_doc
129 async fn #name(self
130 #(if req_message.is_some()) {
131 , request: #req_message
132 }
133 ) -> ::anyhow::Result<#res_message>;
134 }
135 }
136 _ => {
137 panic!("Streaming RPC not supported");
138 }
139 };
140
141 stream.extend(method);
142 }
143
144 stream
145}
146
147fn generate_supported_methods<T: Service>(service: &T, config: &Builder) -> TokenStream {
148 let mut all_methods = TokenStream::new();
149 for method in service.methods() {
150 let path = crate::join_path(
151 config,
152 service.package(),
153 service.identifier(),
154 method.identifier(),
155 );
156
157 let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
158 all_methods.extend(quote! {
159 #method_path,
160 });
161 }
162
163 quote! {
164 pub fn supported_methods()
165 -> &'static [&'static str] {
166 &[
167 #all_methods
168 ]
169 }
170 }
171}
172
173fn generate_methods_enum<T: Service>(service: &T, config: &Builder) -> TokenStream {
174 let mut paths = vec![];
175 let mut variants = vec![];
176 for method in service.methods() {
177 let path = crate::join_path(
178 config,
179 service.package(),
180 service.identifier(),
181 method.identifier(),
182 );
183
184 let variant = Ident::new(method.identifier(), Span::call_site());
185 variants.push(variant);
186
187 let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
188 paths.push(method_path);
189 }
190
191 let enum_name = Ident::new(
192 &format!("{}Method", service.identifier()),
193 Span::call_site(),
194 );
195 quote! {
196 pub enum #enum_name {
197 #(#variants,)*
198 }
199
200 impl #enum_name {
201 #[allow(clippy::should_implement_trait)]
202 pub fn from_str(path: &str) -> Option<Self> {
203 match path {
204 #(#paths => Some(Self::#variants),)*
205 _ => None,
206 }
207 }
208 }
209 }
210}
211
212fn generate_methods<T: Service>(service: &T, config: &Builder, json: bool) -> TokenStream {
213 let mut stream = TokenStream::new();
214
215 for method in service.methods() {
216 let path = crate::join_path(
217 config,
218 service.package(),
219 service.identifier(),
220 method.identifier(),
221 );
222 let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
223 let method_ident = quote::format_ident!("{}", method.name());
224
225 let method_stream = match (method.client_streaming(), method.server_streaming()) {
226 (false, false) => generate_unary(method, config, method_ident, json),
227 _ => {
228 panic!("Streaming RPC not supported");
229 }
230 };
231
232 let method = quote! {
233 #method_path => {
234 #method_stream
235 }
236 };
237 stream.extend(method);
238 }
239
240 stream
241}
242
243fn generate_unary<T: Method>(
244 method: &T,
245 config: &Builder,
246 method_ident: Ident,
247 json: bool,
248) -> TokenStream {
249 let (request, _response) =
250 method.request_response_name(&config.proto_path, config.compile_well_known_types);
251
252 if json {
253 template_quote::quote! {
254 #(if request.is_none()) {
255 let response = self.inner.#method_ident().await?;
256 }
257 #(else) {
258 let data = _data.as_ref();
259 let input: #request = if data.is_empty() {
260 Default::default()
261 } else if _query {
262 ::prpc::serde_qs::from_bytes(data)?
263 } else {
264 ::prpc::serde_json::from_slice(data)?
265 };
266 let response = self.inner.#method_ident(input).await?;
267 }
268 Ok(serde_json::to_vec(&response)?)
269 }
270 } else {
271 template_quote::quote! {
272 #(if request.is_none()) {
273 let response = self.inner.#method_ident().await?;
274 }
275 #(else) {
276 let input: #request = ::prpc::Message::decode(_data.as_ref())?;
277 let response = self.inner.#method_ident(input).await?;
278 }
279 Ok(::prpc::codec::encode_message_to_vec(&response))
280 }
281 }
282}