// junobuild-macros 0.4.1
//
// Procedural macros for hooking into various events in Juno.
// Documentation
use crate::error::MacroError;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use serde::Deserialize;
use serde_tokenstream::from_tokenstream;
use std::string::ToString;
use syn::{parse, ItemFn, ReturnType, Signature, Type};

/// Arguments parsed (via `serde_tokenstream`) from the token stream passed to a
/// document/asset/assertion hook attribute.
#[derive(Deserialize, Default)]
struct HookAttributes {
    /// Collections the hook is limited to. `Some(list)` is emitted as
    /// `Some(vec![...])` and `None` as `None` by the generated
    /// `*_collections` function.
    collections: Option<Vec<String>>,
}

/// Every Juno event a user function can be attached to with one of the hook
/// attribute macros.
///
/// Each variant is turned into the `#[no_mangle]` symbol its generated wrapper
/// is exported under by `map_hook_name`.
// NOTE(review): clippy's `enum_variant_names` is silenced for this enum; the
// exact trigger is not documented here — confirm before removing.
#[allow(clippy::enum_variant_names)]
#[derive(Clone)]
pub enum Hook {
    // Document hooks (wrapper built by `parse_on_hook`).
    OnSetDoc,
    OnSetManyDocs,
    OnDeleteDoc,
    OnDeleteManyDocs,
    OnDeleteFilteredDocs,
    // Asset hooks (wrapper built by `parse_on_hook`).
    OnUploadAsset,
    OnDeleteAsset,
    OnDeleteManyAssets,
    OnDeleteFilteredAssets,
    // Lifecycle hooks: zero-argument wrappers, no collections filter.
    OnInit,
    OnInitSync,
    OnPostUpgrade,
    OnPostUpgradeSync,
    OnInitRandomSeed,
    // Assertion hooks: wrappers return `Result<(), String>`.
    AssertSetDoc,
    AssertDeleteDoc,
    AssertUploadAsset,
    AssertDeleteAsset,
}

/// Identifier used both for the single parameter of the generated
/// document/asset/assertion wrappers and for the argument passed on to the
/// user's function.
const CONTEXT_PARAM: &str = "context";

/// Returns the symbol name under which the `#[no_mangle]` wrapper generated for
/// `hook` is exported.
fn map_hook_name(hook: Hook) -> String {
    let symbol: &str = match hook {
        // Document hooks.
        Hook::OnSetDoc => "juno_on_set_doc",
        Hook::OnSetManyDocs => "juno_on_set_many_docs",
        Hook::OnDeleteDoc => "juno_on_delete_doc",
        Hook::OnDeleteManyDocs => "juno_on_delete_many_docs",
        Hook::OnDeleteFilteredDocs => "juno_on_delete_filtered_docs",
        // Asset hooks.
        Hook::OnUploadAsset => "juno_on_upload_asset",
        Hook::OnDeleteAsset => "juno_on_delete_asset",
        Hook::OnDeleteManyAssets => "juno_on_delete_many_assets",
        Hook::OnDeleteFilteredAssets => "juno_on_delete_filtered_assets",
        // Lifecycle hooks.
        Hook::OnInit => "juno_on_init",
        Hook::OnInitSync => "juno_on_init_sync",
        Hook::OnPostUpgrade => "juno_on_post_upgrade",
        Hook::OnPostUpgradeSync => "juno_on_post_upgrade_sync",
        Hook::OnInitRandomSeed => "juno_on_init_random_seed",
        // Assertion hooks.
        Hook::AssertSetDoc => "juno_assert_set_doc",
        Hook::AssertDeleteDoc => "juno_assert_delete_doc",
        Hook::AssertUploadAsset => "juno_assert_upload_asset",
        Hook::AssertDeleteAsset => "juno_assert_delete_asset",
    };
    symbol.to_string()
}

