1use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::{FnArg, GenericArgument, PathArguments, Type};
7
8use crate::Generator;
9use crate::openapi::Operation;
10use crate::types::TypeGenerator;
11
12pub struct MethodTransformer<'a> {
14 generator: &'a Generator,
15 types_mod: &'a syn::Ident,
16}
17
18impl<'a> MethodTransformer<'a> {
19 pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
20 Self {
21 generator,
22 types_mod,
23 }
24 }
25
26 pub fn transform(&self, method: &syn::TraitItemFn, op: &Operation) -> syn::Result<TokenStream> {
28 let op_name = op
29 .operation_id
30 .as_deref()
31 .unwrap_or(&op.path)
32 .to_upper_camel_case();
33
34 let types_mod = self.types_mod;
35 let ok_type = format_ident!("{}Ok", op_name);
36 let err_type = format_ident!("{}Err", op_name);
37
38 let transformed_params = self.transform_params(method, op)?;
40
41 let is_async = method.sig.asyncness.is_some();
43
44 let return_type = quote! {
45 ::core::result::Result<#types_mod::#ok_type, #types_mod::#err_type>
46 };
47
48 let method_name = &method.sig.ident;
49
50 if is_async {
52 Ok(quote! {
54 fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
55 })
56 } else {
57 Ok(quote! {
59 fn #method_name(#transformed_params) -> #return_type;
60 })
61 }
62 }
63
64 fn transform_params(
66 &self,
67 method: &syn::TraitItemFn,
68 op: &Operation,
69 ) -> syn::Result<TokenStream> {
70 let type_gen = self.generator.type_generator();
71
72 let mut transformed = Vec::new();
73
74 for arg in &method.sig.inputs {
75 match arg {
76 FnArg::Receiver(_) => {
77 return Err(syn::Error::new_spanned(
78 arg,
79 "oxapi trait methods must be static (no self)",
80 ));
81 }
82 FnArg::Typed(pat_type) => {
83 let pat = &pat_type.pat;
84 let ty = &pat_type.ty;
85
86 let transformed_ty = self.transform_type(ty, op, type_gen)?;
88
89 transformed.push(quote! { #pat: #transformed_ty });
90 }
91 }
92 }
93
94 Ok(quote! { #(#transformed),* })
95 }
96
97 fn transform_type(
99 &self,
100 ty: &Type,
101 op: &Operation,
102 type_gen: &TypeGenerator,
103 ) -> syn::Result<TokenStream> {
104 let types_mod = self.types_mod;
105
106 match ty {
107 Type::Path(type_path) => {
108 let last_segment = type_path
109 .path
110 .segments
111 .last()
112 .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
113
114 let type_name = last_segment.ident.to_string();
115
116 match type_name.as_str() {
117 "Path" => {
118 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
120 type_gen.generate_path_type(op)
121 })?;
122 Ok(quote! { ::axum::extract::Path<#inner> })
123 }
124 "Query" => {
125 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
127 if let Some((name, _)) = type_gen.generate_query_struct(op) {
129 quote! { #types_mod::#name }
130 } else {
131 quote! { () }
132 }
133 })?;
134 Ok(quote! { ::axum::extract::Query<#inner> })
135 }
136 "Json" => {
137 let inner = self.get_or_infer_inner(&last_segment.arguments, || {
139 if let Some(body) = &op.request_body {
140 let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
141 type_gen.request_body_type(body, op_name)
142 } else {
143 quote! { serde_json::Value }
144 }
145 })?;
146 Ok(quote! { ::axum::extract::Json<#inner> })
147 }
148 "State" => {
149 Ok(quote! { #ty })
151 }
152 _ => {
153 Ok(quote! { #ty })
155 }
156 }
157 }
158 _ => {
159 Ok(quote! { #ty })
161 }
162 }
163 }
164
165 fn get_or_infer_inner<F>(&self, args: &PathArguments, infer: F) -> syn::Result<TokenStream>
167 where
168 F: FnOnce() -> TokenStream,
169 {
170 match args {
171 PathArguments::None => {
172 Ok(infer())
174 }
175 PathArguments::AngleBracketed(args) => {
176 if let Some(GenericArgument::Type(Type::Infer(_))) = args.args.first() {
177 Ok(infer())
179 } else if let Some(GenericArgument::Type(ty)) = args.args.first() {
180 Ok(quote! { #ty })
182 } else {
183 Err(syn::Error::new_spanned(args, "expected type argument"))
184 }
185 }
186 PathArguments::Parenthesized(_) => Err(syn::Error::new_spanned(
187 args,
188 "unexpected parenthesized arguments",
189 )),
190 }
191 }
192}