1use proc_macro2::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5
6use crate::openapi::{Operation, ParsedSpec, ResponseStatus};
7use crate::types::TypeGenerator;
8use crate::{GeneratedTypeKind, ResponseSuffixes, TypeOverride, TypeOverrides};
9
10pub struct ResponseGenerator<'a> {
12 spec: &'a ParsedSpec,
13 type_gen: &'a TypeGenerator,
14 overrides: &'a TypeOverrides,
15 suffixes: &'a ResponseSuffixes,
16}
17
18impl<'a> ResponseGenerator<'a> {
19 pub fn new(
20 spec: &'a ParsedSpec,
21 type_gen: &'a TypeGenerator,
22 overrides: &'a TypeOverrides,
23 suffixes: &'a ResponseSuffixes,
24 ) -> Self {
25 Self {
26 spec,
27 type_gen,
28 overrides,
29 suffixes,
30 }
31 }
32
33 fn response_matches_kind(status: &ResponseStatus, kind: GeneratedTypeKind) -> bool {
35 match kind {
36 GeneratedTypeKind::Ok => status.is_success(),
37 GeneratedTypeKind::Err => status.is_error(),
38 GeneratedTypeKind::Query => false,
39 }
40 }
41
42 pub fn generate_all(&self) -> TokenStream {
44 let enums: Vec<_> = self
45 .spec
46 .operations()
47 .map(|op| self.generate_for_operation(op))
48 .collect();
49
50 quote! {
51 #(#enums)*
52 }
53 }
54
55 pub fn generate_for_operation(&self, op: &Operation) -> TokenStream {
57 let op_name = op.name();
58
59 let ok_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Ok);
60 let err_enum = self.generate_response_enum(op, &op_name, GeneratedTypeKind::Err);
61
62 quote! {
63 #ok_enum
64 #err_enum
65 }
66 }
67
68 fn get_enum_name(
70 &self,
71 op: &Operation,
72 default_name: &str,
73 kind: GeneratedTypeKind,
74 ) -> syn::Ident {
75 if let Some(TypeOverride::Rename { name, .. }) =
76 self.overrides.get(op.method, &op.path, kind)
77 {
78 name.clone()
79 } else {
80 format_ident!("{}", default_name)
81 }
82 }
83
84 fn get_variant_override<T, F>(
86 &self,
87 op: &Operation,
88 status: u16,
89 kind: GeneratedTypeKind,
90 getter: F,
91 default: T,
92 ) -> T
93 where
94 F: FnOnce(&crate::VariantOverride) -> T,
95 {
96 if let Some(TypeOverride::Rename {
97 variant_overrides, ..
98 }) = self.overrides.get(op.method, &op.path, kind)
99 && let Some(ov) = variant_overrides.get(&status)
100 {
101 return getter(ov);
102 }
103 default
104 }
105
106 fn get_variant_name(&self, op: &Operation, status: u16, kind: GeneratedTypeKind) -> syn::Ident {
108 self.get_variant_override(
109 op,
110 status,
111 kind,
112 |ov| ov.name.clone(),
113 format_ident!("Status{}", status),
114 )
115 }
116
117 fn get_inner_type_name(
119 &self,
120 op: &Operation,
121 status: u16,
122 kind: GeneratedTypeKind,
123 ) -> Option<syn::Ident> {
124 self.get_variant_override(op, status, kind, |ov| ov.inner_type_name.clone(), None)
125 }
126
127 fn get_variant_attrs(
129 &self,
130 op: &Operation,
131 status: u16,
132 kind: GeneratedTypeKind,
133 ) -> Vec<TokenStream> {
134 self.get_variant_override(op, status, kind, |ov| ov.attrs.clone(), Vec::new())
135 }
136
137 fn validate_variant_overrides(&self, op: &Operation, kind: GeneratedTypeKind) -> TokenStream {
140 let Some(TypeOverride::Rename {
141 variant_overrides, ..
142 }) = self.overrides.get(op.method, &op.path, kind)
143 else {
144 return quote! {};
145 };
146
147 let valid_codes: std::collections::HashSet<u16> = op
149 .responses
150 .iter()
151 .filter(|r| Self::response_matches_kind(&r.status_code, kind))
152 .filter_map(|r| match &r.status_code {
153 ResponseStatus::Code(code) => Some(*code),
154 _ => None,
155 })
156 .collect();
157
158 let errors: Vec<_> = variant_overrides
160 .iter()
161 .filter(|(status, _)| !valid_codes.contains(status))
162 .map(|(status, ov)| {
163 let valid_list: Vec<_> = valid_codes.iter().collect();
164 let kind_name = match kind {
165 GeneratedTypeKind::Ok => "success",
166 GeneratedTypeKind::Err => "error",
167 GeneratedTypeKind::Query => "query",
168 };
169 let msg = format!(
170 "status code {} does not exist in the OpenAPI spec for {} {} (valid {} codes: {:?})",
171 status, op.method, op.path, kind_name, valid_list
172 );
173 quote_spanned! { ov.name.span() =>
174 compile_error!(#msg);
175 }
176 })
177 .collect();
178
179 quote! { #(#errors)* }
180 }
181
182 fn get_enum_attrs(&self, op: &Operation, kind: GeneratedTypeKind) -> Vec<TokenStream> {
185 if let Some(TypeOverride::Rename { attrs, .. }) =
186 self.overrides.get(op.method, &op.path, kind)
187 {
188 let has_derive = attrs.iter().any(|a| a.to_string().starts_with("#[derive"));
190 if has_derive {
191 return attrs.clone();
192 } else {
193 let mut result = vec![self.suffixes.default_derives.clone()];
195 result.extend(attrs.clone());
196 return result;
197 }
198 }
199 Vec::new()
200 }
201
202 fn generate_response_enum(
204 &self,
205 op: &Operation,
206 op_name: &str,
207 kind: GeneratedTypeKind,
208 ) -> TokenStream {
209 if self.overrides.is_replaced(op.method, &op.path, kind) {
211 return quote! {};
212 }
213
214 let validation_errors = self.validate_variant_overrides(op, kind);
216
217 let suffix = match kind {
218 GeneratedTypeKind::Ok => &self.suffixes.ok_suffix,
219 GeneratedTypeKind::Err => &self.suffixes.err_suffix,
220 GeneratedTypeKind::Query => unreachable!(),
221 };
222
223 let enum_name = self.get_enum_name(op, &format!("{}{}", op_name, suffix), kind);
224
225 let responses: Vec<_> = op
227 .responses
228 .iter()
229 .filter(|r| Self::response_matches_kind(&r.status_code, kind))
230 .collect();
231
232 if responses.is_empty() {
234 return self.generate_empty_fallback(&enum_name, kind);
235 }
236
237 let mut errors = Vec::new();
238 let mut inline_definitions = Vec::new();
239
240 let variants: Vec<VariantInfo> = responses
242 .iter()
243 .map(|resp| {
244 let (variant_name, status, is_default) = match &resp.status_code {
245 ResponseStatus::Code(code) => {
246 (self.get_variant_name(op, *code, kind), *code, false)
247 }
248 ResponseStatus::Default => {
249 (format_ident!("Default"), 500, true)
251 }
252 };
253
254 let variant_attrs = self.get_variant_attrs(op, status, kind);
255 let inner_type_override = self.get_inner_type_name(op, status, kind);
256
257 let body_type = if let Some(schema) = &resp.schema {
258 if let Some(ref inner_name) = inner_type_override
260 && !TypeGenerator::is_inline_schema(schema)
261 {
262 errors.push(quote_spanned! { inner_name.span() =>
263 compile_error!("inner type name override can only be used with inline schemas, not $ref");
264 });
265 }
266
267 let name_hint = inner_type_override
269 .as_ref()
270 .map(|i| i.to_string())
271 .unwrap_or_else(|| format!("{}{}", enum_name, variant_name));
272 let generated = self
273 .type_gen
274 .type_for_schema_with_definitions(schema, &name_hint);
275 inline_definitions.extend(generated.definitions);
276 Some(generated.type_ref)
277 } else {
278 None
279 };
280
281 VariantInfo {
282 name: variant_name,
283 body_type,
284 status,
285 is_default,
286 attrs: variant_attrs,
287 }
288 })
289 .collect();
290
291 let variant_defs = variants.iter().map(|v| {
293 let name = &v.name;
294 let attrs = &v.attrs;
295 match (&v.body_type, v.is_default) {
296 (Some(ty), true) => quote! { #(#attrs)* #name(::axum::http::StatusCode, #ty) },
297 (Some(ty), false) => quote! { #(#attrs)* #name(#ty) },
298 (None, true) => quote! { #(#attrs)* #name(::axum::http::StatusCode) },
299 (None, false) => quote! { #(#attrs)* #name },
300 }
301 });
302
303 let into_response_arms = variants.iter().map(|v| {
305 let name = &v.name;
306 let status_code = status_code_ident(v.status);
307
308 match (&v.body_type, v.is_default) {
309 (Some(_), true) => quote! {
310 Self::#name(status, body) => {
311 (status, ::axum::Json(body)).into_response()
312 }
313 },
314 (Some(_), false) => quote! {
315 Self::#name(body) => {
316 (#status_code, ::axum::Json(body)).into_response()
317 }
318 },
319 (None, true) => quote! {
320 Self::#name(status) => {
321 status.into_response()
322 }
323 },
324 (None, false) => quote! {
325 Self::#name => {
326 #status_code.into_response()
327 }
328 },
329 }
330 });
331
332 let enum_attrs = self.get_enum_attrs(op, kind);
333 let default_derives = &self.suffixes.default_derives;
334
335 let attrs_tokens = if enum_attrs.is_empty() {
337 quote! { #default_derives }
338 } else {
339 quote! { #(#enum_attrs)* }
340 };
341
342 let enum_span = enum_name.span();
345
346 let enum_def = quote_spanned! { enum_span =>
347 #attrs_tokens
348 pub enum #enum_name {
349 #(#variant_defs,)*
350 }
351 };
352
353 quote! {
354 #validation_errors
355
356 #(#errors)*
357
358 #(#inline_definitions)*
359
360 #enum_def
361
362 impl ::axum::response::IntoResponse for #enum_name {
363 fn into_response(self) -> ::axum::response::Response {
364 use ::axum::response::IntoResponse;
365 match self {
366 #(#into_response_arms)*
367 }
368 }
369 }
370 }
371 }
372
373 fn generate_empty_fallback(
377 &self,
378 enum_name: &syn::Ident,
379 kind: GeneratedTypeKind,
380 ) -> TokenStream {
381 let span = enum_name.span();
382 match kind {
383 GeneratedTypeKind::Ok => quote_spanned! { span =>
384 #[derive(Debug)]
385 pub enum #enum_name {
386 Status200,
387 }
388
389 impl ::axum::response::IntoResponse for #enum_name {
390 fn into_response(self) -> ::axum::response::Response {
391 match self {
392 Self::Status200 => {
393 ::axum::http::StatusCode::OK.into_response()
394 }
395 }
396 }
397 }
398 },
399 GeneratedTypeKind::Err => quote! {},
401 GeneratedTypeKind::Query => unreachable!(),
402 }
403 }
404}
405
406struct VariantInfo {
408 name: syn::Ident,
409 body_type: Option<TokenStream>,
410 status: u16,
411 is_default: bool,
412 attrs: Vec<TokenStream>,
413}
414
/// HTTP status codes paired with the name of their associated constant on
/// `::axum::http::StatusCode`.
///
/// `status_code_ident` looks codes up here so it can emit a named constant
/// (e.g. `StatusCode::NOT_FOUND`). Codes not listed take its fallback path.
const STATUS_CODE_NAMES: &[(u16, &str)] = &[
    // 2xx: success
    (200, "OK"),
    (201, "CREATED"),
    (202, "ACCEPTED"),
    (204, "NO_CONTENT"),
    // 3xx: redirection
    (301, "MOVED_PERMANENTLY"),
    (302, "FOUND"),
    (304, "NOT_MODIFIED"),
    // 4xx: client errors
    (400, "BAD_REQUEST"),
    (401, "UNAUTHORIZED"),
    (403, "FORBIDDEN"),
    (404, "NOT_FOUND"),
    (405, "METHOD_NOT_ALLOWED"),
    (409, "CONFLICT"),
    (410, "GONE"),
    (422, "UNPROCESSABLE_ENTITY"),
    (429, "TOO_MANY_REQUESTS"),
    // 5xx: server errors
    (500, "INTERNAL_SERVER_ERROR"),
    (501, "NOT_IMPLEMENTED"),
    (502, "BAD_GATEWAY"),
    (503, "SERVICE_UNAVAILABLE"),
    (504, "GATEWAY_TIMEOUT"),
];
439
440fn status_code_ident(status: u16) -> TokenStream {
442 if let Some((_, name)) = STATUS_CODE_NAMES.iter().find(|(code, _)| *code == status) {
444 let ident = format_ident!("{}", name);
445 quote! { ::axum::http::StatusCode::#ident }
446 } else {
447 quote! { ::axum::http::StatusCode::from_u16(#status).unwrap() }
448 }
449}