1use heck::AsSnakeCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, quote_spanned};
6
7use crate::openapi::{Operation, ParsedSpec, ResponseHeader, ResponseStatus};
8use crate::types::TypeGenerator;
9use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
10
11pub struct ResponseGenerator<'a> {
13 spec: &'a ParsedSpec,
14 type_gen: &'a TypeGenerator,
15 overrides: &'a TypeOverrides,
16 suffixes: &'a ResponseSuffixes,
17}
18
19impl<'a> ResponseGenerator<'a> {
20 pub fn new(
21 spec: &'a ParsedSpec,
22 type_gen: &'a TypeGenerator,
23 overrides: &'a TypeOverrides,
24 suffixes: &'a ResponseSuffixes,
25 ) -> Self {
26 Self {
27 spec,
28 type_gen,
29 overrides,
30 suffixes,
31 }
32 }
33
34 fn response_matches_kind(status: &ResponseStatus, kind: GeneratedTypeKind) -> bool {
36 match kind {
37 GeneratedTypeKind::Ok => status.is_success(),
38 GeneratedTypeKind::Err => status.is_error(),
39 GeneratedTypeKind::Query | GeneratedTypeKind::Path => false,
40 }
41 }
42
43 pub fn generate_all(&self) -> TokenStream {
45 let enums: Vec<_> = self
46 .spec
47 .operations()
48 .map(|op| self.generate_for_operation(op))
49 .collect();
50
51 quote! {
52 #(#enums)*
53 }
54 }
55
56 pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
58 let op_name = op.name();
59
60 let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
61 let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
62
63 quote! {
64 #ok_enum
65 #err_enum
66 }
67 }
68
69 fn get_enum_name(
71 &self,
72 op: &Operation,
73 default_name: &str,
74 kind: GeneratedTypeKind,
75 ) -> syn::Ident {
76 if let Some(TypeOverride::Rename { name, .. }) =
77 self.overrides.get(op.method, &op.path, kind)
78 {
79 name.clone()
80 } else {
81 format_ident!("{}", default_name)
82 }
83 }
84
85 fn get_variant_override<T, F>(
87 &self,
88 op: &Operation,
89 status: u16,
90 kind: GeneratedTypeKind,
91 getter: F,
92 default: T,
93 ) -> T
94 where
95 F: FnOnce(&crate::VariantOverride) -> T,
96 {
97 if let Some(TypeOverride::Rename {
98 variant_overrides, ..
99 }) = self.overrides.get(op.method, &op.path, kind)
100 && let Some(ov) = variant_overrides.get(&status)
101 {
102 return getter(ov);
103 }
104 default
105 }
106
107 fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
109 self.get_variant_override(
110 op,
111 status,
112 kind,
113 |ov| ov.name.clone(),
114 format_ident!("Status{}", status),
115 )
116 }
117
118 fn get_inner_type_name(
120 &self,
121 op: &Operation,
122 status: u16,
123 kind: GeneratedTypeKind,
124 ) -> Option<syn::Ident> {
125 self.get_variant_override(op, status, kind, |ov| ov.inner_type_name.clone(), None)
126 }
127
128 fn get_variant_attrs(
130 &self,
131 op: &Operation,
132 status: u16,
133 kind: GeneratedTypeKind,
134 ) -> Vec<TokenStream> {
135 self.get_variant_override(op, status, kind, |ov| ov.attrs.clone(), Vec::new())
136 }
137
138 fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
141 let Some(TypeOverride::Rename {
142 variant_overrides, ..
143 }) = self.overrides.get(op.method, &op.path, kind)
144 else {
145 return quote! {};
146 };
147
148 let valid_codes: std::collections::HashSet<u16> = op
150 .responses
151 .iter()
152 .filter(|r| Self::response_matches_kind(&r.status_code, kind))
153 .filter_map(|r| match &r.status_code {
154 ResponseStatus::Code(code) => Some(*code),
155 _ => None,
156 })
157 .collect();
158
159 let errors: Vec<_> = variant_overrides
161 .iter()
162 .filter(|(status, _)| !valid_codes.contains(status))
163 .map(|(status, ov)| {
164 let valid_list: Vec<_> = valid_codes.iter().collect();
165 let kind_name = match kind {
166 GeneratedTypeKind::Ok => "success",
167 GeneratedTypeKind::Err => "error",
168 GeneratedTypeKind::Query => "query",
169 GeneratedTypeKind::Path => "path",
170 };
171 let msg = format!(
172 "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
173 status, op.method, op.path, kind_name, valid_list
174 );
175 quote_spanned! { ov.name.span() =>
176 compile_error!(#msg);
177 }
178 })
179 .collect();
180
181 quote! { #(#errors)* }
182 }
183
184 fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
187 if let Some(TypeOverride::Rename { attrs, .. }) =
188 self.overrides.get(op.method, &op.path, kind)
189 {
190 let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
192 if has_derive {
193 return attrs.clone();
194 } else {
195 let mut result = vec![self.suffixes.default_derives.clone()];
197 result.extend(attrs.clone());
198 return result;
199 }
200 }
201 Vec::new()
202 }
203
204 fn collect_variants(
206 &self,
207 op: &Operation,
208 kind: GeneratedTypeKind,
209 enum_name: &syn::Ident,
210 responses: &[&crate::openapi::OperationResponse],
211 ) -> CollectedVariants {
212 let mut errors = Vec::new();
213 let mut inline_definitions = Vec::new();
214
215 let variants = responses
216 .iter()
217 .map(|resp| {
218 let (variant_name, status, is_default) = match &resp.status_code {
219 ResponseStatus::Code(code) => {
220 (self.get_variant_name(op, *code, kind), *code, false)
221 }
222 ResponseStatus::Default => {
223 (format_ident!("Default"), 500, true)
225 }
226 };
227
228 let variant_attrs = self.get_variant_attrs(op, status, kind);
229 let inner_type_override = self.get_inner_type_name(op, status, kind);
230
231 let body_type = if let Some(schema) = &resp.schema {
232 if let Some(ref inner_name) = inner_type_override
234 && !TypeGenerator::is_inline_schema(schema)
235 {
236 errors.push(quote_spanned! { inner_name.span() =>
237 compile_error!("inner type name override can only be used with inline schemas, not $ref");
238 });
239 }
240
241 let name_hint = inner_type_override
243 .as_ref()
244 .map(|i| i.to_string())
245 .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
246 let generated = self
247 .type_gen
248 .type_for_schema_with_definitions(schema, &name_hint);
249 inline_definitions.extend(generated.definitions);
250 Some(generated.type_ref)
251 } else {
252 None
253 };
254
255 VariantInfo {
256 name: variant_name,
257 body_type,
258 status,
259 is_default,
260 attrs: variant_attrs,
261 headers: resp.headers.clone(),
262 }
263 })
264 .collect();
265
266 CollectedVariants {
267 variants,
268 errors,
269 inline_definitions,
270 }
271 }
272
273 fn generate_response_enum(
275 &self,
276 op: &Operation,
277 op_name: &str,
278 kind: GeneratedTypeKind,
279 ) -> TokenStream {
280 if self.overrides.is_replaced(op.method, &op.path, kind) {
282 return quote! {};
283 }
284
285 let validation_errors = self.validate_variant_overrides(op, kind);
287
288 let suffix = match kind {
289 GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
290 GeneratedTypeKind::Err => &self.suffixes.err_suffix,
291 GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
292 };
293
294 let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
295
296 let responses: Vec<_> = op
298 .responses
299 .iter()
300 .filter(|r| Self::response_matches_kind(&r.status_code, kind))
301 .collect();
302
303 if responses.is_empty() {
305 return self.generate_empty_fallback(&enum_name, kind);
306 }
307
308 let CollectedVariants {
310 variants,
311 errors,
312 inline_definitions,
313 } = self.collect_variants(op, kind, &enum_name, &responses);
314
315 let header_structs: Vec<_> = variants
317 .iter()
318 .filter(|v| !v.headers.is_empty())
319 .map(|v| self.generate_header_struct(&enum_name, v))
320 .collect();
321
322 let variant_defs = variants.iter().map(|v| {
324 let name = &v.name;
325 let attrs = &v.attrs;
326 let has_headers = !v.headers.is_empty();
327
328 match (has_headers, &v.body_type, v.is_default) {
329 (true, Some(ty), true) => {
331 let hs = header_struct_ident(&enum_name, &v.name);
332 quote! { #(#attrs)* #name { headers: #hs, status: ::axum::http::StatusCode, body: #ty } }
333 }
334 (true, Some(ty), false) => {
335 let hs = header_struct_ident(&enum_name, &v.name);
336 quote! { #(#attrs)* #name { headers: #hs, body: #ty } }
337 }
338 (true, None, true) => {
339 let hs = header_struct_ident(&enum_name, &v.name);
340 quote! { #(#attrs)* #name { headers: #hs, status: ::axum::http::StatusCode } }
341 }
342 (true, None, false) => {
343 let hs = header_struct_ident(&enum_name, &v.name);
344 quote! { #(#attrs)* #name { headers: #hs } }
345 }
346 (false, Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
348 (false, Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
349 (false, None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
350 (false, None, false) => quote! { #(#attrs)* #name },
351 }
352 });
353
354 let into_response_arms = variants.iter().map(|v| {
356 let name = &v.name;
357 let status_code = status_code_ident(v.status);
358 let has_headers = !v.headers.is_empty();
359
360 let insert_headers = if has_headers {
361 self.generate_header_insertions(&v.headers)
362 } else {
363 quote! {}
364 };
365
366 match (has_headers, &v.body_type, v.is_default) {
367 (true, Some(_), true) => quote! {
369 Self::#name { headers, status, body } => {
370 let mut response = (status, ::axum::Json(body)).into_response();
371 #insert_headers
372 response
373 }
374 },
375 (true, Some(_), false) => quote! {
376 Self::#name { headers, body } => {
377 let mut response = (#status_code, ::axum::Json(body)).into_response();
378 #insert_headers
379 response
380 }
381 },
382 (true, None, true) => quote! {
383 Self::#name { headers, status } => {
384 let mut response = status.into_response();
385 #insert_headers
386 response
387 }
388 },
389 (true, None, false) => quote! {
390 Self::#name { headers } => {
391 let mut response = #status_code.into_response();
392 #insert_headers
393 response
394 }
395 },
396 (false, Some(_), true) => quote! {
398 Self::#name(status, body) => {
399 (status, ::axum::Json(body)).into_response()
400 }
401 },
402 (false, Some(_), false) => quote! {
403 Self::#name(body) => {
404 (#status_code, ::axum::Json(body)).into_response()
405 }
406 },
407 (false, None, true) => quote! {
408 Self::#name(status) => {
409 status.into_response()
410 }
411 },
412 (false, None, false) => quote! {
413 Self::#name => {
414 #status_code.into_response()
415 }
416 },
417 }
418 });
419
420 let enum_attrs = self.get_enum_attrs(op, kind);
421 let default_derives = &self.suffixes.default_derives;
422
423 let attrs_tokens = if enum_attrs.is_empty() {
425 quote! { #default_derives }
426 } else {
427 quote! { #(#enum_attrs)* }
428 };
429
430 let enum_span = enum_name.span();
433
434 let enum_def = quote_spanned! { enum_span =>
435 #attrs_tokens
436 pub enum #enum_name {
437 #(#variant_defs,)*
438 }
439 };
440
441 quote! {
442 #validation_errors
443
444 #(#errors)*
445
446 #(#inline_definitions)*
447
448 #(#header_structs)*
449
450 #enum_def
451
452 impl ::axum::response::IntoResponse for #enum_name {
453 fn into_response(self) -> ::axum::response::Response {
454 use ::axum::response::IntoResponse;
455 match self {
456 #(#into_response_arms)*
457 }
458 }
459 }
460 }
461 }
462
463 fn generate_header_struct(&self, enum_name: &syn::Ident, variant: &VariantInfo) -> TokenStream {
465 let struct_name = header_struct_ident(enum_name, &variant.name);
466
467 let fields = variant.headers.iter().map(|h| {
468 let field_name = format_ident!("{}", AsSnakeCase(&h.name).to_string());
469 let inner_type = if let Some(schema) = &h.schema {
470 self.type_gen.type_for_schema(schema, &h.name)
471 } else {
472 quote! { String }
473 };
474
475 let field_type = if h.required {
476 inner_type
477 } else {
478 quote! { Option<#inner_type> }
479 };
480
481 quote! { pub #field_name: #field_type }
482 });
483
484 quote! {
485 #[derive(Debug, Default)]
486 pub struct #struct_name {
487 #(#fields,)*
488 }
489 }
490 }
491
492 fn generate_header_insertions(&self, headers: &[ResponseHeader]) -> TokenStream {
494 let insertions = headers.iter().map(|h| {
495 let field_name = format_ident!("{}", AsSnakeCase(&h.name).to_string());
496 let header_name = h.name.to_lowercase();
497
498 if h.required {
499 quote! {
500 response.headers_mut().insert(
501 ::axum::http::HeaderName::from_static(#header_name),
502 ::axum::http::HeaderValue::from_str(&headers.#field_name.to_string()).unwrap(),
503 );
504 }
505 } else {
506 quote! {
507 if let Some(ref v) = headers.#field_name {
508 response.headers_mut().insert(
509 ::axum::http::HeaderName::from_static(#header_name),
510 ::axum::http::HeaderValue::from_str(&v.to_string()).unwrap(),
511 );
512 }
513 }
514 }
515 });
516
517 quote! { #(#insertions)* }
518 }
519
520 fn generate_empty_fallback(
524 &self,
525 enum_name: &syn::Ident,
526 kind: GeneratedTypeKind,
527 ) -> TokenStream {
528 let span = enum_name.span();
529 match kind {
530 GeneratedTypeKind::Ok => quote_spanned! { span =>
531 #[derive(Debug)]
532 pub enum #enum_name {
533 Status200,
534 }
535
536 impl ::axum::response::IntoResponse for #enum_name {
537 fn into_response(self) -> ::axum::response::Response {
538 match self {
539 Self::Status200 => {
540 ::axum::http::StatusCode::OK.into_response()
541 }
542 }
543 }
544 }
545 },
546 GeneratedTypeKind::Err => quote! {},
548 GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
549 }
550 }
551}
552
553struct VariantInfo {
555 name: syn::Ident,
556 body_type: Option<TokenStream>,
557 status: u16,
558 is_default: bool,
559 attrs: Vec<TokenStream>,
560 headers: Vec<ResponseHeader>,
561}
562
563struct CollectedVariants {
565 variants: Vec<VariantInfo>,
566 errors: Vec<TokenStream>,
567 inline_definitions: Vec<TokenStream>,
568}
569
/// Status codes that have a named associated constant on `http::StatusCode`
/// (re-exported as `axum::http::StatusCode`), keyed by numeric code.
///
/// Codes found here are emitted as `StatusCode::NAME`; any other code falls back to
/// `StatusCode::from_u16(code).unwrap()` in the generated code. Kept sorted by code with
/// no duplicates. Every name must be an existing `StatusCode` constant, or the generated
/// code will not compile.
const STATUS_CODE_NAMES: &[(u16, &str)] = &[
    (200, "OK"),
    (201, "CREATED"),
    (202, "ACCEPTED"),
    (203, "NON_AUTHORITATIVE_INFORMATION"),
    (204, "NO_CONTENT"),
    (205, "RESET_CONTENT"),
    (206, "PARTIAL_CONTENT"),
    (207, "MULTI_STATUS"),
    (300, "MULTIPLE_CHOICES"),
    (301, "MOVED_PERMANENTLY"),
    (302, "FOUND"),
    (303, "SEE_OTHER"),
    (304, "NOT_MODIFIED"),
    (307, "TEMPORARY_REDIRECT"),
    (308, "PERMANENT_REDIRECT"),
    (400, "BAD_REQUEST"),
    (401, "UNAUTHORIZED"),
    (402, "PAYMENT_REQUIRED"),
    (403, "FORBIDDEN"),
    (404, "NOT_FOUND"),
    (405, "METHOD_NOT_ALLOWED"),
    (406, "NOT_ACCEPTABLE"),
    (407, "PROXY_AUTHENTICATION_REQUIRED"),
    (408, "REQUEST_TIMEOUT"),
    (409, "CONFLICT"),
    (410, "GONE"),
    (411, "LENGTH_REQUIRED"),
    (412, "PRECONDITION_FAILED"),
    (413, "PAYLOAD_TOO_LARGE"),
    (414, "URI_TOO_LONG"),
    (415, "UNSUPPORTED_MEDIA_TYPE"),
    (416, "RANGE_NOT_SATISFIABLE"),
    (417, "EXPECTATION_FAILED"),
    (422, "UNPROCESSABLE_ENTITY"),
    (423, "LOCKED"),
    (424, "FAILED_DEPENDENCY"),
    (426, "UPGRADE_REQUIRED"),
    (428, "PRECONDITION_REQUIRED"),
    (429, "TOO_MANY_REQUESTS"),
    (431, "REQUEST_HEADER_FIELDS_TOO_LARGE"),
    (451, "UNAVAILABLE_FOR_LEGAL_REASONS"),
    (500, "INTERNAL_SERVER_ERROR"),
    (501, "NOT_IMPLEMENTED"),
    (502, "BAD_GATEWAY"),
    (503, "SERVICE_UNAVAILABLE"),
    (504, "GATEWAY_TIMEOUT"),
    (505, "HTTP_VERSION_NOT_SUPPORTED"),
    (507, "INSUFFICIENT_STORAGE"),
    (511, "NETWORK_AUTHENTICATION_REQUIRED"),
];
597
598fn header_struct_ident(enum_name: &syn::Ident, variant_name: &syn::Ident) -> syn::Ident {
600 format_ident!("{}{}Headers", enum_name, variant_name)
601}
602
603fn status_code_ident(status: u16) -> TokenStream {
605 if let Some((_, name)) = STATUS_CODE_NAMES.iter().find(|(code, _)| *code == status) {
607 let ident = format_ident!("{}", name);
608 quote! { ::axum::http::StatusCode::#ident }
609 } else {
610 quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
611 }
612}