Skip to main content

oxapi_impl/
responses.rs

1//! Response enum generation.
2
3use heck::AsSnakeCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, quote_spanned};
6
7use crate::openapi::{Operation, ParsedSpec, ResponseHeader, ResponseStatus};
8use crate::types::TypeGenerator;
9use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
10
11/// Generator for response enums.
12pub struct ResponseGenerator<'a> {
13    spec: &'a ParsedSpec,
14    type_gen: &'a TypeGenerator,
15    overrides: &'a TypeOverrides,
16    suffixes: &'a ResponseSuffixes,
17}
18
19impl<'a> ResponseGenerator<'a> {
20    pub fn new(
21        spec: &'a ParsedSpec,
22        type_gen: &'a TypeGenerator,
23        overrides: &'a TypeOverrides,
24        suffixes: &'a ResponseSuffixes,
25    ) -> Self {
26        Self {
27            spec,
28            type_gen,
29            overrides,
30            suffixes,
31        }
32    }
33
34    /// Check if a response matches the given kind (Ok for success codes, Err for error codes).
35    fn response_matches_kind(status: &ResponseStatus, kind: GeneratedTypeKind) -> bool {
36        match kind {
37            GeneratedTypeKind::Ok => status.is_success(),
38            GeneratedTypeKind::Err => status.is_error(),
39            GeneratedTypeKind::Query | GeneratedTypeKind::Path => false,
40        }
41    }
42
43    /// Generate response enums for all operations.
44    pub fn generate_all(&self) -> TokenStream {
45        let enums: Vec<_> = self
46            .spec
47            .operations()
48            .map(|op| self.generate_for_operation(op))
49            .collect();
50
51        quote! {
52            #(#enums)*
53        }
54    }
55
56    /// Generate Ok and Err enums for a single operation.
57    pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
58        let op_name = op.name();
59
60        let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
61        let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
62
63        quote! {
64            #ok_enum
65            #err_enum
66        }
67    }
68
69    /// Get the enum name, applying any rename override.
70    fn get_enum_name(
71        &self,
72        op: &Operation,
73        default_name: &str,
74        kind: GeneratedTypeKind,
75    ) -> syn::Ident {
76        if let Some(TypeOverride::Rename { name, .. }) =
77            self.overrides.get(op.method, &op.path, kind)
78        {
79            name.clone()
80        } else {
81            format_ident!("{}", default_name)
82        }
83    }
84
85    /// Get a value from a variant override, or return a default.
86    fn get_variant_override<T, F>(
87        &self,
88        op: &Operation,
89        status: u16,
90        kind: GeneratedTypeKind,
91        getter: F,
92        default: T,
93    ) -> T
94    where
95        F: FnOnce(&crate::VariantOverride) -> T,
96    {
97        if let Some(TypeOverride::Rename {
98            variant_overrides, ..
99        }) = self.overrides.get(op.method, &op.path, kind)
100            && let Some(ov) = variant_overrides.get(&status)
101        {
102            return getter(ov);
103        }
104        default
105    }
106
107    /// Get the variant name for a status code, applying any rename override.
108    fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
109        self.get_variant_override(
110            op,
111            status,
112            kind,
113            |ov| ov.name.clone(),
114            format_ident!("Status{}", status),
115        )
116    }
117
118    /// Get the inner type name override for a status code, if specified.
119    fn get_inner_type_name(
120        &self,
121        op: &Operation,
122        status: u16,
123        kind: GeneratedTypeKind,
124    ) -> Option<syn::Ident> {
125        self.get_variant_override(op, status, kind, |ov| ov.inner_type_name.clone(), None)
126    }
127
128    /// Get the variant attributes for a status code.
129    fn get_variant_attrs(
130        &self,
131        op: &Operation,
132        status: u16,
133        kind: GeneratedTypeKind,
134    ) -> Vec<TokenStream> {
135        self.get_variant_override(op, status, kind, |ov| ov.attrs.clone(), Vec::new())
136    }
137
138    /// Validate that all variant overrides reference status codes that exist in the spec.
139    /// Returns compile errors for any invalid status codes.
140    fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
141        let Some(TypeOverride::Rename {
142            variant_overrides, ..
143        }) = self.overrides.get(op.method, &op.path, kind)
144        else {
145            return quote! {};
146        };
147
148        // Collect valid status codes from the spec
149        let valid_codes: std::collections::HashSet<u16> = op
150            .responses
151            .iter()
152            .filter(|r| Self::response_matches_kind(&r.status_code, kind))
153            .filter_map(|r| match &r.status_code {
154                ResponseStatus::Code(code) => Some(*code),
155                _ => None,
156            })
157            .collect();
158
159        // Check each override
160        let errors: Vec<_> = variant_overrides
161            .iter()
162            .filter(|(status, _)| !valid_codes.contains(status))
163            .map(|(status, ov)| {
164                let valid_list: Vec<_> = valid_codes.iter().collect();
165                let kind_name = match kind {
166                    GeneratedTypeKind::Ok => "success",
167                    GeneratedTypeKind::Err => "error",
168                    GeneratedTypeKind::Query => "query",
169                    GeneratedTypeKind::Path => "path",
170                };
171                let msg = format!(
172                    "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
173                    status, op.method, op.path, kind_name, valid_list
174                );
175                quote_spanned! { ov.name.span() =>
176                    compile_error!(#msg);
177                }
178            })
179            .collect();
180
181        quote! { #(#errors)* }
182    }
183
184    /// Get attributes for an enum. If attrs contains a derive, returns only the attrs.
185    /// Otherwise prepends the default derives.
186    fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
187        if let Some(TypeOverride::Rename { attrs, .. }) =
188            self.overrides.get(op.method, &op.path, kind)
189        {
190            // Check if any attr is a derive
191            let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
192            if has_derive {
193                return attrs.clone();
194            } else {
195                // Prepend default derives
196                let mut result = vec![self.suffixes.default_derives.clone()];
197                result.extend(attrs.clone());
198                return result;
199            }
200        }
201        Vec::new()
202    }
203
204    /// Collect variants from responses, separating data collection from code generation.
205    fn collect_variants(
206        &self,
207        op: &Operation,
208        kind: GeneratedTypeKind,
209        enum_name: &syn::Ident,
210        responses: &[&crate::openapi::OperationResponse],
211    ) -> CollectedVariants {
212        let mut errors = Vec::new();
213        let mut inline_definitions = Vec::new();
214
215        let variants = responses
216            .iter()
217            .map(|resp| {
218                let (variant_name, status, is_default) = match &resp.status_code {
219                    ResponseStatus::Code(code) => {
220                        (self.get_variant_name(op, *code, kind), *code, false)
221                    }
222                    ResponseStatus::Default => {
223                        // Only error enums should have Default responses
224                        (format_ident!("Default"), 500, true)
225                    }
226                };
227
228                let variant_attrs = self.get_variant_attrs(op, status, kind);
229                let inner_type_override = self.get_inner_type_name(op, status, kind);
230
231                let body_type = if let Some(schema) = &resp.schema {
232                    // Validate: inner type override only allowed for inline schemas
233                    if let Some(ref inner_name) = inner_type_override
234                        && !TypeGenerator::is_inline_schema(schema)
235                    {
236                        errors.push(quote_spanned! { inner_name.span() =>
237                            compile_error!("inner type name override can only be used with inline schemas, not $ref");
238                        });
239                    }
240
241                    // Use override name or default name hint
242                    let name_hint = inner_type_override
243                        .as_ref()
244                        .map(|i| i.to_string())
245                        .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
246                    let generated = self
247                        .type_gen
248                        .type_for_schema_with_definitions(schema, &name_hint);
249                    inline_definitions.extend(generated.definitions);
250                    Some(generated.type_ref)
251                } else {
252                    None
253                };
254
255                VariantInfo {
256                    name: variant_name,
257                    body_type,
258                    status,
259                    is_default,
260                    attrs: variant_attrs,
261                    headers: resp.headers.clone(),
262                }
263            })
264            .collect();
265
266        CollectedVariants {
267            variants,
268            errors,
269            inline_definitions,
270        }
271    }
272
273    /// Generate a response enum (Ok or Err) for an operation.
274    fn generate_response_enum(
275        &self,
276        op: &Operation,
277        op_name: &str,
278        kind: GeneratedTypeKind,
279    ) -> TokenStream {
280        // Skip if replaced
281        if self.overrides.is_replaced(op.method, &op.path, kind) {
282            return quote! {};
283        }
284
285        // Validate variant overrides reference valid status codes
286        let validation_errors = self.validate_variant_overrides(op, kind);
287
288        let suffix = match kind {
289            GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
290            GeneratedTypeKind::Err => &self.suffixes.err_suffix,
291            GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
292        };
293
294        let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
295
296        // Collect relevant responses based on kind
297        let responses: Vec<_> = op
298            .responses
299            .iter()
300            .filter(|r| Self::response_matches_kind(&r.status_code, kind))
301            .collect();
302
303        // Handle empty responses with a default enum
304        if responses.is_empty() {
305            return self.generate_empty_fallback(&enum_name, kind);
306        }
307
308        // Collect variants with their errors and inline definitions
309        let CollectedVariants {
310            variants,
311            errors,
312            inline_definitions,
313        } = self.collect_variants(op, kind, &enum_name, &responses);
314
315        // Generate per-variant header structs
316        let header_structs: Vec<_> = variants
317            .iter()
318            .filter(|v| !v.headers.is_empty())
319            .map(|v| self.generate_header_struct(&enum_name, v))
320            .collect();
321
322        // Generate variant definitions
323        let variant_defs = variants.iter().map(|v| {
324            let name = &v.name;
325            let attrs = &v.attrs;
326            let has_headers = !v.headers.is_empty();
327
328            match (has_headers, &v.body_type, v.is_default) {
329                // Has headers → struct variant
330                (true, Some(ty), true) => {
331                    let hs = header_struct_ident(&enum_name, &v.name);
332                    quote! { #(#attrs)* #name { headers: #hs, status: ::axum::http::StatusCode, body: #ty } }
333                }
334                (true, Some(ty), false) => {
335                    let hs = header_struct_ident(&enum_name, &v.name);
336                    quote! { #(#attrs)* #name { headers: #hs, body: #ty } }
337                }
338                (true, None, true) => {
339                    let hs = header_struct_ident(&enum_name, &v.name);
340                    quote! { #(#attrs)* #name { headers: #hs, status: ::axum::http::StatusCode } }
341                }
342                (true, None, false) => {
343                    let hs = header_struct_ident(&enum_name, &v.name);
344                    quote! { #(#attrs)* #name { headers: #hs } }
345                }
346                // No headers → tuple or unit variant (unchanged)
347                (false, Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
348                (false, Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
349                (false, None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
350                (false, None, false) => quote! { #(#attrs)* #name },
351            }
352        });
353
354        // Generate IntoResponse match arms
355        let into_response_arms = variants.iter().map(|v| {
356            let name = &v.name;
357            let status_code = status_code_ident(v.status);
358            let has_headers = !v.headers.is_empty();
359
360            let insert_headers = if has_headers {
361                self.generate_header_insertions(&v.headers)
362            } else {
363                quote! {}
364            };
365
366            match (has_headers, &v.body_type, v.is_default) {
367                // With headers → struct variant destructuring
368                (true, Some(_), true) => quote! {
369                    Self::#name { headers, status, body } => {
370                        let mut response = (status, ::axum::Json(body)).into_response();
371                        #insert_headers
372                        response
373                    }
374                },
375                (true, Some(_), false) => quote! {
376                    Self::#name { headers, body } => {
377                        let mut response = (#status_code, ::axum::Json(body)).into_response();
378                        #insert_headers
379                        response
380                    }
381                },
382                (true, None, true) => quote! {
383                    Self::#name { headers, status } => {
384                        let mut response = status.into_response();
385                        #insert_headers
386                        response
387                    }
388                },
389                (true, None, false) => quote! {
390                    Self::#name { headers } => {
391                        let mut response = #status_code.into_response();
392                        #insert_headers
393                        response
394                    }
395                },
396                // Without headers → unchanged
397                (false, Some(_), true) => quote! {
398                    Self::#name(status, body) => {
399                        (status, ::axum::Json(body)).into_response()
400                    }
401                },
402                (false, Some(_), false) => quote! {
403                    Self::#name(body) => {
404                        (#status_code, ::axum::Json(body)).into_response()
405                    }
406                },
407                (false, None, true) => quote! {
408                    Self::#name(status) => {
409                        status.into_response()
410                    }
411                },
412                (false, None, false) => quote! {
413                    Self::#name => {
414                        #status_code.into_response()
415                    }
416                },
417            }
418        });
419
420        let enum_attrs = self.get_enum_attrs(op, kind);
421        let default_derives = &self.suffixes.default_derives;
422
423        // If we have override attrs, use them; otherwise use default derives
424        let attrs_tokens = if enum_attrs.is_empty() {
425            quote! { #default_derives }
426        } else {
427            quote! { #(#enum_attrs)* }
428        };
429
430        // Use the enum name's span for the definition so error messages
431        // point to the user's enum definition, not the macro attribute
432        let enum_span = enum_name.span();
433
434        let enum_def = quote_spanned! { enum_span =>
435            #attrs_tokens
436            pub enum #enum_name {
437                #(#variant_defs,)*
438            }
439        };
440
441        quote! {
442            #validation_errors
443
444            #(#errors)*
445
446            #(#inline_definitions)*
447
448            #(#header_structs)*
449
450            #enum_def
451
452            impl ::axum::response::IntoResponse for #enum_name {
453                fn into_response(self) -> ::axum::response::Response {
454                    use ::axum::response::IntoResponse;
455                    match self {
456                        #(#into_response_arms)*
457                    }
458                }
459            }
460        }
461    }
462
463    /// Generate a header struct for a variant that has response headers.
464    fn generate_header_struct(&self, enum_name: &syn::Ident, variant: &VariantInfo) -> TokenStream {
465        let struct_name = header_struct_ident(enum_name, &variant.name);
466
467        let fields = variant.headers.iter().map(|h| {
468            let field_name = format_ident!("{}", AsSnakeCase(&h.name).to_string());
469            let inner_type = if let Some(schema) = &h.schema {
470                self.type_gen.type_for_schema(schema, &h.name)
471            } else {
472                quote! { String }
473            };
474
475            let field_type = if h.required {
476                inner_type
477            } else {
478                quote! { Option<#inner_type> }
479            };
480
481            quote! { pub #field_name: #field_type }
482        });
483
484        quote! {
485            #[derive(Debug, Default)]
486            pub struct #struct_name {
487                #(#fields,)*
488            }
489        }
490    }
491
492    /// Generate header insertion code for the IntoResponse impl.
493    fn generate_header_insertions(&self, headers: &[ResponseHeader]) -> TokenStream {
494        let insertions = headers.iter().map(|h| {
495            let field_name = format_ident!("{}", AsSnakeCase(&h.name).to_string());
496            let header_name = h.name.to_lowercase();
497
498            if h.required {
499                quote! {
500                    response.headers_mut().insert(
501                        ::axum::http::HeaderName::from_static(#header_name),
502                        ::axum::http::HeaderValue::from_str(&headers.#field_name.to_string()).unwrap(),
503                    );
504                }
505            } else {
506                quote! {
507                    if let Some(ref v) = headers.#field_name {
508                        response.headers_mut().insert(
509                            ::axum::http::HeaderName::from_static(#header_name),
510                            ::axum::http::HeaderValue::from_str(&v.to_string()).unwrap(),
511                        );
512                    }
513                }
514            }
515        });
516
517        quote! { #(#insertions)* }
518    }
519
520    /// Generate a fallback enum when no responses are defined in the spec.
521    /// For Ok: generates a default Status200 variant.
522    /// For Err: returns empty (no error type generated when spec has no errors).
523    fn generate_empty_fallback(
524        &self,
525        enum_name: &syn::Ident,
526        kind: GeneratedTypeKind,
527    ) -> TokenStream {
528        let span = enum_name.span();
529        match kind {
530            GeneratedTypeKind::Ok => quote_spanned! { span =>
531                #[derive(Debug)]
532                pub enum #enum_name {
533                    Status200,
534                }
535
536                impl ::axum::response::IntoResponse for #enum_name {
537                    fn into_response(self) -> ::axum::response::Response {
538                        match self {
539                            Self::Status200 => {
540                                ::axum::http::StatusCode::OK.into_response()
541                            }
542                        }
543                    }
544                }
545            },
546            // No error type generated when the spec defines no error responses
547            GeneratedTypeKind::Err => quote! {},
548            GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
549        }
550    }
551}
552
553/// Information about a response variant.
554struct VariantInfo {
555    name: syn::Ident,
556    body_type: Option<TokenStream>,
557    status: u16,
558    is_default: bool,
559    attrs: Vec<TokenStream>,
560    headers: Vec<ResponseHeader>,
561}
562
563/// Result of collecting variants from responses.
564struct CollectedVariants {
565    variants: Vec<VariantInfo>,
566    errors: Vec<TokenStream>,
567    inline_definitions: Vec<TokenStream>,
568}
569
/// Lookup table from numeric HTTP status to the matching associated-constant name on
/// `http::StatusCode`, e.g. `404` → `NOT_FOUND`. Codes not listed here are built with
/// `StatusCode::from_u16` instead.
const STATUS_CODE_NAMES: &[(u16, &str)] = &[
    // 2xx — success
    (200, "OK"),
    (201, "CREATED"),
    (202, "ACCEPTED"),
    (204, "NO_CONTENT"),
    // 3xx — redirection
    (301, "MOVED_PERMANENTLY"),
    (302, "FOUND"),
    (303, "SEE_OTHER"),
    (304, "NOT_MODIFIED"),
    (307, "TEMPORARY_REDIRECT"),
    (308, "PERMANENT_REDIRECT"),
    // 4xx — client errors
    (400, "BAD_REQUEST"),
    (401, "UNAUTHORIZED"),
    (403, "FORBIDDEN"),
    (404, "NOT_FOUND"),
    (405, "METHOD_NOT_ALLOWED"),
    (409, "CONFLICT"),
    (410, "GONE"),
    (422, "UNPROCESSABLE_ENTITY"),
    (429, "TOO_MANY_REQUESTS"),
    // 5xx — server errors
    (500, "INTERNAL_SERVER_ERROR"),
    (501, "NOT_IMPLEMENTED"),
    (502, "BAD_GATEWAY"),
    (503, "SERVICE_UNAVAILABLE"),
    (504, "GATEWAY_TIMEOUT"),
];
597
598/// Generate the header struct identifier: `{EnumName}{VariantName}Headers`.
599fn header_struct_ident(enum_name: &syn::Ident, variant_name: &syn::Ident) -> syn::Ident {
600    format_ident!("{}{}Headers", enum_name, variant_name)
601}
602
603/// Generate a StatusCode expression for a numeric status.
604fn status_code_ident(status: u16) -> TokenStream {
605    // Use well-known constants where possible
606    if let Some((_, name)) = STATUS_CODE_NAMES.iter().find(|(code, _)| *code == status) {
607        let ident = format_ident!("{}", name);
608        quote! { ::axum::http::StatusCode::#ident }
609    } else {
610        quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
611    }
612}