//! Method transformation for trait methods.

3use openapiv3::{ReferenceOr, SchemaKind, Type as OapiType};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::{FnArg, GenericArgument, PathArguments, Type};
7
8use crate::openapi::Operation;
9use crate::types::TypeGenerator;
10use crate::{GeneratedTypeKind, Generator, TypeOverride};
11
12/// Role of a parameter in a handler function.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum ParamRole {
15    /// Path parameters (e.g., `/items/{id}`)
16    Path,
17    /// Query parameters (e.g., `?foo=bar`)
18    /// The optional `unknown_field` specifies a field name for capturing unknown query params
19    /// via `#[serde(flatten)]` with a `HashMap<String, String>`.
20    Query {
21        unknown_field: Option<proc_macro2::Ident>,
22    },
23    /// Request body
24    Body,
25    /// Other extractors (State, custom extractors, etc.)
26    Other,
27}
28
29/// Transforms user-defined trait methods based on OpenAPI operations.
30pub struct MethodTransformer<'a> {
31    generator: &'a Generator,
32    types_mod: &'a syn::Ident,
33}
34
35impl<'a> MethodTransformer<'a> {
36    pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
37        Self {
38            generator,
39            types_mod,
40        }
41    }
42
43    /// Resolve a response type (Ok or Err) for an operation, applying any overrides.
44    fn resolve_response_type(
45        &self,
46        op: &Operation,
47        kind: GeneratedTypeKind,
48        default_suffix: &str,
49    ) -> TokenStream {
50        let types_mod = self.types_mod;
51        let overrides = self.generator.type_overrides();
52        let op_name = op.name();
53
54        match overrides.get(op.method, &op.path, kind) {
55            Some(TypeOverride::Rename { name, .. }) => {
56                let ident = format_ident!("{}", name);
57                quote! { #types_mod::#ident }
58            }
59            Some(TypeOverride::Replace(replacement)) => replacement.clone(),
60            None => {
61                let ident = format_ident!("{}{}", op_name, default_suffix);
62                quote! { #types_mod::#ident }
63            }
64        }
65    }
66
67    /// Transform a trait method based on its operation.
68    ///
69    /// The `param_roles` argument contains the role for each parameter (if any explicit
70    /// `#[oxapi(path)]`, `#[oxapi(query)]`, or `#[oxapi(body)]` attribute was found).
71    /// If `None`, inference is used based on HTTP method and parameter position.
72    pub fn transform(
73        &self,
74        method: &syn::TraitItemFn,
75        op: &Operation,
76        param_roles: &[Option<ParamRole>],
77    ) -> syn::Result<TokenStream> {
78        let suffixes = self.generator.response_suffixes();
79
80        // Get Ok type (may be renamed or replaced)
81        let ok_type = self.resolve_response_type(op, GeneratedTypeKind::Ok, &suffixes.ok_suffix);
82
83        // Transform each parameter, filling in type elisions
84        let transformed_params = self.transform_params(method, op, param_roles)?;
85
86        // Determine if async and build return type
87        let is_async = method.sig.asyncness.is_some();
88
89        // Only wrap in Result if the operation has error responses defined
90        let return_type = if op.has_error_responses() {
91            let err_type =
92                self.resolve_response_type(op, GeneratedTypeKind::Err, &suffixes.err_suffix);
93            quote! { ::core::result::Result<#ok_type, #err_type> }
94        } else {
95            // No errors in spec, return Ok type directly
96            ok_type.clone()
97        };
98
99        let method_name = &method.sig.ident;
100
101        // Build the signature
102        if is_async {
103            // Rewrite async fn to fn -> impl Send + Future<Output = ...>
104            Ok(quote! {
105                fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
106            })
107        } else {
108            // Sync function
109            Ok(quote! {
110                fn #method_name(#transformed_params) -> #return_type;
111            })
112        }
113    }
114
115    /// Transform the parameters, filling in type elisions.
116    ///
117    /// The `param_roles` argument contains the role for each parameter that has an
118    /// explicit `#[oxapi(...)]` attribute.
119    ///
120    /// **Important**: If ANY parameter has an explicit role attribute, inference is
121    /// disabled for ALL parameters. This means you must either:
122    /// - Use no explicit attrs (rely entirely on type name detection), OR
123    /// - Use explicit attrs on ALL parameters that need type elision
124    ///
125    /// This allows non-standard use cases like adding a body to a GET request.
126    fn transform_params(
127        &self,
128        method: &syn::TraitItemFn,
129        op: &Operation,
130        param_roles: &[Option<ParamRole>],
131    ) -> syn::Result<TokenStream> {
132        let type_gen = self.generator.type_generator();
133
134        let mut all_params: Vec<&syn::PatType> = Vec::new();
135
136        for arg in &method.sig.inputs {
137            match arg {
138                FnArg::Receiver(_) => {
139                    return Err(syn::Error::new_spanned(
140                        arg,
141                        "oxapi trait methods must be static (no self)",
142                    ));
143                }
144                FnArg::Typed(pat_type) => {
145                    all_params.push(pat_type);
146                }
147            }
148        }
149
150        // Check if any parameter has an explicit role attribute.
151        // If so, inference is disabled for ALL parameters - only explicit attrs are used.
152        let has_any_explicit = param_roles.iter().any(|r| r.is_some());
153
154        // Compute roles for all parameters
155        let roles: Vec<ParamRole> = all_params
156            .iter()
157            .enumerate()
158            .map(|(idx, pat_type)| {
159                // Explicit attr always applies
160                if let Some(Some(role)) = param_roles.get(idx) {
161                    return role.clone();
162                }
163                // If any explicit attr is present, don't infer - use Other
164                if has_any_explicit {
165                    return ParamRole::Other;
166                }
167                // Otherwise detect from type name
168                detect_role_from_type(&pat_type.ty)
169            })
170            .collect();
171
172        // Transform each parameter
173        let mut transformed = Vec::new();
174        for (idx, pat_type) in all_params.iter().enumerate() {
175            let pat = &pat_type.pat;
176            let ty = &pat_type.ty;
177            let role = roles.get(idx).cloned().unwrap_or(ParamRole::Other);
178
179            let transformed_ty = self.transform_type_with_role(ty, op, type_gen, role)?;
180            transformed.push(quote! { #pat: #transformed_ty });
181        }
182
183        Ok(quote! { #(#transformed),* })
184    }
185
186    /// Transform a type based on its role, preserving the user's extractor type.
187    ///
188    /// This method only replaces the `_` in a type like `MyExtractor<_>` with the
189    /// inferred type. It does NOT replace the extractor itself with a fully-qualified path.
190    fn transform_type_with_role(
191        &self,
192        ty: &Type,
193        op: &Operation,
194        type_gen: &TypeGenerator,
195        role: ParamRole,
196    ) -> syn::Result<TokenStream> {
197        match ty {
198            Type::Path(type_path) => {
199                // Check if this type has elision that needs to be filled
200                if !has_type_elision(ty) {
201                    // No elision, pass through unchanged
202                    return Ok(quote! { #ty });
203                }
204
205                // Get all segments except the last one (the prefix path)
206                let prefix_segments: Vec<_> = type_path
207                    .path
208                    .segments
209                    .iter()
210                    .take(type_path.path.segments.len().saturating_sub(1))
211                    .collect();
212
213                let last_segment = type_path
214                    .path
215                    .segments
216                    .last()
217                    .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
218
219                // Infer the inner type based on the role
220                let inner_type = match &role {
221                    ParamRole::Path => self.generate_path_inner_type(op, type_gen),
222                    ParamRole::Query { unknown_field } => {
223                        self.generate_query_inner_type(op, type_gen, unknown_field.clone())
224                    }
225                    ParamRole::Body => self.generate_body_inner_type(op, type_gen),
226                    ParamRole::Other => {
227                        // Other role: pass through unchanged (State, custom extractors, etc.)
228                        return Ok(quote! { #ty });
229                    }
230                };
231
232                // Reconstruct the type with the inferred inner type
233                let type_ident = &last_segment.ident;
234                let leading_colon = type_path.path.leading_colon;
235
236                if prefix_segments.is_empty() {
237                    // Simple type like `Path<_>` or `MyExtractor<_>`
238                    Ok(quote! { #leading_colon #type_ident<#inner_type> })
239                } else {
240                    // Qualified path like `axum::extract::Path<_>`
241                    Ok(quote! { #leading_colon #(#prefix_segments ::)* #type_ident<#inner_type> })
242                }
243            }
244            _ => {
245                // Non-path types pass through unchanged
246                Ok(quote! { #ty })
247            }
248        }
249    }
250
251    /// Generate the inner type for a path parameter.
252    fn generate_path_inner_type(&self, op: &Operation, type_gen: &TypeGenerator) -> TokenStream {
253        let types_mod = self.types_mod;
254        let overrides = self.generator.type_overrides();
255
256        // Check for replacement first
257        if let Some(TypeOverride::Replace(replacement)) =
258            overrides.get(op.method, &op.path, GeneratedTypeKind::Path)
259        {
260            return replacement.clone();
261        }
262
263        // Generate a path struct if the operation has path params
264        if let Some((name, _)) = type_gen.generate_path_struct(op, overrides) {
265            quote! { #types_mod::#name }
266        } else {
267            // No path params - return unit type
268            quote! { () }
269        }
270    }
271
272    /// Generate the inner type for a query parameter.
273    ///
274    /// If `unknown_field` is provided, the generated query struct will include an extra
275    /// `#[serde(flatten)] pub <name>: HashMap<String, String>` field.
276    fn generate_query_inner_type(
277        &self,
278        op: &Operation,
279        type_gen: &TypeGenerator,
280        unknown_field: Option<proc_macro2::Ident>,
281    ) -> TokenStream {
282        let types_mod = self.types_mod;
283        let overrides = self.generator.type_overrides();
284
285        // Check for replacement first
286        if let Some(TypeOverride::Replace(replacement)) =
287            overrides.get(op.method, &op.path, GeneratedTypeKind::Query)
288        {
289            return replacement.clone();
290        }
291
292        // Generate a query struct if the operation has query params
293        if let Some((name, _)) =
294            type_gen.generate_query_struct(op, overrides, unknown_field.as_ref())
295        {
296            quote! { #types_mod::#name }
297        } else {
298            // Fallback to HashMap<String, String> for operations without query params
299            quote! { ::std::collections::HashMap<String, String> }
300        }
301    }
302
303    /// Generate the inner type for a request body.
304    fn generate_body_inner_type(&self, op: &Operation, type_gen: &TypeGenerator) -> TokenStream {
305        let types_mod = self.types_mod;
306
307        if let Some(body) = &op.request_body {
308            if let Some(schema) = &body.schema {
309                let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
310                let body_type = type_gen.request_body_type(body, op_name);
311
312                // Check if this type needs to be prefixed with the types module.
313                let needs_prefix = match schema {
314                    ReferenceOr::Reference { .. } => true,
315                    ReferenceOr::Item(inline) => {
316                        matches!(&inline.schema_kind, SchemaKind::Type(OapiType::Object(_)))
317                    }
318                };
319
320                if needs_prefix {
321                    quote! { #types_mod::#body_type }
322                } else {
323                    body_type
324                }
325            } else {
326                quote! { serde_json::Value }
327            }
328        } else {
329            quote! { serde_json::Value }
330        }
331    }
332}
333
334/// Check if a type contains type elision (`_`).
335fn has_type_elision(ty: &Type) -> bool {
336    match ty {
337        Type::Infer(_) => true,
338        Type::Path(type_path) => {
339            // Check the generic arguments of the last segment
340            if let Some(last) = type_path.path.segments.last()
341                && let PathArguments::AngleBracketed(args) = &last.arguments
342            {
343                return args
344                    .args
345                    .iter()
346                    .any(|arg| matches!(arg, GenericArgument::Type(Type::Infer(_))));
347            }
348            false
349        }
350        _ => false,
351    }
352}
353
354/// Detect the role from the extractor type name.
355/// Recognizes: Path, Query, Json (as simple name or full path like axum::extract::Path).
356fn detect_role_from_type(ty: &Type) -> ParamRole {
357    if let Type::Path(type_path) = ty
358        && let Some(last) = type_path.path.segments.last()
359    {
360        let name = last.ident.to_string();
361        match name.as_str() {
362            "Path" => return ParamRole::Path,
363            "Query" => {
364                return ParamRole::Query {
365                    unknown_field: None,
366                };
367            }
368            "Json" => return ParamRole::Body,
369            _ => {}
370        }
371    }
372    ParamRole::Other
373}