1use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
8use crate::types::TypeGenerator;
9
10pub struct ResponseGenerator<'a> {
12 spec: &'a ParsedSpec,
13 type_gen: &'a TypeGenerator,
14}
15
16impl<'a> ResponseGenerator<'a> {
17 pub fn new(spec: &'a ParsedSpec, type_gen: &'a TypeGenerator) -> Self {
18 Self { spec, type_gen }
19 }
20
21 pub fn generate_all(&self) -> TokenStream {
23 let enums: Vec<_> = self
24 .spec
25 .operations()
26 .map(|op| self.generate_for_operation(op))
27 .collect();
28
29 quote! {
30 #(#enums)*
31 }
32 }
33
34 pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
36 let op_name = op
37 .operation_id
38 .as_deref()
39 .unwrap_or(&op.path)
40 .to_upper_camel_case();
41
42 let ok_enum = self.generate_ok_enum(op, &op_name);
43 let err_enum = self.generate_err_enum(op, &op_name);
44
45 quote! {
46 #ok_enum
47 #err_enum
48 }
49 }
50
51 fn generate_ok_enum(&self, op: &Operation, op_name: &str) -> TokenStream {
53 let enum_name = format_ident!("{}Ok", op_name);
54
55 let success_responses: Vec<_> = op
57 .responses
58 .iter()
59 .filter(|r| r.status_code.is_success())
60 .collect();
61
62 if success_responses.is_empty() {
63 return quote! {
65 #[derive(Debug)]
66 pub enum #enum_name {
67 Status200,
68 }
69
70 impl ::axum::response::IntoResponse for #enum_name {
71 fn into_response(self) -> ::axum::response::Response {
72 match self {
73 Self::Status200 => {
74 ::axum::http::StatusCode::OK.into_response()
75 }
76 }
77 }
78 }
79 };
80 }
81
82 let variants: Vec<_> = success_responses
83 .iter()
84 .map(|resp| {
85 let status = match &resp.status_code {
86 ResponseStatus::Code(code) => *code,
87 ResponseStatus::Default => 200,
88 };
89 let variant_name = format_ident!("Status{}", status);
90
91 if let Some(schema) = &resp.schema {
92 let ty = self
93 .type_gen
94 .type_for_schema(schema, &format!("{}Response{}", op_name, status));
95 (variant_name, Some(ty), status)
96 } else {
97 (variant_name, None, status)
98 }
99 })
100 .collect();
101
102 let variant_defs = variants.iter().map(|(name, ty, _)| {
103 if let Some(ty) = ty {
104 quote! { #name(#ty) }
105 } else {
106 quote! { #name }
107 }
108 });
109
110 let into_response_arms = variants.iter().map(|(name, ty, status)| {
111 let status_code = status_code_ident(*status);
112
113 if ty.is_some() {
114 quote! {
115 Self::#name(body) => {
116 (
117 #status_code,
118 ::axum::Json(body),
119 ).into_response()
120 }
121 }
122 } else {
123 quote! {
124 Self::#name => {
125 #status_code.into_response()
126 }
127 }
128 }
129 });
130
131 quote! {
132 #[derive(Debug)]
133 pub enum #enum_name {
134 #(#variant_defs,)*
135 }
136
137 impl ::axum::response::IntoResponse for #enum_name {
138 fn into_response(self) -> ::axum::response::Response {
139 use ::axum::response::IntoResponse;
140 match self {
141 #(#into_response_arms)*
142 }
143 }
144 }
145 }
146 }
147
148 fn generate_err_enum(&self, op: &Operation, op_name: &str) -> TokenStream {
150 let enum_name = format_ident!("{}Err", op_name);
151
152 let error_responses: Vec<_> = op
154 .responses
155 .iter()
156 .filter(|r| r.status_code.is_error())
157 .collect();
158
159 if error_responses.is_empty() {
160 return quote! {
162 #[derive(Debug)]
163 pub enum #enum_name {
164 Status500(String),
165 }
166
167 impl ::axum::response::IntoResponse for #enum_name {
168 fn into_response(self) -> ::axum::response::Response {
169 use ::axum::response::IntoResponse;
170 match self {
171 Self::Status500(msg) => {
172 (
173 ::axum::http::StatusCode::INTERNAL_SERVER_ERROR,
174 msg,
175 ).into_response()
176 }
177 }
178 }
179 }
180 };
181 }
182
183 let variants: Vec<_> = error_responses
184 .iter()
185 .map(|resp| {
186 let (variant_name, status) = match &resp.status_code {
187 ResponseStatus::Code(code) => (format_ident!("Status{}", code), *code),
188 ResponseStatus::Default => (format_ident!("Default"), 500),
189 };
190
191 if let Some(schema) = &resp.schema {
192 let ty = self
193 .type_gen
194 .type_for_schema(schema, &format!("{}Error{}", op_name, status));
195 (
196 variant_name,
197 Some(ty),
198 status,
199 resp.status_code == ResponseStatus::Default,
200 )
201 } else {
202 (
203 variant_name,
204 None,
205 status,
206 resp.status_code == ResponseStatus::Default,
207 )
208 }
209 })
210 .collect();
211
212 let variant_defs = variants.iter().map(|(name, ty, _, is_default)| {
213 if let Some(ty) = ty {
214 if *is_default {
215 quote! { #name(::axum::http::StatusCode, #ty) }
217 } else {
218 quote! { #name(#ty) }
219 }
220 } else if *is_default {
221 quote! { #name(::axum::http::StatusCode) }
222 } else {
223 quote! { #name }
224 }
225 });
226
227 let into_response_arms = variants.iter().map(|(name, ty, status, is_default)| {
228 let status_code = status_code_ident(*status);
229
230 if *is_default {
231 if ty.is_some() {
232 quote! {
233 Self::#name(status, body) => {
234 (status, ::axum::Json(body)).into_response()
235 }
236 }
237 } else {
238 quote! {
239 Self::#name(status) => {
240 status.into_response()
241 }
242 }
243 }
244 } else if ty.is_some() {
245 quote! {
246 Self::#name(body) => {
247 (
248 #status_code,
249 ::axum::Json(body),
250 ).into_response()
251 }
252 }
253 } else {
254 quote! {
255 Self::#name => {
256 #status_code.into_response()
257 }
258 }
259 }
260 });
261
262 quote! {
263 #[derive(Debug)]
264 pub enum #enum_name {
265 #(#variant_defs,)*
266 }
267
268 impl ::axum::response::IntoResponse for #enum_name {
269 fn into_response(self) -> ::axum::response::Response {
270 use ::axum::response::IntoResponse;
271 match self {
272 #(#into_response_arms)*
273 }
274 }
275 }
276 }
277 }
278}
279
280fn status_code_ident(status: u16) -> TokenStream {
282 let name = match status {
284 200 => Some("OK"),
285 201 => Some("CREATED"),
286 202 => Some("ACCEPTED"),
287 204 => Some("NO_CONTENT"),
288 301 => Some("MOVED_PERMANENTLY"),
289 302 => Some("FOUND"),
290 304 => Some("NOT_MODIFIED"),
291 400 => Some("BAD_REQUEST"),
292 401 => Some("UNAUTHORIZED"),
293 403 => Some("FORBIDDEN"),
294 404 => Some("NOT_FOUND"),
295 405 => Some("METHOD_NOT_ALLOWED"),
296 409 => Some("CONFLICT"),
297 410 => Some("GONE"),
298 422 => Some("UNPROCESSABLE_ENTITY"),
299 429 => Some("TOO_MANY_REQUESTS"),
300 500 => Some("INTERNAL_SERVER_ERROR"),
301 501 => Some("NOT_IMPLEMENTED"),
302 502 => Some("BAD_GATEWAY"),
303 503 => Some("SERVICE_UNAVAILABLE"),
304 504 => Some("GATEWAY_TIMEOUT"),
305 _ => None,
306 };
307
308 if let Some(name) = name {
309 let ident = format_ident!("{}", name);
310 quote! { ::axum::http::StatusCode::#ident }
311 } else {
312 quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
313 }
314}