#![doc = include_str!("../README.md")]
mod tests;
use proc_macro2::TokenStream;
use proc_macro_error::{abort, SpanRange};
use quote::quote;
use std::str::FromStr;
use strum::{Display, EnumString};
use syn::fold::Fold;
use syn::WhereClause;
use syn::{
parse2, parse_quote, parse_str, punctuated::Punctuated, token::Comma, Block, FnArg,
GenericArgument, GenericParam, Generics, Ident, ItemFn, Lifetime, Pat, PatIdent, PatType,
PathArguments, Signature, Stmt, Type, TypePath, WherePredicate,
};
/// Expands a function annotated with the `anyinput` attribute.
///
/// The attribute itself accepts no arguments; any are reported with `abort!`.
/// `input` is parsed as a function item and rewritten by `transform_fn`. If it
/// does not parse as a function, the parse error is returned as
/// compile-error tokens instead.
pub fn anyinput_core(args: TokenStream, input: TokenStream) -> TokenStream {
    if !args.is_empty() {
        abort!(args, "anyinput does not take any arguments.")
    }
    parse2::<ItemFn>(input).map_or_else(
        |error| error.to_compile_error(),
        |old_item_fn| {
            let new_item_fn = transform_fn(old_item_fn);
            quote!(#new_item_fn)
        },
    )
}
/// Sample/debug variant of `anyinput_core`.
///
/// It checks the arguments and parses `input` the same way, but hands the
/// parsed function to `transform_fn_sample` instead of `transform_fn`.
pub fn anyinput_core_sample(args: TokenStream, input: TokenStream) -> TokenStream {
    if !args.is_empty() {
        abort!(args, "anyinput does not take any arguments.")
    }
    match parse2::<ItemFn>(input) {
        Ok(parsed_fn) => {
            let sample_fn = transform_fn_sample(parsed_fn);
            quote!(#sample_fn)
        }
        Err(parse_error) => parse_error.to_compile_error(),
    }
}
/// Debugging aid.
///
/// Prints the incoming function both as source tokens and as its `Debug`
/// syntax tree. It then ignores the input and returns a fixed `hello_world`
/// function.
fn transform_fn_sample(item_fn: ItemFn) -> ItemFn {
    let code = quote!(#item_fn);
    println!("input code : {}", code);
    println!("input syntax: {:?}", item_fn);
    parse_quote! {
        fn hello_world() {
            println!("Hello, world!");
        }
    }
}
/// Rewrites `item_fn` argument by argument.
///
/// Each argument is turned into a `DeltaFnArg` (special types replaced by
/// generics, plus bounds and an optional conversion statement). The deltas
/// are merged into an `ItemFnAcc`, which is then rebuilt into a function.
/// A single suffix iterator is shared by all arguments, so generated names
/// do not collide within one function.
fn transform_fn(item_fn: ItemFn) -> ItemFn {
    let mut suffixes = simple_suffix_iter_factory();
    let mut acc = ItemFnAcc::init(&item_fn);
    for fn_arg in &item_fn.sig.inputs {
        let delta = DeltaFnArg::new(fn_arg, &mut suffixes);
        acc = acc.fold(delta);
    }
    acc.to_item_fn()
}
/// Accumulator used by `transform_fn` while folding over the original
/// function's arguments; `to_item_fn` turns it back into an `ItemFn`.
struct ItemFnAcc<'a> {
    /// The function being rewritten; supplies every part not overridden below.
    old_fn: &'a ItemFn,
    /// Rewritten argument list, built up one argument at a time.
    fn_args: Punctuated<FnArg, Comma>,
    /// The original generic parameters plus any generated for special types.
    generic_params: Punctuated<GenericParam, Comma>,
    /// The original where-predicates plus any generated for special types.
    where_predicates: Punctuated<WherePredicate, Comma>,
    /// Body statements; generated conversion statements are inserted at the front.
    stmts: Vec<Stmt>,
}
impl ItemFnAcc<'_> {
    /// Starts an accumulator for `item_fn`.
    ///
    /// The argument list starts empty (arguments are added by `fold`). The
    /// function's existing generic parameters, where-predicates and body
    /// statements are copied in so they survive the rewrite.
    fn init(item_fn: &ItemFn) -> ItemFnAcc {
        let existing_generics = &item_fn.sig.generics;
        ItemFnAcc {
            old_fn: item_fn,
            fn_args: Punctuated::new(),
            generic_params: existing_generics.params.clone(),
            where_predicates: ItemFnAcc::extract_where_predicates(item_fn),
            stmts: item_fn.block.stmts.clone(),
        }
    }

    /// Clones the predicates of `item_fn`'s `where` clause, or returns an
    /// empty list when the function has no `where` clause.
    fn extract_where_predicates(item_fn: &ItemFn) -> Punctuated<WherePredicate, Comma> {
        match &item_fn.sig.generics.where_clause {
            Some(where_clause) => where_clause.predicates.clone(),
            None => Punctuated::new(),
        }
    }

    /// Merges one argument's delta into the accumulator and returns it.
    fn fold(mut self, delta: DeltaFnArg) -> Self {
        // The statement goes to the very front of the body, so statements for
        // later arguments end up ahead of those for earlier ones.
        if let Some(stmt) = delta.stmt {
            self.stmts.insert(0, stmt);
        }
        self.where_predicates.extend(delta.where_predicates);
        self.generic_params.extend(delta.generic_params);
        self.fn_args.push(delta.fn_arg);
        self
    }

    /// Builds the rewritten function.
    ///
    /// The result is a copy of the original with its generics, argument list
    /// and body statements replaced by the accumulated ones.
    fn to_item_fn(&self) -> ItemFn {
        let mut new_fn = self.old_fn.clone();
        new_fn.sig.generics = self.to_generics();
        new_fn.sig.inputs = self.fn_args.clone();
        new_fn.block.stmts = self.stmts.clone();
        new_fn
    }

    /// Assembles `<...>` generics from the accumulated parameters, with a
    /// `where` clause when there are predicates.
    fn to_generics(&self) -> Generics {
        let where_clause = self.to_where_clause();
        let params = self.generic_params.clone();
        Generics {
            lt_token: parse_quote!(<),
            params,
            gt_token: parse_quote!(>),
            where_clause,
        }
    }

    /// Returns a `where` clause holding the accumulated predicates, or `None`
    /// when there are none.
    fn to_where_clause(&self) -> Option<WhereClause> {
        (!self.where_predicates.is_empty()).then(|| WhereClause {
            where_token: parse_quote!(where),
            predicates: self.where_predicates.clone(),
        })
    }
}
/// Returns an endless iterator of decimal suffixes: `"0"`, `"1"`, `"2"`, ...
///
/// `transform_fn` creates one of these per function and shares it across all
/// arguments. Each generated generic or lifetime name therefore receives a
/// distinct suffix.
fn simple_suffix_iter_factory() -> impl Iterator<Item = String> + 'static {
    // `0usize..` is already an iterator; no `.into_iter()` conversion is needed.
    (0usize..).map(|i| i.to_string())
}
/// Placeholder type names recognized in argument types.
///
/// `EnumString` is used to recognize an identifier as one of these variants
/// (see `Special::maybe_new`). `Display` supplies the prefix of the generated
/// generic name (see `DeltaPatType::create_generic`).
#[derive(Debug, Clone, EnumString, Display)]
// Every variant deliberately shares the `Any` prefix, which this lint would flag.
#[allow(clippy::enum_variant_names)]
enum Special {
    /// Replaced by a generic bounded by `AsRef<[T]>`; requires `<T>`.
    AnyArray,
    /// Replaced by a generic bounded by `AsRef<str>`; takes no `<T>`.
    AnyString,
    /// Replaced by a generic bounded by `AsRef<std::path::Path>`; takes no `<T>`.
    AnyPath,
    /// Replaced by a generic bounded by `IntoIterator<Item = T>`; requires `<T>`.
    AnyIter,
    /// Replaced by a generic bounded by `Into<ndarray::ArrayView1<'lt, T>>`
    /// plus a freshly generated lifetime; requires `<T>`.
    AnyNdArray,
}
impl Special {
    /// Builds the where-predicate that bounds `generic` for this special.
    ///
    /// * `generic` - the fresh type parameter that replaces the special.
    /// * `maybe_sub_type` - the `<T>` written after the special, if any.
    /// * `maybe_lifetime` - the fresh lifetime, present for specials whose
    ///   `should_add_lifetime` is true.
    /// * `span_range` - where compile errors point.
    ///
    /// Aborts compilation via `abort!` in three cases: a generic argument is
    /// given where none is accepted, a required one is missing, or a lifetime
    /// is present for a special that does not take one. Panics if
    /// `AnyNdArray` is not given a lifetime, which would be an internal error.
    fn special_to_where_predicate(
        &self,
        generic: &TypePath,
        maybe_sub_type: Option<Type>,
        maybe_lifetime: Option<Lifetime>,
        span_range: &SpanRange,
    ) -> WherePredicate {
        match self {
            Special::AnyArray => {
                let sub_type = maybe_sub_type.unwrap_or_else(|| {
                    abort!(span_range,"AnyArray expects a generic parameter, for example, AnyArray<usize> or AnyArray<AnyString>.")
                });
                if maybe_lifetime.is_some() {
                    abort!(span_range, "AnyArray should not have a lifetime.")
                }
                parse_quote!(#generic : AsRef<[#sub_type]>)
            }
            Special::AnyString => {
                if maybe_sub_type.is_some() {
                    abort!(span_range,"AnyString should not have a generic parameter, so 'AnyString', not 'AnyString<_>'.")
                }
                if maybe_lifetime.is_some() {
                    abort!(span_range, "AnyString should not have a lifetime.")
                }
                parse_quote!(#generic : AsRef<str>)
            }
            Special::AnyPath => {
                if maybe_sub_type.is_some() {
                    abort!(span_range,"AnyPath should not have a generic parameter, so 'AnyPath', not 'AnyPath<_>'.")
                }
                if maybe_lifetime.is_some() {
                    abort!(span_range, "AnyPath should not have a lifetime.")
                }
                parse_quote!(#generic : AsRef<std::path::Path>)
            }
            Special::AnyIter => {
                let sub_type = maybe_sub_type.unwrap_or_else(|| {
                    abort!(span_range,"AnyIter expects a generic parameter, for example, AnyIter<usize> or AnyIter<AnyString>.")
                });
                if maybe_lifetime.is_some() {
                    abort!(span_range, "AnyIter should not have a lifetime.")
                }
                parse_quote!(#generic : IntoIterator<Item = #sub_type>)
            }
            Special::AnyNdArray => {
                let sub_type = maybe_sub_type.unwrap_or_else(|| {
                    abort!(span_range,"AnyNdArray expects a generic parameter, for example, AnyNdArray<usize> or AnyNdArray<AnyString>.")
                });
                // The caller creates the lifetime whenever `should_add_lifetime`
                // is true, so its absence is a bug in this crate.
                let lifetime =
                    maybe_lifetime.expect("Internal error: AnyNdArray should be given a lifetime.");
                parse_quote!(#generic: Into<ndarray::ArrayView1<#lifetime, #sub_type>>)
            }
        }
    }

    /// Returns the statement placed at the top of the function body.
    ///
    /// The statement shadows argument `name` with its converted form:
    /// `.as_ref()`, `.into_iter()` or `.into()`, depending on the special.
    fn ident_to_stmt(&self, name: &Ident) -> Stmt {
        match self {
            Special::AnyIter => parse_quote!(let #name = #name.into_iter();),
            Special::AnyNdArray => parse_quote!(let #name = #name.into();),
            Special::AnyArray | Special::AnyString | Special::AnyPath => {
                parse_quote!(let #name = #name.as_ref();)
            }
        }
    }

    /// Whether this special needs a freshly generated lifetime parameter.
    fn should_add_lifetime(&self) -> bool {
        match self {
            Special::AnyNdArray => true,
            Special::AnyArray | Special::AnyString | Special::AnyPath | Special::AnyIter => false,
        }
    }

    /// Recognizes `type_path` as a special.
    ///
    /// Matching requires that the path has no qualified-self part, consists of
    /// exactly one segment, and that segment's identifier names a variant. On a
    /// match, returns the variant together with its optional `<T>` argument.
    /// A special with malformed arguments aborts compilation (see
    /// `create_maybe_sub_type`).
    fn maybe_new(type_path: &TypePath, span_range: &SpanRange) -> Option<(Special, Option<Type>)> {
        if type_path.qself.is_some() {
            return None;
        }
        let segment = first_and_only(type_path.path.segments.iter())?;
        let special = Special::from_str(&segment.ident.to_string()).ok()?;
        let maybe_sub_type = Special::create_maybe_sub_type(&segment.arguments, span_range);
        Some((special, maybe_sub_type))
    }

    /// Extracts the single type argument from a special's path arguments.
    ///
    /// No arguments yields `None`. Exactly one `<Type>` argument yields that
    /// type. Anything else (several arguments, a non-type argument, or
    /// `(..)`-style arguments) aborts compilation.
    fn create_maybe_sub_type(args: &PathArguments, span_range: &SpanRange) -> Option<Type> {
        match args {
            PathArguments::None => None,
            PathArguments::Parenthesized(_) => {
                abort!(span_range, "Expected <..> generic parameter.")
            }
            PathArguments::AngleBracketed(bracketed) => {
                match first_and_only(bracketed.args.iter()) {
                    Some(GenericArgument::Type(sub_type)) => Some(sub_type.clone()),
                    Some(_) => abort!(span_range, "Expected generic parameter to be a type."),
                    None => abort!(span_range, "Expected at exactly one generic parameter."),
                }
            }
        }
    }

    /// Converts the variant's `Display` name from CamelCase to snake_case.
    ///
    /// Example: `AnyNdArray` becomes `any_nd_array`.
    fn to_snake_case(&self) -> String {
        let camel = self.to_string();
        let mut snake = String::with_capacity(camel.len() * 2);
        for (position, letter) in camel.chars().enumerate() {
            if position != 0 && letter.is_uppercase() {
                snake.push('_');
            }
            snake.push(letter.to_ascii_lowercase());
        }
        snake
    }
}
/// What a single argument contributes to the rewritten function.
#[derive(Debug)]
struct DeltaFnArg {
    /// The argument itself, with any special types replaced by generics.
    fn_arg: FnArg,
    /// Generic parameters (types and lifetimes) introduced for this argument.
    generic_params: Vec<GenericParam>,
    /// Where-predicates bounding the newly introduced generics.
    where_predicates: Vec<WherePredicate>,
    /// Optional conversion statement that `ItemFnAcc::fold` inserts at the
    /// front of the function body.
    stmt: Option<Stmt>,
}
impl DeltaFnArg {
    /// Computes the delta for one argument of the original function.
    ///
    /// Only arguments of the form `ident: Path` are searched for special
    /// types. Every other argument (receivers, destructuring patterns,
    /// non-path types) is passed through unchanged, with no new generics,
    /// predicates or statement. Suffixes for generated names are drawn from
    /// `suffix_iter`.
    fn new(fn_arg: &FnArg, suffix_iter: &mut impl Iterator<Item = String>) -> DeltaFnArg {
        match DeltaFnArg::is_normal_fn_arg(fn_arg) {
            Some((pat_ident, pat_type)) => {
                DeltaFnArg::replace_any_specials(pat_type.clone(), pat_ident, suffix_iter)
            }
            None => DeltaFnArg {
                fn_arg: fn_arg.clone(),
                generic_params: Vec::new(),
                where_predicates: Vec::new(),
                stmt: None,
            },
        }
    }

    /// Returns the identifier pattern and the typed pattern when `fn_arg` has
    /// the shape `ident: Path`; otherwise returns `None`.
    fn is_normal_fn_arg(fn_arg: &FnArg) -> Option<(&PatIdent, &PatType)> {
        match fn_arg {
            FnArg::Typed(pat_type) => match (&*pat_type.pat, &*pat_type.ty) {
                (Pat::Ident(pat_ident), Type::Path(_)) => Some((pat_ident, pat_type)),
                _ => None,
            },
            _ => None,
        }
    }

    /// Folds `old_pat_type` to replace special types with fresh generics.
    ///
    /// Collects the generics and bounds this creates. Also collects the
    /// conversion statement for `pat_ident` when the argument's outermost
    /// type is a special.
    // The former `#[allow(clippy::ptr_arg)]` was dropped: no parameter here is
    // a `&Vec`/`&String`/`&PathBuf`, so it suppressed nothing.
    fn replace_any_specials(
        old_pat_type: PatType,
        pat_ident: &PatIdent,
        suffix_iter: &mut impl Iterator<Item = String>,
    ) -> DeltaFnArg {
        let mut delta_pat_type = DeltaPatType::new(suffix_iter);
        let new_pat_type = delta_pat_type.fold_pat_type(old_pat_type);
        // Borrow-based, so it must run before the fields are moved out below.
        let stmt = delta_pat_type.generate_any_stmt(pat_ident);
        DeltaFnArg {
            fn_arg: FnArg::Typed(new_pat_type),
            stmt,
            generic_params: delta_pat_type.generic_params,
            where_predicates: delta_pat_type.where_predicates,
        }
    }
}
/// `syn::fold::Fold` visitor for one argument's type.
///
/// It replaces special types with freshly named generics and records
/// everything it adds.
struct DeltaPatType<'a> {
    /// Generic parameters (types and lifetimes) created while folding.
    generic_params: Vec<GenericParam>,
    /// Where-predicates created while folding.
    where_predicates: Vec<WherePredicate>,
    /// Source of suffixes that make generated generic/lifetime names unique.
    suffix_iter: &'a mut dyn Iterator<Item = String>,
    /// Special matched by the most recently folded type path (`None` if that
    /// path was not special). Children are folded before their parent, so
    /// after a fold this reflects the outermost type path.
    last_special: Option<Special>,
}
impl Fold for DeltaPatType<'_> {
    /// Folds a type path bottom-up: its generic arguments first, then the path
    /// itself.
    ///
    /// If the resulting path names a special, it is replaced by a fresh
    /// generic (with a where-predicate recorded) and remembered in
    /// `last_special`. Otherwise `last_special` is cleared and the folded path
    /// is returned.
    fn fold_type_path(&mut self, original: TypePath) -> TypePath {
        // Capture the span before `original` is consumed by the recursive fold.
        let span_range = SpanRange::from_tokens(&original);
        let folded = syn::fold::fold_type_path(self, original);
        let found = Special::maybe_new(&folded, &span_range);
        self.last_special = found.as_ref().map(|(special, _)| special.clone());
        match found {
            Some((special, maybe_sub_type)) => {
                self.create_and_define_generic(special, maybe_sub_type, &span_range)
            }
            None => folded,
        }
    }
}
impl<'a> DeltaPatType<'a> {
    /// Creates a visitor with nothing recorded yet that draws name suffixes
    /// from `suffix_iter`.
    fn new(suffix_iter: &'a mut dyn Iterator<Item = String>) -> Self {
        Self {
            suffix_iter,
            last_special: None,
            generic_params: Vec::new(),
            where_predicates: Vec::new(),
        }
    }

    /// Returns the body statement converting argument `pat_ident`.
    ///
    /// Returns `None` when the most recently folded type path was not a
    /// special.
    fn generate_any_stmt(&self, pat_ident: &PatIdent) -> Option<Stmt> {
        self.last_special
            .as_ref()
            .map(|special| special.ident_to_stmt(&pat_ident.ident))
    }

    /// Creates the generic that replaces `special`, records its parameter and
    /// bound, and returns it as a type path.
    fn create_and_define_generic(
        &mut self,
        special: Special,
        maybe_sub_type: Option<Type>,
        span_range: &SpanRange,
    ) -> TypePath {
        // The generic's suffix is drawn before the lifetime's. Any lifetime
        // parameter is pushed before the generic's own parameter.
        let generic = self.create_generic(&special);
        let maybe_lifetime = self.create_maybe_lifetime(&special);
        let where_predicate =
            special.special_to_where_predicate(&generic, maybe_sub_type, maybe_lifetime, span_range);
        self.generic_params.push(parse_quote!(#generic));
        self.where_predicates.push(where_predicate);
        generic
    }

    /// Creates a fresh lifetime and records it as a generic parameter, for
    /// specials that need one; otherwise returns `None`.
    fn create_maybe_lifetime(&mut self, special: &Special) -> Option<Lifetime> {
        if !special.should_add_lifetime() {
            return None;
        }
        let lifetime = self.create_lifetime(special);
        self.generic_params.push(parse_quote!(#lifetime));
        Some(lifetime)
    }

    /// Names a new generic after the special's `Display` name plus a fresh
    /// suffix (e.g. `AnyString0`).
    fn create_generic(&mut self, special: &Special) -> TypePath {
        let suffix = self.create_suffix();
        let generic_name = format!("{special}{suffix}");
        parse_str(&generic_name).expect("Internal error: failed to parse generic name")
    }

    /// Names a new lifetime after the special's snake_case name plus a fresh
    /// suffix (e.g. `'any_nd_array1`).
    fn create_lifetime(&mut self, special: &Special) -> Lifetime {
        let stem = special.to_snake_case();
        let suffix = self.create_suffix();
        let lifetime_name = format!("'{stem}{suffix}");
        parse_str(&lifetime_name).expect("Internal error: failed to parse lifetime name")
    }

    /// Takes the next suffix; running out is treated as an internal error.
    fn create_suffix(&mut self) -> String {
        match self.suffix_iter.next() {
            Some(suffix) => suffix,
            None => panic!("Internal error: ran out of generic suffixes"),
        }
    }
}
/// Returns the sole item of `iter`, or `None` if it yields zero or more than
/// one item.
///
/// At most two items are pulled. The second `next()` call happens only when
/// the first produced an item.
fn first_and_only<T, I: Iterator<Item = T>>(mut iter: I) -> Option<T> {
    iter.next().filter(|_| iter.next().is_none())
}