Skip to main content

oxapi_impl/
method.rs

1//! Method transformation for trait methods.
2
3use openapiv3::{ReferenceOr, SchemaKind, Type as OapiType};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::{FnArg, GenericArgument, PathArguments, Type};
7
8use crate::openapi::Operation;
9use crate::types::TypeGenerator;
10use crate::{GeneratedTypeKind, Generator, TypeOverride};
11
/// How a handler parameter is used; this decides what (if anything) fills in a `_`
/// placeholder in the parameter's extractor type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ParamRole {
    /// Parameters taken from the URL path template (e.g. `/items/{id}`).
    Path,
    /// Query-string parameters (e.g. `?foo=bar`).
    Query,
    /// The request body.
    Body,
    /// Any other extractor (`State`, custom extractors, ...); its type is never inferred.
    Other,
}
24
25/// Transforms user-defined trait methods based on OpenAPI operations.
26pub struct MethodTransformer<'a> {
27    generator: &'a Generator,
28    types_mod: &'a syn::Ident,
29}
30
31impl<'a> MethodTransformer<'a> {
32    pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
33        Self {
34            generator,
35            types_mod,
36        }
37    }
38
39    /// Resolve a response type (Ok or Err) for an operation, applying any overrides.
40    fn resolve_response_type(
41        &self,
42        op: &Operation,
43        kind: GeneratedTypeKind,
44        default_suffix: &str,
45    ) -> TokenStream {
46        let types_mod = self.types_mod;
47        let overrides = self.generator.type_overrides();
48        let op_name = op.name();
49
50        match overrides.get(op.method, &op.path, kind) {
51            Some(TypeOverride::Rename { name, .. }) => {
52                let ident = format_ident!("{}", name);
53                quote! { #types_mod::#ident }
54            }
55            Some(TypeOverride::Replace(replacement)) => replacement.clone(),
56            None => {
57                let ident = format_ident!("{}{}", op_name, default_suffix);
58                quote! { #types_mod::#ident }
59            }
60        }
61    }
62
63    /// Transform a trait method based on its operation.
64    ///
65    /// The `param_roles` argument contains the role for each parameter (if any explicit
66    /// `#[oxapi(path)]`, `#[oxapi(query)]`, or `#[oxapi(body)]` attribute was found).
67    /// If `None`, inference is used based on HTTP method and parameter position.
68    pub fn transform(
69        &self,
70        method: &syn::TraitItemFn,
71        op: &Operation,
72        param_roles: &[Option<ParamRole>],
73    ) -> syn::Result<TokenStream> {
74        let suffixes = self.generator.response_suffixes();
75
76        // Get Ok type (may be renamed or replaced)
77        let ok_type = self.resolve_response_type(op, GeneratedTypeKind::Ok, &suffixes.ok_suffix);
78
79        // Transform each parameter, filling in type elisions
80        let transformed_params = self.transform_params(method, op, param_roles)?;
81
82        // Determine if async and build return type
83        let is_async = method.sig.asyncness.is_some();
84
85        // Only wrap in Result if the operation has error responses defined
86        let return_type = if op.has_error_responses() {
87            let err_type =
88                self.resolve_response_type(op, GeneratedTypeKind::Err, &suffixes.err_suffix);
89            quote! { ::core::result::Result<#ok_type, #err_type> }
90        } else {
91            // No errors in spec, return Ok type directly
92            ok_type.clone()
93        };
94
95        let method_name = &method.sig.ident;
96
97        // Build the signature
98        if is_async {
99            // Rewrite async fn to fn -> impl Send + Future<Output = ...>
100            Ok(quote! {
101                fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
102            })
103        } else {
104            // Sync function
105            Ok(quote! {
106                fn #method_name(#transformed_params) -> #return_type;
107            })
108        }
109    }
110
111    /// Transform the parameters, filling in type elisions.
112    ///
113    /// The `param_roles` argument contains the role for each parameter that has an
114    /// explicit `#[oxapi(...)]` attribute.
115    ///
116    /// **Important**: If ANY parameter has an explicit role attribute, inference is
117    /// disabled for ALL parameters. This means you must either:
118    /// - Use no explicit attrs (rely entirely on type name detection), OR
119    /// - Use explicit attrs on ALL parameters that need type elision
120    ///
121    /// This allows non-standard use cases like adding a body to a GET request.
122    fn transform_params(
123        &self,
124        method: &syn::TraitItemFn,
125        op: &Operation,
126        param_roles: &[Option<ParamRole>],
127    ) -> syn::Result<TokenStream> {
128        let type_gen = self.generator.type_generator();
129
130        let mut all_params: Vec<&syn::PatType> = Vec::new();
131
132        for arg in &method.sig.inputs {
133            match arg {
134                FnArg::Receiver(_) => {
135                    return Err(syn::Error::new_spanned(
136                        arg,
137                        "oxapi trait methods must be static (no self)",
138                    ));
139                }
140                FnArg::Typed(pat_type) => {
141                    all_params.push(pat_type);
142                }
143            }
144        }
145
146        // Check if any parameter has an explicit role attribute.
147        // If so, inference is disabled for ALL parameters - only explicit attrs are used.
148        let has_any_explicit = param_roles.iter().any(|r| r.is_some());
149
150        // Compute roles for all parameters
151        let roles: Vec<ParamRole> = all_params
152            .iter()
153            .enumerate()
154            .map(|(idx, pat_type)| {
155                // Explicit attr always applies
156                if let Some(Some(role)) = param_roles.get(idx) {
157                    return *role;
158                }
159                // If any explicit attr is present, don't infer - use Other
160                if has_any_explicit {
161                    return ParamRole::Other;
162                }
163                // Otherwise detect from type name
164                detect_role_from_type(&pat_type.ty)
165            })
166            .collect();
167
168        // Transform each parameter
169        let mut transformed = Vec::new();
170        for (idx, pat_type) in all_params.iter().enumerate() {
171            let pat = &pat_type.pat;
172            let ty = &pat_type.ty;
173            let role = roles.get(idx).copied().unwrap_or(ParamRole::Other);
174
175            let transformed_ty = self.transform_type_with_role(ty, op, type_gen, role)?;
176            transformed.push(quote! { #pat: #transformed_ty });
177        }
178
179        Ok(quote! { #(#transformed),* })
180    }
181
182    /// Transform a type based on its role, preserving the user's extractor type.
183    ///
184    /// This method only replaces the `_` in a type like `MyExtractor<_>` with the
185    /// inferred type. It does NOT replace the extractor itself with a fully-qualified path.
186    fn transform_type_with_role(
187        &self,
188        ty: &Type,
189        op: &Operation,
190        type_gen: &TypeGenerator,
191        role: ParamRole,
192    ) -> syn::Result<TokenStream> {
193        match ty {
194            Type::Path(type_path) => {
195                // Check if this type has elision that needs to be filled
196                if !has_type_elision(ty) {
197                    // No elision, pass through unchanged
198                    return Ok(quote! { #ty });
199                }
200
201                // Get all segments except the last one (the prefix path)
202                let prefix_segments: Vec<_> = type_path
203                    .path
204                    .segments
205                    .iter()
206                    .take(type_path.path.segments.len().saturating_sub(1))
207                    .collect();
208
209                let last_segment = type_path
210                    .path
211                    .segments
212                    .last()
213                    .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
214
215                // Infer the inner type based on the role
216                let inner_type = match role {
217                    ParamRole::Path => type_gen.generate_path_type(op),
218                    ParamRole::Query => self.generate_query_inner_type(op, type_gen),
219                    ParamRole::Body => self.generate_body_inner_type(op, type_gen),
220                    ParamRole::Other => {
221                        // Other role: pass through unchanged (State, custom extractors, etc.)
222                        return Ok(quote! { #ty });
223                    }
224                };
225
226                // Reconstruct the type with the inferred inner type
227                let type_ident = &last_segment.ident;
228                let leading_colon = type_path.path.leading_colon;
229
230                if prefix_segments.is_empty() {
231                    // Simple type like `Path<_>` or `MyExtractor<_>`
232                    Ok(quote! { #leading_colon #type_ident<#inner_type> })
233                } else {
234                    // Qualified path like `axum::extract::Path<_>`
235                    Ok(quote! { #leading_colon #(#prefix_segments ::)* #type_ident<#inner_type> })
236                }
237            }
238            _ => {
239                // Non-path types pass through unchanged
240                Ok(quote! { #ty })
241            }
242        }
243    }
244
245    /// Generate the inner type for a query parameter.
246    fn generate_query_inner_type(&self, op: &Operation, type_gen: &TypeGenerator) -> TokenStream {
247        let types_mod = self.types_mod;
248        let overrides = self.generator.type_overrides();
249
250        // Check for replacement first
251        if let Some(TypeOverride::Replace(replacement)) =
252            overrides.get(op.method, &op.path, GeneratedTypeKind::Query)
253        {
254            return replacement.clone();
255        }
256
257        // Generate a query struct if the operation has query params
258        if let Some((name, _)) = type_gen.generate_query_struct(op, overrides) {
259            quote! { #types_mod::#name }
260        } else {
261            // Fallback to HashMap<String, String> for operations without query params
262            quote! { ::std::collections::HashMap<String, String> }
263        }
264    }
265
266    /// Generate the inner type for a request body.
267    fn generate_body_inner_type(&self, op: &Operation, type_gen: &TypeGenerator) -> TokenStream {
268        let types_mod = self.types_mod;
269
270        if let Some(body) = &op.request_body {
271            if let Some(schema) = &body.schema {
272                let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
273                let body_type = type_gen.request_body_type(body, op_name);
274
275                // Check if this type needs to be prefixed with the types module.
276                let needs_prefix = match schema {
277                    ReferenceOr::Reference { .. } => true,
278                    ReferenceOr::Item(inline) => {
279                        matches!(&inline.schema_kind, SchemaKind::Type(OapiType::Object(_)))
280                    }
281                };
282
283                if needs_prefix {
284                    quote! { #types_mod::#body_type }
285                } else {
286                    body_type
287                }
288            } else {
289                quote! { serde_json::Value }
290            }
291        } else {
292            quote! { serde_json::Value }
293        }
294    }
295}
296
297/// Check if a type contains type elision (`_`).
298fn has_type_elision(ty: &Type) -> bool {
299    match ty {
300        Type::Infer(_) => true,
301        Type::Path(type_path) => {
302            // Check the generic arguments of the last segment
303            if let Some(last) = type_path.path.segments.last()
304                && let PathArguments::AngleBracketed(args) = &last.arguments
305            {
306                return args
307                    .args
308                    .iter()
309                    .any(|arg| matches!(arg, GenericArgument::Type(Type::Infer(_))));
310            }
311            false
312        }
313        _ => false,
314    }
315}
316
317/// Detect the role from the extractor type name.
318/// Recognizes: Path, Query, Json (as simple name or full path like axum::extract::Path).
319fn detect_role_from_type(ty: &Type) -> ParamRole {
320    if let Type::Path(type_path) = ty
321        && let Some(last) = type_path.path.segments.last()
322    {
323        let name = last.ident.to_string();
324        match name.as_str() {
325            "Path" => return ParamRole::Path,
326            "Query" => return ParamRole::Query,
327            "Json" => return ParamRole::Body,
328            _ => {}
329        }
330    }
331    ParamRole::Other
332}