// tailcall-impl 2.2.0
//
// The procedural macro implementation for the tailcall crate.
// Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Error, ImplItemMethod, ItemFn, Signature};

use crate::{
    analyze::{is_simple_self_tail_recursive, is_simple_self_tail_recursive_method},
    loop_lower::{lower_self_tail_loop, lower_self_tail_method_loop},
    rewrite::TailPositionRewriter,
    signature::{function_argument_exprs, helper_signature, method_helper_signature},
};

/// Expands a free function with `TailcallTransform` and returns the generated tokens.
///
/// If expansion fails, the `syn::Error` is rendered via `to_compile_error()` and
/// those tokens are returned instead, so this function itself never fails.
pub fn apply_fn_tailcall_transform(item_fn: ItemFn) -> TokenStream {
    TailcallTransform::new(item_fn)
        .expand()
        .unwrap_or_else(|error| error.to_compile_error())
}

/// Expands an `impl` method with `TailcallMethodTransform` and returns the generated tokens.
///
/// If expansion fails, the `syn::Error` is rendered via `to_compile_error()` and
/// those tokens are returned instead, so this function itself never fails.
pub fn apply_method_tailcall_transform(method: ImplItemMethod) -> TokenStream {
    TailcallMethodTransform::new(method)
        .expand()
        .unwrap_or_else(|error| error.to_compile_error())
}

/// Holds a parsed free function (`ItemFn`) that is consumed by `expand` to
/// produce the wrapper function and its thunk-building helper.
struct TailcallTransform {
    item_fn: ItemFn,
}

/// Holds a parsed `impl` method (`ImplItemMethod`) that is consumed by `expand`
/// to produce the wrapper method and its thunk-building helper.
struct TailcallMethodTransform {
    method: ImplItemMethod,
}

impl TailcallMethodTransform {
    /// Wraps a parsed method so it can be expanded with `expand`.
    fn new(method: ImplItemMethod) -> Self {
        Self { method }
    }

    /// Expands the method into two items: the original method (same attributes,
    /// visibility, defaultness and signature) with a new body, and a hidden
    /// helper that builds a `tailcall::runtime::Thunk`.
    ///
    /// Two shapes are produced, selected by `is_simple_self_tail_recursive_method`:
    ///
    /// * optimized: the wrapper body is whatever `lower_self_tail_method_loop`
    ///   returns, and the helper wraps a call to the original method in
    ///   `Thunk::value`;
    /// * otherwise: the wrapper calls the helper and then `.call()` on the
    ///   result, and the helper wraps the body rewritten by
    ///   `TailPositionRewriter::rewrite` in `Thunk::bounce(move || ...)`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `reject_unsupported_signature`,
    /// `method_helper_signature`, `function_argument_exprs`,
    /// `lower_self_tail_method_loop` or `TailPositionRewriter::rewrite`, in
    /// that order of evaluation.
    fn expand(self) -> Result<TokenStream, Error> {
        let method = self.method;

        reject_unsupported_signature(&method.sig)?;

        let helper_sig = method_helper_signature(&method.sig)?;
        let helper_fn_ident = &helper_sig.ident;
        let helper_args = function_argument_exprs(&method.sig)?;

        // The analysis and the loop lowering only need a shared borrow, so the
        // parsed method is inspected in place rather than cloned and rebuilt.
        let optimized = is_simple_self_tail_recursive_method(&method);
        let wrapper_body = if optimized {
            lower_self_tail_method_loop(&method)?
        } else {
            quote! { Self::#helper_fn_ident(#(#helper_args),*).call() }
        };
        let helper_body = if optimized {
            let method_ident = &method.sig.ident;
            quote! {
                tailcall::runtime::Thunk::value(Self::#method_ident(#(#helper_args),*))
            }
        } else {
            // The body is consumed only on this path; moving just this field
            // leaves the remaining fields available for the output below.
            let helper_block = TailPositionRewriter::rewrite(method.block)?;
            quote! {
                tailcall::runtime::Thunk::bounce(move || #helper_block)
            }
        };

        let attrs = &method.attrs;
        let vis = &method.vis;
        let defaultness = &method.defaultness;
        let sig = &method.sig;

        Ok(quote! {
            #(#attrs)*
            #defaultness #vis #sig {
                #wrapper_body
            }

            #[doc(hidden)]
            #[allow(unused)]
            #[inline(always)]
            #helper_sig {
                #helper_body
            }
        })
    }
}

impl TailcallTransform {
    /// Wraps a parsed free function so it can be expanded with `expand`.
    fn new(item_fn: ItemFn) -> Self {
        Self { item_fn }
    }

