1use crate::context::{BaseModule, Context};
2use crate::human_size;
3use crate::types::objects::{
4 ArgumentDefinition, AuthType, EndpointDefinition, ParameterType, ServiceDefinition, Type,
5};
6use heck::ToUpperCamelCase;
7use proc_macro2::{Ident, TokenStream};
8use quote::quote;
9
/// Which flavor of server trait is being generated.
#[derive(Clone, Copy)]
enum Style {
    /// The trait whose methods are `async fn`, named with an `Async` prefix.
    Async,
    /// The blocking trait, named after the service itself.
    Sync,
}
15
16pub fn generate(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
17 let sync_trait = generate_trait(ctx, def, Style::Sync);
18 let async_trait = generate_trait(ctx, def, Style::Async);
19
20 quote! {
21 use conjure_http::endpoint;
22
23 #sync_trait
24 #async_trait
25 }
26}
27
28fn generate_trait(ctx: &Context, def: &ServiceDefinition, style: Style) -> TokenStream {
29 let docs = ctx.docs(def.docs());
30 let service_name = def.service_name().name();
31 let name = trait_name(ctx, def, style);
32 let params = params(ctx, def);
33
34 let use_legacy_error_serialization = if ctx.use_legacy_error_serialization() {
35 quote!(, use_legacy_error_serialization)
36 } else {
37 quote!()
38 };
39
40 let binary_types = def
41 .endpoints()
42 .iter()
43 .flat_map(|e| generate_binary_type(ctx, def, e, style));
44
45 let endpoints = def
46 .endpoints()
47 .iter()
48 .map(|e| generate_trait_endpoint(ctx, def, e, style));
49
50 quote! {
51 #docs
52 #[conjure_http::conjure_endpoints(name = #service_name #use_legacy_error_serialization)]
53 pub trait #name #params {
54 #(#binary_types)*
55
56 #(#endpoints)*
57 }
58 }
59}
60
61fn trait_name(ctx: &Context, def: &ServiceDefinition, style: Style) -> Ident {
62 match style {
63 Style::Async => ctx.type_name(&format!("Async{}", def.service_name().name())),
64 Style::Sync => ctx.type_name(def.service_name().name()),
65 }
66}
67
68fn params(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
69 let mut params = vec![];
70 if service_has_binary_request_body(ctx, def) {
71 params.push(quote! {
72 #[request_body]
73 I
74 });
75 }
76 if service_has_binary_response_body(ctx, def) {
77 params.push(quote! {
78 #[response_writer]
79 O
80 });
81 }
82
83 if params.is_empty() {
84 quote!()
85 } else {
86 quote!(<#(#params),*>)
87 }
88}
89
90fn service_has_binary_request_body(ctx: &Context, def: &ServiceDefinition) -> bool {
91 def.endpoints()
92 .iter()
93 .any(|e| endpoint_has_binary_request_body(ctx, e))
94}
95
96fn endpoint_has_binary_request_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
97 endpoint.args().iter().any(|a| match a.param_type() {
98 ParameterType::Body(_) => ctx.is_binary(a.type_()),
99 _ => false,
100 })
101}
102
103fn service_has_binary_response_body(ctx: &Context, def: &ServiceDefinition) -> bool {
104 def.endpoints()
105 .iter()
106 .any(|e| endpoint_has_binary_response_body(ctx, e))
107}
108
109fn endpoint_has_binary_response_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
110 match return_type(ctx, endpoint) {
111 ReturnType::Binary | ReturnType::OptionalBinary => true,
112 ReturnType::None | ReturnType::Json(_) => false,
113 }
114}
115
116fn generate_binary_type(
117 ctx: &Context,
118 def: &ServiceDefinition,
119 endpoint: &EndpointDefinition,
120 style: Style,
121) -> Option<TokenStream> {
122 if endpoint_has_binary_response_body(ctx, endpoint) {
123 let docs = format!(
124 "The body type returned by the `{}` method.",
125 ctx.field_name(endpoint.endpoint_name())
126 );
127 let name = binary_type(endpoint);
128 let bounds = match style {
129 Style::Async => {
130 let send = ctx.send_ident(def.service_name());
131 quote!(conjure_http::server::AsyncWriteBody<O> + 'static + #send)
132 }
133 Style::Sync => quote!(conjure_http::server::WriteBody<O> + 'static),
134 };
135 Some(quote! {
136 #[doc = #docs]
137 type #name: #bounds;
138 })
139 } else {
140 None
141 }
142}
143
144fn binary_type(endpoint: &EndpointDefinition) -> TokenStream {
145 format!("{}Body", endpoint.endpoint_name().to_upper_camel_case())
146 .parse()
147 .unwrap()
148}
149
150fn generate_trait_endpoint(
151 ctx: &Context,
152 def: &ServiceDefinition,
153 endpoint: &EndpointDefinition,
154 style: Style,
155) -> TokenStream {
156 let docs = ctx.docs(endpoint.docs());
157 let method = endpoint
158 .http_method()
159 .as_str()
160 .parse::<TokenStream>()
161 .unwrap();
162 let path = &**endpoint.http_path();
163 let endpoint_name = &**endpoint.endpoint_name();
164 let async_ = match style {
165 Style::Async => quote!(async),
166 Style::Sync => quote!(),
167 };
168 let name = ctx.field_name(endpoint.endpoint_name());
169 let produces = match endpoint.returns() {
170 Some(ty) => {
171 let produces = produces(ctx, ty);
172 quote!(, produces = #produces)
173 }
174 None => quote!(),
175 };
176
177 let auth_arg = auth_arg(endpoint);
178 let args = endpoint.args().iter().map(|a| arg(ctx, def, endpoint, a));
179 let request_context_arg = request_context_arg(endpoint);
180
181 let result = ctx.result_ident(def.service_name());
182
183 let ret_ty = rust_return_type(ctx, def, endpoint, &return_type(ctx, endpoint));
184 let ret_ty = quote!(#result<#ret_ty, conjure_http::private::Error>);
185
186 quote! {
188 #docs
189 #[endpoint(method = #method, path = #path, name = #endpoint_name #produces)]
190 #async_ fn #name(&self #auth_arg #(, #args)* #request_context_arg) -> #ret_ty;
191 }
192}
193
194fn produces(ctx: &Context, ty: &Type) -> TokenStream {
195 match ctx.is_optional(ty) {
196 Some(inner) if ctx.is_binary(inner) => {
197 quote!(conjure_http::server::conjure::OptionalBinaryResponseSerializer)
198 }
199 _ if ctx.is_binary(ty) => quote!(conjure_http::server::conjure::BinaryResponseSerializer),
200 _ if ctx.is_iterable(ty) => {
201 quote!(conjure_http::server::conjure::CollectionResponseSerializer)
202 }
203 _ => quote!(conjure_http::server::StdResponseSerializer),
204 }
205}
206
207fn auth_arg(endpoint: &EndpointDefinition) -> TokenStream {
208 match endpoint.auth() {
209 Some(auth) => {
210 let params = match auth {
211 AuthType::Header(_) => quote!(),
212 AuthType::Cookie(cookie) => {
213 let name = &cookie.cookie_name();
214 quote!((cookie_name = #name))
215 }
216 };
217 quote!(, #[auth #params] auth_: conjure_object::BearerToken)
218 }
219 None => quote!(),
220 }
221}
222
223fn arg(
224 ctx: &Context,
225 def: &ServiceDefinition,
226 endpoint: &EndpointDefinition,
227 arg: &ArgumentDefinition,
228) -> TokenStream {
229 let name = ctx.field_name(arg.arg_name());
230
231 let log_as = if name == **arg.arg_name() {
232 quote!()
233 } else {
234 let log_as = &**arg.arg_name();
235 quote!(, log_as = #log_as)
236 };
237
238 let safe = if ctx.is_safe_arg(arg) {
239 quote!(, safe)
240 } else {
241 quote!()
242 };
243
244 let attr = match arg.param_type() {
245 ParameterType::Body(_) => {
246 let deserializer = if ctx.is_optional(arg.type_()).is_some() {
247 let mut decoder =
248 quote!(conjure_http::server::conjure::OptionalRequestDeserializer);
249 let dealiased = ctx.dealiased_type(arg.type_());
250 if dealiased != arg.type_() {
251 let dealiased =
252 ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
253 decoder =
254 quote!(conjure_http::server::FromRequestDeserializer<#decoder, #dealiased>)
255 }
256 decoder
257 } else if ctx.is_binary(arg.type_()) {
258 quote!(conjure_http::server::conjure::BinaryRequestDeserializer)
259 } else {
260 let param = match server_limit_request_size(endpoint) {
261 Ok(Some(limit)) => quote!(<#limit>),
262 Ok(None) => quote!(),
263 Err(e) => quote!(<compile_error!(#e)>),
264 };
265 quote!(conjure_http::server::StdRequestDeserializer #param)
266 };
267 quote!(#[body(deserializer = #deserializer #log_as #safe)])
268 }
269 ParameterType::Header(header) => {
270 let name = &**header.param_id();
271 let decoder = if ctx.is_optional(arg.type_()).is_some() {
272 optional_decoder(ctx, def, arg.type_())
273 } else {
274 quote!(conjure_http::server::conjure::FromPlainDecoder)
275 };
276 quote!(#[header(name = #name, decoder = #decoder #log_as #safe)])
277 }
278 ParameterType::Path(_) => {
279 let name = &**arg.arg_name();
280 quote! {
281 #[path(
282 name = #name,
283 decoder = conjure_http::server::conjure::FromPlainDecoder
284 #log_as
285 #safe
286 )]
287 }
288 }
289 ParameterType::Query(query) => {
290 let name = &**query.param_id();
291 let decoder = if ctx.is_optional(arg.type_()).is_some() {
292 optional_decoder(ctx, def, arg.type_())
293 } else if ctx.is_iterable(arg.type_()) {
294 quote!(conjure_http::server::conjure::FromPlainSeqDecoder<_>)
295 } else {
296 quote!(conjure_http::server::conjure::FromPlainDecoder)
297 };
298 quote!(#[query(name = #name, decoder = #decoder #log_as #safe)])
299 }
300 };
301
302 let ty = if ctx.is_binary(arg.type_()) {
303 quote!(I)
304 } else {
305 ctx.rust_type(BaseModule::Endpoints, def.service_name(), arg.type_())
306 };
307 quote!(#attr #name: #ty)
308}
309
310fn optional_decoder(ctx: &Context, def: &ServiceDefinition, ty: &Type) -> TokenStream {
311 let mut decoder = quote!(conjure_http::server::conjure::FromPlainOptionDecoder);
312 let dealiased = ctx.dealiased_type(ty);
313 if dealiased != ty {
314 let dealiased = ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
315 decoder = quote!(conjure_http::server::FromDecoder<#decoder, #dealiased>)
316 }
317 decoder
318}
319
320fn request_context_arg(endpoint: &EndpointDefinition) -> TokenStream {
321 if has_request_context(endpoint) {
322 quote!(, #[context] request_context_: conjure_http::server::RequestContext<'_>)
323 } else {
324 quote!()
325 }
326}
327
328fn return_type<'a>(ctx: &Context, endpoint: &'a EndpointDefinition) -> ReturnType<'a> {
329 match endpoint.returns() {
330 Some(ty) => match ctx.is_optional(ty) {
331 Some(inner) if ctx.is_binary(inner) => ReturnType::OptionalBinary,
332 _ if ctx.is_binary(ty) => ReturnType::Binary,
333 _ => ReturnType::Json(ty),
334 },
335 None => ReturnType::None,
336 }
337}
338
339fn rust_return_type(
340 ctx: &Context,
341 def: &ServiceDefinition,
342 endpoint: &EndpointDefinition,
343 ty: &ReturnType<'_>,
344) -> TokenStream {
345 match ty {
346 ReturnType::None => quote!(()),
347 ReturnType::Json(ty) => ctx.rust_type(BaseModule::Endpoints, def.service_name(), ty),
348 ReturnType::Binary => {
349 let name = binary_type(endpoint);
350 quote!(Self::#name)
351 }
352 ReturnType::OptionalBinary => {
353 let name = binary_type(endpoint);
354 let option = ctx.option_ident(def.service_name());
355 quote!(#option<Self::#name>)
356 }
357 }
358}
359
360enum ReturnType<'a> {
361 None,
362 Json(&'a Type),
363 Binary,
364 OptionalBinary,
365}
366
367fn has_request_context(endpoint: &EndpointDefinition) -> bool {
368 endpoint
369 .tags()
370 .iter()
371 .any(|t| t == "server-request-context")
372}
373
374fn server_limit_request_size(endpoint: &EndpointDefinition) -> Result<Option<usize>, String> {
375 let mut it = endpoint
376 .tags()
377 .iter()
378 .filter_map(|t| t.strip_prefix("server-limit-request-size:"))
379 .map(|s| s.trim());
380
381 let Some(limit) = it.next() else {
382 return Ok(None);
383 };
384
385 if it.next().is_some() {
386 return Err("invalid endpoint definition includes multiple tags with the `server-limit-request-size` prefix".to_string());
387 }
388
389 human_size::parse(limit).map(Some)
390}