/// Returns the symbol name of the `#[no_mangle]` function, generated next to
/// `hook`'s wrapper, that reports which collections the hook is limited to.
///
/// Returns `None` for the lifecycle hooks (`OnInit`, `OnInitSync`,
/// `OnPostUpgrade`, `OnPostUpgradeSync`, `OnInitRandomSeed`). Those hooks are
/// expanded without a collections function.
///
/// The match is deliberately exhaustive, with no `_` arm. Adding a new `Hook`
/// variant is therefore a compile error here until someone decides whether it
/// gets a collections symbol. A wildcard would silently map it to `None`.
fn map_hook_collections(hook: Hook) -> Option<String> {
    let symbol = match hook {
        Hook::OnSetDoc => "juno_on_set_doc_collections",
        Hook::OnSetManyDocs => "juno_on_set_many_docs_collections",
        Hook::OnDeleteDoc => "juno_on_delete_doc_collections",
        Hook::OnDeleteManyDocs => "juno_on_delete_many_docs_collections",
        Hook::OnDeleteFilteredDocs => "juno_on_delete_filtered_docs_collections",
        Hook::OnUploadAsset => "juno_on_upload_asset_collections",
        Hook::OnDeleteAsset => "juno_on_delete_asset_collections",
        Hook::OnDeleteManyAssets => "juno_on_delete_many_assets_collections",
        Hook::OnDeleteFilteredAssets => "juno_on_delete_filtered_assets_collections",
        Hook::AssertSetDoc => "juno_assert_set_doc_collections",
        Hook::AssertDeleteDoc => "juno_assert_delete_doc_collections",
        Hook::AssertUploadAsset => "juno_assert_upload_asset_collections",
        Hook::AssertDeleteAsset => "juno_assert_delete_asset_collections",
        // Lifecycle hooks are not scoped to collections.
        Hook::OnInit
        | Hook::OnInitSync
        | Hook::OnPostUpgrade
        | Hook::OnPostUpgradeSync
        | Hook::OnInitRandomSeed => return None,
    };
    Some(symbol.to_string())
}

/// Returns the name of the context type taken as the single parameter of the
/// wrapper generated for `hook`.
///
/// Returns `None` for the lifecycle hooks, whose generated wrappers take no
/// parameter.
///
/// The match is deliberately exhaustive, with no `_` arm. Adding a new `Hook`
/// variant is therefore a compile error here until its context type is
/// decided, instead of silently falling through to `None`.
fn map_hook_type(hook: &Hook) -> Option<String> {
    let type_name = match hook {
        Hook::OnSetDoc => "OnSetDocContext",
        Hook::OnSetManyDocs => "OnSetManyDocsContext",
        Hook::OnDeleteDoc => "OnDeleteDocContext",
        Hook::OnDeleteManyDocs => "OnDeleteManyDocsContext",
        Hook::OnDeleteFilteredDocs => "OnDeleteFilteredDocsContext",
        Hook::OnUploadAsset => "OnUploadAssetContext",
        Hook::OnDeleteAsset => "OnDeleteAssetContext",
        Hook::OnDeleteManyAssets => "OnDeleteManyAssetsContext",
        Hook::OnDeleteFilteredAssets => "OnDeleteFilteredAssetsContext",
        Hook::AssertSetDoc => "AssertSetDocContext",
        Hook::AssertDeleteDoc => "AssertDeleteDocContext",
        Hook::AssertUploadAsset => "AssertUploadAssetContext",
        Hook::AssertDeleteAsset => "AssertDeleteAssetContext",
        // Lifecycle wrappers are zero-argument functions.
        Hook::OnInit
        | Hook::OnInitSync
        | Hook::OnPostUpgrade
        | Hook::OnPostUpgradeSync
        | Hook::OnInitRandomSeed => return None,
    };
    Some(type_name.to_string())
}

/// Shared entry point for every hook attribute macro.
///
/// Expands `item` (the annotated function) for `hook`. When expansion fails,
/// the error message is converted into tokens with `to_error()`.
pub fn hook_macro(hook: Hook, attr: TokenStream, item: TokenStream) -> TokenStream {
    match parse_hook(&hook, attr, item) {
        Ok(expanded) => expanded,
        Err(message) => message.to_error(),
    }
}

