cubecl-macros 0.10.0-pre.3

Procedural macros for CubeCL
Documentation
use ident_case::RenameRule;
use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::{Ident, parse_quote};

use crate::{
    parse::kernel::{AddressType, GenericArg, KernelParam, Launch},
    paths::{core_type, prelude_type},
};

impl ToTokens for Launch {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let vis = &self.vis;

        let name = &self.func.sig.name;
        let launch = self.launch();
        let launch_unchecked = self.launch_unchecked();
        let aliases = self.create_type_alias();
        let dummy = self.create_dummy_kernel();
        let kernel = self.kernel_definition();
        let mut func = self.func.clone();
        func.sig.name = format_ident!("expand");
        let func = func.to_tokens_mut();

        let out = quote! {
            #vis mod #name {
                use super::*;

                #aliases

                #[allow(unused, clippy::all)]
                #func

                #kernel
                #launch
                #launch_unchecked
                #dummy
            }
        };

        if self.args.debug.is_present() {
            let file = syn::parse_file(&out.to_string()).unwrap();
            let tokens = prettyplease::unparse(&file);
            panic!("{tokens}");
        }
        tokens.extend(out);
    }
}

impl Launch {
    fn launch(&self) -> TokenStream {
        if self.args.launch.is_present() {
            let compute_client = prelude_type("ComputeClient");
            let cube_count = prelude_type("CubeCount");
            let cube_dim = prelude_type("CubeDim");
            let address_type = prelude_type("AddressType");

            let kernel_doc = format!(
                "Launch the kernel [{}()] on the given runtime",
                self.func.sig.name
            );
            let generics = &self.launch_generics;
            let args = self.launch_args();
            let body = self.launch_body();

            let address_type = match self.args.address_type {
                AddressType::Dynamic => quote![__address_type: #address_type,],
                _ => quote![],
            };

            quote! {
                #[allow(clippy::too_many_arguments)]
                #[doc = #kernel_doc]
                pub fn launch #generics(
                    __client: &#compute_client<__R>,
                    __cube_count: #cube_count,
                    __cube_dim: #cube_dim,
                    #address_type
                    #(#args),*
                ) {
                    #body
                    launcher.launch(__cube_count, __kernel, __client)
                }
            }
        } else {
            TokenStream::new()
        }
    }

    fn launch_unchecked(&self) -> TokenStream {
        if self.args.launch_unchecked.is_present() {
            let compute_client = prelude_type("ComputeClient");
            let cube_count = prelude_type("CubeCount");
            let cube_dim = prelude_type("CubeDim");
            let address_type = prelude_type("AddressType");

            let kernel_doc = format!(
                "Launch the kernel [{}()] on the given runtime without bound checks.\n\n\
                 # Safety\n\n\
                 The kernel must not:\n\
                 - Contain any out of bounds reads or writes. Doing so is immediate UB.\n\
                 - Contain any loops that never terminate. These may be optimized away entirely or cause\n\
                   other unpredictable behaviour.",
                self.func.sig.name
            );
            let generics = &self.launch_generics;
            let args = self.launch_args();
            let body = self.launch_body();

            let address_type = match self.args.address_type {
                AddressType::Dynamic => quote![__address_type: #address_type,],
                _ => quote![],
            };

            quote! {
                #[allow(clippy::too_many_arguments)]
                #[doc = #kernel_doc]
                pub unsafe fn launch_unchecked #generics(
                    __client: &#compute_client<__R>,
                    __cube_count: #cube_count,
                    __cube_dim: #cube_dim,
                    #address_type
                    #(#args),*
                ) {
                    #body
                    launcher.launch_unchecked(__cube_count, __kernel, __client)
                }
            }
        } else {
            TokenStream::new()
        }
    }