    /// Expands the function into two items: the original function (same
    /// attributes, visibility and signature) with a new body, and a hidden
    /// helper that builds a `tailcall::runtime::Thunk`.
    ///
    /// Two shapes are produced, selected by `is_simple_self_tail_recursive`:
    ///
    /// * optimized: the wrapper body is whatever `lower_self_tail_loop`
    ///   returns, and the helper wraps a call to the original function in
    ///   `Thunk::value`;
    /// * otherwise: the wrapper calls the helper and then `.call()` on the
    ///   result, and the helper wraps the body rewritten by
    ///   `TailPositionRewriter::rewrite` in `Thunk::bounce(move || ...)`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `reject_unsupported_signature`,
    /// `function_argument_exprs`, `lower_self_tail_loop` or
    /// `TailPositionRewriter::rewrite`, in that order of evaluation.
    fn expand(self) -> Result<TokenStream, Error> {
        let item_fn = self.item_fn;

        reject_unsupported_signature(&item_fn.sig)?;

        let helper_sig = helper_signature(&item_fn.sig);
        let helper_fn_ident = &helper_sig.ident;
        let helper_args = function_argument_exprs(&item_fn.sig)?;

        // The analysis and the loop lowering only need a shared borrow, so the
        // parsed function is inspected in place rather than cloned and rebuilt.
        let optimized = is_simple_self_tail_recursive(&item_fn);
        let wrapper_body = if optimized {
            lower_self_tail_loop(&item_fn)?
        } else {
            quote! { #helper_fn_ident(#(#helper_args),*).call() }
        };
        let helper_body = if optimized {
            let fn_ident = &item_fn.sig.ident;
            quote! {
                tailcall::runtime::Thunk::value(#fn_ident(#(#helper_args),*))
            }
        } else {
            // The boxed body is consumed only on this path; moving it out of
            // its `Box` leaves the remaining fields available for the output.
            let helper_block = TailPositionRewriter::rewrite(*item_fn.block)?;
            quote! {
                tailcall::runtime::Thunk::bounce(move || #helper_block)
            }
        };

        let attrs = &item_fn.attrs;
        let vis = &item_fn.vis;
        let sig = &item_fn.sig;

        Ok(quote! {
            #(#attrs)*
            #vis #sig {
                #wrapper_body
            }

            #[doc(hidden)]
            #[allow(unused)]
            #[inline(always)]
            #helper_sig {
                #helper_body
            }
        })
    }
}

fn reject_unsupported_signature(sig: &Signature) -> Result<(), Error> {
    if sig.constness.is_some() {
        return Err(Error::new_spanned(
            sig.constness,
            "#[tailcall] does not support const functions",
        ));
    }

    if sig.asyncness.is_some() {
        return Err(Error::new_spanned(
            sig.asyncness,
            "#[tailcall] does not support async functions",
        ));
    }

    Ok(())
}

// Expansion tests: each test feeds a parsed item through one of the public
// `apply_*_tailcall_transform` entry points and compares the generated tokens
// against a hand-written expected expansion.
#[cfg(test)]
mod tests {
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::parse_quote;

    use super::{apply_fn_tailcall_transform, apply_method_tailcall_transform};

