#![recursion_limit = "128"]
extern crate proc_macro;
extern crate proc_macro2;
#[macro_use]
extern crate syn;
#[macro_use]
extern crate quote;
use inflector::Inflector;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{
fold::Fold,
parse::{Error, Parse, ParseStream, Result},
parse2,
punctuated::Punctuated,
spanned::Spanned,
ArgCaptured, Expr, FnArg, FnDecl, Ident, Item, Lit, Meta, Pat, ReturnType, Type,
};
#[derive(Debug)]
struct Parameters {
cache: CacheValue,
start: usize,
}
#[derive(Debug)]
enum CacheValue {
Auto,
NoLimit,
Defined(usize),
}
impl Parse for Parameters {
fn parse(input: ParseStream) -> Result<Self> {
let mut params = Parameters {
cache: CacheValue::Auto,
start: 0,
};
let metas: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input).unwrap();
for meta in metas.into_iter() {
if let Meta::NameValue(nvp) = meta {
match &nvp.ident.to_string()[..] {
"cache" => match nvp.lit {
Lit::Int(i) => params.cache = CacheValue::Defined(i.value() as usize),
Lit::Str(s) => match &s.value()[..] {
"auto" => params.cache = CacheValue::Auto,
"no-limit" => params.cache = CacheValue::NoLimit,
_ => Err(Error::new(s.span(), "expected integer or `auto`"))?,
},
other => Err(Error::new(other.span(), "expected integer or `auto`"))?,
},
"start" => {
if let Lit::Int(i) = nvp.lit {
params.start = i.value() as usize;
} else {
Err(Error::new(nvp.lit.span(), "expected integer"))?
}
}
_ => Err(Error::new(nvp.ident.span(), "expected `cache` or `start`"))?,
}
} else {
Err(Error::new(meta.span(), "expected parameter assignment"))?
}
}
Ok(params)
}
}
struct FunctionCallUpdater<'a> {
func_ident: &'a Ident,
func_decl: &'a FnDecl,
determined_cache_size: Option<usize>,
}
impl Fold for FunctionCallUpdater<'_> {
fn fold_expr(&mut self, outer_node: Expr) -> Expr {
if let Expr::Call(node) = &outer_node {
if let Expr::Path(path) = &*node.func {
if path.path.segments.last().unwrap().value().ident == *self.func_ident {
let arg = node.args.last().unwrap();
let arg = arg.value();
if let Expr::Binary(binary) = arg {
if let Expr::Path(path) = &*binary.left {
let fn_param = self.func_decl.inputs.first().unwrap();
let fn_param = fn_param.value();
if let FnArg::Captured(ArgCaptured {
pat: Pat::Ident(ident),
..
}) = fn_param
{
if path.path.segments.first().unwrap().value().ident == ident.ident {
if let Expr::Lit(lit) = &*binary.right {
if let Lit::Int(int) = &lit.lit {
if let Some(current) = self.determined_cache_size {
self.determined_cache_size = Some(current.max(int.value() as usize));
}
}
}
}
} else {
panic!()
}
}
} else {
self.determined_cache_size = None;
}
return parse_quote!(
self.get(#arg)
);
}
}
syn::fold::fold_expr(self, outer_node)
} else {
syn::fold::fold_expr(self, outer_node)
}
}
}
#[proc_macro_attribute]
pub fn reciter(attr: proc_macro::TokenStream, input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let params = parse_macro_input!(attr as Parameters);
let input: proc_macro2::TokenStream = input.into();
let item: Item = parse2(input.clone()).unwrap();
let func = if let Item::Fn(func) = item {
func
} else {
return Error::new(item.span(), "expected function").to_compile_error().into();
};
let func_ident = &func.ident;
let func_inputs = &func.decl.inputs;
let iterator_ident = Ident::new(&(func_ident.to_string().to_pascal_case() + "Iterator"), Span::call_site());
let ty = if let ReturnType::Type(_, ty) = &func.decl.output {
ty
} else {
return Error::new(func.span(), "function needs to have a return type").to_compile_error().into();
}
.into_token_stream();
if func.decl.inputs.len() == 1 {
if let FnArg::Captured(arg) = func.decl.inputs.first().unwrap().value() {
if let Type::Path(typath) = &arg.ty {
if typath.path.segments.first().unwrap().value().ident != "usize" {
return Error::new(typath.span(), format!("expected usize, found {}", typath.into_token_stream()))
.to_compile_error()
.into();
}
}
} else {
unreachable!()
}
} else {
return Error::new(func.span(), "expected exactly 1 parameter").to_compile_error().into();
}
let mut fcu = FunctionCallUpdater {
func_ident: &func.ident,
func_decl: &func.decl,
determined_cache_size: Some(0),
};
let func_body = fcu.fold_block(*func.block);
let cache_size = match params.cache {
CacheValue::Defined(int) => Some(int),
CacheValue::Auto => {
if let Some(value) = fcu.determined_cache_size {
Some(value)
} else {
panic!("Couldn't figure out cache size automatically");
}
}
CacheValue::NoLimit => None,
};
let start = params.start;
let tokens = if let Some(cache_size) = cache_size {
if cache_size > 32 {
panic!("cache > 32 doesn't work at the momeent because Default is only implemented for arrays up to the size of 32")
}
quote!(
struct #iterator_ident {
counter : usize,
cache : [#ty;#cache_size],
cache_cursor : usize
}
impl #iterator_ident {
fn new() -> Self {
#iterator_ident {
counter : #start,
cache : <[#ty;#cache_size]>::default(), cache_cursor : 0
}
}
fn recursive (&mut self, #func_inputs ) -> #ty {
#func_body
}
fn get(&mut self, n : usize) -> #ty {
let back = self.counter - n;
if back == 0 {
self.recursive(n)
} else {
if back > #cache_size {
self.recursive(n)
} else {
let mut ind = if self.cache_cursor >= back {
self.cache_cursor - back
} else {
self.cache_cursor + #cache_size - back
};
self.cache[ind].clone() }
}
}
fn cache_write(&mut self, item : #ty) {
self.cache[self.cache_cursor] = item;
self.cache_cursor += 1;
if self.cache_cursor == #cache_size {
self.cache_cursor = 0;
}
}
}
impl Iterator for #iterator_ident {
type Item = #ty;
fn next(&mut self) -> std::option::Option<Self::Item> {
let ret = self.get(self.counter);
if #cache_size > 0 {self.cache_write(ret.clone());}
self.counter+=1;
Some(ret)
}
}
)
.into()
} else {
quote!(
struct #iterator_ident {
counter : usize,
cache : Vec<#ty>,
}
impl #iterator_ident {
fn new() -> Self {
#iterator_ident {
counter : #start,
cache : Vec::new(),
}
}
fn recursive (&mut self, #func_inputs ) -> #ty {
#func_body
}
fn get(&mut self, n : usize) -> #ty {
if self.counter == n {
self.recursive(n)
} else {
self.cache[n].clone() }
}
pub fn calculated_values(&self) -> &Vec<#ty> {
&self.cache
}
}
impl Iterator for #iterator_ident {
type Item = #ty;
fn next(&mut self) -> std::option::Option<Self::Item> {
let ret = self.get(self.counter);
self.cache.push(ret.clone());
self.counter+=1;
Some(ret)
}
}
)
.into()
};
tokens
}