anyinput_core/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod tests;
4
5use proc_macro2::TokenStream;
6use proc_macro_error::{abort, SpanRange};
7use quote::quote;
8use std::str::FromStr;
9use strum::{Display, EnumString};
10use syn::fold::Fold;
11use syn::WhereClause;
12use syn::{
13    parse2, parse_quote, parse_str, punctuated::Punctuated, token::Comma, Block, FnArg,
14    GenericArgument, GenericParam, Generics, Ident, ItemFn, Lifetime, Pat, PatIdent, PatType,
15    PathArguments, Signature, Stmt, Type, TypePath, WherePredicate,
16};
17
18pub fn anyinput_core(args: TokenStream, input: TokenStream) -> TokenStream {
19    if !args.is_empty() {
20        abort!(args, "anyinput does not take any arguments.")
21    }
22
23    // proc_marco2 version of "parse_macro_input!(input as ItemFn)"
24    let old_item_fn = match parse2::<ItemFn>(input) {
25        Ok(syntax_tree) => syntax_tree,
26        Err(error) => return error.to_compile_error(),
27    };
28
29    let new_item_fn = transform_fn(old_item_fn);
30
31    quote!(#new_item_fn)
32}
33
34pub fn anyinput_core_sample(args: TokenStream, input: TokenStream) -> TokenStream {
35    if !args.is_empty() {
36        abort!(args, "anyinput does not take any arguments.")
37    }
38
39    // proc_marco2 version of "parse_macro_input!(input as ItemFn)"
40    let old_item_fn = match parse2::<ItemFn>(input) {
41        Ok(syntax_tree) => syntax_tree,
42        Err(error) => return error.to_compile_error(),
43    };
44
45    let new_item_fn = transform_fn_sample(old_item_fn);
46
47    quote!(#new_item_fn)
48}
49
50fn transform_fn_sample(_item_fn: ItemFn) -> ItemFn {
51    println!("input code  : {}", quote!(#_item_fn));
52    println!("input syntax: {:?}", _item_fn);
53    parse_quote! {
54        fn hello_world() {
55            println!("Hello, world!");
56        }
57    }
58}
59
60fn transform_fn(item_fn: ItemFn) -> ItemFn {
61    let mut suffix_iter = simple_suffix_iter_factory();
62    let delta_fn_arg_new = |fn_arg| DeltaFnArg::new(fn_arg, &mut suffix_iter);
63
64    // Transform each old argument of the function, accumulating: the new argument, new generics, wheres, and statements
65    // Then, turn the accumulation into a new function.
66    item_fn
67        .sig
68        .inputs
69        .iter()
70        .map(delta_fn_arg_new)
71        .fold(ItemFnAcc::init(&item_fn), ItemFnAcc::fold)
72        .to_item_fn()
73}
74
75struct ItemFnAcc<'a> {
76    old_fn: &'a ItemFn,
77    fn_args: Punctuated<FnArg, Comma>,
78    generic_params: Punctuated<GenericParam, Comma>,
79    where_predicates: Punctuated<WherePredicate, Comma>,
80    stmts: Vec<Stmt>,
81}
82
83impl ItemFnAcc<'_> {
84    fn init(item_fn: &ItemFn) -> ItemFnAcc {
85        // Start with 1. no function arguments, 2. the old function's generics, wheres, and statements
86        ItemFnAcc {
87            old_fn: item_fn,
88            fn_args: Punctuated::<FnArg, Comma>::new(),
89            generic_params: item_fn.sig.generics.params.clone(),
90            where_predicates: ItemFnAcc::extract_where_predicates(item_fn),
91            stmts: item_fn.block.stmts.clone(),
92        }
93    }
94
95    // Even if the where clause is None, we still need to return an empty Punctuated
96    fn extract_where_predicates(item_fn: &ItemFn) -> Punctuated<WherePredicate, Comma> {
97        if let Some(WhereClause { predicates, .. }) = &item_fn.sig.generics.where_clause {
98            predicates.clone()
99        } else {
100            parse_quote!()
101        }
102    }
103
104    fn fold(mut self, delta: DeltaFnArg) -> Self {
105        self.fn_args.push(delta.fn_arg);
106        self.generic_params.extend(delta.generic_params);
107        self.where_predicates.extend(delta.where_predicates);
108        for (index, element) in delta.stmt.into_iter().enumerate() {
109            self.stmts.insert(index, element);
110        }
111        self
112    }
113
114    // Use Rust's struct update syntax (https://www.reddit.com/r/rust/comments/pchp8h/media_struct_update_syntax_in_rust/)
115    fn to_item_fn(&self) -> ItemFn {
116        ItemFn {
117            sig: Signature {
118                generics: self.to_generics(),
119                inputs: self.fn_args.clone(),
120                ..self.old_fn.sig.clone()
121            },
122            block: Box::new(Block {
123                stmts: self.stmts.clone(),
124                ..*self.old_fn.block.clone()
125            }),
126            ..self.old_fn.clone()
127        }
128    }
129
130    fn to_generics(&self) -> Generics {
131        Generics {
132            lt_token: parse_quote!(<),
133            params: self.generic_params.clone(),
134            gt_token: parse_quote!(>),
135            where_clause: self.to_where_clause(),
136        }
137    }
138
139    fn to_where_clause(&self) -> Option<WhereClause> {
140        if self.where_predicates.is_empty() {
141            None
142        } else {
143            Some(WhereClause {
144                where_token: parse_quote!(where),
145                predicates: self.where_predicates.clone(),
146            })
147        }
148    }
149}
150
/// Returns an endless iterator of decimal suffixes: "0", "1", "2", ...
///
/// The suffixes are appended to generated generic and lifetime names so that each name is
/// unique. Plain counters are used instead of, say, UUIDs because they keep the generated
/// code readable.
fn simple_suffix_iter_factory() -> impl Iterator<Item = String> + 'static {
    (0usize..).map(|n| n.to_string())
}
157
158// Define the Specials and their properties.
159#[derive(Debug, Clone, EnumString, Display)]
160#[allow(clippy::enum_variant_names)]
161enum Special {
162    AnyArray,
163    AnyString,
164    AnyPath,
165    AnyIter,
166    AnyNdArray,
167}
168
169impl Special {
170    fn special_to_where_predicate(
171        &self,
172        generic: &TypePath, // for example: AnyArray0
173        maybe_sub_type: Option<Type>,
174        maybe_lifetime: Option<Lifetime>,
175        span_range: &SpanRange,
176    ) -> WherePredicate {
177        match &self {
178            Special::AnyString => {
179                if maybe_sub_type.is_some() {
180                    abort!(span_range,"AnyString should not have a generic parameter, so 'AnyString', not 'AnyString<_>'.")
181                };
182                if maybe_lifetime.is_some() {
183                    abort!(span_range, "AnyString should not have a lifetime.")
184                };
185                parse_quote! {
186                    #generic : AsRef<str>
187                }
188            }
189            Special::AnyPath => {
190                if maybe_sub_type.is_some() {
191                    abort!(span_range,"AnyPath should not have a generic parameter, so 'AnyPath', not 'AnyPath<_>'.")
192                };
193                if maybe_lifetime.is_some() {
194                    abort!(span_range, "AnyPath should not have a lifetime.")
195                };
196                parse_quote! {
197                    #generic : AsRef<std::path::Path>
198                }
199            }
200            Special::AnyArray => {
201                let sub_type = match maybe_sub_type {
202                    Some(sub_type) => sub_type,
203                    None => {
204                        abort!(span_range,"AnyArray expects a generic parameter, for example, AnyArray<usize> or AnyArray<AnyString>.")
205                    }
206                };
207                if maybe_lifetime.is_some() {
208                    abort!(span_range, "AnyArray should not have a lifetime.")
209                };
210                parse_quote! {
211                    #generic : AsRef<[#sub_type]>
212                }
213            }
214            Special::AnyIter => {
215                let sub_type = match maybe_sub_type {
216                    Some(sub_type) => sub_type,
217                    None => {
218                        abort!(span_range,"AnyIter expects a generic parameter, for example, AnyIter<usize> or AnyIter<AnyString>.")
219                    }
220                };
221                if maybe_lifetime.is_some() {
222                    abort!(span_range, "AnyIter should not have a lifetime.")
223                };
224                parse_quote! {
225                    #generic : IntoIterator<Item = #sub_type>
226                }
227            }
228            Special::AnyNdArray => {
229                let sub_type = match maybe_sub_type {
230                    Some(sub_type) => sub_type,
231                    None => {
232                        abort!(span_range,"AnyNdArray expects a generic parameter, for example, AnyNdArray<usize> or AnyNdArray<AnyString>.")
233                    }
234                };
235                let lifetime =
236                    maybe_lifetime.expect("Internal error: AnyNdArray should be given a lifetime.");
237                parse_quote! {
238                    #generic: Into<ndarray::ArrayView1<#lifetime, #sub_type>>
239                }
240            }
241        }
242    }
243
244    fn ident_to_stmt(&self, name: &Ident) -> Stmt {
245        match &self {
246            Special::AnyArray | Special::AnyString | Special::AnyPath => {
247                parse_quote! {
248                    let #name = #name.as_ref();
249                }
250            }
251            Special::AnyIter => {
252                parse_quote! {
253                    let #name = #name.into_iter();
254                }
255            }
256            Special::AnyNdArray => {
257                parse_quote! {
258                    let #name = #name.into();
259                }
260            }
261        }
262    }
263
264    fn should_add_lifetime(&self) -> bool {
265        match self {
266            Special::AnyArray | Special::AnyString | Special::AnyPath | Special::AnyIter => false,
267            Special::AnyNdArray => true,
268        }
269    }
270
271    fn maybe_new(type_path: &TypePath, span_range: &SpanRange) -> Option<(Special, Option<Type>)> {
272        // A special type path has exactly one segment and a name from the Special enum.
273        if type_path.qself.is_none() {
274            if let Some(segment) = first_and_only(type_path.path.segments.iter()) {
275                if let Ok(special) = Special::from_str(segment.ident.to_string().as_ref()) {
276                    let maybe_sub_type =
277                        Special::create_maybe_sub_type(&segment.arguments, span_range);
278                    return Some((special, maybe_sub_type));
279                }
280            }
281        }
282        None
283    }
284
285    fn create_maybe_sub_type(args: &PathArguments, span_range: &SpanRange) -> Option<Type> {
286        match args {
287            PathArguments::None => None,
288            PathArguments::AngleBracketed(ref args) => {
289                let arg = first_and_only(args.args.iter()).unwrap_or_else(|| {
290                    abort!(span_range, "Expected at exactly one generic parameter.")
291                });
292                if let GenericArgument::Type(sub_type2) = arg {
293                    Some(sub_type2.clone())
294                } else {
295                    abort!(span_range, "Expected generic parameter to be a type.")
296                }
297            }
298            PathArguments::Parenthesized(_) => {
299                abort!(span_range, "Expected <..> generic parameter.")
300            }
301        }
302    }
303
304    // Utility that turns camel case into snake case.
305    // For example, "AnyString" -> "any_string".
306    fn to_snake_case(&self) -> String {
307        let mut snake_case_string = String::new();
308        for (index, ch) in self.to_string().chars().enumerate() {
309            if index > 0 && ch.is_uppercase() {
310                snake_case_string.push('_');
311            }
312            snake_case_string.push(ch.to_ascii_lowercase());
313        }
314        snake_case_string
315    }
316}
317
318#[derive(Debug)]
319// The new function input, any statements to add, and any new generic definitions.
320struct DeltaFnArg {
321    fn_arg: FnArg,
322    generic_params: Vec<GenericParam>,
323    where_predicates: Vec<WherePredicate>,
324    stmt: Option<Stmt>,
325}
326
327impl DeltaFnArg {
328    // If a function argument contains a special type(s), re-write it/them.
329    fn new(fn_arg: &FnArg, suffix_iter: &mut impl Iterator<Item = String>) -> DeltaFnArg {
330        // If the function input is normal (not self, not a macro, etc) ...
331        if let Some((pat_ident, pat_type)) = DeltaFnArg::is_normal_fn_arg(fn_arg) {
332            // Replace any specials in the type with generics.
333            DeltaFnArg::replace_any_specials(pat_type.clone(), pat_ident, suffix_iter)
334        } else {
335            // if input is not normal, return it unchanged.
336            DeltaFnArg {
337                fn_arg: fn_arg.clone(),
338                generic_params: vec![],
339                where_predicates: vec![],
340                stmt: None,
341            }
342        }
343    }
344
345    // A function argument is normal if it is not self, not a macro, etc.
346    fn is_normal_fn_arg(fn_arg: &FnArg) -> Option<(&PatIdent, &PatType)> {
347        if let FnArg::Typed(pat_type) = fn_arg {
348            if let Pat::Ident(pat_ident) = &*pat_type.pat {
349                if let Type::Path(_) = &*pat_type.ty {
350                    return Some((pat_ident, pat_type));
351                }
352            }
353        }
354        None
355    }
356
357    // Search type and its (sub)subtypes for specials starting at the deepest level.
358    // When one is found, replace it with a generic.
359    // Finally, return the new type and a list of the generic definitions.
360    // Also, if the top-level type was special, return the special type.
361    #[allow(clippy::ptr_arg)]
362    fn replace_any_specials(
363        old_pat_type: PatType,
364        pat_ident: &PatIdent,
365        suffix_iter: &mut impl Iterator<Item = String>,
366    ) -> DeltaFnArg {
367        let mut delta_pat_type = DeltaPatType::new(suffix_iter);
368        let new_pat_type = delta_pat_type.fold_pat_type(old_pat_type);
369
370        // Return the new function input, any statements to add, and any new generic definitions.
371        DeltaFnArg {
372            fn_arg: FnArg::Typed(new_pat_type),
373            stmt: delta_pat_type.generate_any_stmt(pat_ident),
374            generic_params: delta_pat_type.generic_params,
375            where_predicates: delta_pat_type.where_predicates,
376        }
377    }
378}
379
380struct DeltaPatType<'a> {
381    generic_params: Vec<GenericParam>,
382    where_predicates: Vec<WherePredicate>,
383    suffix_iter: &'a mut dyn Iterator<Item = String>,
384    last_special: Option<Special>,
385}
386
387impl Fold for DeltaPatType<'_> {
388    fn fold_type_path(&mut self, type_path_old: TypePath) -> TypePath {
389        let span_range = SpanRange::from_tokens(&type_path_old); // used by abort!
390
391        // Apply "fold" recursively to process specials in subtypes, for example, Vec<AnyString>.
392        let type_path_middle = syn::fold::fold_type_path(self, type_path_old);
393
394        // If this type is special, replace it with a generic.
395        if let Some((special, maybe_sub_types)) = Special::maybe_new(&type_path_middle, &span_range)
396        {
397            self.last_special = Some(special.clone()); // remember the special found (used for stmt generation)
398            self.create_and_define_generic(special, maybe_sub_types, &span_range)
399        } else {
400            self.last_special = None;
401            type_path_middle
402        }
403    }
404}
405
406impl<'a> DeltaPatType<'a> {
407    fn new(suffix_iter: &'a mut dyn Iterator<Item = String>) -> Self {
408        DeltaPatType {
409            generic_params: vec![],
410            where_predicates: vec![],
411            suffix_iter,
412            last_special: None,
413        }
414    }
415
416    // If the top-level type is a special, add a statement to convert
417    // from its generic type to to a concrete type.
418    // For example,  "let x = x.into_iter();" for AnyIter.
419    fn generate_any_stmt(&self, pat_ident: &PatIdent) -> Option<Stmt> {
420        if let Some(special) = &self.last_special {
421            let stmt = special.ident_to_stmt(&pat_ident.ident);
422            Some(stmt)
423        } else {
424            None
425        }
426    }
427
428    // Define the generic type, for example, "AnyString3: AsRef<str>", and remember the definition.
429    fn create_and_define_generic(
430        &mut self,
431        special: Special,
432        maybe_sub_type: Option<Type>,
433        span_range: &SpanRange,
434    ) -> TypePath {
435        let generic = self.create_generic(&special); // for example, "AnyString3"
436        let maybe_lifetime = self.create_maybe_lifetime(&special);
437        let where_predicate = special.special_to_where_predicate(
438            &generic,
439            maybe_sub_type,
440            maybe_lifetime,
441            span_range,
442        );
443        let generic_param: GenericParam = parse_quote!(#generic);
444        self.generic_params.push(generic_param);
445        self.where_predicates.push(where_predicate);
446        generic
447    }
448
449    // create a lifetime if needed, for example, Some('any_nd_array_3) or None
450    fn create_maybe_lifetime(&mut self, special: &Special) -> Option<Lifetime> {
451        if special.should_add_lifetime() {
452            let lifetime = self.create_lifetime(special);
453            let generic_param: GenericParam = parse_quote!(#lifetime);
454            self.generic_params.push(generic_param);
455
456            Some(lifetime)
457        } else {
458            None
459        }
460    }
461
462    // Create a new generic type, for example, "AnyString3"
463    fn create_generic(&mut self, special: &Special) -> TypePath {
464        let suffix = self.create_suffix();
465        let generic_name = format!("{}{}", &special, suffix);
466        parse_str(&generic_name).expect("Internal error: failed to parse generic name")
467    }
468
469    // Create a new lifetime, for example, "'any_nd_array_4"
470    fn create_lifetime(&mut self, special: &Special) -> Lifetime {
471        let lifetime_name = format!("'{}{}", special.to_snake_case(), self.create_suffix());
472        parse_str(&lifetime_name).expect("Internal error: failed to parse lifetime name")
473    }
474
475    // Create a new suffix, for example, "4"
476    fn create_suffix(&mut self) -> String {
477        self.suffix_iter
478            .next()
479            .expect("Internal error: ran out of generic suffixes")
480    }
481}
482
/// Returns the iterator's sole element, or `None` when it yields zero or more than one element.
///
/// At most two elements are pulled; the second pull is skipped when the iterator is empty.
fn first_and_only<T, I: Iterator<Item = T>>(mut iter: I) -> Option<T> {
    let only = iter.next()?;
    match iter.next() {
        None => Some(only),
        Some(_) => None,
    }
}
492
493// todo later could nested .as_ref(), .into_iter(), and .into() be replaced with a single method or macro?
494// todo later do something interesting with 2d ndarray/views
495// todo later when does the std lib use where clauses? Is there an informal rule? Should there be an option?