1use proc_macro2::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5
6use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
7use crate::types::TypeGenerator;
8use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
9
10pub struct ResponseGenerator<'a> {
12 spec: &'a ParsedSpec,
13 type_gen: &'a TypeGenerator,
14 overrides: &'a TypeOverrides,
15 suffixes: &'a ResponseSuffixes,
16}
17
18impl<'a> ResponseGenerator<'a> {
19 pub fn new(
20 spec: &'a ParsedSpec,
21 type_gen: &'a TypeGenerator,
22 overrides: &'a TypeOverrides,
23 suffixes: &'a ResponseSuffixes,
24 ) -> Self {
25 Self {
26 spec,
27 type_gen,
28 overrides,
29 suffixes,
30 }
31 }
32
33 fn response_matches_kind(status: &ResponseStatus, kind: GeneratedTypeKind) -> bool {
35 match kind {
36 GeneratedTypeKind::Ok => status.is_success(),
37 GeneratedTypeKind::Err => status.is_error(),
38 GeneratedTypeKind::Query | GeneratedTypeKind::Path => false,
39 }
40 }
41
42 pub fn generate_all(&self) -> TokenStream {
44 let enums: Vec<_> = self
45 .spec
46 .operations()
47 .map(|op| self.generate_for_operation(op))
48 .collect();
49
50 quote! {
51 #(#enums)*
52 }
53 }
54
55 pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
57 let op_name = op.name();
58
59 let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
60 let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
61
62 quote! {
63 #ok_enum
64 #err_enum
65 }
66 }
67
68 fn get_enum_name(
70 &self,
71 op: &Operation,
72 default_name: &str,
73 kind: GeneratedTypeKind,
74 ) -> syn::Ident {
75 if let Some(TypeOverride::Rename { name, .. }) =
76 self.overrides.get(op.method, &op.path, kind)
77 {
78 name.clone()
79 } else {
80 format_ident!("{}", default_name)
81 }
82 }
83
84 fn get_variant_override<T, F>(
86 &self,
87 op: &Operation,
88 status: u16,
89 kind: GeneratedTypeKind,
90 getter: F,
91 default: T,
92 ) -> T
93 where
94 F: FnOnce(&crate::VariantOverride) -> T,
95 {
96 if let Some(TypeOverride::Rename {
97 variant_overrides, ..
98 }) = self.overrides.get(op.method, &op.path, kind)
99 && let Some(ov) = variant_overrides.get(&status)
100 {
101 return getter(ov);
102 }
103 default
104 }
105
106 fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
108 self.get_variant_override(
109 op,
110 status,
111 kind,
112 |ov| ov.name.clone(),
113 format_ident!("Status{}", status),
114 )
115 }
116
117 fn get_inner_type_name(
119 &self,
120 op: &Operation,
121 status: u16,
122 kind: GeneratedTypeKind,
123 ) -> Option<syn::Ident> {
124 self.get_variant_override(op, status, kind, |ov| ov.inner_type_name.clone(), None)
125 }
126
127 fn get_variant_attrs(
129 &self,
130 op: &Operation,
131 status: u16,
132 kind: GeneratedTypeKind,
133 ) -> Vec<TokenStream> {
134 self.get_variant_override(op, status, kind, |ov| ov.attrs.clone(), Vec::new())
135 }
136
137 fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
140 let Some(TypeOverride::Rename {
141 variant_overrides, ..
142 }) = self.overrides.get(op.method, &op.path, kind)
143 else {
144 return quote! {};
145 };
146
147 let valid_codes: std::collections::HashSet<u16> = op
149 .responses
150 .iter()
151 .filter(|r| Self::response_matches_kind(&r.status_code, kind))
152 .filter_map(|r| match &r.status_code {
153 ResponseStatus::Code(code) => Some(*code),
154 _ => None,
155 })
156 .collect();
157
158 let errors: Vec<_> = variant_overrides
160 .iter()
161 .filter(|(status, _)| !valid_codes.contains(status))
162 .map(|(status, ov)| {
163 let valid_list: Vec<_> = valid_codes.iter().collect();
164 let kind_name = match kind {
165 GeneratedTypeKind::Ok => "success",
166 GeneratedTypeKind::Err => "error",
167 GeneratedTypeKind::Query => "query",
168 GeneratedTypeKind::Path => "path",
169 };
170 let msg = format!(
171 "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
172 status, op.method, op.path, kind_name, valid_list
173 );
174 quote_spanned! { ov.name.span() =>
175 compile_error!(#msg);
176 }
177 })
178 .collect();
179
180 quote! { #(#errors)* }
181 }
182
183 fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
186 if let Some(TypeOverride::Rename { attrs, .. }) =
187 self.overrides.get(op.method, &op.path, kind)
188 {
189 let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
191 if has_derive {
192 return attrs.clone();
193 } else {
194 let mut result = vec![self.suffixes.default_derives.clone()];
196 result.extend(attrs.clone());
197 return result;
198 }
199 }
200 Vec::new()
201 }
202
203 fn collect_variants(
205 &self,
206 op: &Operation,
207 kind: GeneratedTypeKind,
208 enum_name: &syn::Ident,
209 responses: &[&crate::openapi::OperationResponse],
210 ) -> CollectedVariants {
211 let mut errors = Vec::new();
212 let mut inline_definitions = Vec::new();
213
214 let variants = responses
215 .iter()
216 .map(|resp| {
217 let (variant_name, status, is_default) = match &resp.status_code {
218 ResponseStatus::Code(code) => {
219 (self.get_variant_name(op, *code, kind), *code, false)
220 }
221 ResponseStatus::Default => {
222 (format_ident!("Default"), 500, true)
224 }
225 };
226
227 let variant_attrs = self.get_variant_attrs(op, status, kind);
228 let inner_type_override = self.get_inner_type_name(op, status, kind);
229
230 let body_type = if let Some(schema) = &resp.schema {
231 if let Some(ref inner_name) = inner_type_override
233 && !TypeGenerator::is_inline_schema(schema)
234 {
235 errors.push(quote_spanned! { inner_name.span() =>
236 compile_error!("inner type name override can only be used with inline schemas, not $ref");
237 });
238 }
239
240 let name_hint = inner_type_override
242 .as_ref()
243 .map(|i| i.to_string())
244 .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
245 let generated = self
246 .type_gen
247 .type_for_schema_with_definitions(schema, &name_hint);
248 inline_definitions.extend(generated.definitions);
249 Some(generated.type_ref)
250 } else {
251 None
252 };
253
254 VariantInfo {
255 name: variant_name,
256 body_type,
257 status,
258 is_default,
259 attrs: variant_attrs,
260 }
261 })
262 .collect();
263
264 CollectedVariants {
265 variants,
266 errors,
267 inline_definitions,
268 }
269 }
270
271 fn generate_response_enum(
273 &self,
274 op: &Operation,
275 op_name: &str,
276 kind: GeneratedTypeKind,
277 ) -> TokenStream {
278 if self.overrides.is_replaced(op.method, &op.path, kind) {
280 return quote! {};
281 }
282
283 let validation_errors = self.validate_variant_overrides(op, kind);
285
286 let suffix = match kind {
287 GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
288 GeneratedTypeKind::Err => &self.suffixes.err_suffix,
289 GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
290 };
291
292 let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
293
294 let responses: Vec<_> = op
296 .responses
297 .iter()
298 .filter(|r| Self::response_matches_kind(&r.status_code, kind))
299 .collect();
300
301 if responses.is_empty() {
303 return self.generate_empty_fallback(&enum_name, kind);
304 }
305
306 let CollectedVariants {
308 variants,
309 errors,
310 inline_definitions,
311 } = self.collect_variants(op, kind, &enum_name, &responses);
312
313 let variant_defs = variants.iter().map(|v| {
315 let name = &v.name;
316 let attrs = &v.attrs;
317 match (&v.body_type, v.is_default) {
318 (Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
319 (Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
320 (None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
321 (None, false) => quote! { #(#attrs)* #name },
322 }
323 });
324
325 let into_response_arms = variants.iter().map(|v| {
327 let name = &v.name;
328 let status_code = status_code_ident(v.status);
329
330 match (&v.body_type, v.is_default) {
331 (Some(_), true) => quote! {
332 Self::#name(status, body) => {
333 (status, ::axum::Json(body)).into_response()
334 }
335 },
336 (Some(_), false) => quote! {
337 Self::#name(body) => {
338 (#status_code, ::axum::Json(body)).into_response()
339 }
340 },
341 (None, true) => quote! {
342 Self::#name(status) => {
343 status.into_response()
344 }
345 },
346 (None, false) => quote! {
347 Self::#name => {
348 #status_code.into_response()
349 }
350 },
351 }
352 });
353
354 let enum_attrs = self.get_enum_attrs(op, kind);
355 let default_derives = &self.suffixes.default_derives;
356
357 let attrs_tokens = if enum_attrs.is_empty() {
359 quote! { #default_derives }
360 } else {
361 quote! { #(#enum_attrs)* }
362 };
363
364 let enum_span = enum_name.span();
367
368 let enum_def = quote_spanned! { enum_span =>
369 #attrs_tokens
370 pub enum #enum_name {
371 #(#variant_defs,)*
372 }
373 };
374
375 quote! {
376 #validation_errors
377
378 #(#errors)*
379
380 #(#inline_definitions)*
381
382 #enum_def
383
384 impl ::axum::response::IntoResponse for #enum_name {
385 fn into_response(self) -> ::axum::response::Response {
386 use ::axum::response::IntoResponse;
387 match self {
388 #(#into_response_arms)*
389 }
390 }
391 }
392 }
393 }
394
395 fn generate_empty_fallback(
399 &self,
400 enum_name: &syn::Ident,
401 kind: GeneratedTypeKind,
402 ) -> TokenStream {
403 let span = enum_name.span();
404 match kind {
405 GeneratedTypeKind::Ok => quote_spanned! { span =>
406 #[derive(Debug)]
407 pub enum #enum_name {
408 Status200,
409 }
410
411 impl ::axum::response::IntoResponse for #enum_name {
412 fn into_response(self) -> ::axum::response::Response {
413 match self {
414 Self::Status200 => {
415 ::axum::http::StatusCode::OK.into_response()
416 }
417 }
418 }
419 }
420 },
421 GeneratedTypeKind::Err => quote! {},
423 GeneratedTypeKind::Query | GeneratedTypeKind::Path => unreachable!(),
424 }
425 }
426}
427
428struct VariantInfo {
430 name: syn::Ident,
431 body_type: Option<TokenStream>,
432 status: u16,
433 is_default: bool,
434 attrs: Vec<TokenStream>,
435}
436
437struct CollectedVariants {
439 variants: Vec<VariantInfo>,
440 errors: Vec<TokenStream>,
441 inline_definitions: Vec<TokenStream>,
442}
443
/// Status codes emitted as named `::axum::http::StatusCode` associated
/// constants. Any code missing from this table is emitted via
/// `StatusCode::from_u16` instead.
const STATUS_CODE_NAMES: &[(u16, &str)] = &[
    // 2xx — success
    (200, "OK"),
    (201, "CREATED"),
    (202, "ACCEPTED"),
    (204, "NO_CONTENT"),
    // 3xx — redirection
    (301, "MOVED_PERMANENTLY"),
    (302, "FOUND"),
    (304, "NOT_MODIFIED"),
    // 4xx — client errors
    (400, "BAD_REQUEST"),
    (401, "UNAUTHORIZED"),
    (403, "FORBIDDEN"),
    (404, "NOT_FOUND"),
    (405, "METHOD_NOT_ALLOWED"),
    (409, "CONFLICT"),
    (410, "GONE"),
    (422, "UNPROCESSABLE_ENTITY"),
    (429, "TOO_MANY_REQUESTS"),
    // 5xx — server errors
    (500, "INTERNAL_SERVER_ERROR"),
    (501, "NOT_IMPLEMENTED"),
    (502, "BAD_GATEWAY"),
    (503, "SERVICE_UNAVAILABLE"),
    (504, "GATEWAY_TIMEOUT"),
];
468
469fn status_code_ident(status: u16) -> TokenStream {
471 if let Some((_, name)) = STATUS_CODE_NAMES.iter().find(|(code, _)| *code == status) {
473 let ident = format_ident!("{}", name);
474 quote! { ::axum::http::StatusCode::#ident }
475 } else {
476 quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
477 }
478}