1use crate::context::{BaseModule, Context};
2use crate::human_size;
3use crate::types::objects::{
4 ArgumentDefinition, AuthType, EndpointDefinition, ParameterType, ServiceDefinition, Type,
5};
6use heck::ToUpperCamelCase;
7use proc_macro2::{Ident, TokenStream};
8use quote::quote;
9
/// Which flavour of server trait is being generated for a service.
///
/// `generate` emits one trait per variant.
#[derive(Clone, Copy)]
enum Style {
    /// Blocking trait: methods are plain `fn`s.
    Sync,
    /// Async trait: methods are `async fn`s, and binary response bodies are additionally bounded
    /// by the service's send marker (`ctx.send_ident`).
    Async,
    /// Local async trait: methods are `async fn`s without that send bound, and the
    /// `conjure_endpoints` attribute receives the `local` flag.
    Local,
}
16
17pub fn generate(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
18 let sync_trait = generate_trait(ctx, def, Style::Sync);
19 let async_trait = generate_trait(ctx, def, Style::Async);
20 let local_trait = generate_trait(ctx, def, Style::Local);
21
22 quote! {
23 use conjure_http::endpoint;
24
25 #sync_trait
26 #async_trait
27 #local_trait
28 }
29}
30
31fn generate_trait(ctx: &Context, def: &ServiceDefinition, style: Style) -> TokenStream {
32 let docs = ctx.docs(def.docs());
33 let service_name = def.service_name().name();
34 let name = trait_name(ctx, def, style);
35 let local = match style {
36 Style::Local => quote!(, local),
37 Style::Async | Style::Sync => quote!(),
38 };
39 let params = params(ctx, def);
40
41 let use_legacy_error_serialization = if ctx.use_legacy_error_serialization() {
42 quote!(, use_legacy_error_serialization)
43 } else {
44 quote!()
45 };
46
47 let binary_types = def
48 .endpoints()
49 .iter()
50 .flat_map(|e| generate_binary_type(ctx, def, e, style));
51
52 let endpoints = def
53 .endpoints()
54 .iter()
55 .map(|e| generate_trait_endpoint(ctx, def, e, style));
56
57 quote! {
58 #docs
59 #[conjure_http::conjure_endpoints(name = #service_name #use_legacy_error_serialization #local)]
60 pub trait #name #params {
61 #(#binary_types)*
62
63 #(#endpoints)*
64 }
65 }
66}
67
68fn trait_name(ctx: &Context, def: &ServiceDefinition, style: Style) -> Ident {
69 match style {
70 Style::Async => ctx.type_name(&format!("Async{}", def.service_name().name())),
71 Style::Local => ctx.type_name(&format!("LocalAsync{}", def.service_name().name())),
72 Style::Sync => ctx.type_name(def.service_name().name()),
73 }
74}
75
76fn params(ctx: &Context, def: &ServiceDefinition) -> TokenStream {
77 let mut params = vec![];
78 if service_has_binary_request_body(ctx, def) {
79 params.push(quote! {
80 #[request_body]
81 I
82 });
83 }
84 if service_has_binary_response_body(ctx, def) {
85 params.push(quote! {
86 #[response_writer]
87 O
88 });
89 }
90
91 if params.is_empty() {
92 quote!()
93 } else {
94 quote!(<#(#params),*>)
95 }
96}
97
98fn service_has_binary_request_body(ctx: &Context, def: &ServiceDefinition) -> bool {
99 def.endpoints()
100 .iter()
101 .any(|e| endpoint_has_binary_request_body(ctx, e))
102}
103
104fn endpoint_has_binary_request_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
105 endpoint.args().iter().any(|a| match a.param_type() {
106 ParameterType::Body(_) => ctx.is_binary(a.type_()),
107 _ => false,
108 })
109}
110
111fn service_has_binary_response_body(ctx: &Context, def: &ServiceDefinition) -> bool {
112 def.endpoints()
113 .iter()
114 .any(|e| endpoint_has_binary_response_body(ctx, e))
115}
116
117fn endpoint_has_binary_response_body(ctx: &Context, endpoint: &EndpointDefinition) -> bool {
118 match return_type(ctx, endpoint) {
119 ReturnType::Binary | ReturnType::OptionalBinary => true,
120 ReturnType::None | ReturnType::Json(_) => false,
121 }
122}
123
124fn generate_binary_type(
125 ctx: &Context,
126 def: &ServiceDefinition,
127 endpoint: &EndpointDefinition,
128 style: Style,
129) -> Option<TokenStream> {
130 if endpoint_has_binary_response_body(ctx, endpoint) {
131 let docs = format!(
132 "The body type returned by the `{}` method.",
133 ctx.field_name(endpoint.endpoint_name())
134 );
135 let name = binary_type(endpoint);
136 let bounds = match style {
137 Style::Async => {
138 let send = ctx.send_ident(def.service_name());
139 quote!(conjure_http::server::AsyncWriteBody<O> + 'static + #send)
140 }
141 Style::Local => {
142 quote!(conjure_http::server::LocalAsyncWriteBody<O> + 'static)
143 }
144 Style::Sync => quote!(conjure_http::server::WriteBody<O> + 'static),
145 };
146 Some(quote! {
147 #[doc = #docs]
148 type #name: #bounds;
149 })
150 } else {
151 None
152 }
153}
154
155fn binary_type(endpoint: &EndpointDefinition) -> TokenStream {
156 format!("{}Body", endpoint.endpoint_name().to_upper_camel_case())
157 .parse()
158 .unwrap()
159}
160
161fn generate_trait_endpoint(
162 ctx: &Context,
163 def: &ServiceDefinition,
164 endpoint: &EndpointDefinition,
165 style: Style,
166) -> TokenStream {
167 let docs = ctx.docs(endpoint.docs());
168 let method = endpoint
169 .http_method()
170 .as_str()
171 .parse::<TokenStream>()
172 .unwrap();
173 let path = &**endpoint.http_path();
174 let endpoint_name = &**endpoint.endpoint_name();
175 let async_ = match style {
176 Style::Async | Style::Local => quote!(async),
177 Style::Sync => quote!(),
178 };
179 let name = ctx.field_name(endpoint.endpoint_name());
180 let produces = match endpoint.returns() {
181 Some(ty) => {
182 let produces = produces(ctx, ty);
183 quote!(, produces = #produces)
184 }
185 None => quote!(),
186 };
187
188 let auth_arg = auth_arg(endpoint);
189 let args = endpoint.args().iter().map(|a| arg(ctx, def, endpoint, a));
190 let request_context_arg = request_context_arg(endpoint);
191
192 let result = ctx.result_ident(def.service_name());
193
194 let ret_ty = rust_return_type(ctx, def, endpoint, &return_type(ctx, endpoint));
195 let ret_ty = quote!(#result<#ret_ty, conjure_http::private::Error>);
196
197 quote! {
199 #docs
200 #[endpoint(method = #method, path = #path, name = #endpoint_name #produces)]
201 #async_ fn #name(&self #auth_arg #(, #args)* #request_context_arg) -> #ret_ty;
202 }
203}
204
205fn produces(ctx: &Context, ty: &Type) -> TokenStream {
206 match ctx.is_optional(ty) {
207 Some(inner) if ctx.is_binary(inner) => {
208 quote!(conjure_http::server::conjure::OptionalBinaryResponseSerializer)
209 }
210 _ if ctx.is_binary(ty) => quote!(conjure_http::server::conjure::BinaryResponseSerializer),
211 _ if ctx.is_iterable(ty) => {
212 quote!(conjure_http::server::conjure::CollectionResponseSerializer)
213 }
214 _ => quote!(conjure_http::server::StdResponseSerializer),
215 }
216}
217
218fn auth_arg(endpoint: &EndpointDefinition) -> TokenStream {
219 match endpoint.auth() {
220 Some(auth) => {
221 let params = match auth {
222 AuthType::Header(_) => quote!(),
223 AuthType::Cookie(cookie) => {
224 let name = &cookie.cookie_name();
225 quote!((cookie_name = #name))
226 }
227 };
228 quote!(, #[auth #params] auth_: conjure_object::BearerToken)
229 }
230 None => quote!(),
231 }
232}
233
234fn arg(
235 ctx: &Context,
236 def: &ServiceDefinition,
237 endpoint: &EndpointDefinition,
238 arg: &ArgumentDefinition,
239) -> TokenStream {
240 let name = ctx.field_name(arg.arg_name());
241
242 let log_as = if name == **arg.arg_name() {
243 quote!()
244 } else {
245 let log_as = &**arg.arg_name();
246 quote!(, log_as = #log_as)
247 };
248
249 let safe = if ctx.is_safe_arg(arg) {
250 quote!(, safe)
251 } else {
252 quote!()
253 };
254
255 let attr = match arg.param_type() {
256 ParameterType::Body(_) => {
257 let deserializer = if ctx.is_optional(arg.type_()).is_some() {
258 let mut decoder =
259 quote!(conjure_http::server::conjure::OptionalRequestDeserializer);
260 if ctx.is_aliased(arg.type_()) {
261 let dealiased = ctx.dealiased_type(arg.type_());
262 let dealiased =
263 ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
264 decoder =
265 quote!(conjure_http::server::FromRequestDeserializer<#decoder, #dealiased>)
266 }
267 decoder
268 } else if ctx.is_binary(arg.type_()) {
269 quote!(conjure_http::server::conjure::BinaryRequestDeserializer)
270 } else {
271 let param = match server_limit_request_size(endpoint) {
272 Ok(Some(limit)) => quote!(<#limit>),
273 Ok(None) => quote!(),
274 Err(e) => quote!(<compile_error!(#e)>),
275 };
276 quote!(conjure_http::server::StdRequestDeserializer #param)
277 };
278 quote!(#[body(deserializer = #deserializer #log_as #safe)])
279 }
280 ParameterType::Header(header) => {
281 let name = &**header.param_id();
282 let decoder = if ctx.is_optional(arg.type_()).is_some() {
283 optional_decoder(ctx, def, arg.type_())
284 } else {
285 quote!(conjure_http::server::conjure::FromPlainDecoder)
286 };
287 quote!(#[header(name = #name, decoder = #decoder #log_as #safe)])
288 }
289 ParameterType::Path(_) => {
290 let name = &**arg.arg_name();
291 quote! {
292 #[path(
293 name = #name,
294 decoder = conjure_http::server::conjure::FromPlainDecoder
295 #log_as
296 #safe
297 )]
298 }
299 }
300 ParameterType::Query(query) => {
301 let name = &**query.param_id();
302 let decoder = if ctx.is_optional(arg.type_()).is_some() {
303 optional_decoder(ctx, def, arg.type_())
304 } else if ctx.is_iterable(arg.type_()) {
305 quote!(conjure_http::server::conjure::FromPlainSeqDecoder<_>)
306 } else {
307 quote!(conjure_http::server::conjure::FromPlainDecoder)
308 };
309 quote!(#[query(name = #name, decoder = #decoder #log_as #safe)])
310 }
311 };
312
313 let ty = if ctx.is_binary(arg.type_()) {
314 quote!(I)
315 } else {
316 ctx.rust_type(BaseModule::Endpoints, def.service_name(), arg.type_())
317 };
318 quote!(#attr #name: #ty)
319}
320
321fn optional_decoder(ctx: &Context, def: &ServiceDefinition, ty: &Type) -> TokenStream {
322 let mut decoder = quote!(conjure_http::server::conjure::FromPlainOptionDecoder);
323 if ctx.is_aliased(ty) {
324 let dealiased = ctx.dealiased_type(ty);
325 let dealiased = ctx.rust_type(BaseModule::Endpoints, def.service_name(), dealiased);
326 decoder = quote!(conjure_http::server::FromDecoder<#decoder, #dealiased>)
327 }
328 decoder
329}
330
331fn request_context_arg(endpoint: &EndpointDefinition) -> TokenStream {
332 if has_request_context(endpoint) {
333 quote!(, #[context] request_context_: conjure_http::server::RequestContext<'_>)
334 } else {
335 quote!()
336 }
337}
338
339fn return_type<'a>(ctx: &Context, endpoint: &'a EndpointDefinition) -> ReturnType<'a> {
340 match endpoint.returns() {
341 Some(ty) => match ctx.is_optional(ty) {
342 Some(inner) if ctx.is_binary(inner) => ReturnType::OptionalBinary,
343 _ if ctx.is_binary(ty) => ReturnType::Binary,
344 _ => ReturnType::Json(ty),
345 },
346 None => ReturnType::None,
347 }
348}
349
350fn rust_return_type(
351 ctx: &Context,
352 def: &ServiceDefinition,
353 endpoint: &EndpointDefinition,
354 ty: &ReturnType<'_>,
355) -> TokenStream {
356 match ty {
357 ReturnType::None => quote!(()),
358 ReturnType::Json(ty) => ctx.rust_type(BaseModule::Endpoints, def.service_name(), ty),
359 ReturnType::Binary => {
360 let name = binary_type(endpoint);
361 quote!(Self::#name)
362 }
363 ReturnType::OptionalBinary => {
364 let name = binary_type(endpoint);
365 let option = ctx.option_ident(def.service_name());
366 quote!(#option<Self::#name>)
367 }
368 }
369}
370
371enum ReturnType<'a> {
372 None,
373 Json(&'a Type),
374 Binary,
375 OptionalBinary,
376}
377
378fn has_request_context(endpoint: &EndpointDefinition) -> bool {
379 endpoint
380 .tags()
381 .iter()
382 .any(|t| t == "server-request-context")
383}
384
385fn server_limit_request_size(endpoint: &EndpointDefinition) -> Result<Option<usize>, String> {
386 let mut it = endpoint
387 .tags()
388 .iter()
389 .filter_map(|t| t.strip_prefix("server-limit-request-size:"))
390 .map(|s| s.trim());
391
392 let Some(limit) = it.next() else {
393 return Ok(None);
394 };
395
396 if it.next().is_some() {
397 return Err("invalid endpoint definition includes multiple tags with the `server-limit-request-size` prefix".to_string());
398 }
399
400 human_size::parse(limit).map(Some)
401}