    /// Asserts that two token streams render to the same string via `to_string()`.
    fn assert_expansion_eq(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// A free function whose tail call targets another function (`is_odd`):
    /// the wrapper is expected to call the helper and then `.call()`, and the
    /// helper body is expected to be wrapped in `Thunk::bounce` with the plain
    /// `true` branch wrapped in `Thunk::value`.
    #[test]
    fn expands_runtime_backed_free_function_as_expected() {
        let item_fn: syn::ItemFn = parse_quote! {
            fn is_even(x: u32) -> bool {
                if x == 0 {
                    true
                } else {
                    tailcall::call! { is_odd(x - 1) }
                }
            }
        };

        let actual = apply_fn_tailcall_transform(item_fn);
        let expected = quote! {
            fn is_even(x: u32) -> bool {
                __tailcall_build_is_even_thunk(x).call()
            }

            #[doc(hidden)]
            #[allow(unused)]
            #[inline(always)]
            fn __tailcall_build_is_even_thunk<'tailcall>(x: u32) -> tailcall::runtime::Thunk<'tailcall, bool> {
                tailcall::runtime::Thunk::bounce(move || {
                    if x == 0 {
                        tailcall::runtime::Thunk::value(true)
                    } else {
                        tailcall::call! { is_odd(x - 1) }
                    }
                })
            }
        };

        assert_expansion_eq(actual, expected);
    }

    /// A free function that tail-calls itself: the wrapper is expected to
    /// become a `loop` that rebinds the argument and `continue`s, and the
    /// helper is expected to wrap a direct call to the function in
    /// `Thunk::value`.
    #[test]
    fn expands_loop_lowered_free_function_as_expected() {
        let item_fn: syn::ItemFn = parse_quote! {
            fn countdown(n: u32) -> u32 {
                if n > 0 {
                    tailcall::call! { countdown(n - 1) }
                } else {
                    0
                }
            }
        };

        let actual = apply_fn_tailcall_transform(item_fn);
        let expected = quote! {
            fn countdown(n: u32) -> u32 {
                let mut n = n;
                loop {
                    if n > 0 {
                        {
                            let __tailcall_next_0 = n - 1;
                            n = __tailcall_next_0;
                            continue;
                        }
                    } else {
                        return 0
                    }
                }
            }

            #[doc(hidden)]
            #[allow(unused)]
            #[inline(always)]
            fn __tailcall_build_countdown_thunk<'tailcall>(n: u32) -> tailcall::runtime::Thunk<'tailcall, u32> {
                tailcall::runtime::Thunk::value(countdown(n))
            }
        };

        assert_expansion_eq(actual, expected);
    }

    /// A `&mut self` method that tail-calls itself: the wrapper is expected to
    /// bind `self` to `__tailcall_self`, loop with the argument rebound, and
    /// the helper is expected to take `&'tailcall mut self` and wrap
    /// `Self::countdown(self, n)` in `Thunk::value`.
    #[test]
    fn expands_loop_lowered_method_as_expected() {
        let method: syn::ImplItemMethod = parse_quote! {
            fn countdown(&mut self, n: u32) -> u32 {
                self.steps += 1;

                if n > 0 {
                    tailcall::call! { self.countdown(n - 1) }
                } else {
                    self.steps as u32
                }
            }
        };

        let actual = apply_method_tailcall_transform(method);
        let expected = quote! {
            fn countdown(&mut self, n: u32) -> u32 {
                let __tailcall_self = self;
                let mut n = n;
                loop {
                    __tailcall_self.steps += 1;
                    if n > 0 {
                        {
                            let __tailcall_next_0 = n - 1;
                            n = __tailcall_next_0;
                            continue;
                        }
                    } else {
                        return __tailcall_self.steps as u32
                    }
                }
            }

            #[doc(hidden)]
            #[allow(unused)]
            #[inline(always)]
            fn __tailcall_build_countdown_thunk<'tailcall>(&'tailcall mut self, n: u32) -> tailcall::runtime::Thunk<'tailcall, u32> {
                tailcall::runtime::Thunk::value(Self::countdown(self, n))
            }
        };

        assert_expansion_eq(actual, expected);
    }
}