1extern crate proc_macro;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11 parse_macro_input, ItemFn, Type, ReturnType, GenericArgument, PathArguments,
12 parse::Parse, parse::ParseStream, Error, Result as SynResult,
13 visit_mut::{self, VisitMut}, Expr, Ident, Lit,
14 spanned::Spanned,
15};
16
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Underscores are dropped and the first character following each underscore
/// (as well as the very first character) is upper-cased with ASCII rules only;
/// every other character is copied through untouched.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    // Splitting on '_' yields empty pieces for leading, trailing and doubled
    // underscores; those contribute nothing, exactly like skipping the '_'.
    for word in s.split('_') {
        let mut letters = word.chars();
        if let Some(initial) = letters.next() {
            pascal.push(initial.to_ascii_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
33
34struct PurityCheckVisitor {
36 errors: Vec<Error>,
37}
38
39impl VisitMut for PurityCheckVisitor {
40 fn visit_expr_mut(&mut self, i: &mut Expr) {
41 match i {
42 Expr::Unsafe(e) => {
43 self.errors.push(Error::new(
44 e.span(),
45 "impure `unsafe` block found in function marked as `pure`",
46 ));
47 }
48 Expr::Macro(e) => {
49 if e.mac.path.is_ident("asm") {
50 self.errors.push(Error::new(
51 e.span(),
52 "impure inline assembly found in function marked as `pure`",
53 ));
54 }
55 }
56 Expr::MethodCall(e) => {
57 self.errors.push(Error::new(
58 e.span(), "method calls are not supported in pure functions"
59 ));
60 }
61 Expr::Call(call_expr) => {
62 if let Expr::Path(expr_path) = &*call_expr.func {
63 let path = &expr_path.path;
64 if let Some(segment) = path.segments.last() {
65 if segment.ident == "Ok" || segment.ident == "Err" {
66 visit_mut::visit_expr_call_mut(self, call_expr);
67 return;
68 }
69 }
70
71 let mut zst_path = path.clone();
72 if let Some(last_segment) = zst_path.segments.last_mut() {
73 let ident_str = last_segment.ident.to_string();
74 let pascal_case_ident = to_pascal_case(&ident_str);
75 last_segment.ident = Ident::new(&pascal_case_ident, last_segment.ident.span());
76
77 let _fn_name = path.segments.last().unwrap().ident.to_string();
79
80 visit_mut::visit_expr_call_mut(self, call_expr);
82
83 let new_node = syn::parse_quote!({
84 {
85 let _ = || {
87 fn _assert_pure_function<T: crate::traits::IsPure>(_: T) {}
88 _assert_pure_function(#zst_path);
89 };
90 #call_expr
91 }
92 });
93 *i = new_node;
94 return; }
96 } else {
97 self.errors.push(Error::new_spanned(&call_expr.func, "closures and other complex function call expressions are not supported in pure functions"));
98 }
99 }
100 _ => {}
101 }
102
103 visit_mut::visit_expr_mut(self, i);
105 }
106}
107
108#[proc_macro_attribute]
115pub fn pure(_args: TokenStream, item: TokenStream) -> TokenStream {
116 let mut input_fn = parse_macro_input!(item as ItemFn);
117
118 let mut visitor = PurityCheckVisitor { errors: vec![] };
119
120 let mut new_body_box = input_fn.block.clone();
122 visitor.visit_block_mut(&mut new_body_box);
123
124 if !visitor.errors.is_empty() {
125 let combined_errors = visitor.errors.into_iter().reduce(|mut a, b| {
126 a.combine(b);
127 a
128 });
129 if let Some(errors) = combined_errors {
130 return errors.to_compile_error().into();
131 }
132 }
133
134 input_fn.block = new_body_box;
136
137 let fn_name_str = input_fn.sig.ident.to_string();
139 let zst_name = Ident::new(&to_pascal_case(&fn_name_str), input_fn.sig.ident.span());
140
141 let expanded = quote! {
142 #input_fn
143
144 #[doc(hidden)]
145 struct #zst_name;
146 #[doc(hidden)]
147 impl crate::traits::IsPure for #zst_name {}
148 };
149
150 TokenStream::from(expanded)
151}
152
153struct AttributeArgs {
155 strategy_type: Type,
156}
157
158impl Parse for AttributeArgs {
159 fn parse(input: ParseStream) -> SynResult<Self> {
160 let strategy_type: Type = input.parse()?;
161 Ok(AttributeArgs { strategy_type })
162 }
163}
164
165fn extract_result_types(return_type: &Type) -> SynResult<(Type, Type)> {
167 if let Type::Path(type_path) = return_type {
168 if let Some(segment) = type_path.path.segments.last() {
169 if segment.ident == "Result" {
170 if let PathArguments::AngleBracketed(args) = &segment.arguments {
171 if args.args.len() == 2 {
172 if let (
173 GenericArgument::Type(ok_type),
174 GenericArgument::Type(err_type)
175 ) = (&args.args[0], &args.args[1]) {
176 return Ok((ok_type.clone(), err_type.clone()));
177 }
178 }
179 }
180 }
181 }
182 }
183
184 Err(Error::new_spanned(
185 return_type,
186 "Expected function to return Result<T, E>"
187 ))
188}
189
190#[proc_macro_attribute]
230pub fn error_strategy(args: TokenStream, item: TokenStream) -> TokenStream {
231 let input_fn = parse_macro_input!(item as ItemFn);
232 let args = parse_macro_input!(args as AttributeArgs);
233
234 let strategy_type = args.strategy_type;
235 let fn_name = &input_fn.sig.ident;
236 let fn_vis = &input_fn.vis;
237 let fn_inputs = &input_fn.sig.inputs;
238 let fn_body = &input_fn.block;
239 let fn_asyncness = &input_fn.sig.asyncness;
240 let fn_generics = &input_fn.sig.generics;
241 let where_clause = &input_fn.sig.generics.where_clause;
242
243 let (ok_type, err_type) = match &input_fn.sig.output {
245 ReturnType::Type(_, ty) => {
246 match extract_result_types(ty) {
247 Ok(types) => types,
248 Err(e) => return e.to_compile_error().into(),
249 }
250 }
251 ReturnType::Default => {
252 return Error::new_spanned(
253 &input_fn.sig,
254 "Function must return Result<T, E>"
255 ).to_compile_error().into();
256 }
257 };
258
259 let original_impl_name = syn::Ident::new(
261 &format!("{}_original_impl", fn_name),
262 fn_name.span()
263 );
264
265 let strategy_name = quote!(#strategy_type).to_string();
267
268 let param_names: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
270 if let syn::FnArg::Typed(pat_type) = arg {
271 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
272 Some(&pat_ident.ident)
273 } else {
274 None
275 }
276 } else {
277 None
278 }
279 }).collect();
280
281 let function_call = if fn_asyncness.is_some() {
283 quote! { #original_impl_name(#(#param_names),*).await }
285 } else {
286 quote! { #original_impl_name(#(#param_names),*) }
288 };
289
290 let expanded = quote! {
291 #[doc(hidden)]
292 #fn_asyncness fn #original_impl_name #fn_generics (#fn_inputs) -> Result<#ok_type, #err_type> #where_clause
293 #fn_body
294
295 #fn_vis #fn_asyncness fn #fn_name #fn_generics (#fn_inputs) -> crate::PipexResult<#ok_type, #err_type> #where_clause {
296 let result = #function_call;
297 crate::PipexResult::new(result, #strategy_name)
298 }
299 };
300
301 TokenStream::from(expanded)
302}
303
304struct MemoizedArgs {
306 capacity: Option<usize>,
307}
308
309impl Parse for MemoizedArgs {
310 fn parse(input: ParseStream) -> SynResult<Self> {
311 let mut capacity = None;
312
313 while !input.is_empty() {
315 let lookahead = input.lookahead1();
316 if lookahead.peek(syn::Ident) {
317 let ident: Ident = input.parse()?;
318 if ident == "capacity" {
319 input.parse::<syn::Token![=]>()?;
320 let lit: Lit = input.parse()?;
321 if let Lit::Int(lit_int) = lit {
322 capacity = Some(lit_int.base10_parse()?);
323 } else {
324 return Err(Error::new_spanned(lit, "capacity must be an integer"));
325 }
326 } else {
327 return Err(Error::new_spanned(ident, "unknown attribute argument"));
328 }
329
330 if input.peek(syn::Token![,]) {
332 input.parse::<syn::Token![,]>()?;
333 }
334 } else {
335 return Err(lookahead.error());
336 }
337 }
338
339 Ok(MemoizedArgs { capacity })
340 }
341}
342
343impl Default for MemoizedArgs {
344 fn default() -> Self {
345 Self { capacity: Some(1000) } }
347}
348
349#[proc_macro_attribute]
387pub fn memoized(args: TokenStream, item: TokenStream) -> TokenStream {
388 let args = if args.is_empty() {
389 MemoizedArgs::default()
390 } else {
391 parse_macro_input!(args as MemoizedArgs)
392 };
393
394 let input_fn = parse_macro_input!(item as ItemFn);
395
396 let fn_name = &input_fn.sig.ident;
398 let fn_vis = &input_fn.vis;
399 let fn_inputs = &input_fn.sig.inputs;
400 let fn_output = &input_fn.sig.output;
401 let fn_generics = &input_fn.sig.generics;
402 let where_clause = &input_fn.sig.generics.where_clause;
403 let fn_asyncness = &input_fn.sig.asyncness;
404
405 let cache_name = Ident::new(&format!("{}_CACHE", fn_name.to_string().to_uppercase()), fn_name.span());
407
408 let param_names: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
410 if let syn::FnArg::Typed(pat_type) = arg {
411 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
412 Some(&pat_ident.ident)
413 } else {
414 None
415 }
416 } else {
417 None
418 }
419 }).collect();
420
421 let original_fn_name = Ident::new(&format!("{}_original", fn_name), fn_name.span());
423
424 let capacity = args.capacity.unwrap_or(1000);
426
427 let return_type = match &input_fn.sig.output {
429 ReturnType::Default => quote! { () },
430 ReturnType::Type(_, ty) => quote! { #ty },
431 };
432
433 let key_type = if param_names.is_empty() {
435 quote! { () }
436 } else {
437 let param_types: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
438 if let syn::FnArg::Typed(pat_type) = arg {
439 Some(&pat_type.ty)
440 } else {
441 None
442 }
443 }).collect();
444
445 if param_types.len() == 1 {
446 quote! { #(#param_types)* }
447 } else {
448 quote! { (#(#param_types),*) }
449 }
450 };
451
452 let key_creation = if param_names.is_empty() {
454 quote! { () }
455 } else if param_names.len() == 1 {
456 let param = ¶m_names[0];
457 quote! { #param.clone() }
458 } else {
459 quote! { (#(#param_names.clone()),*) }
460 };
461
462 let fn_call = if fn_asyncness.is_some() {
464 quote! { #original_fn_name(#(#param_names),*).await }
465 } else {
466 quote! { #original_fn_name(#(#param_names),*) }
467 };
468
469 let fn_body = &input_fn.block;
470
471 let expanded = quote! {
472 #fn_asyncness fn #original_fn_name #fn_generics (#fn_inputs) #fn_output #where_clause
474 #fn_body
475
476 #fn_vis #fn_asyncness fn #fn_name #fn_generics (#fn_inputs) #fn_output #where_clause {
478 #[cfg(feature = "memoization")]
479 {
480 use std::sync::Arc;
481
482 static #cache_name: crate::once_cell::sync::Lazy<crate::dashmap::DashMap<#key_type, #return_type>> = crate::once_cell::sync::Lazy::new(|| {
484 crate::dashmap::DashMap::with_capacity(#capacity)
485 });
486
487 let cache = &#cache_name;
488
489 let key = #key_creation;
490
491 if let Some(cached_result) = cache.get(&key) {
493 return cached_result.clone();
494 }
495
496 let result = #fn_call;
498
499 if cache.len() < #capacity {
501 cache.insert(key, result.clone());
502 }
503
504 result
505 }
506
507 #[cfg(not(feature = "memoization"))]
508 {
509 #fn_call
511 }
512 }
513 };
514
515 TokenStream::from(expanded)
516}
517
518
519
520