/// Parses `item` as the user's hook function and dispatches to the expander
/// for `hook`'s kind.
///
/// The lifecycle hooks (`OnInit`, `OnPostUpgrade`, `OnInitRandomSeed` and the
/// two `*Sync` variants) ignore `attr`. Every other hook passes it on to
/// `parse_doc_hook`.
///
/// # Errors
/// Returns a message when `item` is not a function, or propagates the error
/// reported by the selected expander.
fn parse_hook(hook: &Hook, attr: TokenStream, item: TokenStream) -> Result<TokenStream, String> {
    let ast = match parse::<ItemFn>(item) {
        Ok(function) => function,
        Err(_) => return Err("Expected a function to register the hooks".to_string()),
    };

    let exported_name = map_hook_name(hook.clone());
    let hook_fn = Ident::new(&exported_name, proc_macro2::Span::call_site());
    let signature = &ast.sig;

    match hook {
        Hook::OnInit | Hook::OnPostUpgrade | Hook::OnInitRandomSeed => {
            parse_lifecycle_hook(&ast, signature, &hook_fn)
        }
        Hook::OnInitSync | Hook::OnPostUpgradeSync => {
            parse_lifecycle_sync_hook(&ast, signature, &hook_fn)
        }
        _ => parse_doc_hook(&ast, signature, &hook_fn, hook, attr),
    }
}

