Skip to main content

oxapi_impl/
responses.rs

1//! Response enum generation.
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5
6use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
7use crate::types::TypeGenerator;
8use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
9
10/// Generator for response enums.
11pub struct ResponseGenerator<'a> {
12    spec: &'a ParsedSpec,
13    type_gen: &'a TypeGenerator,
14    overrides: &'a TypeOverrides,
15    suffixes: &'a ResponseSuffixes,
16}
17
18impl<'a> ResponseGenerator<'a> {
19    pub fn new(
20        spec: &'a ParsedSpec,
21        type_gen: &'a TypeGenerator,
22        overrides: &'a TypeOverrides,
23        suffixes: &'a ResponseSuffixes,
24    ) -> Self {
25        Self {
26            spec,
27            type_gen,
28            overrides,
29            suffixes,
30        }
31    }
32
33    /// Check if a response matches the given kind (Ok for success codes, Err for error codes).
34    fn response_matches_kind(status: &ResponseStatus, kind: GeneratedTypeKind) -> bool {
35        match kind {
36            GeneratedTypeKind::Ok => status.is_success(),
37            GeneratedTypeKind::Err => status.is_error(),
38            GeneratedTypeKind::Query => false,
39        }
40    }
41
42    /// Generate response enums for all operations.
43    pub fn generate_all(&self) -> TokenStream {
44        let enums: Vec<_> = self
45            .spec
46            .operations()
47            .map(|op| self.generate_for_operation(op))
48            .collect();
49
50        quote! {
51            #(#enums)*
52        }
53    }
54
55    /// Generate Ok and Err enums for a single operation.
56    pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
57        let op_name = op.name();
58
59        let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
60        let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
61
62        quote! {
63            #ok_enum
64            #err_enum
65        }
66    }
67
68    /// Get the enum name, applying any rename override.
69    fn get_enum_name(
70        &self,
71        op: &Operation,
72        default_name: &str,
73        kind: GeneratedTypeKind,
74    ) -> syn::Ident {
75        if let Some(TypeOverride::Rename { name, .. }) =
76            self.overrides.get(op.method, &op.path, kind)
77        {
78            name.clone()
79        } else {
80            format_ident!("{}", default_name)
81        }
82    }
83
84    /// Get a value from a variant override, or return a default.
85    fn get_variant_override<T, F>(
86        &self,
87        op: &Operation,
88        status: u16,
89        kind: GeneratedTypeKind,
90        getter: F,
91        default: T,
92    ) -> T
93    where
94        F: FnOnce(&crate::VariantOverride) -> T,
95    {
96        if let Some(TypeOverride::Rename {
97            variant_overrides, ..
98        }) = self.overrides.get(op.method, &op.path, kind)
99            && let Some(ov) = variant_overrides.get(&status)
100        {
101            return getter(ov);
102        }
103        default
104    }
105
106    /// Get the variant name for a status code, applying any rename override.
107    fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
108        self.get_variant_override(
109            op,
110            status,
111            kind,
112            |ov| ov.name.clone(),
113            format_ident!("Status{}", status),
114        )
115    }
116
117    /// Get the inner type name override for a status code, if specified.
118    fn get_inner_type_name(
119        &self,
120        op: &Operation,
121        status: u16,
122        kind: GeneratedTypeKind,
123    ) -> Option<syn::Ident> {
124        self.get_variant_override(op, status, kind, |ov| ov.inner_type_name.clone(), None)
125    }
126
127    /// Get the variant attributes for a status code.
128    fn get_variant_attrs(
129        &self,
130        op: &Operation,
131        status: u16,
132        kind: GeneratedTypeKind,
133    ) -> Vec<TokenStream> {
134        self.get_variant_override(op, status, kind, |ov| ov.attrs.clone(), Vec::new())
135    }
136
137    /// Validate that all variant overrides reference status codes that exist in the spec.
138    /// Returns compile errors for any invalid status codes.
139    fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
140        let Some(TypeOverride::Rename {
141            variant_overrides, ..
142        }) = self.overrides.get(op.method, &op.path, kind)
143        else {
144            return quote! {};
145        };
146
147        // Collect valid status codes from the spec
148        let valid_codes: std::collections::HashSet<u16> = op
149            .responses
150            .iter()
151            .filter(|r| Self::response_matches_kind(&r.status_code, kind))
152            .filter_map(|r| match &r.status_code {
153                ResponseStatus::Code(code) => Some(*code),
154                _ => None,
155            })
156            .collect();
157
158        // Check each override
159        let errors: Vec<_> = variant_overrides
160            .iter()
161            .filter(|(status, _)| !valid_codes.contains(status))
162            .map(|(status, ov)| {
163                let valid_list: Vec<_> = valid_codes.iter().collect();
164                let kind_name = match kind {
165                    GeneratedTypeKind::Ok => "success",
166                    GeneratedTypeKind::Err => "error",
167                    GeneratedTypeKind::Query => "query",
168                };
169                let msg = format!(
170                    "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
171                    status, op.method, op.path, kind_name, valid_list
172                );
173                quote_spanned! { ov.name.span() =>
174                    compile_error!(#msg);
175                }
176            })
177            .collect();
178
179        quote! { #(#errors)* }
180    }
181
182    /// Get attributes for an enum. If attrs contains a derive, returns only the attrs.
183    /// Otherwise prepends the default derives.
184    fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
185        if let Some(TypeOverride::Rename { attrs, .. }) =
186            self.overrides.get(op.method, &op.path, kind)
187        {
188            // Check if any attr is a derive
189            let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
190            if has_derive {
191                return attrs.clone();
192            } else {
193                // Prepend default derives
194                let mut result = vec![self.suffixes.default_derives.clone()];
195                result.extend(attrs.clone());
196                return result;
197            }
198        }
199        Vec::new()
200    }
201
202    /// Generate a response enum (Ok or Err) for an operation.
203    fn generate_response_enum(
204        &self,
205        op: &Operation,
206        op_name: &str,
207        kind: GeneratedTypeKind,
208    ) -> TokenStream {
209        // Skip if replaced
210        if self.overrides.is_replaced(op.method, &op.path, kind) {
211            return quote! {};
212        }
213
214        // Validate variant overrides reference valid status codes
215        let validation_errors = self.validate_variant_overrides(op, kind);
216
217        let suffix = match kind {
218            GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
219            GeneratedTypeKind::Err => &self.suffixes.err_suffix,
220            GeneratedTypeKind::Query => unreachable!(),
221        };
222
223        let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
224
225        // Collect relevant responses based on kind
226        let responses: Vec<_> = op
227            .responses
228            .iter()
229            .filter(|r| Self::response_matches_kind(&r.status_code, kind))
230            .collect();
231
232        // Handle empty responses with a default enum
233        if responses.is_empty() {
234            return self.generate_empty_fallback(&enum_name, kind);
235        }
236
237        let mut errors = Vec::new();
238        let mut inline_definitions = Vec::new();
239
240        // Process each response into variant info
241        let variants: Vec<VariantInfo> = responses
242            .iter()
243            .map(|resp| {
244                let (variant_name, status, is_default) = match &resp.status_code {
245                    ResponseStatus::Code(code) => {
246                        (self.get_variant_name(op, *code, kind), *code, false)
247                    }
248                    ResponseStatus::Default => {
249                        // Only error enums should have Default responses
250                        (format_ident!("Default"), 500, true)
251                    }
252                };
253
254                let variant_attrs = self.get_variant_attrs(op, status, kind);
255                let inner_type_override = self.get_inner_type_name(op, status, kind);
256
257                let body_type = if let Some(schema) = &resp.schema {
258                    // Validate: inner type override only allowed for inline schemas
259                    if let Some(ref inner_name) = inner_type_override
260                        && !TypeGenerator::is_inline_schema(schema)
261                    {
262                        errors.push(quote_spanned! { inner_name.span() =>
263                            compile_error!("inner type name override can only be used with inline schemas, not $ref");
264                        });
265                    }
266
267                    // Use override name or default name hint
268                    let name_hint = inner_type_override
269                        .as_ref()
270                        .map(|i| i.to_string())
271                        .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
272                    let generated = self
273                        .type_gen
274                        .type_for_schema_with_definitions(schema, &name_hint);
275                    inline_definitions.extend(generated.definitions);
276                    Some(generated.type_ref)
277                } else {
278                    None
279                };
280
281                VariantInfo {
282                    name: variant_name,
283                    body_type,
284                    status,
285                    is_default,
286                    attrs: variant_attrs,
287                }
288            })
289            .collect();
290
291        // Generate variant definitions
292        let variant_defs = variants.iter().map(|v| {
293            let name = &v.name;
294            let attrs = &v.attrs;
295            match (&v.body_type, v.is_default) {
296                (Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
297                (Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
298                (None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
299                (None, false) => quote! { #(#attrs)* #name },
300            }
301        });
302
303        // Generate IntoResponse match arms
304        let into_response_arms = variants.iter().map(|v| {
305            let name = &v.name;
306            let status_code = status_code_ident(v.status);
307
308            match (&v.body_type, v.is_default) {
309                (Some(_), true) => quote! {
310                    Self::#name(status, body) => {
311                        (status, ::axum::Json(body)).into_response()
312                    }
313                },
314                (Some(_), false) => quote! {
315                    Self::#name(body) => {
316                        (#status_code, ::axum::Json(body)).into_response()
317                    }
318                },
319                (None, true) => quote! {
320                    Self::#name(status) => {
321                        status.into_response()
322                    }
323                },
324                (None, false) => quote! {
325                    Self::#name => {
326                        #status_code.into_response()
327                    }
328                },
329            }
330        });
331
332        let enum_attrs = self.get_enum_attrs(op, kind);
333        let default_derives = &self.suffixes.default_derives;
334
335        // If we have override attrs, use them; otherwise use default derives
336        let attrs_tokens = if enum_attrs.is_empty() {
337            quote! { #default_derives }
338        } else {
339            quote! { #(#enum_attrs)* }
340        };
341
342        // Use the enum name's span for the definition so error messages
343        // point to the user's enum definition, not the macro attribute
344        let enum_span = enum_name.span();
345
346        let enum_def = quote_spanned! { enum_span =>
347            #attrs_tokens
348            pub enum #enum_name {
349                #(#variant_defs,)*
350            }
351        };
352
353        quote! {
354            #validation_errors
355
356            #(#errors)*
357
358            #(#inline_definitions)*
359
360            #enum_def
361
362            impl ::axum::response::IntoResponse for #enum_name {
363                fn into_response(self) -> ::axum::response::Response {
364                    use ::axum::response::IntoResponse;
365                    match self {
366                        #(#into_response_arms)*
367                    }
368                }
369            }
370        }
371    }
372
373    /// Generate a fallback enum when no responses are defined in the spec.
374    /// For Ok: generates a default Status200 variant.
375    /// For Err: returns empty (no error type generated when spec has no errors).
376    fn generate_empty_fallback(
377        &self,
378        enum_name: &syn::Ident,
379        kind: GeneratedTypeKind,
380    ) -> TokenStream {
381        let span = enum_name.span();
382        match kind {
383            GeneratedTypeKind::Ok => quote_spanned! { span =>
384                #[derive(Debug)]
385                pub enum #enum_name {
386                    Status200,
387                }
388
389                impl ::axum::response::IntoResponse for #enum_name {
390                    fn into_response(self) -> ::axum::response::Response {
391                        match self {
392                            Self::Status200 => {
393                                ::axum::http::StatusCode::OK.into_response()
394                            }
395                        }
396                    }
397                }
398            },
399            // No error type generated when the spec defines no error responses
400            GeneratedTypeKind::Err => quote! {},
401            GeneratedTypeKind::Query => unreachable!(),
402        }
403    }
404}
405
406/// Information about a response variant.
407struct VariantInfo {
408    name: syn::Ident,
409    body_type: Option<TokenStream>,
410    status: u16,
411    is_default: bool,
412    attrs: Vec<TokenStream>,
413}
414
/// Status codes that have a named `StatusCode` constant, paired with the name
/// of that constant. Any code missing from this table is built at runtime
/// via `StatusCode::from_u16` instead.
const STATUS_CODE_NAMES: &[(u16, &str)] = &[
    // 2xx — success
    (200, "OK"), (201, "CREATED"), (202, "ACCEPTED"), (204, "NO_CONTENT"),
    // 3xx — redirection
    (301, "MOVED_PERMANENTLY"), (302, "FOUND"), (304, "NOT_MODIFIED"),
    // 4xx — client errors
    (400, "BAD_REQUEST"), (401, "UNAUTHORIZED"), (403, "FORBIDDEN"), (404, "NOT_FOUND"),
    (405, "METHOD_NOT_ALLOWED"), (409, "CONFLICT"), (410, "GONE"),
    (422, "UNPROCESSABLE_ENTITY"), (429, "TOO_MANY_REQUESTS"),
    // 5xx — server errors
    (500, "INTERNAL_SERVER_ERROR"), (501, "NOT_IMPLEMENTED"), (502, "BAD_GATEWAY"),
    (503, "SERVICE_UNAVAILABLE"), (504, "GATEWAY_TIMEOUT"),
];
439
440/// Generate a StatusCode expression for a numeric status.
441fn status_code_ident(status: u16) -> TokenStream {
442    // Use well-known constants where possible
443    if let Some((_, name)) = STATUS_CODE_NAMES.iter().find(|(code, _)| *code == status) {
444        let ident = format_ident!("{}", name);
445        quote! { ::axum::http::StatusCode::#ident }
446    } else {
447        quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
448    }
449}