pipex_macros/
lib.rs

1//! # Pipex Macros
2//! 
3//! Procedural macros for the pipex crate, providing error handling strategies
4//! and pipeline decorators for async and sync functions.
5
6extern crate proc_macro;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11    parse_macro_input, ItemFn, Type, ReturnType, GenericArgument, PathArguments,
12    parse::Parse, parse::ParseStream, Error, Result as SynResult,
13    visit_mut::{self, VisitMut}, Expr, Ident, Lit,
14    spanned::Spanned,
15};
16
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Every `_` is dropped; the first character and each character that follows an
/// underscore are upper-cased with ASCII rules, and all other characters are
/// copied unchanged. A leading `r#` raw-identifier prefix is stripped first:
/// the result is always passed to `Ident::new`, which panics on strings such as
/// `"R#type"` that `Ident::to_string()` would otherwise produce for `r#type`.
///
/// ```text
/// "process_item" -> "ProcessItem"
/// "r#type"       -> "Type"
/// ```
fn to_pascal_case(s: &str) -> String {
    // `Ident::to_string()` renders raw identifiers as `r#name`; the prefix is not
    // part of the name and would make the generated identifier invalid.
    let s = s.strip_prefix("r#").unwrap_or(s);

    let mut pascal = String::with_capacity(s.len());
    let mut capitalize = true;
    for c in s.chars() {
        if c == '_' {
            capitalize = true;
        } else if capitalize {
            pascal.push(c.to_ascii_uppercase());
            capitalize = false;
        } else {
            pascal.push(c);
        }
    }
    pascal
}
33
34/// A visitor that recursively checks a function for impurity and injects purity checks for function calls.
35struct PurityCheckVisitor {
36    errors: Vec<Error>,
37}
38
39impl VisitMut for PurityCheckVisitor {
40    fn visit_expr_mut(&mut self, i: &mut Expr) {
41        match i {
42            Expr::Unsafe(e) => {
43                self.errors.push(Error::new(
44                    e.span(),
45                    "impure `unsafe` block found in function marked as `pure`",
46                ));
47            }
48            Expr::Macro(e) => {
49                if e.mac.path.is_ident("asm") {
50                    self.errors.push(Error::new(
51                        e.span(),
52                        "impure inline assembly found in function marked as `pure`",
53                    ));
54                }
55            }
56            Expr::MethodCall(e) => {
57                self.errors.push(Error::new(
58                    e.span(), "method calls are not supported in pure functions"
59                ));
60            }
61            Expr::Call(call_expr) => {
62                if let Expr::Path(expr_path) = &*call_expr.func {
63                    let path = &expr_path.path;
64                    if let Some(segment) = path.segments.last() {
65                        if segment.ident == "Ok" || segment.ident == "Err" {
66                            visit_mut::visit_expr_call_mut(self, call_expr);
67                            return;
68                        }
69                    }
70
71                    let mut zst_path = path.clone();
72                    if let Some(last_segment) = zst_path.segments.last_mut() {
73                        let ident_str = last_segment.ident.to_string();
74                        let pascal_case_ident = to_pascal_case(&ident_str);
75                        last_segment.ident = Ident::new(&pascal_case_ident, last_segment.ident.span());
76
77                        // Create a simple compile-time purity check
78                        let _fn_name = path.segments.last().unwrap().ident.to_string();
79                        
80                        // Manually recurse before we replace the node
81                        visit_mut::visit_expr_call_mut(self, call_expr);
82                        
83                        let new_node = syn::parse_quote!({
84                            {
85                                // Compile-time purity check - generates a clear error if function isn't pure
86                                let _ = || {
87                                    fn _assert_pure_function<T: crate::traits::IsPure>(_: T) {}
88                                    _assert_pure_function(#zst_path);
89                                };
90                                #call_expr
91                            }
92                        });
93                        *i = new_node;
94                        return; // Return to avoid visiting the new node and causing infinite recursion
95                    }
96                } else {
97                    self.errors.push(Error::new_spanned(&call_expr.func, "closures and other complex function call expressions are not supported in pure functions"));
98                }
99            }
100            _ => {}
101        }
102        
103        // Default recursion for all other expression types.
104        visit_mut::visit_expr_mut(self, i);
105    }
106}
107
108/// The `pure` attribute macro.
109///
110/// This macro checks if a function is "pure". A function is pure if:
111/// 1. It contains no `unsafe` blocks or inline assembly (`asm!`).
112/// 2. It does not call any methods.
113/// 3. It only calls other functions that are themselves marked `#[pure]`.
114#[proc_macro_attribute]
115pub fn pure(_args: TokenStream, item: TokenStream) -> TokenStream {
116    let mut input_fn = parse_macro_input!(item as ItemFn);
117
118    let mut visitor = PurityCheckVisitor { errors: vec![] };
119
120    // Clone the function body's box, and visit the block inside.
121    let mut new_body_box = input_fn.block.clone();
122    visitor.visit_block_mut(&mut new_body_box);
123
124    if !visitor.errors.is_empty() {
125        let combined_errors = visitor.errors.into_iter().reduce(|mut a, b| {
126            a.combine(b);
127            a
128        });
129        if let Some(errors) = combined_errors {
130            return errors.to_compile_error().into();
131        }
132    }
133
134    // Replace the old body with the new one containing the checks.
135    input_fn.block = new_body_box;
136
137    // Generate the ZST and IsPure impl.
138    let fn_name_str = input_fn.sig.ident.to_string();
139    let zst_name = Ident::new(&to_pascal_case(&fn_name_str), input_fn.sig.ident.span());
140
141    let expanded = quote! {
142        #input_fn
143
144        #[doc(hidden)]
145        struct #zst_name;
146        #[doc(hidden)]
147        impl crate::traits::IsPure for #zst_name {}
148    };
149
150    TokenStream::from(expanded)
151}
152
153/// Parser for attribute arguments
154struct AttributeArgs {
155    strategy_type: Type,
156}
157
158impl Parse for AttributeArgs {
159    fn parse(input: ParseStream) -> SynResult<Self> {
160        let strategy_type: Type = input.parse()?;
161        Ok(AttributeArgs { strategy_type })
162    }
163}
164
165/// Extract the inner types from Result<T, E>
166fn extract_result_types(return_type: &Type) -> SynResult<(Type, Type)> {
167    if let Type::Path(type_path) = return_type {
168        if let Some(segment) = type_path.path.segments.last() {
169            if segment.ident == "Result" {
170                if let PathArguments::AngleBracketed(args) = &segment.arguments {
171                    if args.args.len() == 2 {
172                        if let (
173                            GenericArgument::Type(ok_type),
174                            GenericArgument::Type(err_type)
175                        ) = (&args.args[0], &args.args[1]) {
176                            return Ok((ok_type.clone(), err_type.clone()));
177                        }
178                    }
179                }
180            }
181        }
182    }
183    
184    Err(Error::new_spanned(
185        return_type,
186        "Expected function to return Result<T, E>"
187    ))
188}
189
190/// The `error_strategy` attribute macro
191/// 
192/// This macro transforms a function that returns `Result<T, E>` into one that 
193/// returns `PipexResult<T, E>`, allowing the pipex library to apply the specified
194/// error handling strategy. Works with both sync and async functions.
195/// 
196/// # Arguments
197/// 
198/// * `strategy` - The error handling strategy type (e.g., `IgnoreHandler`, `CollectHandler`)
199/// 
200/// # Examples
201/// 
202/// ```rust,ignore
203/// use pipex_macros::error_strategy;
204/// 
205/// // Async function
206/// #[error_strategy(IgnoreHandler)]
207/// async fn process_item_async(x: i32) -> Result<i32, String> {
208///     if x % 2 == 0 {
209///         Ok(x * 2)
210///     } else {
211///         Err("Odd number".to_string())
212///     }
213/// }
214/// 
215/// // Sync function  
216/// #[error_strategy(CollectHandler)]
217/// fn process_item_sync(x: i32) -> Result<i32, String> {
218///     if x % 2 == 0 {
219///         Ok(x * 2)
220///     } else {
221///         Err("Odd number".to_string())
222///     }
223/// }
224/// ```
225/// 
226/// The generated function will automatically wrap the result in a `PipexResult`
227/// with the specified strategy name, allowing the pipeline to handle errors
228/// according to the strategy.
229#[proc_macro_attribute]
230pub fn error_strategy(args: TokenStream, item: TokenStream) -> TokenStream {
231    let input_fn = parse_macro_input!(item as ItemFn);
232    let args = parse_macro_input!(args as AttributeArgs);
233    
234    let strategy_type = args.strategy_type;
235    let fn_name = &input_fn.sig.ident;
236    let fn_vis = &input_fn.vis;
237    let fn_inputs = &input_fn.sig.inputs;
238    let fn_body = &input_fn.block;
239    let fn_asyncness = &input_fn.sig.asyncness;
240    let fn_generics = &input_fn.sig.generics;
241    let where_clause = &input_fn.sig.generics.where_clause;
242    
243    // Extract return type and validate it's Result<T, E>
244    let (ok_type, err_type) = match &input_fn.sig.output {
245        ReturnType::Type(_, ty) => {
246            match extract_result_types(ty) {
247                Ok(types) => types,
248                Err(e) => return e.to_compile_error().into(),
249            }
250        }
251        ReturnType::Default => {
252            return Error::new_spanned(
253                &input_fn.sig,
254                "Function must return Result<T, E>"
255            ).to_compile_error().into();
256        }
257    };
258    
259    // Create hidden function name for original implementation
260    let original_impl_name = syn::Ident::new(
261        &format!("{}_original_impl", fn_name),
262        fn_name.span()
263    );
264    
265    // Use the strategy type name as the strategy identifier
266    let strategy_name = quote!(#strategy_type).to_string();
267    
268    // Extract parameter names for the function call
269    let param_names: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
270        if let syn::FnArg::Typed(pat_type) = arg {
271            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
272                Some(&pat_ident.ident)
273            } else {
274                None
275            }
276        } else {
277            None
278        }
279    }).collect();
280    
281    // Generate different code based on whether function is async or sync
282    let function_call = if fn_asyncness.is_some() {
283        // Async function - use .await
284        quote! { #original_impl_name(#(#param_names),*).await }
285    } else {
286        // Sync function - no .await
287        quote! { #original_impl_name(#(#param_names),*) }
288    };
289    
290    let expanded = quote! {
291        #[doc(hidden)]
292        #fn_asyncness fn #original_impl_name #fn_generics (#fn_inputs) -> Result<#ok_type, #err_type> #where_clause
293        #fn_body
294        
295        #fn_vis #fn_asyncness fn #fn_name #fn_generics (#fn_inputs) -> crate::PipexResult<#ok_type, #err_type> #where_clause {
296            let result = #function_call;
297            crate::PipexResult::new(result, #strategy_name)
298        }
299    };
300    
301    TokenStream::from(expanded)
302}
303
304/// Memoization configuration for the `#[memoized]` attribute
305struct MemoizedArgs {
306    capacity: Option<usize>,
307}
308
309impl Parse for MemoizedArgs {
310    fn parse(input: ParseStream) -> SynResult<Self> {
311        let mut capacity = None;
312        
313        // Parse optional arguments like: capacity = 1000
314        while !input.is_empty() {
315            let lookahead = input.lookahead1();
316            if lookahead.peek(syn::Ident) {
317                let ident: Ident = input.parse()?;
318                if ident == "capacity" {
319                    input.parse::<syn::Token![=]>()?;
320                    let lit: Lit = input.parse()?;
321                    if let Lit::Int(lit_int) = lit {
322                        capacity = Some(lit_int.base10_parse()?);
323                    } else {
324                        return Err(Error::new_spanned(lit, "capacity must be an integer"));
325                    }
326                } else {
327                    return Err(Error::new_spanned(ident, "unknown attribute argument"));
328                }
329                
330                // Handle optional comma
331                if input.peek(syn::Token![,]) {
332                    input.parse::<syn::Token![,]>()?;
333                }
334            } else {
335                return Err(lookahead.error());
336            }
337        }
338        
339        Ok(MemoizedArgs { capacity })
340    }
341}
342
343impl Default for MemoizedArgs {
344    fn default() -> Self {
345        Self { capacity: Some(1000) } // Default capacity
346    }
347}
348
349/// The `memoized` attribute macro for automatic function memoization.
350///
351/// This macro provides automatic memoization for functions, caching results based on input parameters.
352/// It's designed to work perfectly with `#[pure]` functions since pure functions are safe to memoize.
353///
354/// # Features
355/// - Thread-safe caching using DashMap
356/// - Configurable cache capacity
357/// - Automatic cache key generation from function parameters
358/// - Zero-cost abstraction when memoization feature is disabled
359///
360/// # Arguments
361/// - `capacity` (optional): Maximum number of entries to cache (default: 1000)
362///
363/// # Requirements
364/// - Function parameters must implement `Clone + std::hash::Hash + Eq`
365/// - Return type must implement `Clone`
366/// - Requires the "memoization" feature to be enabled
367///
368/// # Examples
369///
370/// ```rust,ignore
371/// use pipex_macros::{pure, memoized};
372///
373/// #[pure]
374/// #[memoized]
375/// fn fibonacci(n: u64) -> u64 {
376///     if n <= 1 { n } else { fibonacci(n-1) + fibonacci(n-2) }
377/// }
378///
379/// #[pure]
380/// #[memoized(capacity = 500)]
381/// fn expensive_calculation(x: i32, y: i32) -> i32 {
382///     // Some expensive computation
383///     x * y + (x ^ y)
384/// }
385/// ```
386#[proc_macro_attribute]
387pub fn memoized(args: TokenStream, item: TokenStream) -> TokenStream {
388    let args = if args.is_empty() {
389        MemoizedArgs::default()
390    } else {
391        parse_macro_input!(args as MemoizedArgs)
392    };
393    
394    let input_fn = parse_macro_input!(item as ItemFn);
395    
396    // Extract function information
397    let fn_name = &input_fn.sig.ident;
398    let fn_vis = &input_fn.vis;
399    let fn_inputs = &input_fn.sig.inputs;
400    let fn_output = &input_fn.sig.output;
401    let fn_generics = &input_fn.sig.generics;
402    let where_clause = &input_fn.sig.generics.where_clause;
403    let fn_asyncness = &input_fn.sig.asyncness;
404    
405    // Generate cache name
406    let cache_name = Ident::new(&format!("{}_CACHE", fn_name.to_string().to_uppercase()), fn_name.span());
407    
408    // Extract parameter names and types for key generation
409    let param_names: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
410        if let syn::FnArg::Typed(pat_type) = arg {
411            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
412                Some(&pat_ident.ident)
413            } else {
414                None
415            }
416        } else {
417            None
418        }
419    }).collect();
420    
421    // Generate original function name
422    let original_fn_name = Ident::new(&format!("{}_original", fn_name), fn_name.span());
423    
424    // Determine cache capacity
425    let capacity = args.capacity.unwrap_or(1000);
426    
427    // Extract return type for cache value
428    let return_type = match &input_fn.sig.output {
429        ReturnType::Default => quote! { () },
430        ReturnType::Type(_, ty) => quote! { #ty },
431    };
432    
433    // Generate cache key type - tuple of all parameter types
434    let key_type = if param_names.is_empty() {
435        quote! { () }
436    } else {
437        let param_types: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
438            if let syn::FnArg::Typed(pat_type) = arg {
439                Some(&pat_type.ty)
440            } else {
441                None
442            }
443        }).collect();
444        
445        if param_types.len() == 1 {
446            quote! { #(#param_types)* }
447        } else {
448            quote! { (#(#param_types),*) }
449        }
450    };
451    
452    // Generate cache key creation
453    let key_creation = if param_names.is_empty() {
454        quote! { () }
455    } else if param_names.len() == 1 {
456        let param = &param_names[0];
457        quote! { #param.clone() }
458    } else {
459        quote! { (#(#param_names.clone()),*) }
460    };
461    
462    // Generate function call
463    let fn_call = if fn_asyncness.is_some() {
464        quote! { #original_fn_name(#(#param_names),*).await }
465    } else {
466        quote! { #original_fn_name(#(#param_names),*) }
467    };
468    
469    let fn_body = &input_fn.block;
470    
471    let expanded = quote! {
472        // Original function implementation
473        #fn_asyncness fn #original_fn_name #fn_generics (#fn_inputs) #fn_output #where_clause
474        #fn_body
475        
476        // Memoized wrapper function
477        #fn_vis #fn_asyncness fn #fn_name #fn_generics (#fn_inputs) #fn_output #where_clause {
478            #[cfg(feature = "memoization")]
479            {
480                use std::sync::Arc;
481                
482                // Thread-safe cache using DashMap
483                static #cache_name: crate::once_cell::sync::Lazy<crate::dashmap::DashMap<#key_type, #return_type>> = crate::once_cell::sync::Lazy::new(|| {
484                    crate::dashmap::DashMap::with_capacity(#capacity)
485                });
486                
487                let cache = &#cache_name;
488                
489                let key = #key_creation;
490                
491                // Check cache first
492                if let Some(cached_result) = cache.get(&key) {
493                    return cached_result.clone();
494                }
495                
496                // Compute result and cache it
497                let result = #fn_call;
498                
499                // Only cache if we haven't exceeded capacity
500                if cache.len() < #capacity {
501                    cache.insert(key, result.clone());
502                }
503                
504                result
505            }
506            
507            #[cfg(not(feature = "memoization"))]
508            {
509                // When memoization is disabled, just call the original function
510                #fn_call
511            }
512        }
513    };
514    
515    TokenStream::from(expanded)
516}
517
518
519
520// No tests in proc macro crate - they can't use the macros defined here
521// Tests will be in the main pipex crate or integration tests