Skip to main content

serverless_fn_macro/
lib.rs

1//! Procedural macros for the serverless-fn crate.
2//!
3//! This crate provides the `#[serverless]` attribute macro for marking
4//! functions as serverless functions.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{FnArg, Pat, PatType, TypePath};
9
10/// Attribute macro for marking a function as a serverless function.
11///
12/// This macro generates:
13/// - The original function implementation
14/// - A client stub for remote calls (when `remote_call` feature is enabled)
15/// - A wrapper function for local calls (when `local_call` feature is enabled)
16/// - Server registration code for automatic function discovery
17///
18/// # Example
19///
20/// ```rust,ignore
21/// use serverless_fn::{serverless, ServerlessError};
22///
23/// #[serverless]
24/// pub async fn hello(name: String) -> Result<String, ServerlessError> {
25///     Ok(format!("Hello, {}!", name))
26/// }
27/// ```
28#[proc_macro_attribute]
29pub fn serverless(_args: TokenStream, input: TokenStream) -> TokenStream {
30    let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
31    let sig = &input_fn.sig;
32    let vis = &input_fn.vis;
33    let block = &input_fn.block;
34    let attrs = &input_fn.attrs;
35
36    let fn_name = &sig.ident;
37    let fn_name_str = fn_name.to_string();
38
39    // Extract function arguments
40    let args: Vec<(syn::Ident, syn::Type)> = sig
41        .inputs
42        .iter()
43        .filter_map(|arg| {
44            if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
45                if let Pat::Ident(pat_ident) = pat.as_ref() {
46                    Some((pat_ident.ident.clone(), ty.as_ref().clone()))
47                } else {
48                    None
49                }
50            } else {
51                None
52            }
53        })
54        .collect();
55
56    let arg_names: Vec<&syn::Ident> = args.iter().map(|(name, _)| name).collect();
57    let arg_types: Vec<&syn::Type> = args.iter().map(|(_, ty)| ty).collect();
58
59    // Generate input struct name
60    let input_struct_name = syn::Ident::new(&format!("__{}_input", fn_name), fn_name.span());
61
62    // Get return type and success type (Ok type in Result)
63    let (return_type, success_type) = match &sig.output {
64        syn::ReturnType::Type(_, ty) => {
65            let return_ty = ty.as_ref().clone();
66            let success_ty = extract_ok_type(&return_ty).unwrap_or_else(|| return_ty.clone());
67            (return_ty, success_ty)
68        }
69        syn::ReturnType::Default => {
70            let unit_ty: syn::Type = syn::parse_str("()").expect("Failed to parse unit type");
71            (unit_ty.clone(), unit_ty)
72        }
73    };
74
75    // Server implementation: rename original function
76    let server_fn_name = syn::Ident::new(&format!("__{}_impl", fn_name), fn_name.span());
77
78    // Generate registration code
79    let registrar_struct_name =
80        syn::Ident::new(&format!("__{}_registrar", fn_name), fn_name.span());
81    let path_static = syn::Ident::new(
82        &format!("__{}_PATH", fn_name.to_string().to_uppercase()),
83        fn_name.span(),
84    );
85
86    let config = RegistrarConfig {
87        registrar_struct_name: &registrar_struct_name,
88        path_static: &path_static,
89        fn_name_str: &fn_name_str,
90        server_fn_name: &server_fn_name,
91        input_struct_name: &input_struct_name,
92        arg_names: &arg_names,
93        arg_types: &arg_types,
94        _success_type: &success_type,
95    };
96
97    let registrar_impl = generate_registrar_impl(&config);
98
99    let output = quote! {
100        // Server implementation (internal function with renamed name)
101        #(#attrs)*
102        #vis async fn #server_fn_name(#(#arg_names: #arg_types),*) -> #return_type #block
103
104        // Client stub for remote calls
105        #[cfg(all(feature = "remote_call", not(feature = "local_call")))]
106        #(#attrs)*
107        #vis async fn #fn_name(#(#arg_names: #arg_types),*) -> #return_type {
108            use serverless_fn::transport::{get_default_transport, Transport};
109            use serverless_fn::serializer::{get_default_serializer, Serializer};
110            use serverless_fn::error::ServerlessError;
111            use serverless_fn::config::Config;
112            use serde::{Serialize, Deserialize};
113
114            #[allow(non_camel_case_types, missing_docs)]
115            #[derive(Serialize, Deserialize)]
116            struct #input_struct_name {
117                #(pub #arg_names: #arg_types,)*
118            }
119
120            let input = #input_struct_name {
121                #(#arg_names,)*
122            };
123
124            let config = Config::from_env();
125            let serializer = get_default_serializer();
126            let serialized_input = serializer.serialize(&input).map_err(ServerlessError::from)?;
127
128            let transport = get_default_transport(config.timeout(), config.retries());
129            let response_bytes = transport.call(#fn_name_str, serialized_input, None).await
130                .map_err(|e| ServerlessError::RemoteExecution(e.to_string()))?;
131
132            let output: #success_type = get_default_serializer()
133                .deserialize(&response_bytes)
134                .map_err(ServerlessError::from)?;
135
136            Ok(output)
137        }
138
139        // Wrapper function for local calls
140        #[cfg(any(feature = "local_call", not(feature = "remote_call")))]
141        #(#attrs)*
142        #vis async fn #fn_name(#(#arg_names: #arg_types),*) -> #return_type {
143            #server_fn_name(#(#arg_names),*).await
144        }
145
146        #registrar_impl
147    };
148
149    output.into()
150}
151
152/// Extracts the Ok type from a Result type.
153fn extract_ok_type(ty: &syn::Type) -> Option<syn::Type> {
154    if let syn::Type::Path(TypePath { path, .. }) = ty
155        && let Some(segment) = path.segments.first()
156        && segment.ident == "Result"
157        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
158        && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
159    {
160        return Some(ty.clone());
161    }
162    None
163}
164
165/// Configuration for generating registrar implementation.
166struct RegistrarConfig<'a> {
167    registrar_struct_name: &'a syn::Ident,
168    path_static: &'a syn::Ident,
169    fn_name_str: &'a str,
170    server_fn_name: &'a syn::Ident,
171    input_struct_name: &'a syn::Ident,
172    arg_names: &'a [&'a syn::Ident],
173    arg_types: &'a [&'a syn::Type],
174    _success_type: &'a syn::Type,
175}
176
177/// Generates the registrar implementation for server registration.
178fn generate_registrar_impl(config: &RegistrarConfig<'_>) -> proc_macro2::TokenStream {
179    let RegistrarConfig {
180        registrar_struct_name,
181        path_static,
182        fn_name_str,
183        server_fn_name,
184        input_struct_name,
185        arg_names,
186        arg_types,
187        _success_type,
188    } = config;
189
190    quote! {
191        // Static path for this function
192        static #path_static: &str = ::std::concat!("/", #fn_name_str);
193
194        // Define a unique registrar struct for this function
195        #[allow(non_camel_case_types, missing_docs)]
196        struct #registrar_struct_name;
197
198        impl ::serverless_fn::server::FunctionRegistry for #registrar_struct_name {
199            fn function_name(&self) -> &'static str {
200                #fn_name_str
201            }
202
203            fn function_path(&self) -> &'static str {
204                #path_static
205            }
206
207            fn register(&self, server: &mut ::serverless_fn::server::FunctionServer) {
208                use serde::{Deserialize, Serialize};
209                use serverless_fn::serializer::get_default_serializer;
210                use serverless_fn::error::ServerlessError;
211                use axum::extract::Json;
212                use axum::http::StatusCode;
213                use axum::body::Bytes;
214
215                #[derive(Serialize, Deserialize)]
216                #[allow(non_camel_case_types, missing_docs)]
217                struct #input_struct_name {
218                    #(pub #arg_names: #arg_types,)*
219                }
220
221                server.register_http_route(#path_static, move |body: Bytes| async move {
222                    let serializer = get_default_serializer();
223                    let input: #input_struct_name = serializer.deserialize(
224                        &body.to_vec()
225                    ).map_err(|e| {
226                        (StatusCode::BAD_REQUEST, e.to_string())
227                    })?;
228
229                    let result = #server_fn_name(#(input.#arg_names),*).await;
230
231                    match result {
232                        Ok(value) => {
233                            let response_bytes = serializer.serialize(&value)
234                                .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
235                            Ok::<_, (StatusCode, String)>(axum::response::Response::builder()
236                                .status(StatusCode::OK)
237                                .header("content-type", "application/octet-stream")
238                                .body(axum::body::Body::from(response_bytes))
239                                .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?)
240                        }
241                        Err(e) => Err((
242                            StatusCode::INTERNAL_SERVER_ERROR,
243                            e.to_string()
244                        )),
245                    }
246                });
247            }
248        }
249
250        // Submit the registrar to inventory
251        ::inventory::submit! {
252            &#registrar_struct_name as &'static dyn ::serverless_fn::server::FunctionRegistry
253        }
254    }
255}