//! reciter 0.1.2
//!
//! Macro that converts a recursive function into an `Iterator` backed by a cache.
#![recursion_limit = "128"]

extern crate proc_macro;
extern crate proc_macro2;
#[macro_use]
extern crate syn;
#[macro_use]
extern crate quote;

use inflector::Inflector;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{
    fold::Fold,
    parse::{Error, Parse, ParseStream, Result},
    parse2,
    punctuated::Punctuated,
    spanned::Spanned,
    ArgCaptured, Expr, FnArg, FnDecl, Ident, Item, Lit, Meta, Pat, ReturnType, Type,
};

/// Settings parsed from the arguments of the `#[reciter(...)]` attribute.
#[derive(Debug)]
struct Parameters {
    /// Requested cache sizing strategy (`cache = ...`).
    cache: CacheValue,
    /// Index the generated iterator starts counting from (`start = ...`).
    start: usize,
}

/// How the size of the generated iterator's cache is chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CacheValue {
    /// `cache = "auto"` (the default): size inferred from the recursive calls in the body.
    Auto,
    /// `cache = "no-limit"`: every produced value is kept in a growable `Vec`.
    NoLimit,
    /// `cache = <integer>`: a fixed-size ring buffer holding that many of the latest values.
    Defined(usize),
}

impl Parse for Parameters {
    fn parse(input: ParseStream) -> Result<Self> {
        let mut params = Parameters {
            cache: CacheValue::Auto,
            start: 0,
        };

        let metas: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input).unwrap();

        for meta in metas.into_iter() {
            if let Meta::NameValue(nvp) = meta {
                match &nvp.ident.to_string()[..] {
                    "cache" => match nvp.lit {
                        Lit::Int(i) => params.cache = CacheValue::Defined(i.value() as usize),
                        Lit::Str(s) => match &s.value()[..] {
                            "auto" => params.cache = CacheValue::Auto,
                            "no-limit" => params.cache = CacheValue::NoLimit,
                            _ => Err(Error::new(s.span(), "expected integer or `auto`"))?,
                        },
                        other => Err(Error::new(other.span(), "expected integer or `auto`"))?,
                    },
                    "start" => {
                        if let Lit::Int(i) = nvp.lit {
                            params.start = i.value() as usize;
                        } else {
                            Err(Error::new(nvp.lit.span(), "expected integer"))?
                        }
                    }
                    _ => Err(Error::new(nvp.ident.span(), "expected `cache` or `start`"))?,
                }
            } else {
                Err(Error::new(meta.span(), "expected parameter assignment"))?
            }
        }

        Ok(params)
    }
}

/// `Fold` pass that rewrites the annotated function's recursive calls into cache lookups.
struct FunctionCallUpdater<'a> {
    /// Name of the annotated function; calls whose path ends in this identifier are rewritten.
    func_ident: &'a Ident,
    /// Declaration of the annotated function, consulted for the name of its parameter.
    func_decl: &'a FnDecl,
    /// Cache size inferred so far from the recursive calls; `None` once it cannot be determined.
    determined_cache_size: Option<usize>,
}

impl Fold for FunctionCallUpdater<'_> {
    fn fold_expr(&mut self, outer_node: Expr) -> Expr {
        if let Expr::Call(node) = &outer_node {
            if let Expr::Path(path) = &*node.func {
                if path.path.segments.last().unwrap().value().ident == *self.func_ident {
                    let arg = node.args.last().unwrap();
                    let arg = arg.value();
                    if let Expr::Binary(binary) = arg {
                        if let Expr::Path(path) = &*binary.left {
                            let fn_param = self.func_decl.inputs.first().unwrap();
                            let fn_param = fn_param.value();
                            if let FnArg::Captured(ArgCaptured {
                                pat: Pat::Ident(ident),
                                ..
                            }) = fn_param
                            {
                                if path.path.segments.first().unwrap().value().ident == ident.ident {
                                    if let Expr::Lit(lit) = &*binary.right {
                                        if let Lit::Int(int) = &lit.lit {
                                            if let Some(current) = self.determined_cache_size {
                                                self.determined_cache_size = Some(current.max(int.value() as usize));
                                            }
                                        }
                                    }
                                }
                            } else {
                                panic!()
                            }
                        }
                    } else {
                        self.determined_cache_size = None;
                    }
                    return parse_quote!(
                        self.get(#arg)
                    );
                }
            }
            syn::fold::fold_expr(self, outer_node)
        } else {
            syn::fold::fold_expr(self, outer_node)
        }
    }
}

