Skip to main content

oxapi_impl/
responses.rs

1//! Response enum generation.
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
8use crate::types::TypeGenerator;
9
10/// Generator for response enums.
11pub struct ResponseGenerator<'a> {
12    spec: &'a ParsedSpec,
13    type_gen: &'a TypeGenerator,
14}
15
16impl<'a> ResponseGenerator<'a> {
17    pub fn new(spec: &'a ParsedSpec, type_gen: &'a TypeGenerator) -> Self {
18        Self { spec, type_gen }
19    }
20
21    /// Generate response enums for all operations.
22    pub fn generate_all(&self) -> TokenStream {
23        let enums: Vec<_> = self
24            .spec
25            .operations()
26            .map(|op| self.generate_for_operation(op))
27            .collect();
28
29        quote! {
30            #(#enums)*
31        }
32    }
33
34    /// Generate Ok and Err enums for a single operation.
35    pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
36        let op_name = op
37            .operation_id
38            .as_deref()
39            .unwrap_or(&op.path)
40            .to_upper_camel_case();
41
42        let ok_enum = self.generate_ok_enum(op, &op_name);
43        let err_enum = self.generate_err_enum(op, &op_name);
44
45        quote! {
46            #ok_enum
47            #err_enum
48        }
49    }
50
51    /// Generate the Ok (success) enum for an operation.
52    fn generate_ok_enum(&self, op: &Operation, op_name: &str) -> TokenStream {
53        let enum_name = format_ident!("{}Ok", op_name);
54
55        // Collect success responses (2xx status codes)
56        let success_responses: Vec<_> = op
57            .responses
58            .iter()
59            .filter(|r| r.status_code.is_success())
60            .collect();
61
62        if success_responses.is_empty() {
63            // No explicit success responses, create a default
64            return quote! {
65                #[derive(Debug)]
66                pub enum #enum_name {
67                    Status200,
68                }
69
70                impl ::axum::response::IntoResponse for #enum_name {
71                    fn into_response(self) -> ::axum::response::Response {
72                        match self {
73                            Self::Status200 => {
74                                ::axum::http::StatusCode::OK.into_response()
75                            }
76                        }
77                    }
78                }
79            };
80        }
81
82        let variants: Vec<_> = success_responses
83            .iter()
84            .map(|resp| {
85                let status = match &resp.status_code {
86                    ResponseStatus::Code(code) => *code,
87                    ResponseStatus::Default => 200,
88                };
89                let variant_name = format_ident!("Status{}", status);
90
91                if let Some(schema) = &resp.schema {
92                    let ty = self
93                        .type_gen
94                        .type_for_schema(schema, &format!("{}Response{}", op_name, status));
95                    (variant_name, Some(ty), status)
96                } else {
97                    (variant_name, None, status)
98                }
99            })
100            .collect();
101
102        let variant_defs = variants.iter().map(|(name, ty, _)| {
103            if let Some(ty) = ty {
104                quote! { #name(#ty) }
105            } else {
106                quote! { #name }
107            }
108        });
109
110        let into_response_arms = variants.iter().map(|(name, ty, status)| {
111            let status_code = status_code_ident(*status);
112
113            if ty.is_some() {
114                quote! {
115                    Self::#name(body) => {
116                        (
117                            #status_code,
118                            ::axum::Json(body),
119                        ).into_response()
120                    }
121                }
122            } else {
123                quote! {
124                    Self::#name => {
125                        #status_code.into_response()
126                    }
127                }
128            }
129        });
130
131        quote! {
132            #[derive(Debug)]
133            pub enum #enum_name {
134                #(#variant_defs,)*
135            }
136
137            impl ::axum::response::IntoResponse for #enum_name {
138                fn into_response(self) -> ::axum::response::Response {
139                    use ::axum::response::IntoResponse;
140                    match self {
141                        #(#into_response_arms)*
142                    }
143                }
144            }
145        }
146    }
147
148    /// Generate the Err (error) enum for an operation.
149    fn generate_err_enum(&self, op: &Operation, op_name: &str) -> TokenStream {
150        let enum_name = format_ident!("{}Err", op_name);
151
152        // Collect error responses (4xx, 5xx, and default)
153        let error_responses: Vec<_> = op
154            .responses
155            .iter()
156            .filter(|r| r.status_code.is_error())
157            .collect();
158
159        if error_responses.is_empty() {
160            // No explicit error responses, create a default
161            return quote! {
162                #[derive(Debug)]
163                pub enum #enum_name {
164                    Status500(String),
165                }
166
167                impl ::axum::response::IntoResponse for #enum_name {
168                    fn into_response(self) -> ::axum::response::Response {
169                        use ::axum::response::IntoResponse;
170                        match self {
171                            Self::Status500(msg) => {
172                                (
173                                    ::axum::http::StatusCode::INTERNAL_SERVER_ERROR,
174                                    msg,
175                                ).into_response()
176                            }
177                        }
178                    }
179                }
180            };
181        }
182
183        let variants: Vec<_> = error_responses
184            .iter()
185            .map(|resp| {
186                let (variant_name, status) = match &resp.status_code {
187                    ResponseStatus::Code(code) => (format_ident!("Status{}", code), *code),
188                    ResponseStatus::Default => (format_ident!("Default"), 500),
189                };
190
191                if let Some(schema) = &resp.schema {
192                    let ty = self
193                        .type_gen
194                        .type_for_schema(schema, &format!("{}Error{}", op_name, status));
195                    (
196                        variant_name,
197                        Some(ty),
198                        status,
199                        resp.status_code == ResponseStatus::Default,
200                    )
201                } else {
202                    (
203                        variant_name,
204                        None,
205                        status,
206                        resp.status_code == ResponseStatus::Default,
207                    )
208                }
209            })
210            .collect();
211
212        let variant_defs = variants.iter().map(|(name, ty, _, is_default)| {
213            if let Some(ty) = ty {
214                if *is_default {
215                    // Default needs to carry the status code
216                    quote! { #name(::axum::http::StatusCode, #ty) }
217                } else {
218                    quote! { #name(#ty) }
219                }
220            } else if *is_default {
221                quote! { #name(::axum::http::StatusCode) }
222            } else {
223                quote! { #name }
224            }
225        });
226
227        let into_response_arms = variants.iter().map(|(name, ty, status, is_default)| {
228            let status_code = status_code_ident(*status);
229
230            if *is_default {
231                if ty.is_some() {
232                    quote! {
233                        Self::#name(status, body) => {
234                            (status, ::axum::Json(body)).into_response()
235                        }
236                    }
237                } else {
238                    quote! {
239                        Self::#name(status) => {
240                            status.into_response()
241                        }
242                    }
243                }
244            } else if ty.is_some() {
245                quote! {
246                    Self::#name(body) => {
247                        (
248                            #status_code,
249                            ::axum::Json(body),
250                        ).into_response()
251                    }
252                }
253            } else {
254                quote! {
255                    Self::#name => {
256                        #status_code.into_response()
257                    }
258                }
259            }
260        });
261
262        quote! {
263            #[derive(Debug)]
264            pub enum #enum_name {
265                #(#variant_defs,)*
266            }
267
268            impl ::axum::response::IntoResponse for #enum_name {
269                fn into_response(self) -> ::axum::response::Response {
270                    use ::axum::response::IntoResponse;
271                    match self {
272                        #(#into_response_arms)*
273                    }
274                }
275            }
276        }
277    }
278}
279
280/// Generate a StatusCode expression for a numeric status.
281fn status_code_ident(status: u16) -> TokenStream {
282    // Use well-known constants where possible
283    let name = match status {
284        200 => Some("OK"),
285        201 => Some("CREATED"),
286        202 => Some("ACCEPTED"),
287        204 => Some("NO_CONTENT"),
288        301 => Some("MOVED_PERMANENTLY"),
289        302 => Some("FOUND"),
290        304 => Some("NOT_MODIFIED"),
291        400 => Some("BAD_REQUEST"),
292        401 => Some("UNAUTHORIZED"),
293        403 => Some("FORBIDDEN"),
294        404 => Some("NOT_FOUND"),
295        405 => Some("METHOD_NOT_ALLOWED"),
296        409 => Some("CONFLICT"),
297        410 => Some("GONE"),
298        422 => Some("UNPROCESSABLE_ENTITY"),
299        429 => Some("TOO_MANY_REQUESTS"),
300        500 => Some("INTERNAL_SERVER_ERROR"),
301        501 => Some("NOT_IMPLEMENTED"),
302        502 => Some("BAD_GATEWAY"),
303        503 => Some("SERVICE_UNAVAILABLE"),
304        504 => Some("GATEWAY_TIMEOUT"),
305        _ => None,
306    };
307
308    if let Some(name) = name {
309        let ident = format_ident!("{}", name);
310        quote! { ::axum::http::StatusCode::#ident }
311    } else {
312        quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
313    }
314}