arcis-internal-expr-macro 0.9.0

Internal helper macro for expression handling in Arcis.
Documentation
use crate::fold::TypeReplacer;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote, ToTokens};
use std::collections::HashMap;
use syn::{
    fold::Fold,
    spanned::Spanned,
    Expr,
    GenericArgument,
    Lit,
    LitInt,
    PathArguments,
    Type,
    TypeArray,
    TypePath,
    TypeTuple,
};

/// A trait to automatically derive some functions used to build `apply` and `proc` on `Expr`.
///
/// For instance, let's say we have a variant:
/// `FieldExpr::SumOfLinearCombinationOfMatrixProducts(Vec<(Vec<Vec<T>>, Vec<Vec<T>>, F)>)`
///
/// In apply mode it produces:
///
/// ```text
///     fn vec_vec_vec_t_vec_vec_t_f(
///         &mut self,
///         val: Vec<(Vec<Vec<T>>, Vec<Vec<T>>, F)>,
///     ) -> Vec<(Vec<Vec<NewT>>, Vec<Vec<NewT>>, F)> {
///         val.into_iter()
///             .map(|x| {
///                 (
///                     x
///                         .0
///                         .into_iter()
///                         .map(|x| { x.into_iter().map(|x| { self.t(x) }).collect() })
///                         .collect(),
///                     x
///                         .1
///                         .into_iter()
///                         .map(|x| { x.into_iter().map(|x| { self.t(x) }).collect() })
///                         .collect(),
///                     x.2,
///                 )
///             })
///             .collect()
///     }
/// ```
///
/// The body is built by recursively walking the field's type.
pub trait ExprDefaultImpl {
    /// Creates a default implementation that works on type `Self`.
    ///
    /// * `generic_finder` — its keys are the generic identifiers to look for. The map is handed
    ///   to `TypeReplacer`, and its keys are matched against a path's first segment. The values
    ///   are presumably the replacement identifiers (confirm against `TypeReplacer`).
    /// * `fn_name_by_type` — the helper method to call for each generic leaf type.
    /// * `has_no_generics` — `true` when the caller already knows the type contains no generics.
    /// * `top_level` — whether `Self` is the outermost field type. A top-level generic path is
    ///   not dispatched through `fn_name_by_type`.
    /// * `is_apply` — `true` builds an owned, transformed value (`into_iter().map(..).collect()`).
    ///   `false` only visits the value by reference (`iter().for_each(..)`).
    /// * `val_name` — tokens naming the value being processed.
    ///
    /// Returns `None` when the type's shape is not supported.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream>;
}

impl ExprDefaultImpl for TypePath {
    /// Handles path types such as `T`, `F` or `Vec<...>`.
    ///
    /// * A type containing none of the generics is passed through unchanged in apply mode and
    ///   produces no code in proc mode.
    /// * A non-top-level path whose first segment is a generic is delegated to the helper that
    ///   `fn_name_by_type` registers for that exact type.
    /// * `Vec<Inner>` is handled element-wise, recursing into `Inner`.
    ///
    /// Returns `None` for anything else, including malformed paths: no segments, or `Vec`
    /// without a leading type argument.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        // The type is generic-free if substituting the generics leaves it unchanged.
        let has_no_generics = has_no_generics || {
            let mut replacer = TypeReplacer::new(generic_finder);
            let replaced = replacer.fold_type_path(self.clone());
            *self == replaced
        };
        if has_no_generics {
            // Nothing to transform: apply returns the value as-is, proc emits nothing.
            return Some(if is_apply {
                quote! { #val_name }
            } else {
                quote! {}
            });
        }

        // Use `first()` rather than indexing so a malformed path yields `None`, not a panic.
        let first = self.path.segments.first()?;
        if !top_level && generic_finder.contains_key(&first.ident) {
            let fn_name = fn_name_by_type.get(&Type::Path(self.clone()))?;
            return Some(quote! {
                self.#fn_name(#val_name)
            });
        }
        if first.ident != "Vec" {
            return None;
        }

        let PathArguments::AngleBracketed(generic_args) = &first.arguments else {
            return None;
        };
        let Some(GenericArgument::Type(elem_ty)) = generic_args.args.first() else {
            return None;
        };
        let loop_var = format_ident!("x");
        let inner = elem_ty.expr_default_impl(
            generic_finder,
            fn_name_by_type,
            has_no_generics,
            false,
            is_apply,
            loop_var.to_token_stream(),
        )?;
        if is_apply {
            Some(quote! {
                #val_name.into_iter().map(|#loop_var| {
                    #inner
                }).collect()
            })
        } else {
            Some(quote! {
                #val_name.iter().for_each(|#loop_var| {
                    #inner;
                })
            })
        }
    }
}

