use std::collections::HashMap;
use darling::usage::{CollectLifetimes as _, CollectTypeParams as _, GenericsExt as _, Purpose};
use inflections::case::to_snake_case;
use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote, quote_spanned};
use syn::{Ident, TypeParamBound};
use crate::{
parse::kernel::{
DefinedGeneric, KernelBody, KernelFn, KernelParam, KernelReturns, KernelSignature, Launch,
strip_ref,
},
paths::{frontend_type, prelude_type},
};
impl KernelFn {
    /// Renders the kernel function as its expanded item.
    ///
    /// The emitted function keeps the original visibility and signature. Its
    /// body consists of, in order: the optional debug-source call, one
    /// `set_debug_name` call per runtime parameter (only when debug symbols are
    /// enabled), the anonymous trait imports, the type registrations produced
    /// by the analysis, and finally the kernel body, wrapped in
    /// `fast_math_expand` when the `fast_math` argument is set.
    ///
    /// Takes `&mut self` because expanding a block body needs mutable access
    /// to the kernel context.
    pub fn to_tokens_mut(&mut self) -> TokenStream {
        let vis = &self.vis;
        let sig = &self.sig;
        let body = match &self.body {
            KernelBody::Block(block) => &block.to_tokens(&mut self.context),
            KernelBody::Verbatim(tokens) => tokens,
        };
        let name = &self.full_name;

        // Debug symbols are on when explicitly requested, or when the
        // `debug_symbols` cfg is set and the kernel did not opt out.
        let debug_enabled = self.args.debug_symbols.is_present()
            || (cfg!(debug_symbols) && !self.args.no_debug_symbols.is_present());

        let (debug_source, debug_params) = if debug_enabled {
            let debug_source_fn = frontend_type("debug_source_expand");
            let cube_debug = frontend_type("CubeDebug");

            // Prefer the explicit `src_file` argument; otherwise fall back to
            // the file name of the span the kernel was declared at.
            let explicit_file = self.args.src_file.as_ref().map(|file| file.value());
            let src_file = explicit_file.or_else(|| {
                let span: proc_macro::Span = self.span.unwrap();
                let source_path = span.local_file();
                let source_file = source_path.as_ref().and_then(|path| path.file_name());
                source_file.map(|file| file.to_string_lossy().into())
            });
            let source_text = src_file
                .map(|file| quote![include_str!(#file)])
                .unwrap_or_else(|| quote![""]);

            let debug_source = quote_spanned! {self.span=>
                #debug_source_fn(scope, #name, file!(), #source_text, line!(), column!())
            };

            let mut debug_params = Vec::new();
            for param in sig.runtime_params() {
                let param_name = &param.name;
                let name_str = param_name.to_string();
                debug_params.push(
                    quote! [#cube_debug::set_debug_name(&#param_name, scope, #name_str);],
                );
            }
            (debug_source, debug_params)
        } else {
            (TokenStream::new(), Vec::new())
        };

        let body = match self.args.fast_math.as_ref() {
            Some(value) => {
                let fast_math = frontend_type("fast_math_expand");
                quote![#fast_math(scope, #value, |scope| {#body})]
            }
            None => quote![#body],
        };

        let imports = trait_imports();
        let mappings = self.sig.define_mappings();
        let registers = self
            .analysis
            .register_types(mappings, quote![scope], false, false);

        quote! {
            #[allow(unused_mut)]
            #vis #sig {
                #debug_source;
                #(#debug_params)*
                #imports;
                #registers
                #body
            }
        }
    }
}
/// Builds the anonymous `use ... as _;` statements that bring the prelude's
/// `IntoRuntime` and `Assign` traits into scope.
fn trait_imports() -> TokenStream {
    let assign_trait = prelude_type("Assign");
    let into_runtime_trait = prelude_type("IntoRuntime");
    quote! {
        use #into_runtime_trait as _;
        use #assign_trait as _;
    }
}
impl ToTokens for KernelSignature {
    /// Emits the expanded function signature.
    ///
    /// A `scope: &mut Scope` argument is inserted before the regular
    /// parameters. When the signature has a receiver, the receiver is emitted
    /// first and the first entry of `parameters` is skipped. The return type is
    /// either the `CubeType::ExpandType` projection of the declared type (with
    /// any reference stripped) or the declared type as-is.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let scope = prelude_type("Scope");
        let cube_type = prelude_type("CubeType");
        let name = &self.name;
        let generics = &self.generics;
        let where_clause = &generics.where_clause;

        let return_type = match &self.returns {
            KernelReturns::Plain(ty) => quote![#ty],
            KernelReturns::ExpandType(ty) => {
                // `strip_ref` reports reference-ness through out-parameters;
                // only the stripped type is needed here.
                let (mut is_ref, mut is_mut) = (false, false);
                let ty = strip_ref(ty.clone(), &mut is_ref, &mut is_mut);
                quote![<#ty as #cube_type>::ExpandType]
            }
        };

        // Zero or one receiver; when present it replaces the first parameter.
        let receiver = self.receiver_arg.iter();
        let skipped = usize::from(self.receiver_arg.is_some());
        let args = self.parameters.iter().skip(skipped);

        tokens.extend(quote! {
            fn #name #generics(
                #(#receiver,)*
                scope: &mut #scope,
                #(#args),*
            ) -> #return_type #where_clause
        });
    }
}
impl ToTokens for KernelParam {
    /// Emits the parameter as `[mut] name: normalized_ty`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self {
            name,
            normalized_ty,
            mut_token,
            ..
        } = self;
        tokens.extend(quote! { #mut_token #name: #normalized_ty });
    }
}
impl Launch {
/// Builds the `__ty: PhantomData<...>` field declaration for generics that
/// are declared on the kernel but not used by any comptime parameter type.
///
/// Returns `None` when every declared lifetime and type parameter already
/// appears in a comptime parameter type, in which case no marker field is
/// needed.
fn kernel_phantom_data(&self) -> Option<TokenStream> {
    let generics = &self.kernel_generics;
    let declared_lifetimes = generics.declared_lifetimes();
    let declared_type_params = generics.declared_type_params();

    let used_lifetimes = self
        .comptime_params()
        .map(|param| &param.ty)
        .collect_lifetimes_cloned(&Purpose::Declare.into(), &declared_lifetimes);
    let used_type_params = self
        .comptime_params()
        .map(|param| &param.ty)
        .collect_type_params_cloned(&Purpose::Declare.into(), &declared_type_params);

    let lifetimes: Vec<_> = declared_lifetimes.difference(&used_lifetimes).collect();
    let type_params: Vec<_> = declared_type_params.difference(&used_type_params).collect();

    if lifetimes.is_empty() && type_params.is_empty() {
        return None;
    }

    // A bare lifetime is not a type and cannot be a tuple element, so each
    // unused lifetime is carried by a `&'a ()` reference instead.
    Some(quote![
        __ty: ::core::marker::PhantomData<(#(&#lifetimes (),)* #(#type_params),*)>
    ])
}
/// Produces, for every runtime parameter, a `name: <Ty as LaunchArg>::CompilationArg`
/// field declaration together with the parameter's identifier.
///
/// Both vectors are in runtime-parameter order and have the same length.
pub fn compilation_args_def(&self) -> (Vec<TokenStream>, Vec<Ident>) {
    let launch_arg = prelude_type("LaunchArg");
    self.runtime_params()
        .map(|param| {
            let ty = param.ty_owned();
            let name = &param.name;
            let field = quote! {
                #name: <#ty as #launch_arg>::CompilationArg
            };
            (field, name.clone())
        })
        .unzip()
}
pub fn arg_registers(&self) -> (TokenStream, TokenStream) {
let launch_arg = prelude_type("LaunchArg");
let mut defined = quote! {};
let mut args = quote! {};
self.runtime_params().enumerate().for_each(|(i, input)| {
let ty = &input.ty_owned();
let ident = &input.name;
let var = Ident::new(format!("comp_arg_{i}").as_str(), ident.span());
args.extend(quote! {#var,});
defined.extend(quote! {
let #var = <#ty as #launch_arg>::register(#ident, &mut launcher);
});
});
(
quote! {
#defined
},
args,
)
}
pub fn io_mappings(&self) -> TokenStream {
let launch_arg = prelude_type("LaunchArg");
let mut define = quote! {};
let expand_fn = |ident, expand_name, ty| {
let ty = self.func.analysis.process_ty(&ty);
quote! {
let #ident = <#ty as #launch_arg>::#expand_name(&self.#ident.dynamic_cast(), &mut builder);
}
};
for param in self.runtime_params() {
let expand_name = match param.is_mut {
true => format_ident!("expand_output"),
false => format_ident!("expand"),
};
define.extend(expand_fn(¶m.name, expand_name, param.ty_owned()));
}
quote! {
#define
}
}
/// Generates the body of the kernel's `define` method.
///
/// The generated code creates a default `KernelBuilder`, feeds it the
/// runtime and device properties, registers the types reported by the
/// analysis and the address type, expands the runtime parameters via
/// `io_mappings`, calls `expand` with the builder's scope and all
/// parameters, and finally builds the kernel from a clone of the settings.
///
/// Const parameters are read from `self.<name>`; all other parameters refer
/// to the local bindings introduced by `io_mappings`. The generated code
/// expects `__R`, `self.client` and `self.settings` to exist at the
/// expansion site.
fn define_body(&self) -> TokenStream {
    let kernel_builder = prelude_type("KernelBuilder");
    let io_map = self.io_mappings();

    let mapping = self.func.sig.define_mappings();
    let register_type =
        self.func
            .analysis
            .register_types(mapping, quote![builder.scope], true, true);

    let args = self.func.sig.parameters.iter().map(|it| {
        let name = &it.name;
        if it.is_const {
            quote![self.#name]
        } else {
            quote![#name]
        }
    });

    let generics = self
        .func
        .analysis
        .process_generic_names(&self.func.sig.generics);

    quote! {
        let mut builder = #kernel_builder::default();
        builder.runtime_properties(__R::target_properties());
        builder.device_properties(self.client.properties());
        #register_type
        self.settings.address_type.register(&mut builder.scope);
        #io_map
        expand #generics(&mut builder.scope, #(#args.clone(),)*);
        builder.build(self.settings.clone())
    }
}
fn kernel_entrypoint_name(&self) -> TokenStream {
let base_name = self.func.sig.name.to_string();
let type_name = prelude_type("type_name_short_sanitized");
let generics = &self.func.sig.generics;
let suffix_producing_bounds = [
format_ident!("Float"),
format_ident!("Numeric"),
format_ident!("Int"),
format_ident!("Scalar"),
format_ident!("Size"),
];
let mut matching_generics = vec![];
for ty in generics.type_params() {
for bound in &ty.bounds {
let TypeParamBound::Trait(t) = bound else {
continue;
};
let Some(generic_trailing) = t.path.segments.last() else {
continue;
};
if suffix_producing_bounds.contains(&generic_trailing.ident) {
matching_generics.push(ty.ident.clone());
continue;
}
}
}
if matching_generics.is_empty() {
quote! {
#base_name
}
} else {
let mut defines = self.func.sig.define_mappings();
let generic_names = matching_generics.iter().map(|ident| {
let name = match defines.remove(ident) {
Some((name, index)) => match index {
Some(index) => {
quote![#name[#index]]
}
None => quote![#name],
},
None => quote![#type_name::<#ident>();],
};
let ident_snake = to_snake_case(&ident.to_string());
quote! {{
let type_name = #name;
name.push_str(&cubecl::__private::format!("_{}_{type_name}", #ident_snake));
}}
});
quote! (
{
let mut name = cubecl::__private::format!("{}", #base_name);
#(#generic_names)*
name
}
)
}
}
/// Generates the kernel struct and its trait implementations.
///
/// Returns an empty stream unless the macro arguments request a launch.
/// Otherwise the output contains:
/// - a public struct named after `kernel_name()` holding the settings, the
///   compute client, one compilation argument per runtime parameter, the
///   comptime parameters and, if needed, a `PhantomData` marker;
/// - the companion info type produced by `info_ty`;
/// - a `new` constructor that applies the kernel name (plus the optional
///   debug-symbol and cluster-dimension settings) to the given settings;
/// - `KernelMetadata` and `CubeKernel` implementations.
pub fn kernel_definition(&self) -> TokenStream {
    if !self.args.is_launch() {
        return TokenStream::new();
    }

    let kernel_metadata = prelude_type("KernelMetadata");
    let cube_kernel = prelude_type("CubeKernel");
    let kernel_settings = prelude_type("KernelSettings");
    let compute_client = prelude_type("ComputeClient");
    let kernel_definition: syn::Path = prelude_type("KernelDefinition");
    let kernel_id = prelude_type("KernelId");
    let storage_ty = prelude_type("StorageType");

    let kernel_name = self.kernel_name();
    let define = self.define_body();
    let kernel_doc = format!("{} Kernel", self.func.sig.name);
    let (generics, generic_names, where_clause) = self.kernel_generics.split_for_impl();

    let const_params: Vec<_> = self.comptime_params().collect();
    let param_names = self
        .comptime_params()
        .map(|param| param.name.clone())
        .collect::<Vec<_>>();
    let phantom_data = self.kernel_phantom_data();
    let phantom_data_init = phantom_data
        .as_ref()
        .map(|_| quote![__ty: ::core::marker::PhantomData]);
    let (compilation_args, args) = self.compilation_args_def();

    // Fields copied into the info value: comptime parameters first, then
    // the compilation arguments of the runtime parameters.
    let info_names = param_names.clone().into_iter().chain(args.clone());
    let info_ty_name = format_ident!("{kernel_name}Info");
    let info_ty = self.info_ty(&info_ty_name);
    let info_generics = generic_names.as_turbofish();

    let kernel_source_name = self.kernel_entrypoint_name();
    let mut settings = quote![settings.kernel_name(#kernel_source_name)];
    let cfg_debug = cfg!(debug_symbols) && !self.args.no_debug_symbols.is_present();
    if cfg_debug || self.args.debug_symbols.is_present() {
        settings.extend(quote![.debug_symbols()]);
    }
    if let Some(cluster_dim) = &self.args.cluster_dim {
        settings.extend(quote![.cluster_dim(#cluster_dim)]);
    }

    quote! {
        #[doc = #kernel_doc]
        pub struct #kernel_name #generics #where_clause {
            settings: #kernel_settings,
            client: #compute_client<__R>,
            #(#compilation_args,)*
            #(#const_params,)*
            #phantom_data
        }
        #info_ty
        #[allow(clippy::too_many_arguments)]
        impl #generics #kernel_name #generic_names #where_clause {
            pub fn new(
                settings: #kernel_settings,
                client: #compute_client<__R>,
                #(#compilation_args,)*
                #(#const_params),*) -> Self {
                Self {
                    settings: #settings,
                    client,
                    #(#args,)*
                    #(#param_names,)*
                    #phantom_data_init
                }
            }
        }
        impl #generics #kernel_metadata for #kernel_name #generic_names #where_clause {
            fn id(&self) -> #kernel_id {
                let cube_dim = self.settings.cube_dim.clone();
                let address_type = self.settings.address_type;
                #kernel_id::new::<Self>()
                    .address_type(address_type)
                    .cube_dim(cube_dim)
                    .info(#info_ty_name #info_generics {
                        #(#info_names: self.#info_names.clone(),)*
                        #phantom_data_init
                    })
            }
            fn address_type(&self) -> #storage_ty {
                self.settings.address_type.unsigned_type()
            }
        }
        impl #generics #cube_kernel for #kernel_name #generic_names #where_clause {
            fn define(&self) -> #kernel_definition {
                #define
            }
        }
    }
}
fn info_ty(&self, name: &Ident) -> proc_macro2::TokenStream {
let const_params: Vec<_> = self.comptime_params().collect();
let param_names = self
.comptime_params()
.map(|param| param.name.clone())
.collect::<Vec<_>>();
let phantom_data = self.kernel_phantom_data();
let phantom_data_init = phantom_data
.as_ref()
.map(|_| quote![__ty: ::core::marker::PhantomData]);
let (compilation_args, args) = self.compilation_args_def();
let info_names = param_names
.clone()
.into_iter()
.chain(args.clone())
.collect::<Vec<_>>();
let kernel_source_name = self.kernel_entrypoint_name();
let mut settings = quote![settings.kernel_name(#kernel_source_name)];
let cfg_debug = cfg!(debug_symbols) && !self.args.no_debug_symbols.is_present();
if cfg_debug || self.args.debug_symbols.is_present() {
settings.extend(quote![.debug_symbols()]);
}
if let Some(cluster_dim) = &self.args.cluster_dim {
settings.extend(quote![.cluster_dim(#cluster_dim)]);
}
let generics = &self.kernel_generics;
let (type_generics_names, impl_generics, where_generics) =
self.kernel_generics.split_for_impl();
let vis = &self.vis;
fn map_vec<'a, T: 'a, F: Fn(&T) -> TokenStream>(
fields: impl Iterator<Item = &'a T>,
func: F,
) -> Vec<TokenStream> {
fields.map(func).collect::<Vec<_>>()
}
let clone = map_vec(info_names.iter(), |name| quote!(#name: self.#name.clone()));
let hash = map_vec(info_names.iter(), |name| quote!(self.#name.hash(state)));
let partial_eq = map_vec(
info_names.iter(),
|name| quote!(self.#name.eq(&other.#name)),
);
let debug = map_vec(info_names.iter(), |name| {
let name_str = name.to_string();
let name_str = name_str.strip_prefix("_").unwrap_or(&name_str);
quote!(.field(#name_str, &self.#name))
});
quote! {
#vis struct #name #generics #where_generics {
#(#compilation_args,)*
#(#const_params,)*
#phantom_data
}
impl #type_generics_names Clone for #name #impl_generics #where_generics {
fn clone(&self) -> Self {
Self {
#(#clone,)*
#phantom_data_init
}
}
}
impl #type_generics_names core::hash::Hash for #name #impl_generics #where_generics {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
#(#hash;)*
}
}
impl #type_generics_names core::cmp::PartialEq for #name #impl_generics #where_generics {
fn eq(&self, other: &Self) -> bool {
#(#partial_eq &&)* true
}
}
impl #type_generics_names core::fmt::Debug for #name #impl_generics #where_generics {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct(stringify!(#name))
#(#debug)*
.finish()
}
}
impl #type_generics_names core::cmp::Eq for #name #impl_generics #where_generics { }
}
}
}