    fn launch_body(&self) -> TokenStream {
        let kernel_launcher = prelude_type("KernelLauncher");

        let mappings = self.func.sig.define_mappings();
        let generic_registers =
            self.func
                .analysis
                .register_types(mappings, quote![scope], false, true);

        let settings = self.configure_settings();
        let kernel_name = self.kernel_name();
        let kernel_generics = self.kernel_generics.split_for_impl();
        let kernel_generics = kernel_generics.1.as_turbofish();
        let comptime_args = self.comptime_params().map(|it| &it.name);
        let (registers, args) = self.arg_registers();

        quote! {
            #settings

            let mut launcher = #kernel_launcher::<__R>::new(__settings.clone());
            launcher.with_scope(|scope| {
                scope.device_properties(__client.properties());
                #generic_registers
            });

            #registers
            let __kernel = #kernel_name #kernel_generics::new(__settings, __client.clone(), #args #(#comptime_args),*);
        }
    }

    fn configure_settings(&self) -> TokenStream {
        let kernel_settings = prelude_type("KernelSettings");
        let addr_ty = prelude_type("AddressType");
        let address_type = match self.args.address_type {
            AddressType::U32 => quote![#addr_ty::U32],
            AddressType::U64 => quote![#addr_ty::U64],
            AddressType::Dynamic => quote![__address_type],
        };

        quote! {
            let mut __settings = #kernel_settings::default()
                .cube_dim(__cube_dim).address_type(#address_type);
        }
    }

    fn create_type_alias(&self) -> TokenStream {
        let mut aliases = quote! {};
        if !self.func.args.explicit_define.is_present() {
            for (
                name,
                GenericArg {
                    expand_ty,
                    marker_ty,
                    ..
                },
            ) in self.func.analysis.map.iter()
            {
                aliases.extend(quote! {
                    pub struct #marker_ty;
                    /// Type to be used as a generic for launch kernel argument.
                    pub type #name = #expand_ty;
                });
            }
        }

        aliases
    }
    fn create_dummy_kernel(&self) -> TokenStream {
        if self.args.create_dummy_kernel.is_present() {
            let cube_count = prelude_type("CubeCount");
            let cube_dim = prelude_type("CubeDim");
            let address_type = prelude_type("AddressType");

            let kernel_doc = format!(
                "Launch the kernel [{}()] on the given runtime",
                self.func.sig.name
            );
            let generics = &self.launch_generics;
            let (_, generic_names, _) = self.kernel_generics.split_for_impl();

            let settings = self.configure_settings();
            let kernel_name = self.kernel_name();
            let comptime_args = self.launch_args();
            let comptime_names = self.comptime_params().map(|it| &it.name);
            let (compilation_args, args) = self.arg_registers();

            let address_type = match self.args.address_type {
                AddressType::Dynamic => quote![__address_type: #address_type,],
                _ => quote![],
            };

            quote! {
                #[allow(clippy::too_many_arguments)]
                #[doc = #kernel_doc]
                pub fn create_dummy_kernel #generics(
                    __cube_count: #cube_count,
                    __cube_dim: #cube_dim,
                    #address_type
                    #(#comptime_args),*
                ) -> #kernel_name #generic_names {
                    #settings
                    #compilation_args

                    #kernel_name::new(__settings, #args #(#comptime_names),*)
                }
            }
        } else {
            TokenStream::new()
        }
    }

    pub fn runtime_params(&self) -> impl Iterator<Item = &KernelParam> {
        self.func.sig.runtime_params()
    }

    fn launch_args(&self) -> Vec<KernelParam> {
        let mut args = self.func.sig.parameters.clone();
        let runtime_arg = core_type("RuntimeArg");
        for arg in args.iter_mut().filter(|it| !it.is_const) {
            let ty = arg.ty_owned();
            arg.normalized_ty = parse_quote![#runtime_arg<#ty, __R>];
            arg.mut_token = None;
        }
        args
    }

    pub fn kernel_name(&self) -> Ident {
        let kernel_name = RenameRule::PascalCase.apply_to_field(self.func.sig.name.to_string());
        format_ident!("{kernel_name}")
    }

    pub fn comptime_params(&self) -> impl Iterator<Item = &KernelParam> {
        self.func
            .sig
            .parameters
            .iter()
            .filter(|param| param.is_const)
    }
}