Skip to main content

oxapi_impl/
method.rs

1//! Method transformation for trait methods.
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, GenericArgument, PathArguments, Type};
6
7use crate::openapi::Operation;
8use crate::types::TypeGenerator;
9use crate::{GeneratedTypeKind, Generator, TypeOverride};
10
11/// Transforms user-defined trait methods based on OpenAPI operations.
12pub struct MethodTransformer<'a> {
13    generator: &'a Generator,
14    types_mod: &'a syn::Ident,
15}
16
17impl<'a> MethodTransformer<'a> {
18    pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
19        Self {
20            generator,
21            types_mod,
22        }
23    }
24
25    /// Transform a trait method based on its operation.
26    pub fn transform(&self, method: &syn::TraitItemFn, op: &Operation) -> syn::Result<TokenStream> {
27        let op_name = op.name();
28
29        let types_mod = self.types_mod;
30        let overrides = self.generator.type_overrides();
31
32        let suffixes = self.generator.response_suffixes();
33
34        // Get Ok type (may be renamed or replaced)
35        let ok_type: TokenStream = match overrides.get(op.method, &op.path, GeneratedTypeKind::Ok) {
36            Some(TypeOverride::Rename { name, .. }) => {
37                let ident = format_ident!("{}", name);
38                quote! { #types_mod::#ident }
39            }
40            Some(TypeOverride::Replace(replacement)) => replacement.clone(),
41            None => {
42                let ident = format_ident!("{}{}", op_name, suffixes.ok_suffix);
43                quote! { #types_mod::#ident }
44            }
45        };
46
47        // Transform each parameter, filling in type elisions
48        let transformed_params = self.transform_params(method, op)?;
49
50        // Determine if async and build return type
51        let is_async = method.sig.asyncness.is_some();
52
53        // Only wrap in Result if the operation has error responses defined
54        let return_type = if op.has_error_responses() {
55            // Get Err type (may be renamed or replaced)
56            let err_type: TokenStream =
57                match overrides.get(op.method, &op.path, GeneratedTypeKind::Err) {
58                    Some(TypeOverride::Rename { name, .. }) => {
59                        let ident = format_ident!("{}", name);
60                        quote! { #types_mod::#ident }
61                    }
62                    Some(TypeOverride::Replace(replacement)) => replacement.clone(),
63                    None => {
64                        let ident = format_ident!("{}{}", op_name, suffixes.err_suffix);
65                        quote! { #types_mod::#ident }
66                    }
67                };
68            quote! { ::core::result::Result<#ok_type, #err_type> }
69        } else {
70            // No errors in spec, return Ok type directly
71            ok_type.clone()
72        };
73
74        let method_name = &method.sig.ident;
75
76        // Build the signature
77        if is_async {
78            // Rewrite async fn to fn -> impl Send + Future<Output = ...>
79            Ok(quote! {
80                fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
81            })
82        } else {
83            // Sync function
84            Ok(quote! {
85                fn #method_name(#transformed_params) -> #return_type;
86            })
87        }
88    }
89
90    /// Transform the parameters, filling in type elisions.
91    fn transform_params(
92        &self,
93        method: &syn::TraitItemFn,
94        op: &Operation,
95    ) -> syn::Result<TokenStream> {
96        let type_gen = self.generator.type_generator();
97
98        let mut transformed = Vec::new();
99
100        for arg in &method.sig.inputs {
101            match arg {
102                FnArg::Receiver(_) => {
103                    return Err(syn::Error::new_spanned(
104                        arg,
105                        "oxapi trait methods must be static (no self)",
106                    ));
107                }
108                FnArg::Typed(pat_type) => {
109                    let pat = &pat_type.pat;
110                    let ty = &pat_type.ty;
111
112                    // Check if this is an extractor with type elision
113                    let transformed_ty = self.transform_type(ty, op, type_gen)?;
114
115                    transformed.push(quote! { #pat: #transformed_ty });
116                }
117            }
118        }
119
120        Ok(quote! { #(#transformed),* })
121    }
122
123    /// Transform a type, filling in elisions like `Path<_>`.
124    fn transform_type(
125        &self,
126        ty: &Type,
127        op: &Operation,
128        type_gen: &TypeGenerator,
129    ) -> syn::Result<TokenStream> {
130        let types_mod = self.types_mod;
131
132        match ty {
133            Type::Path(type_path) => {
134                let last_segment = type_path
135                    .path
136                    .segments
137                    .last()
138                    .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
139
140                let type_name = last_segment.ident.to_string();
141
142                match type_name.as_str() {
143                    "Path" => {
144                        // Fill in path parameter type
145                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
146                            type_gen.generate_path_type(op)
147                        })?;
148                        Ok(quote! { ::axum::extract::Path<#inner> })
149                    }
150                    "Query" => {
151                        // Fill in query parameter type
152                        let overrides = self.generator.type_overrides();
153                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
154                            // Check for replacement first
155                            if let Some(TypeOverride::Replace(replacement)) =
156                                overrides.get(op.method, &op.path, GeneratedTypeKind::Query)
157                            {
158                                return replacement.clone();
159                            }
160
161                            // Generate a query struct if needed (may be renamed)
162                            if let Some((name, _)) = type_gen.generate_query_struct(op, overrides) {
163                                quote! { #types_mod::#name }
164                            } else {
165                                quote! { () }
166                            }
167                        })?;
168                        Ok(quote! { ::axum::extract::Query<#inner> })
169                    }
170                    "Json" => {
171                        // Fill in request body type
172                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
173                            if let Some(body) = &op.request_body {
174                                let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
175                                type_gen.request_body_type(body, op_name)
176                            } else {
177                                quote! { serde_json::Value }
178                            }
179                        })?;
180                        Ok(quote! { ::axum::extract::Json<#inner> })
181                    }
182                    "State" => {
183                        // State type is user-provided, don't modify
184                        Ok(quote! { #ty })
185                    }
186                    _ => {
187                        // Other types pass through unchanged
188                        Ok(quote! { #ty })
189                    }
190                }
191            }
192            _ => {
193                // Non-path types pass through unchanged
194                Ok(quote! { #ty })
195            }
196        }
197    }
198
199    /// Get the inner type from generics, or infer it if it's `_`.
200    fn get_or_infer_inner<F>(&self, args: &PathArguments, infer: F) -> syn::Result<TokenStream>
201    where
202        F: FnOnce() -> TokenStream,
203    {
204        match args {
205            PathArguments::None => {
206                // No generics, infer the type
207                Ok(infer())
208            }
209            PathArguments::AngleBracketed(args) => {
210                if let Some(GenericArgument::Type(Type::Infer(_))) = args.args.first() {
211                    // Type is `_`, infer it
212                    Ok(infer())
213                } else if let Some(GenericArgument::Type(ty)) = args.args.first() {
214                    // Type is explicit, use it
215                    Ok(quote! { #ty })
216                } else {
217                    Err(syn::Error::new_spanned(args, "expected type argument"))
218                }
219            }
220            PathArguments::Parenthesized(_) => Err(syn::Error::new_spanned(
221                args,
222                "unexpected parenthesized arguments",
223            )),
224        }
225    }
226}