// Source: conjure_codegen/servers.rs

1use crate::context::{BaseModule, Context};
2use crate::human_size;
3use crate::types::objects::{
4    ArgumentDefinition, AuthType, EndpointDefinition, ParameterType, ServiceDefinition, Type,
5};
6use heck::ToUpperCamelCase;
7use proc_macro2::{Ident, TokenStream};
8use quote::quote;
9
/// Flavour of server trait being generated; one trait is emitted per style.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Style {
    /// `async fn` methods; the trait is named `Async<Service>` and binary
    /// response bodies are bounded by `AsyncWriteBody<O>` plus the context's
    /// send identifier.
    Async,
    /// Plain `fn` methods; the trait is named after the service and binary
    /// response bodies are bounded by `WriteBody<O>`.
    Sync,
    /// `async fn` methods; the trait is named `LocalAsync<Service>`, the
    /// `local` flag is passed to `conjure_endpoints`, and binary response
    /// bodies are bounded by `LocalAsyncWriteBody<O>` with no send bound.
    Local,
}
16
17pub fn generate(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
18    let sync_trait = generate_trait(ctx, def, Style::Sync);
19    let async_trait = generate_trait(ctx, def, Style::Async);
20    let local_trait = generate_trait(ctx, def, Style::Local);
21
22    quote! {
23        use conjure_http::endpoint;
24
25        #sync_trait
26        #async_trait
27        #local_trait
28    }
29}
30
31fn generate_trait(ctx: &Context, def: &ServiceDefinition, style: Style) -> TokenStream {
32    let docs = ctx.docs(def.docs());
33    let service_name = def.service_name().name();
34    let name = trait_name(ctx, def, style);
35    let local = match style {
36        Style::Local => quote!(, local),
37        Style::Async | Style::Sync => quote!(),
38    };
39    let params = params(ctx, def);
40
41    let use_legacy_error_serialization = if ctx.use_legacy_error_serialization() {
42        quote!(, use_legacy_error_serialization)
43    } else {
44        quote!()
45    };
46
47    let binary_types = def
48        .endpoints()
49        .iter()
50        .flat_map(|e| generate_binary_type(ctx, def, e, style));
51
52    let endpoints = def
53        .endpoints()
54        .iter()
55        .map(|e| generate_trait_endpoint(ctx, def, e, style));
56
57    quote! {
58        #docs
59        #[conjure_http::conjure_endpoints(name = #service_name #use_legacy_error_serialization #local)]
60        pub trait #name #params {
61            #(#binary_types)*
62
63            #(#endpoints)*
64        }
65    }
66}
67
68fn trait_name(ctx: &Context, def: &ServiceDefinition, style: Style) -> Ident {
69    match style {
70        Style::Async => ctx.type_name(&format!("Async{}", def.service_name().name())),
71        Style::Local => ctx.type_name(&format!("LocalAsync{}", def.service_name().name())),
72        Style::Sync => ctx.type_name(def.service_name().name()),
73    }
74}
75
76fn params(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
77    let mut params = vec![];
78    if service_has_binary_request_body(ctx, def) {
79        params.push(quote! {
80                #[request_body]
81                I
82        });
83    }
84    if service_has_binary_response_body(ctx, def) {
85        params.push(quote! {
86                #[response_writer]
87                O
88        });
89    }
90
91    if params.is_empty() {
92        quote!()
93    } else {
94        quote!(<#(#params),*>)
95    }
96}
97
98fn service_has_binary_request_body(ctx: &Context, def: &ServiceDefinition) -> bool {
99    def.endpoints()
100        .iter()
101        .any(|e| endpoint_has_binary_request_body(ctx, e))
102}
103
104fn endpoint_has_binary_request_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
105    endpoint.args().iter().any(|a| match a.param_type() {
106        ParameterType::Body(_) => ctx.is_binary(a.type_()),
107        _ => false,
108    })
109}
110
111fn service_has_binary_response_body(ctx: &Context, def: &ServiceDefinition) -> bool {
112    def.endpoints()
113        .iter()
114        .any(|e| endpoint_has_binary_response_body(ctx, e))
115}
116
117fn endpoint_has_binary_response_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
118    match return_type(ctx, endpoint) {
119        ReturnType::Binary | ReturnType::OptionalBinary => true,
120        ReturnType::None | ReturnType::Json(_) => false,
121    }
122}
123
124fn generate_binary_type(
125    ctx: &Context,
126    def: &ServiceDefinition,
127    endpoint: &EndpointDefinition,
128    style: Style,
129) -> Option<TokenStream> {
130    if endpoint_has_binary_response_body(ctx, endpoint) {
131        let docs = format!(
132            "The body type returned by the `{}` method.",
133            ctx.field_name(endpoint.endpoint_name())
134        );
135        let name = binary_type(endpoint);
136        let bounds = match style {
137            Style::Async => {
138                let send = ctx.send_ident(def.service_name());
139                quote!(conjure_http::server::AsyncWriteBody<O> + 'static + #send)
140            }
141            Style::Local => {
142                quote!(conjure_http::server::LocalAsyncWriteBody<O> + 'static)
143            }
144            Style::Sync => quote!(conjure_http::server::WriteBody<O> + 'static),
145        };
146        Some(quote! {
147            #[doc = #docs]
148            type #name: #bounds;
149        })
150    } else {
151        None
152    }
153}
154
155fn binary_type(endpoint: &EndpointDefinition) -> TokenStream {
156    format!("{}Body", endpoint.endpoint_name().to_upper_camel_case())
157        .parse()
158        .unwrap()
159}
160
161fn generate_trait_endpoint(
162    ctx: &Context,
163    def: &ServiceDefinition,
164    endpoint: &EndpointDefinition,
165    style: Style,
166) -> TokenStream {
167    let docs = ctx.docs(endpoint.docs());
168    let method = endpoint
169        .http_method()
170        .as_str()
171        .parse::<TokenStream>()
172        .unwrap();
173    let path = &**endpoint.http_path();
174    let endpoint_name = &**endpoint.endpoint_name();
175    let async_ = match style {
176        Style::Async | Style::Local => quote!(async),
177        Style::Sync => quote!(),
178    };
179    let name = ctx.field_name(endpoint.endpoint_name());
180    let produces = match endpoint.returns() {
181        Some(ty) => {
182            let produces = produces(ctx, ty);
183            quote!(, produces = #produces)
184        }
185        None => quote!(),
186    };
187
188    let auth_arg = auth_arg(endpoint);
189    let args = endpoint.args().iter().map(|a| arg(ctx, def, endpoint, a));
190    let request_context_arg = request_context_arg(endpoint);
191
192    let result = ctx.result_ident(def.service_name());
193
194    let ret_ty = rust_return_type(ctx, def, endpoint, &return_type(ctx, endpoint));
195    let ret_ty = quote!(#result<#ret_ty, conjure_http::private::Error>);
196
197    // ignore deprecation since the endpoint has to be implemented regardless
198    quote! {
199        #docs
200        #[endpoint(method = #method, path = #path, name = #endpoint_name #produces)]
201        #async_ fn #name(&self #auth_arg #(, #args)* #request_context_arg) -> #ret_ty;
202    }
203}
204
205fn produces(ctx: &Context, ty: &Type) -> TokenStream {
206    match ctx.is_optional(ty) {
207        Some(inner) if ctx.is_binary(inner) => {
208            quote!(conjure_http::server::conjure::OptionalBinaryResponseSerializer)
209        }
210        _ if ctx.is_binary(ty) => quote!(conjure_http::server::conjure::BinaryResponseSerializer),
211        _ if ctx.is_iterable(ty) => {
212            quote!(conjure_http::server::conjure::CollectionResponseSerializer)
213        }
214        _ => quote!(conjure_http::server::StdResponseSerializer),
215    }
216}
217
218fn auth_arg(endpoint: &EndpointDefinition) -> TokenStream {
219    match endpoint.auth() {
220        Some(auth) => {
221            let params = match auth {
222                AuthType::Header(_) => quote!(),
223                AuthType::Cookie(cookie) => {
224                    let name = &cookie.cookie_name();
225                    quote!((cookie_name = #name))
226                }
227            };
228            quote!(, #[auth #params] auth_: conjure_object::BearerToken)
229        }
230        None => quote!(),
231    }
232}
233
234fn arg(
235    ctx: &Context,
236    def: &ServiceDefinition,
237    endpoint: &EndpointDefinition,
238    arg: &ArgumentDefinition,
239) -> TokenStream {
240    let name = ctx.field_name(arg.arg_name());
241
242    let log_as = if name == **arg.arg_name() {
243        quote!()
244    } else {
245        let log_as = &**arg.arg_name();
246        quote!(, log_as = #log_as)
247    };
248
249    let safe = if ctx.is_safe_arg(arg) {
250        quote!(, safe)
251    } else {
252        quote!()
253    };
254
255    let attr = match arg.param_type() {
256        ParameterType::Body(_) => {
257            let deserializer = if ctx.is_optional(arg.type_()).is_some() {
258                let mut decoder =
259                    quote!(conjure_http::server::conjure::OptionalRequestDeserializer);
260                if ctx.is_aliased(arg.type_()) {
261                    let dealiased = ctx.dealiased_type(arg.type_());
262                    let dealiased =
263                        ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
264                    decoder =
265                        quote!(conjure_http::server::FromRequestDeserializer<#decoder, #dealiased>)
266                }
267                decoder
268            } else if ctx.is_binary(arg.type_()) {
269                quote!(conjure_http::server::conjure::BinaryRequestDeserializer)
270            } else {
271                let param = match server_limit_request_size(endpoint) {
272                    Ok(Some(limit)) => quote!(<#limit>),
273                    Ok(None) => quote!(),
274                    Err(e) => quote!(<compile_error!(#e)>),
275                };
276                quote!(conjure_http::server::StdRequestDeserializer #param)
277            };
278            quote!(#[body(deserializer = #deserializer #log_as #safe)])
279        }
280        ParameterType::Header(header) => {
281            let name = &**header.param_id();
282            let decoder = if ctx.is_optional(arg.type_()).is_some() {
283                optional_decoder(ctx, def, arg.type_())
284            } else {
285                quote!(conjure_http::server::conjure::FromPlainDecoder)
286            };
287            quote!(#[header(name = #name, decoder = #decoder #log_as #safe)])
288        }
289        ParameterType::Path(_) => {
290            let name = &**arg.arg_name();
291            quote! {
292                #[path(
293                    name = #name,
294                    decoder = conjure_http::server::conjure::FromPlainDecoder
295                    #log_as
296                    #safe
297                )]
298            }
299        }
300        ParameterType::Query(query) => {
301            let name = &**query.param_id();
302            let decoder = if ctx.is_optional(arg.type_()).is_some() {
303                optional_decoder(ctx, def, arg.type_())
304            } else if ctx.is_iterable(arg.type_()) {
305                quote!(conjure_http::server::conjure::FromPlainSeqDecoder<_>)
306            } else {
307                quote!(conjure_http::server::conjure::FromPlainDecoder)
308            };
309            quote!(#[query(name = #name, decoder = #decoder #log_as #safe)])
310        }
311    };
312
313    let ty = if ctx.is_binary(arg.type_()) {
314        quote!(I)
315    } else {
316        ctx.rust_type(BaseModule::Endpoints, def.service_name(), arg.type_())
317    };
318    quote!(#attr #name: #ty)
319}
320
321fn optional_decoder(ctx: &Context, def: &ServiceDefinition, ty: &Type) -> TokenStream {
322    let mut decoder = quote!(conjure_http::server::conjure::FromPlainOptionDecoder);
323    if ctx.is_aliased(ty) {
324        let dealiased = ctx.dealiased_type(ty);
325        let dealiased = ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
326        decoder = quote!(conjure_http::server::FromDecoder<#decoder, #dealiased>)
327    }
328    decoder
329}
330
331fn request_context_arg(endpoint: &EndpointDefinition) -> TokenStream {
332    if has_request_context(endpoint) {
333        quote!(, #[context] request_context_: conjure_http::server::RequestContext<'_>)
334    } else {
335        quote!()
336    }
337}
338
339fn return_type<'a>(ctx: &Context, endpoint: &'a EndpointDefinition) -> ReturnType<'a> {
340    match endpoint.returns() {
341        Some(ty) => match ctx.is_optional(ty) {
342            Some(inner) if ctx.is_binary(inner) => ReturnType::OptionalBinary,
343            _ if ctx.is_binary(ty) => ReturnType::Binary,
344            _ => ReturnType::Json(ty),
345        },
346        None => ReturnType::None,
347    }
348}
349
350fn rust_return_type(
351    ctx: &Context,
352    def: &ServiceDefinition,
353    endpoint: &EndpointDefinition,
354    ty: &ReturnType<'_>,
355) -> TokenStream {
356    match ty {
357        ReturnType::None => quote!(()),
358        ReturnType::Json(ty) => ctx.rust_type(BaseModule::Endpoints, def.service_name(), ty),
359        ReturnType::Binary => {
360            let name = binary_type(endpoint);
361            quote!(Self::#name)
362        }
363        ReturnType::OptionalBinary => {
364            let name = binary_type(endpoint);
365            let option = ctx.option_ident(def.service_name());
366            quote!(#option<Self::#name>)
367        }
368    }
369}
370
371enum ReturnType<'a> {
372    None,
373    Json(&'a Type),
374    Binary,
375    OptionalBinary,
376}
377
378fn has_request_context(endpoint: &EndpointDefinition) -> bool {
379    endpoint
380        .tags()
381        .iter()
382        .any(|t| t == "server-request-context")
383}
384
385fn server_limit_request_size(endpoint: &EndpointDefinition) -> Result<Option<usize>, String> {
386    let mut it = endpoint
387        .tags()
388        .iter()
389        .filter_map(|t| t.strip_prefix("server-limit-request-size:"))
390        .map(|s| s.trim());
391
392    let Some(limit) = it.next() else {
393        return Ok(None);
394    };
395
396    if it.next().is_some() {
397        return Err("invalid endpoint definition includes multiple tags with the `server-limit-request-size` prefix".to_string());
398    }
399
400    human_size::parse(limit).map(Some)
401}