impl ExprDefaultImpl for Type {
    /// Dispatches to the implementation for the concrete syntactic form of the type.
    ///
    /// Only paths receive the caller's `top_level` flag. Tuples and arrays always recurse with
    /// `top_level = false`. Any other type form is reported on stdout and yields `None`.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        if let Type::Path(path_ty) = self {
            return path_ty.expr_default_impl(
                generic_finder,
                fn_name_by_type,
                has_no_generics,
                top_level,
                is_apply,
                val_name,
            );
        }
        if let Type::Tuple(tuple_ty) = self {
            return tuple_ty.expr_default_impl(
                generic_finder,
                fn_name_by_type,
                has_no_generics,
                false,
                is_apply,
                val_name,
            );
        }
        if let Type::Array(array_ty) = self {
            return array_ty.expr_default_impl(
                generic_finder,
                fn_name_by_type,
                has_no_generics,
                false,
                is_apply,
                val_name,
            );
        }
        let unsupported = self;
        println!("unknown type: {unsupported:?}");
        None
    }
}

impl ExprDefaultImpl for TypeTuple {
    /// Handles tuple types by processing each field `val.0`, `val.1`, ... in order.
    ///
    /// Apply mode uses the fields by value. Proc mode borrows them (`&val.i`). The result is a
    /// tuple expression with one entry per element. A comma follows every entry, so a
    /// one-element tuple `(T,)` stays a tuple.
    ///
    /// Returns `None` if any element type is unsupported.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        _top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        let items = self
            .elems
            .iter()
            .enumerate()
            .map(|(i, elem)| {
                let idx = LitInt::new(&i.to_string(), val_name.span());
                let field = if is_apply {
                    quote! { #val_name.#idx }
                } else {
                    quote! { &#val_name.#idx }
                };
                let item = elem.expr_default_impl(
                    generic_finder,
                    fn_name_by_type,
                    has_no_generics,
                    false,
                    is_apply,
                    field,
                )?;
                // In proc mode a generic-free element yields empty tokens. A `()` placeholder
                // keeps the tuple syntactically valid (otherwise e.g. `( , self.t(&x.1))`).
                Some(if item.is_empty() {
                    quote! { () }
                } else {
                    item
                })
            })
            .collect::<Option<Vec<TokenStream>>>()?;
        Some(quote! {
            (#(#items,)*)
        })
    }
}

impl ExprDefaultImpl for TypeArray {
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        _top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        let Expr::Lit(len) = &self.len else { None? };
        let Lit::Int(len) = &len.lit else { None? };
        let len = len.base10_parse::<usize>().ok()?;
        if is_apply {
            let inner = (0..len)
                .map(|i| {
                    let local_val = quote! {#val_name[#i]};
                    let inner = self.elem.expr_default_impl(
                        generic_finder,
                        fn_name_by_type,
                        has_no_generics,
                        false,
                        is_apply,
                        local_val,
                    );
                    inner.map(|x| {
                        if i == 0 {
                            quote! {#x}
                        } else {
                            quote! {, #x}
                        }
                    })
                })
                .collect::<Option<TokenStream>>()?;
            Some(quote! {
                [#inner]
            })
        } else {
            let local_val = quote! {x};
            let inner = self.elem.expr_default_impl(
                generic_finder,
                fn_name_by_type,
                has_no_generics,
                false,
                is_apply,
                local_val.clone(),
            );
            Some(quote! {
                #val_name.iter().for_each(|#local_val| {#inner});
            })
        }
    }
}