conjure_codegen/servers.rs

1use crate::context::{BaseModule, Context};
2use crate::human_size;
3use crate::types::objects::{
4    ArgumentDefinition, AuthType, EndpointDefinition, ParameterType, ServiceDefinition, Type,
5};
6use heck::ToUpperCamelCase;
7use proc_macro2::{Ident, TokenStream};
8use quote::quote;
9
/// Which flavor of server trait to generate: a blocking (`Sync`) trait or an `async` one.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Style {
    /// Generate the `Async`-prefixed trait whose methods are `async fn`.
    Async,
    /// Generate the blocking trait whose methods are plain `fn`.
    Sync,
}
15
16pub fn generate(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
17    let sync_trait = generate_trait(ctx, def, Style::Sync);
18    let async_trait = generate_trait(ctx, def, Style::Async);
19
20    quote! {
21        use conjure_http::endpoint;
22
23        #sync_trait
24        #async_trait
25    }
26}
27
28fn generate_trait(ctx: &Context, def: &ServiceDefinition, style: Style) -> TokenStream {
29    let docs = ctx.docs(def.docs());
30    let service_name = def.service_name().name();
31    let name = trait_name(ctx, def, style);
32    let params = params(ctx, def);
33
34    let use_legacy_error_serialization = if ctx.use_legacy_error_serialization() {
35        quote!(, use_legacy_error_serialization)
36    } else {
37        quote!()
38    };
39
40    let binary_types = def
41        .endpoints()
42        .iter()
43        .flat_map(|e| generate_binary_type(ctx, def, e, style));
44
45    let endpoints = def
46        .endpoints()
47        .iter()
48        .map(|e| generate_trait_endpoint(ctx, def, e, style));
49
50    quote! {
51        #docs
52        #[conjure_http::conjure_endpoints(name = #service_name #use_legacy_error_serialization)]
53        pub trait #name #params {
54            #(#binary_types)*
55
56            #(#endpoints)*
57        }
58    }
59}
60
61fn trait_name(ctx: &Context, def: &ServiceDefinition, style: Style) -> Ident {
62    match style {
63        Style::Async => ctx.type_name(&format!("Async{}", def.service_name().name())),
64        Style::Sync => ctx.type_name(def.service_name().name()),
65    }
66}
67
68fn params(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
69    let mut params = vec![];
70    if service_has_binary_request_body(ctx, def) {
71        params.push(quote! {
72                #[request_body]
73                I
74        });
75    }
76    if service_has_binary_response_body(ctx, def) {
77        params.push(quote! {
78                #[response_writer]
79                O
80        });
81    }
82
83    if params.is_empty() {
84        quote!()
85    } else {
86        quote!(<#(#params),*>)
87    }
88}
89
90fn service_has_binary_request_body(ctx: &Context, def: &ServiceDefinition) -> bool {
91    def.endpoints()
92        .iter()
93        .any(|e| endpoint_has_binary_request_body(ctx, e))
94}
95
96fn endpoint_has_binary_request_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
97    endpoint.args().iter().any(|a| match a.param_type() {
98        ParameterType::Body(_) => ctx.is_binary(a.type_()),
99        _ => false,
100    })
101}
102
103fn service_has_binary_response_body(ctx: &Context, def: &ServiceDefinition) -> bool {
104    def.endpoints()
105        .iter()
106        .any(|e| endpoint_has_binary_response_body(ctx, e))
107}
108
109fn endpoint_has_binary_response_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
110    match return_type(ctx, endpoint) {
111        ReturnType::Binary | ReturnType::OptionalBinary => true,
112        ReturnType::None | ReturnType::Json(_) => false,
113    }
114}
115
116fn generate_binary_type(
117    ctx: &Context,
118    def: &ServiceDefinition,
119    endpoint: &EndpointDefinition,
120    style: Style,
121) -> Option<TokenStream> {
122    if endpoint_has_binary_response_body(ctx, endpoint) {
123        let docs = format!(
124            "The body type returned by the `{}` method.",
125            ctx.field_name(endpoint.endpoint_name())
126        );
127        let name = binary_type(endpoint);
128        let bounds = match style {
129            Style::Async => {
130                let send = ctx.send_ident(def.service_name());
131                quote!(conjure_http::server::AsyncWriteBody<O> + 'static + #send)
132            }
133            Style::Sync => quote!(conjure_http::server::WriteBody<O> + 'static),
134        };
135        Some(quote! {
136            #[doc = #docs]
137            type #name: #bounds;
138        })
139    } else {
140        None
141    }
142}
143
144fn binary_type(endpoint: &EndpointDefinition) -> TokenStream {
145    format!("{}Body", endpoint.endpoint_name().to_upper_camel_case())
146        .parse()
147        .unwrap()
148}
149
150fn generate_trait_endpoint(
151    ctx: &Context,
152    def: &ServiceDefinition,
153    endpoint: &EndpointDefinition,
154    style: Style,
155) -> TokenStream {
156    let docs = ctx.docs(endpoint.docs());
157    let method = endpoint
158        .http_method()
159        .as_str()
160        .parse::<TokenStream>()
161        .unwrap();
162    let path = &**endpoint.http_path();
163    let endpoint_name = &**endpoint.endpoint_name();
164    let async_ = match style {
165        Style::Async => quote!(async),
166        Style::Sync => quote!(),
167    };
168    let name = ctx.field_name(endpoint.endpoint_name());
169    let produces = match endpoint.returns() {
170        Some(ty) => {
171            let produces = produces(ctx, ty);
172            quote!(, produces = #produces)
173        }
174        None => quote!(),
175    };
176
177    let auth_arg = auth_arg(endpoint);
178    let args = endpoint.args().iter().map(|a| arg(ctx, def, endpoint, a));
179    let request_context_arg = request_context_arg(endpoint);
180
181    let result = ctx.result_ident(def.service_name());
182
183    let ret_ty = rust_return_type(ctx, def, endpoint, &return_type(ctx, endpoint));
184    let ret_ty = quote!(#result<#ret_ty, conjure_http::private::Error>);
185
186    // ignore deprecation since the endpoint has to be implemented regardless
187    quote! {
188        #docs
189        #[endpoint(method = #method, path = #path, name = #endpoint_name #produces)]
190        #async_ fn #name(&self #auth_arg #(, #args)* #request_context_arg) -> #ret_ty;
191    }
192}
193
194fn produces(ctx: &Context, ty: &Type) -> TokenStream {
195    match ctx.is_optional(ty) {
196        Some(inner) if ctx.is_binary(inner) => {
197            quote!(conjure_http::server::conjure::OptionalBinaryResponseSerializer)
198        }
199        _ if ctx.is_binary(ty) => quote!(conjure_http::server::conjure::BinaryResponseSerializer),
200        _ if ctx.is_iterable(ty) => {
201            quote!(conjure_http::server::conjure::CollectionResponseSerializer)
202        }
203        _ => quote!(conjure_http::server::StdResponseSerializer),
204    }
205}
206
207fn auth_arg(endpoint: &EndpointDefinition) -> TokenStream {
208    match endpoint.auth() {
209        Some(auth) => {
210            let params = match auth {
211                AuthType::Header(_) => quote!(),
212                AuthType::Cookie(cookie) => {
213                    let name = &cookie.cookie_name();
214                    quote!((cookie_name = #name))
215                }
216            };
217            quote!(, #[auth #params] auth_: conjure_object::BearerToken)
218        }
219        None => quote!(),
220    }
221}
222
223fn arg(
224    ctx: &Context,
225    def: &ServiceDefinition,
226    endpoint: &EndpointDefinition,
227    arg: &ArgumentDefinition,
228) -> TokenStream {
229    let name = ctx.field_name(arg.arg_name());
230
231    let log_as = if name == **arg.arg_name() {
232        quote!()
233    } else {
234        let log_as = &**arg.arg_name();
235        quote!(, log_as = #log_as)
236    };
237
238    let safe = if ctx.is_safe_arg(arg) {
239        quote!(, safe)
240    } else {
241        quote!()
242    };
243
244    let attr = match arg.param_type() {
245        ParameterType::Body(_) => {
246            let deserializer = if ctx.is_optional(arg.type_()).is_some() {
247                let mut decoder =
248                    quote!(conjure_http::server::conjure::OptionalRequestDeserializer);
249                let dealiased = ctx.dealiased_type(arg.type_());
250                if dealiased != arg.type_() {
251                    let dealiased =
252                        ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
253                    decoder =
254                        quote!(conjure_http::server::FromRequestDeserializer<#decoder, #dealiased>)
255                }
256                decoder
257            } else if ctx.is_binary(arg.type_()) {
258                quote!(conjure_http::server::conjure::BinaryRequestDeserializer)
259            } else {
260                let param = match server_limit_request_size(endpoint) {
261                    Ok(Some(limit)) => quote!(<#limit>),
262                    Ok(None) => quote!(),
263                    Err(e) => quote!(<compile_error!(#e)>),
264                };
265                quote!(conjure_http::server::StdRequestDeserializer #param)
266            };
267            quote!(#[body(deserializer = #deserializer #log_as #safe)])
268        }
269        ParameterType::Header(header) => {
270            let name = &**header.param_id();
271            let decoder = if ctx.is_optional(arg.type_()).is_some() {
272                optional_decoder(ctx, def, arg.type_())
273            } else {
274                quote!(conjure_http::server::conjure::FromPlainDecoder)
275            };
276            quote!(#[header(name = #name, decoder = #decoder #log_as #safe)])
277        }
278        ParameterType::Path(_) => {
279            let name = &**arg.arg_name();
280            quote! {
281                #[path(
282                    name = #name,
283                    decoder = conjure_http::server::conjure::FromPlainDecoder
284                    #log_as
285                    #safe
286                )]
287            }
288        }
289        ParameterType::Query(query) => {
290            let name = &**query.param_id();
291            let decoder = if ctx.is_optional(arg.type_()).is_some() {
292                optional_decoder(ctx, def, arg.type_())
293            } else if ctx.is_iterable(arg.type_()) {
294                quote!(conjure_http::server::conjure::FromPlainSeqDecoder<_>)
295            } else {
296                quote!(conjure_http::server::conjure::FromPlainDecoder)
297            };
298            quote!(#[query(name = #name, decoder = #decoder #log_as #safe)])
299        }
300    };
301
302    let ty = if ctx.is_binary(arg.type_()) {
303        quote!(I)
304    } else {
305        ctx.rust_type(BaseModule::Endpoints, def.service_name(), arg.type_())
306    };
307    quote!(#attr #name: #ty)
308}
309
310fn optional_decoder(ctx: &Context, def: &ServiceDefinition, ty: &Type) -> TokenStream {
311    let mut decoder = quote!(conjure_http::server::conjure::FromPlainOptionDecoder);
312    let dealiased = ctx.dealiased_type(ty);
313    if dealiased != ty {
314        let dealiased = ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
315        decoder = quote!(conjure_http::server::FromDecoder<#decoder, #dealiased>)
316    }
317    decoder
318}
319
320fn request_context_arg(endpoint: &EndpointDefinition) -> TokenStream {
321    if has_request_context(endpoint) {
322        quote!(, #[context] request_context_: conjure_http::server::RequestContext<'_>)
323    } else {
324        quote!()
325    }
326}
327
328fn return_type<'a>(ctx: &Context, endpoint: &'a EndpointDefinition) -> ReturnType<'a> {
329    match endpoint.returns() {
330        Some(ty) => match ctx.is_optional(ty) {
331            Some(inner) if ctx.is_binary(inner) => ReturnType::OptionalBinary,
332            _ if ctx.is_binary(ty) => ReturnType::Binary,
333            _ => ReturnType::Json(ty),
334        },
335        None => ReturnType::None,
336    }
337}
338
339fn rust_return_type(
340    ctx: &Context,
341    def: &ServiceDefinition,
342    endpoint: &EndpointDefinition,
343    ty: &ReturnType<'_>,
344) -> TokenStream {
345    match ty {
346        ReturnType::None => quote!(()),
347        ReturnType::Json(ty) => ctx.rust_type(BaseModule::Endpoints, def.service_name(), ty),
348        ReturnType::Binary => {
349            let name = binary_type(endpoint);
350            quote!(Self::#name)
351        }
352        ReturnType::OptionalBinary => {
353            let name = binary_type(endpoint);
354            let option = ctx.option_ident(def.service_name());
355            quote!(#option<Self::#name>)
356        }
357    }
358}
359
360enum ReturnType<'a> {
361    None,
362    Json(&'a Type),
363    Binary,
364    OptionalBinary,
365}
366
367fn has_request_context(endpoint: &EndpointDefinition) -> bool {
368    endpoint
369        .tags()
370        .iter()
371        .any(|t| t == "server-request-context")
372}
373
374fn server_limit_request_size(endpoint: &EndpointDefinition) -> Result<Option<usize>, String> {
375    let mut it = endpoint
376        .tags()
377        .iter()
378        .filter_map(|t| t.strip_prefix("server-limit-request-size:"))
379        .map(|s| s.trim());
380
381    let Some(limit) = it.next() else {
382        return Ok(None);
383    };
384
385    if it.next().is_some() {
386        return Err("invalid endpoint definition includes multiple tags with the `server-limit-request-size` prefix".to_string());
387    }
388
389    human_size::parse(limit).map(Some)
390}