/// Expands a document, asset or assertion hook.
///
/// Emits, next to the user's function `ast`:
/// * a `#[no_mangle]` `*_collections` function returning the collections listed
///   in `attr`, or `None` when the attribute provides none;
/// * the `#[no_mangle]` wrapper `hook_fn`. It is built by `parse_assert_hook`
///   for the `Assert*` hooks and by `parse_on_hook` for everything else.
///
/// # Errors
/// Fails when `hook` has no collections symbol or context type (the lifecycle
/// hooks), or when `attr` cannot be deserialized into `HookAttributes`.
fn parse_doc_hook(
    ast: &ItemFn,
    signature: &Signature,
    hook_fn: &Ident,
    hook: &Hook,
    attr: TokenStream,
) -> Result<TokenStream, String> {
    let span = proc_macro2::Span::call_site();

    let hook_collections_fn = map_hook_collections(hook.clone())
        .map(|symbol| Ident::new(&symbol, span))
        .ok_or_else(|| "Hook collection function cannot be None.".to_string())?;

    let hook_param = Ident::new(CONTEXT_PARAM, span);

    let hook_param_type = map_hook_type(hook)
        .map(|type_name| Ident::new(&type_name, span))
        .ok_or_else(|| "Hook type cannot be None.".to_string())?;

    let attr_tokens = proc_macro2::TokenStream::from(attr);
    let attrs: HookAttributes = from_tokenstream(&attr_tokens)
        .map_err(|_| "Expected valid attributes to register the hooks".to_string())?;

    // Each collection name is interpolated as a string literal followed by
    // `.to_string()`, giving `Some(vec!["a".to_string(), ...,])`.
    let collections_tokens = match attrs.collections {
        Some(collections) => quote! { Some(vec![#(#collections.to_string(),)*]) },
        None => quote! { None },
    };

    let is_assertion = matches!(
        hook,
        Hook::AssertSetDoc
            | Hook::AssertDeleteDoc
            | Hook::AssertUploadAsset
            | Hook::AssertDeleteAsset
    );
    let hook_body = if is_assertion {
        parse_assert_hook(signature, hook_fn, &hook_param, &hook_param_type)
    } else {
        parse_on_hook(signature, hook_fn, &hook_param, &hook_param_type)
    };

    let expanded = quote! {
        #ast

        #[no_mangle]
        pub extern "Rust" fn #hook_collections_fn() -> Option<Vec<String>> {
            #collections_tokens
        }

        #hook_body
    };

    Ok(expanded.into())
}

/// Builds the `#[no_mangle]` wrapper for an `On*` document/asset hook.
///
/// The wrapper takes the context by value. It calls the user's function inside
/// an `async` block handed to `ic_cdk::futures::spawn`, adding `.await` when
/// the user's function is `async`. The call's result is then checked with the
/// tokens from `parse_hook_return`.
fn parse_on_hook(
    signature: &Signature,
    hook_fn: &Ident,
    hook_param: &Ident,
    hook_param_type: &Ident,
) -> proc_macro2::TokenStream {
    let user_fn = &signature.ident;

    let call = match signature.asyncness {
        Some(_) => quote! { #user_fn(#hook_param).await },
        None => quote! { #user_fn(#hook_param) },
    };

    let result_check = parse_hook_return(signature);

    quote! {
        #[no_mangle]
        pub extern "Rust" fn #hook_fn(#hook_param: #hook_param_type) {
            ic_cdk::futures::spawn(async {
                let result = #call;
                #result_check
            });
        }
    }
}

/// Returns the tokens that inspect a `result` binding holding the user
/// function's return value.
///
/// The check is emitted when the signature declares a single return value:
/// any non-tuple type, or a one-element tuple. In that case the value is
/// matched as a `Result`, and an `Err` is formatted with `Display` and passed
/// to `ic_cdk::trap`. With no return type, or a tuple of zero or several
/// elements, no tokens are emitted.
fn parse_hook_return(signature: &Signature) -> proc_macro2::TokenStream {
    let single_value = match &signature.output {
        ReturnType::Default => false,
        ReturnType::Type(_, ty) => match &**ty {
            Type::Tuple(tuple) => tuple.elems.len() == 1,
            _ => true,
        },
    };

    if !single_value {
        return quote! {};
    }

    quote! {
        match result {
            Ok(_) => {}
            Err(e) => {
                let error = format!("{}", e);
                ic_cdk::trap(&error);
            }
        }
    }
}

/// Builds the `#[no_mangle]` wrapper for an `Assert*` hook.
///
/// The wrapper passes the context straight to the user's function, without
/// spawning or awaiting, and returns its `Result<(), String>` unchanged.
fn parse_assert_hook(
    signature: &Signature,
    hook_fn: &Ident,
    hook_param: &Ident,
    hook_param_type: &Ident,
) -> proc_macro2::TokenStream {
    let user_fn = &signature.ident;

    quote! {
        #[no_mangle]
        pub extern "Rust" fn #hook_fn(#hook_param: #hook_param_type) -> Result<(), String> {
            #user_fn(#hook_param)
        }
    }
}

/// Expands a non-sync lifecycle hook (`OnInit`, `OnPostUpgrade`,
/// `OnInitRandomSeed`). It re-emits the user's function `ast` followed by the
/// exported wrapper built by `parse_lifecycle_hook_body`.
///
/// Never fails; the `Result` return type matches the other expanders.
fn parse_lifecycle_hook(
    ast: &ItemFn,
    signature: &Signature,
    hook_fn: &Ident,
) -> Result<TokenStream, String> {
    let wrapper = parse_lifecycle_hook_body(signature, hook_fn);

    Ok(quote! {
        #ast

        #wrapper
    }
    .into())
}

/// Builds the zero-argument `#[no_mangle]` wrapper for a non-sync lifecycle
/// hook.
///
/// The user's function is called directly. A single-valued return is checked
/// with the tokens from `parse_hook_return`.
// NOTE(review): the call is never `.await`ed, so an `async fn` hook would not
// be driven here — confirm async lifecycle hooks are unsupported by design.
fn parse_lifecycle_hook_body(signature: &Signature, hook_fn: &Ident) -> proc_macro2::TokenStream {
    let user_fn = &signature.ident;
    let result_check = parse_hook_return(signature);

    quote! {
        #[no_mangle]
        pub extern "Rust" fn #hook_fn() {
            let result = #user_fn();
            #result_check
        }
    }
}

/// Expands a sync lifecycle hook (`OnInitSync`, `OnPostUpgradeSync`). It
/// re-emits the user's function `ast` followed by the exported wrapper built
/// by `parse_lifecycle_hook_sync_body`.
///
/// Never fails; the `Result` return type matches the other expanders.
fn parse_lifecycle_sync_hook(
    ast: &ItemFn,
    signature: &Signature,
    hook_fn: &Ident,
) -> Result<TokenStream, String> {
    let wrapper = parse_lifecycle_hook_sync_body(signature, hook_fn);

    Ok(quote! {
        #ast

        #wrapper
    }
    .into())
}

/// Builds the zero-argument `#[no_mangle]` wrapper for a sync lifecycle hook.
///
/// The user's function is called directly and whatever it returns is
/// discarded. Unlike `parse_lifecycle_hook_body`, no `Err` check is emitted.
// NOTE(review): confirm that silently dropping an `Err` from a sync lifecycle
// hook is intentional.
fn parse_lifecycle_hook_sync_body(
    signature: &Signature,
    hook_fn: &Ident,
) -> proc_macro2::TokenStream {
    let user_fn = &signature.ident;

    quote! {
        #[no_mangle]
        pub extern "Rust" fn #hook_fn() {
            #user_fn();
        }
    }
}