1use openapiv3::{ReferenceOr, SchemaKind, Type as OapiType};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::{FnArg, GenericArgument, PathArguments, Type};
7
8use crate::openapi::Operation;
9use crate::types::TypeGenerator;
10use crate::{GeneratedTypeKind, Generator, TypeOverride};
11
12pub struct MethodTransformer<'a> {
14 generator: &'a Generator,
15 types_mod: &'a syn::Ident,
16}
17
18impl<'a> MethodTransformer<'a> {
19 pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
20 Self {
21 generator,
22 types_mod,
23 }
24 }
25
26 pub fn transform(&self, method: &syn::TraitItemFn, op: &Operation) -> syn::Result<TokenStream> {
28 let op_name = op.name();
29
30 let types_mod = self.types_mod;
31 let overrides = self.generator.type_overrides();
32
33 let suffixes = self.generator.response_suffixes();
34
35 let ok_type: TokenStream = match overrides.get(op.method, &op.path, GeneratedTypeKind::Ok) {
37 Some(TypeOverride::Rename { name, .. }) => {
38 let ident = format_ident!("{}", name);
39 quote! { #types_mod::#ident }
40 }
41 Some(TypeOverride::Replace(replacement)) => replacement.clone(),
42 None => {
43 let ident = format_ident!("{}{}", op_name, suffixes.ok_suffix);
44 quote! { #types_mod::#ident }
45 }
46 };
47
48 let transformed_params = self.transform_params(method, op)?;
50
51 let is_async = method.sig.asyncness.is_some();
53
54 let return_type = if op.has_error_responses() {
56 let err_type: TokenStream =
58 match overrides.get(op.method, &op.path, GeneratedTypeKind::Err) {
59 Some(TypeOverride::Rename { name, .. }) => {
60 let ident = format_ident!("{}", name);
61 quote! { #types_mod::#ident }
62 }
63 Some(TypeOverride::Replace(replacement)) => replacement.clone(),
64 None => {
65 let ident = format_ident!("{}{}", op_name, suffixes.err_suffix);
66 quote! { #types_mod::#ident }
67 }
68 };
69 quote! { ::core::result::Result<#ok_type, #err_type> }
70 } else {
71 ok_type.clone()
73 };
74
75 let method_name = &method.sig.ident;
76
77 if is_async {
79 Ok(quote! {
81 fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
82 })
83 } else {
84 Ok(quote! {
86 fn #method_name(#transformed_params) -> #return_type;
87 })
88 }
89 }
90
91 fn transform_params(
93 &self,
94 method: &syn::TraitItemFn,
95 op: &Operation,
96 ) -> syn::Result<TokenStream> {
97 let type_gen = self.generator.type_generator();
98
99 let mut transformed = Vec::new();
100
101 for arg in &method.sig.inputs {
102 match arg {
103 FnArg::Receiver(_) => {
104 return Err(syn::Error::new_spanned(
105 arg,
106 "oxapi trait methods must be static (no self)",
107 ));
108 }
109 FnArg::Typed(pat_type) => {
110 let pat = &pat_type.pat;
111 let ty = &pat_type.ty;
112
113 let transformed_ty = self.transform_type(ty, op, type_gen)?;
115
116 transformed.push(quote! { #pat: #transformed_ty });
117 }
118 }
119 }
120
121 Ok(quote! { #(#transformed),* })
122 }
123
124 fn transform_type(
126 &self,
127 ty: &Type,
128 op: &Operation,
129 type_gen: &TypeGenerator,
130 ) -> syn::Result<TokenStream> {
131 let types_mod = self.types_mod;
132
133 match ty {
134 Type::Path(type_path) => {
135 let last_segment = type_path
136 .path
137 .segments
138 .last()
139 .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
140
141 let type_name = last_segment.ident.to_string();
142
143 match type_name.as_str() {
144 "Path" => {
145 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
147 type_gen.generate_path_type(op)
148 })?;
149 Ok(quote! { ::axum::extract::Path<#inner> })
150 }
151 "Query" => {
152 let overrides = self.generator.type_overrides();
154 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
155 if let Some(TypeOverride::Replace(replacement)) =
157 overrides.get(op.method, &op.path, GeneratedTypeKind::Query)
158 {
159 return replacement.clone();
160 }
161
162 if let Some((name, _)) = type_gen.generate_query_struct(op, overrides) {
164 quote! { #types_mod::#name }
165 } else {
166 quote! { () }
167 }
168 })?;
169 Ok(quote! { ::axum::extract::Query<#inner> })
170 }
171 "Json" => {
172 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
174 if let Some(body) = &op.request_body {
175 if let Some(schema) = &body.schema {
176 let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
177 let body_type = type_gen.request_body_type(body, op_name);
178
179 let needs_prefix = match schema {
183 ReferenceOr::Reference { .. } => true,
184 ReferenceOr::Item(inline) => matches!(
185 &inline.schema_kind,
186 SchemaKind::Type(OapiType::Object(_))
187 ),
188 };
189
190 if needs_prefix {
191 quote! { #types_mod::#body_type }
192 } else {
193 body_type
194 }
195 } else {
196 quote! { serde_json::Value }
197 }
198 } else {
199 quote! { serde_json::Value }
200 }
201 })?;
202 Ok(quote! { ::axum::extract::Json<#inner> })
203 }
204 "State" => {
205 Ok(quote! { #ty })
207 }
208 _ => {
209 Ok(quote! { #ty })
211 }
212 }
213 }
214 _ => {
215 Ok(quote! { #ty })
217 }
218 }
219 }
220
221 fn get_or_infer_inner<F>(&self, args: &PathArguments, infer: F) -> syn::Result<TokenStream>
223 where
224 F: FnOnce() -> TokenStream,
225 {
226 match args {
227 PathArguments::None => {
228 Ok(infer())
230 }
231 PathArguments::AngleBracketed(args) => {
232 if let Some(GenericArgument::Type(Type::Infer(_))) = args.args.first() {
233 Ok(infer())
235 } else if let Some(GenericArgument::Type(ty)) = args.args.first() {
236 Ok(quote! { #ty })
238 } else {
239 Err(syn::Error::new_spanned(args, "expected type argument"))
240 }
241 }
242 PathArguments::Parenthesized(_) => Err(syn::Error::new_spanned(
243 args,
244 "unexpected parenthesized arguments",
245 )),
246 }
247 }
248}