Skip to main content

oxapi_impl/
responses.rs

1//! Response enum generation.
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5
6use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
7use crate::types::TypeGenerator;
8use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
9
10/// Generator for response enums.
11pub struct ResponseGenerator<'a> {
12    spec: &'a ParsedSpec,
13    type_gen: &'a TypeGenerator,
14    overrides: &'a TypeOverrides,
15    suffixes: &'a ResponseSuffixes,
16}
17
18impl<'a> ResponseGenerator<'a> {
19    pub fn new(
20        spec: &'a ParsedSpec,
21        type_gen: &'a TypeGenerator,
22        overrides: &'a TypeOverrides,
23        suffixes: &'a ResponseSuffixes,
24    ) -> Self {
25        Self {
26            spec,
27            type_gen,
28            overrides,
29            suffixes,
30        }
31    }
32
33    /// Check if a response matches the given kind (Ok for success codes, Err for error codes).
34    fn response_matches_kind(status: &ResponseStatus, kind: GeneratedTypeKind) -> bool {
35        match kind {
36            GeneratedTypeKind::Ok => status.is_success(),
37            GeneratedTypeKind::Err => status.is_error(),
38            GeneratedTypeKind::Query | GeneratedTypeKind::Path => false,
39        }
40    }
41
42    /// Generate response enums for all operations.
43    pub fn generate_all(&self) -> TokenStream {
44        let enums: Vec<_> = self
45            .spec
46            .operations()
47            .map(|op| self.generate_for_operation(op))
48            .collect();
49
50        quote! {
51            #(#enums)*
52        }
53    }
54
55    /// Generate Ok and Err enums for a single operation.
56    pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
57        let op_name = op.name();
58
59        let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
60        let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
61
62        quote! {
63            #ok_enum
64            #err_enum
65        }
66    }
67
68    /// Get the enum name, applying any rename override.
69    fn get_enum_name(
70        &self,
71        op: &Operation,
72        default_name: &str,
73        kind: GeneratedTypeKind,
74    ) -> syn::Ident {
75        if let Some(TypeOverride::Rename { name, .. }) =
76            self.overrides.get(op.method, &op.path, kind)
77        {
78            name.clone()
79        } else {
80            format_ident!("{}", default_name)
81        }
82    }
83
84    /// Get a value from a variant override, or return a default.
85    fn get_variant_override<T, F>(
86        &self,
87        op: &Operation,
88        status: u16,
89        kind: GeneratedTypeKind,
90        getter: F,
91        default: T,
92    ) -> T
93    where
94        F: FnOnce(&crate::VariantOverride) -> T,
95    {
96        if let Some(TypeOverride::Rename {
97            variant_overrides, ..
98        }) = self.overrides.get(op.method, &op.path, kind)
99            && let Some(ov) = variant_overrides.get(&status)
100        {
101            return getter(ov);
102        }
103        default
104    }
105
106    /// Get the variant name for a status code, applying any rename override.
107    fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
108        self.get_variant_override(
109            op,
110            status,
111            kind,
112            |ov| ov.name.clone(),
113            format_ident!("Status{}", status),
114        )
115    }
116
117    /// Get the inner type name override for a status code, if specified.
118    fn get_inner_type_name(
119        &self,
120        op: &Operation,
121        status: u16,
122        kind: GeneratedTypeKind,
123    ) -> Option<syn::Ident> {
124        self.get_variant_override(op, status, kind, |ov| ov.inner_type_name.clone(), None)
125    }
126
127    /// Get the variant attributes for a status code.
128    fn get_variant_attrs(
129        &self,
130        op: &Operation,
131        status: u16,
132        kind: GeneratedTypeKind,
133    ) -> Vec<TokenStream> {
134        self.get_variant_override(op, status, kind, |ov| ov.attrs.clone(), Vec::new())
135    }
136
137    /// Validate that all variant overrides reference status codes that exist in the spec.
138    /// Returns compile errors for any invalid status codes.
139    fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
140        let Some(TypeOverride::Rename {
141            variant_overrides, ..
142        }) = self.overrides.get(op.method, &op.path, kind)
143        else {
144            return quote! {};
145        };
146
147        // Collect valid status codes from the spec
148        let valid_codes: std::collections::HashSet<u16> = op
149            .responses
150            .iter()
151            .filter(|r| Self::response_matches_kind(&r.status_code, kind))
152            .filter_map(|r| match &r.status_code {
153                ResponseStatus::Code(code) => Some(*code),
154                _ => None,
155            })
156            .collect();
157
158        // Check each override
159        let errors: Vec<_> = variant_overrides
160            .iter()
161            .filter(|(status, _)| !valid_codes.contains(status))
162            .map(|(status, ov)| {
163                let valid_list: Vec<_> = valid_codes.iter().collect();
164                let kind_name = match kind {
165                    GeneratedTypeKind::Ok => "success",
166                    GeneratedTypeKind::Err => "error",
167                    GeneratedTypeKind::Query => "query",
168                    GeneratedTypeKind::Path => "path",
169                };
170                let msg = format!(
171                    "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
172                    status, op.method, op.path, kind_name, valid_list
173                );
174                quote_spanned! { ov.name.span() =>
175                    compile_error!(#msg);
176                }
177            })
178            .collect();
179
180        quote! { #(#errors)* }
181    }
182
183    /// Get attributes for an enum. If attrs contains a derive, returns only the attrs.
184    /// Otherwise prepends the default derives.
185    fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
186        if let Some(TypeOverride::Rename { attrs, .. }) =
187            self.overrides.get(op.method, &op.path, kind)
188        {
189            // Check if any attr is a derive
190            let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
191            if has_derive {
192                return attrs.clone();
193            } else {
194                // Prepend default derives
195                let mut result = vec![self.suffixes.default_derives.clone()];
196                result.extend(attrs.clone());
197                return result;
198            }
199        }
200        Vec::new()
201    }
202
203    /// Collect variants from responses, separating data collection from code generation.
204    fn collect_variants(
205        &self,
206        op: &Operation,
207        kind: GeneratedTypeKind,
208        enum_name: &syn::Ident,
209        responses: &[&crate::openapi::OperationResponse],
210    ) -> CollectedVariants {
211        let mut errors = Vec::new();
212        let mut inline_definitions = Vec::new();
213
214        let variants = responses
215            .iter()
216            .map(|resp| {
217                let (variant_name, status, is_default) = match &resp.status_code {
218                    ResponseStatus::Code(code) => {
219                        (self.get_variant_name(op, *code, kind), *code, false)
220                    }
221                    ResponseStatus::Default => {
222                        // Only error enums should have Default responses
223                        (format_ident!("Default"), 500, true)
224                    }
225                };
226
227                let variant_attrs = self.get_variant_attrs(op, status, kind);
228                let inner_type_override = self.get_inner_type_name(op, status, kind);
229
230                let body_type = if let Some(schema) = &resp.schema {
231                    // Validate: inner type override only allowed for inline schemas
232                    if let Some(ref inner_name) = inner_type_override
233                        && !TypeGenerator::is_inline_schema(schema)
234                    {
235                        errors.push(quote_spanned! { inner_name.span() =>
236                            compile_error!("inner type name override can only be used with inline schemas, not $ref");
237                        });
238                    }
239
240                    // Use override name or default name hint
241                    let name_hint = inner_type_override
242                        .as_ref()
243                        .map(|i| i.to_string())
244                        .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
245                    let generated = self
246                        .type_gen
247                        .type_for_schema_with_definitions(schema, &name_hint);
248                    inline_definitions.extend(generated.definitions);
249                    Some(generated.type_ref)
250                } else {
251                    None
252                };
253
254                VariantInfo {
255                    name: variant_name,
256                    body_type,
257                    status,
258                    is_default,
259                    attrs: variant_attrs,
260                }
261            })
262            .collect();
263
264        CollectedVariants {
265            variants,
266            errors,
267            inline_definitions,
268        }
269    }
270
271    /// Generate a response enum (Ok or Err) for an operation.
272    fn generate_response_enum(
273        &self,
274        op: &Operation,
275        op_name: &str,
276        kind: GeneratedTypeKind,
277    ) -> TokenStream {
278        // Skip if replaced
279        if self.overrides.is_replaced(op.method, &op.path, kind) {
280            return quote! {};
281        }
282
283        // Validate variant overrides reference valid status codes
284        let validation_errors = self.validate_variant_overrides(op, kind);
285
286        let suffix = match kind {
287            GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
288            GeneratedTypeKind::Err => &self.suffixes.err_suffix,
289            GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
290        };
291
292        let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
293
294        // Collect relevant responses based on kind
295        let responses: Vec<_> = op
296            .responses
297            .iter()
298            .filter(|r| Self::response_matches_kind(&r.status_code, kind))
299            .collect();
300
301        // Handle empty responses with a default enum
302        if responses.is_empty() {
303            return self.generate_empty_fallback(&enum_name, kind);
304        }
305
306        // Collect variants with their errors and inline definitions
307        let CollectedVariants {
308            variants,
309            errors,
310            inline_definitions,
311        } = self.collect_variants(op, kind, &enum_name, &responses);
312
313        // Generate variant definitions
314        let variant_defs = variants.iter().map(|v| {
315            let name = &v.name;
316            let attrs = &v.attrs;
317            match (&v.body_type, v.is_default) {
318                (Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
319                (Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
320                (None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
321                (None, false) => quote! { #(#attrs)* #name },
322            }
323        });
324
325        // Generate IntoResponse match arms
326        let into_response_arms = variants.iter().map(|v| {
327            let name = &v.name;
328            let status_code = status_code_ident(v.status);
329
330            match (&v.body_type, v.is_default) {
331                (Some(_), true) => quote! {
332                    Self::#name(status, body) => {
333                        (status, ::axum::Json(body)).into_response()
334                    }
335                },
336                (Some(_), false) => quote! {
337                    Self::#name(body) => {
338                        (#status_code, ::axum::Json(body)).into_response()
339                    }
340                },
341                (None, true) => quote! {
342                    Self::#name(status) => {
343                        status.into_response()
344                    }
345                },
346                (None, false) => quote! {
347                    Self::#name => {
348                        #status_code.into_response()
349                    }
350                },
351            }
352        });
353
354        let enum_attrs = self.get_enum_attrs(op, kind);
355        let default_derives = &self.suffixes.default_derives;
356
357        // If we have override attrs, use them; otherwise use default derives
358        let attrs_tokens = if enum_attrs.is_empty() {
359            quote! { #default_derives }
360        } else {
361            quote! { #(#enum_attrs)* }
362        };
363
364        // Use the enum name's span for the definition so error messages
365        // point to the user's enum definition, not the macro attribute
366        let enum_span = enum_name.span();
367
368        let enum_def = quote_spanned! { enum_span =>
369            #attrs_tokens
370            pub enum #enum_name {
371                #(#variant_defs,)*
372            }
373        };
374
375        quote! {
376            #validation_errors
377
378            #(#errors)*
379
380            #(#inline_definitions)*
381
382            #enum_def
383
384            impl ::axum::response::IntoResponse for #enum_name {
385                fn into_response(self) -> ::axum::response::Response {
386                    use ::axum::response::IntoResponse;
387                    match self {
388                        #(#into_response_arms)*
389                    }
390                }
391            }
392        }
393    }
394
395    /// Generate a fallback enum when no responses are defined in the spec.
396    /// For Ok: generates a default Status200 variant.
397    /// For Err: returns empty (no error type generated when spec has no errors).
398    fn generate_empty_fallback(
399        &self,
400        enum_name: &syn::Ident,
401        kind: GeneratedTypeKind,
402    ) -> TokenStream {
403        let span = enum_name.span();
404        match kind {
405            GeneratedTypeKind::Ok => quote_spanned! { span =>
406                #[derive(Debug)]
407                pub enum #enum_name {
408                    Status200,
409                }
410
411                impl ::axum::response::IntoResponse for #enum_name {
412                    fn into_response(self) -> ::axum::response::Response {
413                        match self {
414                            Self::Status200 => {
415                                ::axum::http::StatusCode::OK.into_response()
416                            }
417                        }
418                    }
419                }
420            },
421            // No error type generated when the spec defines no error responses
422            GeneratedTypeKind::Err => quote! {},
423            GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
424        }
425    }
426}
427
428/// Information about a response variant.
429struct VariantInfo {
430    name: syn::Ident,
431    body_type: Option<TokenStream>,
432    status: u16,
433    is_default: bool,
434    attrs: Vec<TokenStream>,
435}
436
437/// Result of collecting variants from responses.
438struct CollectedVariants {
439    variants: Vec<VariantInfo>,
440    errors: Vec<TokenStream>,
441    inline_definitions: Vec<TokenStream>,
442}
443
/// Well-known HTTP status codes with their `http::StatusCode` constant names.
///
/// Codes listed here are emitted as named constants (e.g. `StatusCode::NOT_FOUND`).
/// Any other code falls back to `StatusCode::from_u16(..).unwrap()` in the generated code.
/// Entries are kept in ascending order of status code.
const STATUS_CODE_NAMES: &[(u16, &str)] = &[
    // 2xx
    (200, "OK"),
    (201, "CREATED"),
    (202, "ACCEPTED"),
    (204, "NO_CONTENT"),
    (206, "PARTIAL_CONTENT"),
    // 3xx
    (301, "MOVED_PERMANENTLY"),
    (302, "FOUND"),
    (303, "SEE_OTHER"),
    (304, "NOT_MODIFIED"),
    (307, "TEMPORARY_REDIRECT"),
    (308, "PERMANENT_REDIRECT"),
    // 4xx
    (400, "BAD_REQUEST"),
    (401, "UNAUTHORIZED"),
    (403, "FORBIDDEN"),
    (404, "NOT_FOUND"),
    (405, "METHOD_NOT_ALLOWED"),
    (406, "NOT_ACCEPTABLE"),
    (408, "REQUEST_TIMEOUT"),
    (409, "CONFLICT"),
    (410, "GONE"),
    (412, "PRECONDITION_FAILED"),
    (415, "UNSUPPORTED_MEDIA_TYPE"),
    (422, "UNPROCESSABLE_ENTITY"),
    (429, "TOO_MANY_REQUESTS"),
    // 5xx
    (500, "INTERNAL_SERVER_ERROR"),
    (501, "NOT_IMPLEMENTED"),
    (502, "BAD_GATEWAY"),
    (503, "SERVICE_UNAVAILABLE"),
    (504, "GATEWAY_TIMEOUT"),
];
468
469/// Generate a StatusCode expression for a numeric status.
470fn status_code_ident(status: u16) -> TokenStream {
471    // Use well-known constants where possible
472    if let Some((_, name)) = STATUS_CODE_NAMES.iter().find(|(code, _)| *code == status) {
473        let ident = format_ident!("{}", name);
474        quote! { ::axum::http::StatusCode::#ident }
475    } else {
476        quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
477    }
478}