1use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, GenericArgument, PathArguments, Type};
6
7use crate::openapi::Operation;
8use crate::types::TypeGenerator;
9use crate::{GeneratedTypeKind, Generator, TypeOverride};
10
11pub struct MethodTransformer<'a> {
13 generator: &'a Generator,
14 types_mod: &'a syn::Ident,
15}
16
17impl<'a> MethodTransformer<'a> {
18 pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
19 Self {
20 generator,
21 types_mod,
22 }
23 }
24
25 pub fn transform(&self, method: &syn::TraitItemFn, op: &Operation) -> syn::Result<TokenStream> {
27 let op_name = op.name();
28
29 let types_mod = self.types_mod;
30 let overrides = self.generator.type_overrides();
31
32 let suffixes = self.generator.response_suffixes();
33
34 let ok_type: TokenStream = match overrides.get(op.method, &op.path, GeneratedTypeKind::Ok) {
36 Some(TypeOverride::Rename { name, .. }) => {
37 let ident = format_ident!("{}", name);
38 quote! { #types_mod::#ident }
39 }
40 Some(TypeOverride::Replace(replacement)) => replacement.clone(),
41 None => {
42 let ident = format_ident!("{}{}", op_name, suffixes.ok_suffix);
43 quote! { #types_mod::#ident }
44 }
45 };
46
47 let transformed_params = self.transform_params(method, op)?;
49
50 let is_async = method.sig.asyncness.is_some();
52
53 let return_type = if op.has_error_responses() {
55 let err_type: TokenStream =
57 match overrides.get(op.method, &op.path, GeneratedTypeKind::Err) {
58 Some(TypeOverride::Rename { name, .. }) => {
59 let ident = format_ident!("{}", name);
60 quote! { #types_mod::#ident }
61 }
62 Some(TypeOverride::Replace(replacement)) => replacement.clone(),
63 None => {
64 let ident = format_ident!("{}{}", op_name, suffixes.err_suffix);
65 quote! { #types_mod::#ident }
66 }
67 };
68 quote! { ::core::result::Result<#ok_type, #err_type> }
69 } else {
70 ok_type.clone()
72 };
73
74 let method_name = &method.sig.ident;
75
76 if is_async {
78 Ok(quote! {
80 fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
81 })
82 } else {
83 Ok(quote! {
85 fn #method_name(#transformed_params) -> #return_type;
86 })
87 }
88 }
89
90 fn transform_params(
92 &self,
93 method: &syn::TraitItemFn,
94 op: &Operation,
95 ) -> syn::Result<TokenStream> {
96 let type_gen = self.generator.type_generator();
97
98 let mut transformed = Vec::new();
99
100 for arg in &method.sig.inputs {
101 match arg {
102 FnArg::Receiver(_) => {
103 return Err(syn::Error::new_spanned(
104 arg,
105 "oxapi trait methods must be static (no self)",
106 ));
107 }
108 FnArg::Typed(pat_type) => {
109 let pat = &pat_type.pat;
110 let ty = &pat_type.ty;
111
112 let transformed_ty = self.transform_type(ty, op, type_gen)?;
114
115 transformed.push(quote! { #pat: #transformed_ty });
116 }
117 }
118 }
119
120 Ok(quote! { #(#transformed),* })
121 }
122
123 fn transform_type(
125 &self,
126 ty: &Type,
127 op: &Operation,
128 type_gen: &TypeGenerator,
129 ) -> syn::Result<TokenStream> {
130 let types_mod = self.types_mod;
131
132 match ty {
133 Type::Path(type_path) => {
134 let last_segment = type_path
135 .path
136 .segments
137 .last()
138 .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
139
140 let type_name = last_segment.ident.to_string();
141
142 match type_name.as_str() {
143 "Path" => {
144 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
146 type_gen.generate_path_type(op)
147 })?;
148 Ok(quote! { ::axum::extract::Path<#inner> })
149 }
150 "Query" => {
151 let overrides = self.generator.type_overrides();
153 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
154 if let Some(TypeOverride::Replace(replacement)) =
156 overrides.get(op.method, &op.path, GeneratedTypeKind::Query)
157 {
158 return replacement.clone();
159 }
160
161 if let Some((name, _)) = type_gen.generate_query_struct(op, overrides) {
163 quote! { #types_mod::#name }
164 } else {
165 quote! { () }
166 }
167 })?;
168 Ok(quote! { ::axum::extract::Query<#inner> })
169 }
170 "Json" => {
171 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
173 if let Some(body) = &op.request_body {
174 let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
175 type_gen.request_body_type(body, op_name)
176 } else {
177 quote! { serde_json::Value }
178 }
179 })?;
180 Ok(quote! { ::axum::extract::Json<#inner> })
181 }
182 "State" => {
183 Ok(quote! { #ty })
185 }
186 _ => {
187 Ok(quote! { #ty })
189 }
190 }
191 }
192 _ => {
193 Ok(quote! { #ty })
195 }
196 }
197 }
198
199 fn get_or_infer_inner<F>(&self, args: &PathArguments, infer: F) -> syn::Result<TokenStream>
201 where
202 F: FnOnce() -> TokenStream,
203 {
204 match args {
205 PathArguments::None => {
206 Ok(infer())
208 }
209 PathArguments::AngleBracketed(args) => {
210 if let Some(GenericArgument::Type(Type::Infer(_))) = args.args.first() {
211 Ok(infer())
213 } else if let Some(GenericArgument::Type(ty)) = args.args.first() {
214 Ok(quote! { #ty })
216 } else {
217 Err(syn::Error::new_spanned(args, "expected type argument"))
218 }
219 }
220 PathArguments::Parenthesized(_) => Err(syn::Error::new_spanned(
221 args,
222 "unexpected parenthesized arguments",
223 )),
224 }
225 }
226}