/// Converts a recursive function into an Iterator, which uses a cache
/// ### Parameters
/// #### cache
/// ##### value: "auto" | "no-limit" | integer
/// Sets the size of the cache the Iterator is using.
///
/// "no-limit" additionally implements a public method for the Iterator
/// `pub fn calculated_values(&self) -> &Vec<#ty>`
/// which returns a reference to the cache
/// #### start
/// ##### value: integer
/// Sets the counter of the Iterator at the start (most of the time 0 or 1)
#[proc_macro_attribute]
pub fn reciter(attr: proc_macro::TokenStream, input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let params = parse_macro_input!(attr as Parameters);

    let input: proc_macro2::TokenStream = input.into();

    let item: Item = parse2(input.clone()).unwrap();

    let func = if let Item::Fn(func) = item {
        func
    } else {
        return Error::new(item.span(), "expected function").to_compile_error().into();
    };

    let func_ident = &func.ident;
    let func_inputs = &func.decl.inputs;

    let iterator_ident = Ident::new(&(func_ident.to_string().to_pascal_case() + "Iterator"), Span::call_site());
    let ty = if let ReturnType::Type(_, ty) = &func.decl.output {
        ty
    } else {
        return Error::new(func.span(), "function needs to have a return type").to_compile_error().into();
    }
    .into_token_stream();

    if func.decl.inputs.len() == 1 {
        if let FnArg::Captured(arg) = func.decl.inputs.first().unwrap().value() {
            if let Type::Path(typath) = &arg.ty {
                if typath.path.segments.first().unwrap().value().ident != "usize" {
                    return Error::new(typath.span(), format!("expected usize, found {}", typath.into_token_stream()))
                        .to_compile_error()
                        .into();
                }
            }
        } else {
            unreachable!()
        }
    } else {
        return Error::new(func.span(), "expected exactly 1 parameter").to_compile_error().into();
    }

    let mut fcu = FunctionCallUpdater {
        func_ident: &func.ident,
        func_decl: &func.decl,
        determined_cache_size: Some(0),
    };

    let func_body = fcu.fold_block(*func.block);

    let cache_size = match params.cache {
        CacheValue::Defined(int) => Some(int),
        CacheValue::Auto => {
            if let Some(value) = fcu.determined_cache_size {
                Some(value)
            } else {
                panic!("Couldn't figure out cache size automatically");
            }
        }
        CacheValue::NoLimit => None,
    };
    let start = params.start;

    let tokens = if let Some(cache_size) = cache_size {
        if cache_size > 32 {
            panic!("cache > 32 doesn't work at the momeent because Default is only implemented for arrays up to the size of 32")
        }
        quote!(
            struct #iterator_ident {
                counter : usize,
                cache : [#ty;#cache_size],
                cache_cursor : usize
            }

            impl #iterator_ident {
                fn new() -> Self {
                   #iterator_ident {
                        counter : #start,
                        cache : <[#ty;#cache_size]>::default(), //TODO: check if impl Default
                        cache_cursor : 0
                    }
                }

                fn recursive (&mut self, #func_inputs ) -> #ty {
                    #func_body
                }

                fn get(&mut self, n : usize) -> #ty {
                    let back = self.counter - n;

                    if back == 0 {
                        self.recursive(n)
                    } else {
                        if back > #cache_size {
                            //panic!("cache too small")
                            self.recursive(n)
                        } else {
                            let mut ind = if self.cache_cursor >= back {
                                self.cache_cursor - back
                            } else {
                                self.cache_cursor + #cache_size - back
                            };
                            self.cache[ind].clone() //TODO: check if impl Clone
                        }
                    }
                }

                fn cache_write(&mut self, item : #ty) {
                    self.cache[self.cache_cursor] = item;
                    self.cache_cursor += 1;
                    if self.cache_cursor == #cache_size {
                        self.cache_cursor = 0;
                    }
                }
            }

            impl Iterator for #iterator_ident {
                type Item = #ty;
                fn next(&mut self) -> std::option::Option<Self::Item> {
                    let ret = self.get(self.counter);
                    if #cache_size > 0 {self.cache_write(ret.clone());}
                    self.counter+=1;
                    Some(ret)
                }
            }
        )
        .into()
    } else {
        quote!(
            struct #iterator_ident {
                counter : usize,
                cache : Vec<#ty>,
            }

            impl #iterator_ident {
                fn new() -> Self {
                   #iterator_ident {
                        counter : #start,
                        cache : Vec::new(),
                    }
                }

                fn recursive (&mut self, #func_inputs ) -> #ty {
                    #func_body
                }

                fn get(&mut self, n : usize) -> #ty {
                    if self.counter == n {
                        self.recursive(n)
                    } else {
                        self.cache[n].clone() //TODO: check if impl Clone
                    }
                }

                pub fn calculated_values(&self) -> &Vec<#ty> {
                    &self.cache
                }
            }

            impl Iterator for #iterator_ident {
                type Item = #ty;
                fn next(&mut self) -> std::option::Option<Self::Item> {
                    let ret = self.get(self.counter);
                    self.cache.push(ret.clone());
                    self.counter+=1;
                    Some(ret)
                }
            }
        )
        .into()
    };
    tokens
}