Skip to main content

oxapi_impl/
method.rs

//! Method transformation for oxapi trait methods.
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::{FnArg, GenericArgument, PathArguments, Type};
7
8use crate::Generator;
9use crate::openapi::Operation;
10use crate::types::TypeGenerator;
11
12/// Transforms user-defined trait methods based on OpenAPI operations.
13pub struct MethodTransformer<'a> {
14    generator: &'a Generator,
15    types_mod: &'a syn::Ident,
16}
17
18impl<'a> MethodTransformer<'a> {
19    pub fn new(generator: &'a Generator, types_mod: &'a syn::Ident) -> Self {
20        Self {
21            generator,
22            types_mod,
23        }
24    }
25
26    /// Transform a trait method based on its operation.
27    pub fn transform(&self, method: &syn::TraitItemFn, op: &Operation) -> syn::Result<TokenStream> {
28        let op_name = op
29            .operation_id
30            .as_deref()
31            .unwrap_or(&op.path)
32            .to_upper_camel_case();
33
34        let types_mod = self.types_mod;
35        let ok_type = format_ident!("{}Ok", op_name);
36        let err_type = format_ident!("{}Err", op_name);
37
38        // Transform each parameter, filling in type elisions
39        let transformed_params = self.transform_params(method, op)?;
40
41        // Determine if async and build return type
42        let is_async = method.sig.asyncness.is_some();
43
44        let return_type = quote! {
45            ::core::result::Result<#types_mod::#ok_type, #types_mod::#err_type>
46        };
47
48        let method_name = &method.sig.ident;
49
50        // Build the signature
51        if is_async {
52            // Rewrite async fn to fn -> impl Send + Future<Output = ...>
53            Ok(quote! {
54                fn #method_name(#transformed_params) -> impl ::core::marker::Send + ::core::future::Future<Output = #return_type>;
55            })
56        } else {
57            // Sync function
58            Ok(quote! {
59                fn #method_name(#transformed_params) -> #return_type;
60            })
61        }
62    }
63
64    /// Transform the parameters, filling in type elisions.
65    fn transform_params(
66        &self,
67        method: &syn::TraitItemFn,
68        op: &Operation,
69    ) -> syn::Result<TokenStream> {
70        let type_gen = self.generator.type_generator();
71
72        let mut transformed = Vec::new();
73
74        for arg in &method.sig.inputs {
75            match arg {
76                FnArg::Receiver(_) => {
77                    return Err(syn::Error::new_spanned(
78                        arg,
79                        "oxapi trait methods must be static (no self)",
80                    ));
81                }
82                FnArg::Typed(pat_type) => {
83                    let pat = &pat_type.pat;
84                    let ty = &pat_type.ty;
85
86                    // Check if this is an extractor with type elision
87                    let transformed_ty = self.transform_type(ty, op, type_gen)?;
88
89                    transformed.push(quote! { #pat: #transformed_ty });
90                }
91            }
92        }
93
94        Ok(quote! { #(#transformed),* })
95    }
96
97    /// Transform a type, filling in elisions like `Path<_>`.
98    fn transform_type(
99        &self,
100        ty: &Type,
101        op: &Operation,
102        type_gen: &TypeGenerator,
103    ) -> syn::Result<TokenStream> {
104        let types_mod = self.types_mod;
105
106        match ty {
107            Type::Path(type_path) => {
108                let last_segment = type_path
109                    .path
110                    .segments
111                    .last()
112                    .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
113
114                let type_name = last_segment.ident.to_string();
115
116                match type_name.as_str() {
117                    "Path" => {
118                        // Fill in path parameter type
119                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
120                            type_gen.generate_path_type(op)
121                        })?;
122                        Ok(quote! { ::axum::extract::Path<#inner> })
123                    }
124                    "Query" => {
125                        // Fill in query parameter type
126                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
127                            // Generate a query struct if needed
128                            if let Some((name, _)) = type_gen.generate_query_struct(op) {
129                                quote! { #types_mod::#name }
130                            } else {
131                                quote! { () }
132                            }
133                        })?;
134                        Ok(quote! { ::axum::extract::Query<#inner> })
135                    }
136                    "Json" => {
137                        // Fill in request body type
138                        let inner = self.get_or_infer_inner(&last_segment.arguments, || {
139                            if let Some(body) = &op.request_body {
140                                let op_name = op.operation_id.as_deref().unwrap_or(&op.path);
141                                type_gen.request_body_type(body, op_name)
142                            } else {
143                                quote! { serde_json::Value }
144                            }
145                        })?;
146                        Ok(quote! { ::axum::extract::Json<#inner> })
147                    }
148                    "State" => {
149                        // State type is user-provided, don't modify
150                        Ok(quote! { #ty })
151                    }
152                    _ => {
153                        // Other types pass through unchanged
154                        Ok(quote! { #ty })
155                    }
156                }
157            }
158            _ => {
159                // Non-path types pass through unchanged
160                Ok(quote! { #ty })
161            }
162        }
163    }
164
165    /// Get the inner type from generics, or infer it if it's `_`.
166    fn get_or_infer_inner<F>(&self, args: &PathArguments, infer: F) -> syn::Result<TokenStream>
167    where
168        F: FnOnce() -> TokenStream,
169    {
170        match args {
171            PathArguments::None => {
172                // No generics, infer the type
173                Ok(infer())
174            }
175            PathArguments::AngleBracketed(args) => {
176                if let Some(GenericArgument::Type(Type::Infer(_))) = args.args.first() {
177                    // Type is `_`, infer it
178                    Ok(infer())
179                } else if let Some(GenericArgument::Type(ty)) = args.args.first() {
180                    // Type is explicit, use it
181                    Ok(quote! { #ty })
182                } else {
183                    Err(syn::Error::new_spanned(args, "expected type argument"))
184                }
185            }
186            PathArguments::Parenthesized(_) => Err(syn::Error::new_spanned(
187                args,
188                "unexpected parenthesized arguments",
189            )),
190        }
191    }
192}