Skip to main content

oxapi_impl/
responses.rs

1//! Response enum generation.
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5
6use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
7use crate::types::TypeGenerator;
8use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
9
10/// Generator for response enums.
11pub struct ResponseGenerator<'a> {
12    spec: &'a ParsedSpec,
13    type_gen: &'a TypeGenerator,
14    overrides: &'a TypeOverrides,
15    suffixes: &'a ResponseSuffixes,
16}
17
18impl<'a> ResponseGenerator<'a> {
19    pub fn new(
20        spec: &'a ParsedSpec,
21        type_gen: &'a TypeGenerator,
22        overrides: &'a TypeOverrides,
23        suffixes: &'a ResponseSuffixes,
24    ) -> Self {
25        Self {
26            spec,
27            type_gen,
28            overrides,
29            suffixes,
30        }
31    }
32
33    /// Generate response enums for all operations.
34    pub fn generate_all(&self) -> TokenStream {
35        let enums: Vec<_> = self
36            .spec
37            .operations()
38            .map(|op| self.generate_for_operation(op))
39            .collect();
40
41        quote! {
42            #(#enums)*
43        }
44    }
45
46    /// Generate Ok and Err enums for a single operation.
47    pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
48        let op_name = op.name();
49
50        let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
51        let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
52
53        quote! {
54            #ok_enum
55            #err_enum
56        }
57    }
58
59    /// Get the enum name, applying any rename override.
60    fn get_enum_name(
61        &self,
62        op: &Operation,
63        default_name: &str,
64        kind: GeneratedTypeKind,
65    ) -> syn::Ident {
66        if let Some(TypeOverride::Rename { name, .. }) =
67            self.overrides.get(op.method, &op.path, kind)
68        {
69            name.clone()
70        } else {
71            format_ident!("{}", default_name)
72        }
73    }
74
75    /// Get the variant name for a status code, applying any rename override.
76    fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
77        if let Some(TypeOverride::Rename {
78            variant_overrides, ..
79        }) = self.overrides.get(op.method, &op.path, kind)
80            && let Some(ov) = variant_overrides.get(&status)
81        {
82            return ov.name.clone();
83        }
84        format_ident!("Status{}", status)
85    }
86
87    /// Get the inner type name override for a status code, if specified.
88    fn get_inner_type_name(
89        &self,
90        op: &Operation,
91        status: u16,
92        kind: GeneratedTypeKind,
93    ) -> Option<syn::Ident> {
94        if let Some(TypeOverride::Rename {
95            variant_overrides, ..
96        }) = self.overrides.get(op.method, &op.path, kind)
97            && let Some(ov) = variant_overrides.get(&status)
98        {
99            return ov.inner_type_name.clone();
100        }
101        None
102    }
103
104    /// Get the variant attributes for a status code.
105    fn get_variant_attrs(
106        &self,
107        op: &Operation,
108        status: u16,
109        kind: GeneratedTypeKind,
110    ) -> Vec<TokenStream> {
111        if let Some(TypeOverride::Rename {
112            variant_overrides, ..
113        }) = self.overrides.get(op.method, &op.path, kind)
114            && let Some(ov) = variant_overrides.get(&status)
115        {
116            return ov.attrs.clone();
117        }
118        Vec::new()
119    }
120
121    /// Validate that all variant overrides reference status codes that exist in the spec.
122    /// Returns compile errors for any invalid status codes.
123    fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
124        let Some(TypeOverride::Rename {
125            variant_overrides, ..
126        }) = self.overrides.get(op.method, &op.path, kind)
127        else {
128            return quote! {};
129        };
130
131        // Collect valid status codes from the spec
132        let valid_codes: std::collections::HashSet<u16> = op
133            .responses
134            .iter()
135            .filter_map(|r| match (&r.status_code, kind) {
136                (ResponseStatus::Code(code), GeneratedTypeKind::Ok)
137                    if r.status_code.is_success() =>
138                {
139                    Some(*code)
140                }
141                (ResponseStatus::Code(code), GeneratedTypeKind::Err)
142                    if r.status_code.is_error() =>
143                {
144                    Some(*code)
145                }
146                _ => None,
147            })
148            .collect();
149
150        // Check each override
151        let errors: Vec<_> = variant_overrides
152            .iter()
153            .filter(|(status, _)| !valid_codes.contains(status))
154            .map(|(status, ov)| {
155                let valid_list: Vec<_> = valid_codes.iter().collect();
156                let kind_name = match kind {
157                    GeneratedTypeKind::Ok => "success",
158                    GeneratedTypeKind::Err => "error",
159                    GeneratedTypeKind::Query => "query",
160                };
161                let msg = format!(
162                    "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
163                    status, op.method, op.path, kind_name, valid_list
164                );
165                quote_spanned! { ov.name.span() =>
166                    compile_error!(#msg);
167                }
168            })
169            .collect();
170
171        quote! { #(#errors)* }
172    }
173
174    /// Get attributes for an enum. If attrs contains a derive, returns only the attrs.
175    /// Otherwise prepends the default derives.
176    fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
177        if let Some(TypeOverride::Rename { attrs, .. }) =
178            self.overrides.get(op.method, &op.path, kind)
179        {
180            // Check if any attr is a derive
181            let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
182            if has_derive {
183                return attrs.clone();
184            } else {
185                // Prepend default derives
186                let mut result = vec![self.suffixes.default_derives.clone()];
187                result.extend(attrs.clone());
188                return result;
189            }
190        }
191        Vec::new()
192    }
193
194    /// Generate a response enum (Ok or Err) for an operation.
195    fn generate_response_enum(
196        &self,
197        op: &Operation,
198        op_name: &str,
199        kind: GeneratedTypeKind,
200    ) -> TokenStream {
201        // Skip if replaced
202        if self.overrides.is_replaced(op.method, &op.path, kind) {
203            return quote! {};
204        }
205
206        // Validate variant overrides reference valid status codes
207        let validation_errors = self.validate_variant_overrides(op, kind);
208
209        let suffix = match kind {
210            GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
211            GeneratedTypeKind::Err => &self.suffixes.err_suffix,
212            GeneratedTypeKind::Query => unreachable!(),
213        };
214
215        let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
216
217        // Collect relevant responses based on kind
218        let responses: Vec<_> = op
219            .responses
220            .iter()
221            .filter(|r| match kind {
222                GeneratedTypeKind::Ok => r.status_code.is_success(),
223                GeneratedTypeKind::Err => r.status_code.is_error(),
224                GeneratedTypeKind::Query => false,
225            })
226            .collect();
227
228        // Handle empty responses with a default enum
229        if responses.is_empty() {
230            return self.generate_empty_fallback(&enum_name, kind);
231        }
232
233        let mut errors = Vec::new();
234        let mut inline_definitions = Vec::new();
235
236        // Process each response into variant info
237        let variants: Vec<VariantInfo> = responses
238            .iter()
239            .map(|resp| {
240                let (variant_name, status, is_default) = match &resp.status_code {
241                    ResponseStatus::Code(code) => {
242                        (self.get_variant_name(op, *code, kind), *code, false)
243                    }
244                    ResponseStatus::Default => {
245                        // Only error enums should have Default responses
246                        (format_ident!("Default"), 500, true)
247                    }
248                };
249
250                let variant_attrs = self.get_variant_attrs(op, status, kind);
251                let inner_type_override = self.get_inner_type_name(op, status, kind);
252
253                let body_type = if let Some(schema) = &resp.schema {
254                    // Validate: inner type override only allowed for inline schemas
255                    if let Some(ref inner_name) = inner_type_override
256                        && !TypeGenerator::is_inline_schema(schema)
257                    {
258                        errors.push(quote_spanned! { inner_name.span() =>
259                            compile_error!("inner type name override can only be used with inline schemas, not $ref");
260                        });
261                    }
262
263                    // Use override name or default name hint
264                    let name_hint = inner_type_override
265                        .as_ref()
266                        .map(|i| i.to_string())
267                        .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
268                    let generated = self
269                        .type_gen
270                        .type_for_schema_with_definitions(schema, &name_hint);
271                    inline_definitions.extend(generated.definitions);
272                    Some(generated.type_ref)
273                } else {
274                    None
275                };
276
277                VariantInfo {
278                    name: variant_name,
279                    body_type,
280                    status,
281                    is_default,
282                    attrs: variant_attrs,
283                }
284            })
285            .collect();
286
287        // Generate variant definitions
288        let variant_defs = variants.iter().map(|v| {
289            let name = &v.name;
290            let attrs = &v.attrs;
291            match (&v.body_type, v.is_default) {
292                (Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
293                (Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
294                (None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
295                (None, false) => quote! { #(#attrs)* #name },
296            }
297        });
298
299        // Generate IntoResponse match arms
300        let into_response_arms = variants.iter().map(|v| {
301            let name = &v.name;
302            let status_code = status_code_ident(v.status);
303
304            match (&v.body_type, v.is_default) {
305                (Some(_), true) => quote! {
306                    Self::#name(status, body) => {
307                        (status, ::axum::Json(body)).into_response()
308                    }
309                },
310                (Some(_), false) => quote! {
311                    Self::#name(body) => {
312                        (#status_code, ::axum::Json(body)).into_response()
313                    }
314                },
315                (None, true) => quote! {
316                    Self::#name(status) => {
317                        status.into_response()
318                    }
319                },
320                (None, false) => quote! {
321                    Self::#name => {
322                        #status_code.into_response()
323                    }
324                },
325            }
326        });
327
328        let enum_attrs = self.get_enum_attrs(op, kind);
329        let default_derives = &self.suffixes.default_derives;
330
331        // If we have override attrs, use them; otherwise use default derives
332        let attrs_tokens = if enum_attrs.is_empty() {
333            quote! { #default_derives }
334        } else {
335            quote! { #(#enum_attrs)* }
336        };
337
338        // Use the enum name's span for the definition so error messages
339        // point to the user's enum definition, not the macro attribute
340        let enum_span = enum_name.span();
341
342        let enum_def = quote_spanned! { enum_span =>
343            #attrs_tokens
344            pub enum #enum_name {
345                #(#variant_defs,)*
346            }
347        };
348
349        quote! {
350            #validation_errors
351
352            #(#errors)*
353
354            #(#inline_definitions)*
355
356            #enum_def
357
358            impl ::axum::response::IntoResponse for #enum_name {
359                fn into_response(self) -> ::axum::response::Response {
360                    use ::axum::response::IntoResponse;
361                    match self {
362                        #(#into_response_arms)*
363                    }
364                }
365            }
366        }
367    }
368
369    /// Generate a fallback enum when no responses are defined in the spec.
370    /// For Ok: generates a default Status200 variant.
371    /// For Err: returns empty (no error type generated when spec has no errors).
372    fn generate_empty_fallback(
373        &self,
374        enum_name: &syn::Ident,
375        kind: GeneratedTypeKind,
376    ) -> TokenStream {
377        let span = enum_name.span();
378        match kind {
379            GeneratedTypeKind::Ok => quote_spanned! { span =>
380                #[derive(Debug)]
381                pub enum #enum_name {
382                    Status200,
383                }
384
385                impl ::axum::response::IntoResponse for #enum_name {
386                    fn into_response(self) -> ::axum::response::Response {
387                        match self {
388                            Self::Status200 => {
389                                ::axum::http::StatusCode::OK.into_response()
390                            }
391                        }
392                    }
393                }
394            },
395            // No error type generated when the spec defines no error responses
396            GeneratedTypeKind::Err => quote! {},
397            GeneratedTypeKind::Query => unreachable!(),
398        }
399    }
400}
401
402/// Information about a response variant.
403struct VariantInfo {
404    name: syn::Ident,
405    body_type: Option<TokenStream>,
406    status: u16,
407    is_default: bool,
408    attrs: Vec<TokenStream>,
409}
410
411/// Generate a StatusCode expression for a numeric status.
412fn status_code_ident(status: u16) -> TokenStream {
413    // Use well-known constants where possible
414    let name = match status {
415        200 => Some("OK"),
416        201 => Some("CREATED"),
417        202 => Some("ACCEPTED"),
418        204 => Some("NO_CONTENT"),
419        301 => Some("MOVED_PERMANENTLY"),
420        302 => Some("FOUND"),
421        304 => Some("NOT_MODIFIED"),
422        400 => Some("BAD_REQUEST"),
423        401 => Some("UNAUTHORIZED"),
424        403 => Some("FORBIDDEN"),
425        404 => Some("NOT_FOUND"),
426        405 => Some("METHOD_NOT_ALLOWED"),
427        409 => Some("CONFLICT"),
428        410 => Some("GONE"),
429        422 => Some("UNPROCESSABLE_ENTITY"),
430        429 => Some("TOO_MANY_REQUESTS"),
431        500 => Some("INTERNAL_SERVER_ERROR"),
432        501 => Some("NOT_IMPLEMENTED"),
433        502 => Some("BAD_GATEWAY"),
434        503 => Some("SERVICE_UNAVAILABLE"),
435        504 => Some("GATEWAY_TIMEOUT"),
436        _ => None,
437    };
438
439    if let Some(name) = name {
440        let ident = format_ident!("{}", name);
441        quote! { ::axum::http::StatusCode::#ident }
442    } else {
443        quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
444    }
445}