1use proc_macro2::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5
6use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
7use crate::types::TypeGenerator;
8use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
9
10pub struct ResponseGenerator<'a> {
12 spec: &'a ParsedSpec,
13 type_gen: &'a TypeGenerator,
14 overrides: &'a TypeOverrides,
15 suffixes: &'a ResponseSuffixes,
16}
17
18impl<'a> ResponseGenerator<'a> {
19 pub fn new(
20 spec: &'a ParsedSpec,
21 type_gen: &'a TypeGenerator,
22 overrides: &'a TypeOverrides,
23 suffixes: &'a ResponseSuffixes,
24 ) -> Self {
25 Self {
26 spec,
27 type_gen,
28 overrides,
29 suffixes,
30 }
31 }
32
33 pub fn generate_all(&self) -> TokenStream {
35 let enums: Vec<_> = self
36 .spec
37 .operations()
38 .map(|op| self.generate_for_operation(op))
39 .collect();
40
41 quote! {
42 #(#enums)*
43 }
44 }
45
46 pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
48 let op_name = op.name();
49
50 let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
51 let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
52
53 quote! {
54 #ok_enum
55 #err_enum
56 }
57 }
58
59 fn get_enum_name(
61 &self,
62 op: &Operation,
63 default_name: &str,
64 kind: GeneratedTypeKind,
65 ) -> syn::Ident {
66 if let Some(TypeOverride::Rename { name, .. }) =
67 self.overrides.get(op.method, &op.path, kind)
68 {
69 name.clone()
70 } else {
71 format_ident!("{}", default_name)
72 }
73 }
74
75 fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
77 if let Some(TypeOverride::Rename {
78 variant_overrides, ..
79 }) = self.overrides.get(op.method, &op.path, kind)
80 && let Some(ov) = variant_overrides.get(&status)
81 {
82 return ov.name.clone();
83 }
84 format_ident!("Status{}", status)
85 }
86
87 fn get_inner_type_name(
89 &self,
90 op: &Operation,
91 status: u16,
92 kind: GeneratedTypeKind,
93 ) -> Option<syn::Ident> {
94 if let Some(TypeOverride::Rename {
95 variant_overrides, ..
96 }) = self.overrides.get(op.method, &op.path, kind)
97 && let Some(ov) = variant_overrides.get(&status)
98 {
99 return ov.inner_type_name.clone();
100 }
101 None
102 }
103
104 fn get_variant_attrs(
106 &self,
107 op: &Operation,
108 status: u16,
109 kind: GeneratedTypeKind,
110 ) -> Vec<TokenStream> {
111 if let Some(TypeOverride::Rename {
112 variant_overrides, ..
113 }) = self.overrides.get(op.method, &op.path, kind)
114 && let Some(ov) = variant_overrides.get(&status)
115 {
116 return ov.attrs.clone();
117 }
118 Vec::new()
119 }
120
121 fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
124 let Some(TypeOverride::Rename {
125 variant_overrides, ..
126 }) = self.overrides.get(op.method, &op.path, kind)
127 else {
128 return quote! {};
129 };
130
131 let valid_codes: std::collections::HashSet<u16> = op
133 .responses
134 .iter()
135 .filter_map(|r| match (&r.status_code, kind) {
136 (ResponseStatus::Code(code), GeneratedTypeKind::Ok)
137 if r.status_code.is_success() =>
138 {
139 Some(*code)
140 }
141 (ResponseStatus::Code(code), GeneratedTypeKind::Err)
142 if r.status_code.is_error() =>
143 {
144 Some(*code)
145 }
146 _ => None,
147 })
148 .collect();
149
150 let errors: Vec<_> = variant_overrides
152 .iter()
153 .filter(|(status, _)| !valid_codes.contains(status))
154 .map(|(status, ov)| {
155 let valid_list: Vec<_> = valid_codes.iter().collect();
156 let kind_name = match kind {
157 GeneratedTypeKind::Ok => "success",
158 GeneratedTypeKind::Err => "error",
159 GeneratedTypeKind::Query => "query",
160 };
161 let msg = format!(
162 "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
163 status, op.method, op.path, kind_name, valid_list
164 );
165 quote_spanned! { ov.name.span() =>
166 compile_error!(#msg);
167 }
168 })
169 .collect();
170
171 quote! { #(#errors)* }
172 }
173
174 fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
177 if let Some(TypeOverride::Rename { attrs, .. }) =
178 self.overrides.get(op.method, &op.path, kind)
179 {
180 let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
182 if has_derive {
183 return attrs.clone();
184 } else {
185 let mut result = vec![self.suffixes.default_derives.clone()];
187 result.extend(attrs.clone());
188 return result;
189 }
190 }
191 Vec::new()
192 }
193
194 fn generate_response_enum(
196 &self,
197 op: &Operation,
198 op_name: &str,
199 kind: GeneratedTypeKind,
200 ) -> TokenStream {
201 if self.overrides.is_replaced(op.method, &op.path, kind) {
203 return quote! {};
204 }
205
206 let validation_errors = self.validate_variant_overrides(op, kind);
208
209 let suffix = match kind {
210 GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
211 GeneratedTypeKind::Err => &self.suffixes.err_suffix,
212 GeneratedTypeKind::Query => unreachable!(),
213 };
214
215 let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
216
217 let responses: Vec<_> = op
219 .responses
220 .iter()
221 .filter(|r| match kind {
222 GeneratedTypeKind::Ok => r.status_code.is_success(),
223 GeneratedTypeKind::Err => r.status_code.is_error(),
224 GeneratedTypeKind::Query => false,
225 })
226 .collect();
227
228 if responses.is_empty() {
230 return self.generate_empty_fallback(&enum_name, kind);
231 }
232
233 let mut errors = Vec::new();
234 let mut inline_definitions = Vec::new();
235
236 let variants: Vec<VariantInfo> = responses
238 .iter()
239 .map(|resp| {
240 let (variant_name, status, is_default) = match &resp.status_code {
241 ResponseStatus::Code(code) => {
242 (self.get_variant_name(op, *code, kind), *code, false)
243 }
244 ResponseStatus::Default => {
245 (format_ident!("Default"), 500, true)
247 }
248 };
249
250 let variant_attrs = self.get_variant_attrs(op, status, kind);
251 let inner_type_override = self.get_inner_type_name(op, status, kind);
252
253 let body_type = if let Some(schema) = &resp.schema {
254 if let Some(ref inner_name) = inner_type_override
256 && !TypeGenerator::is_inline_schema(schema)
257 {
258 errors.push(quote_spanned! { inner_name.span() =>
259 compile_error!("inner type name override can only be used with inline schemas, not $ref");
260 });
261 }
262
263 let name_hint = inner_type_override
265 .as_ref()
266 .map(|i| i.to_string())
267 .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
268 let generated = self
269 .type_gen
270 .type_for_schema_with_definitions(schema, &name_hint);
271 inline_definitions.extend(generated.definitions);
272 Some(generated.type_ref)
273 } else {
274 None
275 };
276
277 VariantInfo {
278 name: variant_name,
279 body_type,
280 status,
281 is_default,
282 attrs: variant_attrs,
283 }
284 })
285 .collect();
286
287 let variant_defs = variants.iter().map(|v| {
289 let name = &v.name;
290 let attrs = &v.attrs;
291 match (&v.body_type, v.is_default) {
292 (Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
293 (Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
294 (None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
295 (None, false) => quote! { #(#attrs)* #name },
296 }
297 });
298
299 let into_response_arms = variants.iter().map(|v| {
301 let name = &v.name;
302 let status_code = status_code_ident(v.status);
303
304 match (&v.body_type, v.is_default) {
305 (Some(_), true) => quote! {
306 Self::#name(status, body) => {
307 (status, ::axum::Json(body)).into_response()
308 }
309 },
310 (Some(_), false) => quote! {
311 Self::#name(body) => {
312 (#status_code, ::axum::Json(body)).into_response()
313 }
314 },
315 (None, true) => quote! {
316 Self::#name(status) => {
317 status.into_response()
318 }
319 },
320 (None, false) => quote! {
321 Self::#name => {
322 #status_code.into_response()
323 }
324 },
325 }
326 });
327
328 let enum_attrs = self.get_enum_attrs(op, kind);
329 let default_derives = &self.suffixes.default_derives;
330
331 let attrs_tokens = if enum_attrs.is_empty() {
333 quote! { #default_derives }
334 } else {
335 quote! { #(#enum_attrs)* }
336 };
337
338 let enum_span = enum_name.span();
341
342 let enum_def = quote_spanned! { enum_span =>
343 #attrs_tokens
344 pub enum #enum_name {
345 #(#variant_defs,)*
346 }
347 };
348
349 quote! {
350 #validation_errors
351
352 #(#errors)*
353
354 #(#inline_definitions)*
355
356 #enum_def
357
358 impl ::axum::response::IntoResponse for #enum_name {
359 fn into_response(self) -> ::axum::response::Response {
360 use ::axum::response::IntoResponse;
361 match self {
362 #(#into_response_arms)*
363 }
364 }
365 }
366 }
367 }
368
369 fn generate_empty_fallback(
373 &self,
374 enum_name: &syn::Ident,
375 kind: GeneratedTypeKind,
376 ) -> TokenStream {
377 let span = enum_name.span();
378 match kind {
379 GeneratedTypeKind::Ok => quote_spanned! { span =>
380 #[derive(Debug)]
381 pub enum #enum_name {
382 Status200,
383 }
384
385 impl ::axum::response::IntoResponse for #enum_name {
386 fn into_response(self) -> ::axum::response::Response {
387 match self {
388 Self::Status200 => {
389 ::axum::http::StatusCode::OK.into_response()
390 }
391 }
392 }
393 }
394 },
395 GeneratedTypeKind::Err => quote! {},
397 GeneratedTypeKind::Query => unreachable!(),
398 }
399 }
400}
401
402struct VariantInfo {
404 name: syn::Ident,
405 body_type: Option<TokenStream>,
406 status: u16,
407 is_default: bool,
408 attrs: Vec<TokenStream>,
409}
410
411fn status_code_ident(status: u16) -> TokenStream {
413 let name = match status {
415 200 => Some("OK"),
416 201 => Some("CREATED"),
417 202 => Some("ACCEPTED"),
418 204 => Some("NO_CONTENT"),
419 301 => Some("MOVED_PERMANENTLY"),
420 302 => Some("FOUND"),
421 304 => Some("NOT_MODIFIED"),
422 400 => Some("BAD_REQUEST"),
423 401 => Some("UNAUTHORIZED"),
424 403 => Some("FORBIDDEN"),
425 404 => Some("NOT_FOUND"),
426 405 => Some("METHOD_NOT_ALLOWED"),
427 409 => Some("CONFLICT"),
428 410 => Some("GONE"),
429 422 => Some("UNPROCESSABLE_ENTITY"),
430 429 => Some("TOO_MANY_REQUESTS"),
431 500 => Some("INTERNAL_SERVER_ERROR"),
432 501 => Some("NOT_IMPLEMENTED"),
433 502 => Some("BAD_GATEWAY"),
434 503 => Some("SERVICE_UNAVAILABLE"),
435 504 => Some("GATEWAY_TIMEOUT"),
436 _ => None,
437 };
438
439 if let Some(name) = name {
440 let ident = format_ident!("{}", name);
441 quote! { ::axum::http::StatusCode::#ident }
442 } else {
443 quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
444 }
445}