Skip to main content

oxapi_impl/
method.rs

1//! Method transformation for trait methods.
2
3use openapiv3::{ReferenceOr, SchemaKind, Type as OapiType};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::{FnArg, GenericArgument, PathArguments, Type};
7
8use crate::openapi::Operation;
9use crate::types::TypeGenerator;
10use crate::{GeneratedTypeKind, Generator, TypeOverride};
11
12/// Transforms user-defined trait methods based on OpenAPI operations.
13pub struct MethodTransformer<'a> {
14    generator: &'a Generator,
15    types_mod: &'a syn::Ident,
16}
17
18impl<'a> MethodTransformer<'a> {
19    pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
20        Self {
21            generator,
22            types_mod,
23        }
24    }
25
26    /// Transform a trait method based on its operation.
27    pub fn transform(&self, method: &syn::TraitItemFn, op: &Operation) -> syn::Result<TokenStream> {
28        let op_name = op.name();
29
30        let types_mod = self.types_mod;
31        let overrides = self.generator.type_overrides();
32
33        let suffixes = self.generator.response_suffixes();
34
35        // Get Ok type (may be renamed or replaced)
36        let ok_type: TokenStream = match overrides.get(op.method, &op.path, GeneratedTypeKind::Ok) {
37            Some(TypeOverride::Rename { name, .. }) => {
38                let ident = format_ident!("{}", name);
39                quote! { #types_mod::#ident }
40            }
41            Some(TypeOverride::Replace(replacement)) => replacement.clone(),
42            None => {
43                let ident = format_ident!("{}{}", op_name, suffixes.ok_suffix);
44                quote! { #types_mod::#ident }
45            }
46        };
47
48        // Transform each parameter, filling in type elisions
49        let transformed_params = self.transform_params(method, op)?;
50
51        // Determine if async and build return type
52        let is_async = method.sig.asyncness.is_some();
53
54        // Only wrap in Result if the operation has error responses defined
55        let return_type = if op.has_error_responses() {
56            // Get Err type (may be renamed or replaced)
57            let err_type: TokenStream =
58                match overrides.get(op.method, &op.path, GeneratedTypeKind::Err) {
59                    Some(TypeOverride::Rename { name, .. }) => {
60                        let ident = format_ident!("{}", name);
61                        quote! { #types_mod::#ident }
62                    }
63                    Some(TypeOverride::Replace(replacement)) => replacement.clone(),
64                    None => {
65                        let ident = format_ident!("{}{}", op_name, suffixes.err_suffix);
66                        quote! { #types_mod::#ident }
67                    }
68                };
69            quote! { ::core::result::Result<#ok_type, #err_type> }
70        } else {
71            // No errors in spec, return Ok type directly
72            ok_type.clone()
73        };
74
75        let method_name = &method.sig.ident;
76
77        // Build the signature
78        if is_async {
79            // Rewrite async fn to fn -> impl Send + Future<Output = ...>
80            Ok(quote! {
81                fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
82            })
83        } else {
84            // Sync function
85            Ok(quote! {
86                fn #method_name(#transformed_params) -> #return_type;
87            })
88        }
89    }
90
91    /// Transform the parameters, filling in type elisions.
92    fn transform_params(
93        &self,
94        method: &syn::TraitItemFn,
95        op: &Operation,
96    ) -> syn::Result<TokenStream> {
97        let type_gen = self.generator.type_generator();
98
99        let mut transformed = Vec::new();
100
101        for arg in &method.sig.inputs {
102            match arg {
103                FnArg::Receiver(_) => {
104                    return Err(syn::Error::new_spanned(
105                        arg,
106                        "oxapi trait methods must be static (no self)",
107                    ));
108                }
109                FnArg::Typed(pat_type) => {
110                    let pat = &pat_type.pat;
111                    let ty = &pat_type.ty;
112
113                    // Check if this is an extractor with type elision
114                    let transformed_ty = self.transform_type(ty, op, type_gen)?;
115
116                    transformed.push(quote! { #pat: #transformed_ty });
117                }
118            }
119        }
120
121        Ok(quote! { #(#transformed),* })
122    }
123
124    /// Transform a type, filling in elisions like `Path<_>`.
125    fn transform_type(
126        &self,
127        ty: &Type,
128        op: &Operation,
129        type_gen: &TypeGenerator,
130    ) -> syn::Result<TokenStream> {
131        let types_mod = self.types_mod;
132
133        match ty {
134            Type::Path(type_path) => {
135                let last_segment = type_path
136                    .path
137                    .segments
138                    .last()
139                    .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
140
141                let type_name = last_segment.ident.to_string();
142
143                match type_name.as_str() {
144                    "Path" => {
145                        // Fill in path parameter type
146                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
147                            type_gen.generate_path_type(op)
148                        })?;
149                        Ok(quote! { ::axum::extract::Path<#inner> })
150                    }
151                    "Query" => {
152                        // Fill in query parameter type
153                        let overrides = self.generator.type_overrides();
154                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
155                            // Check for replacement first
156                            if let Some(TypeOverride::Replace(replacement)) =
157                                overrides.get(op.method, &op.path, GeneratedTypeKind::Query)
158                            {
159                                return replacement.clone();
160                            }
161
162                            // Generate a query struct if needed (may be renamed)
163                            if let Some((name, _)) = type_gen.generate_query_struct(op, overrides) {
164                                quote! { #types_mod::#name }
165                            } else {
166                                quote! { () }
167                            }
168                        })?;
169                        Ok(quote! { ::axum::extract::Query<#inner> })
170                    }
171                    "Json" => {
172                        // Fill in request body type
173                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
174                            if let Some(body) = &op.request_body {
175                                if let Some(schema) = &body.schema {
176                                    let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
177                                    let body_type = type_gen.request_body_type(body, op_name);
178
179                                    // Check if this type needs to be prefixed with the types module.
180                                    // References to component schemas and inline objects need prefixing
181                                    // because their types are generated in the types module.
182                                    let needs_prefix = match schema {
183                                        ReferenceOr::Reference { .. } => true,
184                                        ReferenceOr::Item(inline) => matches!(
185                                            &inline.schema_kind,
186                                            SchemaKind::Type(OapiType::Object(_))
187                                        ),
188                                    };
189
190                                    if needs_prefix {
191                                        quote! { #types_mod::#body_type }
192                                    } else {
193                                        body_type
194                                    }
195                                } else {
196                                    quote! { serde_json::Value }
197                                }
198                            } else {
199                                quote! { serde_json::Value }
200                            }
201                        })?;
202                        Ok(quote! { ::axum::extract::Json<#inner> })
203                    }
204                    "State" => {
205                        // State type is user-provided, don't modify
206                        Ok(quote! { #ty })
207                    }
208                    _ => {
209                        // Other types pass through unchanged
210                        Ok(quote! { #ty })
211                    }
212                }
213            }
214            _ => {
215                // Non-path types pass through unchanged
216                Ok(quote! { #ty })
217            }
218        }
219    }
220
221    /// Get the inner type from generics, or infer it if it's `_`.
222    fn get_or_infer_inner<F>(&self, args: &PathArguments, infer: F) -> syn::Result<TokenStream>
223    where
224        F: FnOnce() -> TokenStream,
225    {
226        match args {
227            PathArguments::None => {
228                // No generics, infer the type
229                Ok(infer())
230            }
231            PathArguments::AngleBracketed(args) => {
232                if let Some(GenericArgument::Type(Type::Infer(_))) = args.args.first() {
233                    // Type is `_`, infer it
234                    Ok(infer())
235                } else if let Some(GenericArgument::Type(ty)) = args.args.first() {
236                    // Type is explicit, use it
237                    Ok(quote! { #ty })
238                } else {
239                    Err(syn::Error::new_spanned(args, "expected type argument"))
240                }
241            }
242            PathArguments::Parenthesized(_) => Err(syn::Error::new_spanned(
243                args,
244                "unexpected parenthesized arguments",
245            )),
246        }
247    }
248}