use crate::{
RemoveHelpers,
expression::{Block, Expression},
parse::helpers::is_define_attribute,
paths::{frontend_type, prelude_type},
scope::Context,
statement::{DefineKind, Pattern, Statement},
};
use core::hash::Hash;
use darling::{FromMeta, ast::NestedMeta, util::Flag};
use inflections::case::to_snake_case;
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, format_ident, quote};
use std::{collections::HashMap, iter};
use syn::{
ConstParam, Expr, FnArg, Generics, Ident, ItemFn, LitStr, ReturnType, Signature, Token,
TraitItemFn, Type, TypeMacro, TypeParam, Visibility, parse, parse_quote,
punctuated::Punctuated, spanned::Spanned, token::Mut, visit_mut::VisitMut,
};
use super::{desugar::Desugar, helpers::is_comptime_attr, statement::parse_pat};
/// Options accepted by the kernel attribute macro, parsed through `darling::FromMeta`.
///
/// `Flag` fields are presence-only switches; `Option` fields carry a value.
/// Fields without a comment are not read anywhere in this file.
#[derive(Default, FromMeta, Clone)]
pub(crate) struct KernelArgs {
    /// Request a launch function; makes `KernelArgs::is_launch` return true.
    pub launch: Flag,
    /// Also makes `KernelArgs::is_launch` return true.
    /// NOTE(review): presumably selects an unchecked launch variant — confirm.
    pub launch_unchecked: Flag,
    /// Force debug symbols on, regardless of the `debug_symbols` cfg.
    pub debug_symbols: Flag,
    /// Suppress the debug symbols that the `debug_symbols` cfg would otherwise enable.
    pub no_debug_symbols: Flag,
    pub fast_math: Option<Expr>,
    pub debug: Flag,
    pub create_dummy_kernel: Flag,
    pub expand_only: Flag,
    pub cluster_dim: Option<Expr>,
    pub src_file: Option<LitStr>,
    pub expand_base_traits: Option<String>,
    /// Defined generics expand to themselves instead of `DynamicScalar<_T>` /
    /// `DynamicSize<_T>`, and are kept in the launch generics.
    pub explicit_define: Flag,
    /// How a `self` receiver is taken by the generated code (default: owned).
    #[darling(default)]
    pub self_type: SelfType,
    #[darling(default)]
    pub address_type: AddressType,
}
/// Receiver shape used for `self` parameters: `Owned` -> `self`, `Ref` -> `&self`,
/// `RefMut` -> `&mut self`. The same choice wraps the receiver's normalized type.
#[derive(Default, FromMeta, PartialEq, Eq, Clone, Copy)]
pub(crate) enum SelfType {
    #[default]
    Owned,
    Ref,
    RefMut,
}
/// Address type option of the kernel macro; defaults to `U32`.
///
/// NOTE(review): not read in this file — confirm what it controls (presumably the
/// integer width used for addressing in generated code).
#[derive(Default, FromMeta, PartialEq, Eq, Clone, Copy)]
pub(crate) enum AddressType {
    #[default]
    U32,
    U64,
    Dynamic,
}
/// Parses a comma-separated list of meta items (such as an attribute macro's
/// arguments) into any `FromMeta` type.
///
/// # Errors
/// Fails if the tokens are not a valid meta list or if `T` rejects them.
pub fn from_tokens<T: FromMeta>(tokens: TokenStream) -> syn::Result<T> {
    let items = NestedMeta::parse_meta_list(tokens)?;
    let parsed = T::from_list(&items)?;
    Ok(parsed)
}
impl KernelArgs {
    /// Whether a launch function was requested, through either `launch` or
    /// `launch_unchecked`.
    pub fn is_launch(&self) -> bool {
        let checked = self.launch.is_present();
        let unchecked = self.launch_unchecked.is_present();
        checked || unchecked
    }
}
/// How one recognized generic parameter is expanded.
#[derive(Clone)]
pub struct GenericArg {
    /// Path substituted for the generic (the generic itself with explicit defines,
    /// otherwise `DynamicScalar<_T>` / `DynamicSize<_T>`).
    pub expand_ty: syn::Path,
    /// Marker identifier `_T` derived from the generic's name `T`.
    pub marker_ty: syn::Ident,
    /// Whether the generic is registered as a type or as a size.
    pub kind: DefineKind,
}
/// Generic parameters that can be defined at runtime, keyed by parameter name
/// (as collected by `GenericAnalysis::from_generics`).
#[derive(Clone)]
pub struct GenericAnalysis {
    pub map: HashMap<syn::Ident, GenericArg>,
}
impl GenericAnalysis {
pub fn process_generic_names(&self, ty: &syn::Generics) -> TokenStream {
let mut output = quote![];
if ty.params.is_empty() {
return output;
}
for param in ty.params.iter() {
match param {
syn::GenericParam::Type(TypeParam { ident, .. })
| syn::GenericParam::Const(ConstParam { ident, .. }) => {
if let Some(GenericArg { expand_ty, .. }) = self.map.get(ident) {
output.extend(quote![#expand_ty,]);
} else {
output.extend(quote![#ident,]);
}
}
_ => todo!(),
}
}
quote! {
::<#output>
}
}
pub fn register_types(
&self,
mut name_mapping: HashMap<Ident, (Ident, Option<usize>)>,
scope: TokenStream,
has_self: bool,
launch: bool,
) -> TokenStream {
let mut output = quote![];
let self_ = has_self.then(|| quote![self.]);
for (
ident,
GenericArg {
kind, expand_ty, ..
},
) in self.map.iter()
{
let name = match name_mapping.remove(ident) {
Some((name, index)) => match index {
Some(index) => {
quote! { #self_ #name[#index].into() }
}
None => quote! { #self_ #name.into() },
},
None if !launch => {
continue;
}
None => match kind {
DefineKind::Type => quote![#ident::as_type_native_unchecked().storage_type()],
DefineKind::Size => quote![#ident::value()],
},
};
match kind {
DefineKind::Type => {
output.extend(quote! {
#scope.register_type::<#expand_ty>(#name);
});
}
DefineKind::Size => {
output.extend(quote! {
#scope.register_size::<#expand_ty>(#name);
});
}
}
}
if !name_mapping.is_empty() {
for key in name_mapping.keys() {
let err = syn::Error::new_spanned(
key,
format!("Generic `{key}` isn't defined correctly. Only `Float`, `Int` and `Numeric` generics can be defined with only a single trait bound."),
).into_compile_error();
output.extend(err);
}
}
output
}
pub fn process_ty(&self, ty: &syn::Type) -> syn::Type {
let type_path = match &ty {
Type::Path(type_path) => type_path,
_ => return ty.clone(),
};
let path = &type_path.path;
let mut returned = syn::Path {
leading_colon: path.leading_colon,
segments: syn::punctuated::Punctuated::new(),
};
for pair in path.segments.pairs() {
let segment = pair.value();
let punc = pair.punct();
if let Some(GenericArg { expand_ty, .. }) = self.map.get(&segment.ident) {
returned.segments.extend(expand_ty.segments.clone());
} else {
match &segment.arguments {
syn::PathArguments::AngleBracketed(arg) => {
let mut args = syn::punctuated::Punctuated::new();
arg.args.iter().for_each(|arg| match arg {
syn::GenericArgument::Type(ty) => {
let ty = self.process_ty(ty);
args.push(syn::GenericArgument::Type(ty));
}
_ => args.push_value(arg.clone()),
});
let segment = syn::PathSegment {
ident: segment.ident.clone(),
arguments: syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments {
colon2_token: arg.colon2_token,
lt_token: arg.lt_token,
args,
gt_token: arg.gt_token,
},
),
};
returned.segments.push_value(segment);
}
_ => returned.segments.push_value((*segment).clone()),
}
}
if let Some(punc) = punc {
returned.segments.push_punct(**punc)
}
}
syn::Type::Path(syn::TypePath {
qself: type_path.qself.clone(),
path: returned,
})
}
pub fn from_generics(generics: &syn::Generics, explicit_defines: bool) -> Self {
let mut map = HashMap::new();
let elem_expand = prelude_type("DynamicScalar");
let size_expand = prelude_type("DynamicSize");
for type_param in generics.type_params() {
if type_param.bounds.len() > 1 {
continue;
}
if let Some(syn::TypeParamBound::Trait(trait_bound)) = type_param.bounds.first()
&& let Some(bound) = trait_bound.path.get_ident()
{
let name = bound.to_string();
let ident = type_param.ident.clone();
let marker_ty = format_ident!("_{ident}");
match name.as_str() {
"Float" | "Int" | "Numeric" | "CubePrimitive" => {
if explicit_defines {
map.insert(
ident.clone(),
GenericArg {
expand_ty: parse_quote!(#ident),
marker_ty,
kind: DefineKind::Type,
},
);
} else {
map.insert(
ident,
GenericArg {
expand_ty: parse_quote!(#elem_expand<#marker_ty>),
marker_ty,
kind: DefineKind::Type,
},
);
}
}
"Size" => {
if explicit_defines {
map.insert(
ident.clone(),
GenericArg {
expand_ty: parse_quote!(#ident),
marker_ty,
kind: DefineKind::Size,
},
);
} else {
map.insert(
type_param.ident.clone(),
GenericArg {
expand_ty: parse_quote!(#size_expand<#marker_ty>),
marker_ty,
kind: DefineKind::Size,
},
);
}
}
_ => {}
};
};
}
Self { map }
}
}
/// Everything needed to generate a kernel's launch code (built by `Launch::from_item_fn`).
pub struct Launch {
    pub args: KernelArgs,
    /// Visibility of the original function item.
    pub vis: Visibility,
    /// The parsed kernel (always given `pub` visibility).
    pub func: KernelFn,
    /// Kernel generics: the function's generics minus `#[define]`-bound type
    /// parameters (unless `explicit_define`), plus a trailing `__R: Runtime`.
    pub kernel_generics: Generics,
    /// `kernel_generics` prefixed with the `'kernel` lifetime.
    pub launch_generics: Generics,
}
/// A parsed kernel function: signature, body and the state needed to expand it.
#[derive(Clone)]
pub struct KernelFn {
    /// Visibility recorded for the kernel (`Launch::from_item_fn` passes `pub`).
    pub vis: Visibility,
    pub sig: KernelSignature,
    pub body: KernelBody,
    /// Kernel name string (`Launch::from_item_fn` passes the function identifier).
    pub full_name: String,
    /// `Span::call_site()` when built by `KernelFn::from_sig_and_block`.
    pub span: Span,
    /// Scope context seeded with the return type and the signature's parameters.
    pub context: Context,
    pub args: KernelArgs,
    /// Generic analysis of `sig.generics`.
    pub analysis: GenericAnalysis,
}
/// Body of a kernel: either a parsed block or raw tokens.
// NOTE(review): the lint is silenced presumably because `Block` is much larger
// than `TokenStream` and boxing it was judged not worth it — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum KernelBody {
    /// Body parsed from the function's block.
    Block(Block),
    /// Tokens emitted as-is (not constructed in this file).
    Verbatim(TokenStream),
}
/// A kernel's signature after parameter and generic processing.
#[derive(Clone, Debug)]
pub struct KernelSignature {
    pub name: Ident,
    /// Synthesized define parameters first, followed by the function's own parameters.
    pub parameters: Vec<KernelParam>,
    pub returns: KernelReturns,
    /// Function generics with macro helper attributes removed.
    pub generics: Generics,
    /// `self`, `&self` or `&mut self` (per `KernelArgs::self_type`) when a `self`
    /// parameter exists.
    pub receiver_arg: Option<FnArg>,
}
impl KernelSignature {
    /// Iterates over the parameters supplied at runtime, i.e. every parameter
    /// that is not compile-time constant.
    pub fn runtime_params(&self) -> impl Iterator<Item = &KernelParam> {
        self.parameters.iter().filter(|param| !param.is_const)
    }

    /// Maps each generic named by a `#[define(...)]` attribute to the parameter
    /// that defines it.
    ///
    /// The index is the generic's position in a multi-name define, and `None` for
    /// a single-name define. When a generic is defined more than once, the last
    /// definition wins.
    pub fn define_mappings(&self) -> HashMap<Ident, (Ident, Option<usize>)> {
        self.parameters
            .iter()
            .flat_map(|param| {
                param.defines.iter().map(move |define| match define {
                    DefinedGeneric::Single(generic) => {
                        (generic.clone(), (param.name.clone(), None))
                    }
                    DefinedGeneric::Multiple(generic, position) => {
                        (generic.clone(), (param.name.clone(), Some(*position)))
                    }
                })
            })
            .collect()
    }
}
/// Return type of a kernel.
#[derive(Clone, Debug)]
pub enum KernelReturns {
    /// Ordinary return type (`()` when none is written).
    ExpandType(Type),
    /// Type taken as-is: from `-> comptime_type!(T)`, or set by
    /// `KernelSignature::plain_self` for `-> Self`.
    Plain(Type),
}
impl KernelReturns {
    /// Returns a clone of the wrapped type, whichever variant holds it.
    pub fn ty(&self) -> Type {
        let (KernelReturns::ExpandType(inner) | KernelReturns::Plain(inner)) = self;
        inner.clone()
    }
}
/// A generic declared by a parameter's `#[define(...)]` attribute.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum DefinedGeneric {
    /// `#[define(T)]`.
    Single(Ident),
    /// One name from `#[define(A, B, ...)]`, with its zero-based position.
    Multiple(Ident, usize),
}
impl DefinedGeneric {
    /// Whether this define names the generic `ident_input`.
    pub fn contains_ident(&self, ident_input: &Ident) -> bool {
        ident_input == self.ident()
    }

    /// The identifier of the defined generic.
    pub fn ident(&self) -> &Ident {
        let (DefinedGeneric::Single(name) | DefinedGeneric::Multiple(name, _)) = self;
        name
    }
}
/// A single kernel parameter, produced by `KernelParam::from_param`.
#[derive(Clone, Debug)]
pub struct KernelParam {
    /// Binding name (`self` for a receiver).
    pub name: Ident,
    /// The type exactly as written in the signature.
    pub ty: Type,
    /// `ty` with one outer reference removed and, unless the parameter is const,
    /// mapped to `<T as CubeType>::ExpandType`. For receivers it is re-wrapped
    /// according to `KernelArgs::self_type`.
    pub normalized_ty: Type,
    /// Generics this parameter defines through `#[define(...)]`.
    pub defines: Vec<DefinedGeneric>,
    /// Compile-time parameter: carries a comptime or `#[define(...)]` attribute.
    pub is_const: bool,
    /// Mutable: reported by `parse_pat` or a `&mut` type; for receivers, `mut` is present.
    pub is_mut: bool,
    /// Taken by reference.
    pub is_ref: bool,
    /// `mut` token to emit; set only when the function has a body and `is_mut` holds.
    pub mut_token: Option<Mut>,
}
impl KernelParam {
pub fn from_param(param: FnArg, args: &KernelArgs, has_body: bool) -> syn::Result<Self> {
let param = match param {
FnArg::Typed(param) => param,
FnArg::Receiver(param) => {
let mut is_ref = false;
let mut is_mut = false;
let normalized_ty =
normalize_kernel_ty(*param.ty.clone(), false, &mut is_ref, &mut is_mut);
let normalized_ty = match args.self_type {
SelfType::Owned => normalized_ty,
SelfType::Ref => parse_quote!(&#normalized_ty),
SelfType::RefMut => parse_quote!(&mut #normalized_ty),
};
is_mut = param.mutability.is_some();
is_ref = param.reference.is_some();
let mut_token = if has_body && is_mut {
let span = param.span();
Some(Token)
} else {
None
};
return Ok(KernelParam {
name: Ident::new("self", param.span()),
ty: *param.ty,
normalized_ty,
defines: Vec::new(),
is_const: false,
is_mut,
is_ref,
mut_token,
});
}
};
let Pattern {
ident,
mut is_ref,
mut is_mut,
..
} = parse_pat(*param.pat.clone())?;
let mut is_const = false;
let mut defines = Vec::new();
for attr in param.attrs.iter() {
if is_comptime_attr(attr) {
is_const = true;
}
if is_define_attribute(attr) {
match attr.parse_args::<Ident>() {
Ok(ident) => {
defines.push(DefinedGeneric::Single(ident));
}
Err(_) => {
let list = attr.meta.require_list().expect("Wrong syntax.");
let tokens = list.tokens.to_string();
let names = tokens.split(",");
for (i, name) in names.enumerate() {
let ident = Ident::new(name.trim(), attr.span());
defines.push(DefinedGeneric::Multiple(ident, i));
}
}
};
is_const = true;
}
}
let ty = *param.ty.clone();
let normalized_ty = normalize_kernel_ty(*param.ty, is_const, &mut is_ref, &mut is_mut);
let mut_token = if has_body && is_mut {
let span = ident.span();
Some(Token)
} else {
None
};
Ok(Self {
name: ident,
ty,
defines,
normalized_ty,
is_const,
is_mut,
is_ref,
mut_token,
})
}
pub fn ty_owned(&self) -> Type {
strip_ref(self.ty.clone(), &mut false, &mut false)
}
pub fn plain_normalized_self(&mut self) {
if let Type::Path(pat) = &self.ty
&& pat
.path
.get_ident()
.filter(|ident| *ident == "Self")
.is_some()
{
self.normalized_ty = self.ty.clone();
}
}
}
impl KernelSignature {
pub fn from_signature(sig: Signature, args: &KernelArgs, has_body: bool) -> syn::Result<Self> {
let name = sig.ident;
let mut generics = sig.generics;
let returns = match sig.output {
syn::ReturnType::Default => KernelReturns::ExpandType(parse_quote![()]),
syn::ReturnType::Type(_, ty) => match *ty.clone() {
Type::Macro(TypeMacro { mac }) => {
if mac.path.is_ident("comptime_type") {
let inner_type = parse::<Type>(mac.tokens.into())
.expect("Interior of comptime_type macro should be a valid type.");
KernelReturns::Plain(inner_type)
} else {
panic!("Only comptime_type macro supported on return type")
}
}
_ => KernelReturns::ExpandType(*ty),
},
};
let sig_params = sig
.inputs
.into_iter()
.map(|it| KernelParam::from_param(it, args, has_body))
.collect::<Result<Vec<_>, _>>()?;
let manually_defined_params = sig_params
.iter()
.flat_map(|it| it.defines.iter().map(|it| it.ident()))
.collect::<Vec<_>>();
let define_params = generics
.type_params()
.filter(|it| !manually_defined_params.contains(&&it.ident))
.filter(|it| {
it.attrs.iter().any(is_define_attribute)
|| (args.is_launch() && it.bounds.to_token_stream().to_string() == "Size")
})
.map(|ty_param| {
let type_ = prelude_type("Type");
let ident = &ty_param.ident;
let name = format_ident!("_{}", to_snake_case(&ident.to_string()));
let is_size = ty_param.bounds.to_token_stream().to_string() == "Size";
let ty = match is_size {
true => quote![usize],
false => quote![#type_],
};
KernelParam::from_param(parse_quote!(#[define(#ident)] #name: #ty), args, has_body)
})
.collect::<Result<Vec<_>, _>>()?;
let parameters = define_params
.into_iter()
.chain(sig_params)
.collect::<Vec<_>>();
let receiver_arg = if parameters.iter().any(|it| it.name == "self") {
Some(match args.self_type {
SelfType::Owned => parse_quote!(self),
SelfType::Ref => parse_quote!(&self),
SelfType::RefMut => parse_quote!(&mut self),
})
} else {
None
};
RemoveHelpers.visit_generics_mut(&mut generics);
Ok(KernelSignature {
generics,
name,
parameters,
returns,
receiver_arg,
})
}
pub fn from_trait_fn(function: TraitItemFn, args: &KernelArgs) -> syn::Result<Self> {
Self::from_signature(function.sig, args, false)
}
pub fn plain_self(&mut self) {
if let Type::Path(pat) = self.returns.ty()
&& pat.path.is_ident("Self")
{
self.returns = KernelReturns::Plain(self.returns.ty());
}
for param in self.parameters.iter_mut() {
if let Type::Path(pat) = ¶m.ty
&& pat.path.is_ident("Self")
{
param.normalized_ty = parse_quote!(Self);
}
}
}
}
impl KernelFn {
    /// Parses a kernel function from its visibility, signature and body.
    ///
    /// Debug symbols are enabled by the `debug_symbols` flag, or by the
    /// `debug_symbols` cfg unless `no_debug_symbols` is given. The body is
    /// desugared and then parsed inside a scope of a `Context` seeded with the
    /// signature's parameters. Finally, owned `mut` parameters are rebound through
    /// `IntoMut` at the top of the body.
    ///
    /// # Errors
    /// Propagates signature and body parsing errors.
    pub fn from_sig_and_block(
        vis: Visibility,
        sig: Signature,
        mut block: syn::Block,
        full_name: String,
        args: &KernelArgs,
    ) -> syn::Result<Self> {
        let debug_symbols = args.debug_symbols.is_present()
            || (cfg!(debug_symbols) && !args.no_debug_symbols.is_present());
        let span = Span::call_site();

        let sig = KernelSignature::from_signature(sig, args, true)?;
        let explicit_define = args.explicit_define.is_present();
        let analysis = GenericAnalysis::from_generics(&sig.generics, explicit_define);

        let mut context = Context::new(sig.returns.ty(), debug_symbols);
        context.extend(sig.parameters.clone());

        Desugar.visit_block_mut(&mut block);
        let (mut body, _) = context.in_scope(|ctx| Block::from_block(block, ctx))?;
        Self::patch_mut_owned_inputs(&mut body, &sig);

        Ok(KernelFn {
            vis,
            sig,
            body: KernelBody::Block(body),
            full_name,
            span,
            context,
            analysis,
            args: args.clone(),
        })
    }

    /// Prepends `let mut <name> = IntoMut::into_mut(<name>, scope);` to `block` for
    /// every parameter that is mutable and not taken by reference.
    fn patch_mut_owned_inputs(block: &mut Block, sig: &KernelSignature) {
        let into_mut = frontend_type("IntoMut");
        let mut rebinds = sig
            .parameters
            .iter()
            .filter(|param| param.is_mut && !param.is_ref)
            .map(|param| {
                let name = &param.name;
                Statement::Expression {
                    expression: Box::new(Expression::Verbatim {
                        tokens: quote! {
                            let mut #name = #into_mut::into_mut(#name, scope);
                        },
                    }),
                    terminated: false,
                }
            })
            .collect::<Vec<_>>();
        if rebinds.is_empty() {
            return;
        }
        rebinds.extend(block.inner.drain(..));
        block.inner = rebinds;
    }
}
impl Launch {
pub fn from_item_fn(function: ItemFn, args: KernelArgs) -> syn::Result<Self> {
let runtime = prelude_type("Runtime");
let ret = function.sig.output.clone();
let vis = function.vis;
let full_name = function.sig.ident.to_string();
let mut func = KernelFn::from_sig_and_block(
Visibility::Public(parse_quote![pub]),
function.sig,
*function.block,
full_name,
&args,
)?;
if args.is_launch()
&& let ReturnType::Type(arrow, ty) = &ret
{
let mut ts = arrow.to_token_stream();
ts.extend(ty.into_token_stream());
return Err(syn::Error::new_spanned(
ts,
format!(
"This is a launch kernel and cannot have a return type. Remove `-> {}`. Use mutable output arguments instead in order to get values out from kernels.",
ty.into_token_stream()
),
));
}
let mut kernel_generics = func.sig.generics.clone();
kernel_generics.params.clear();
let explicit_define = args.explicit_define.is_present();
for param in func.sig.generics.params.iter_mut() {
let is_defined = |ident| {
func.sig
.parameters
.iter()
.any(|p| p.defines.iter().any(|d| d.contains_ident(ident)))
};
match param.clone() {
syn::GenericParam::Type(TypeParam { ident, .. })
if is_defined(&ident) && !explicit_define => {}
param => {
kernel_generics.params.push(param);
}
}
}
kernel_generics.params.push(parse_quote![__R: #runtime]);
let mut launch_generics = kernel_generics.clone();
launch_generics.params =
Punctuated::from_iter(iter::once(parse_quote!['kernel]).chain(launch_generics.params));
Ok(Launch {
args,
vis,
func,
launch_generics,
kernel_generics,
})
}
}
/// Strips one outer reference from `ty`, recording it in `is_ref` / `is_mut`.
///
/// Const parameters keep the stripped type. All others are mapped to
/// `<T as CubeType>::ExpandType`.
fn normalize_kernel_ty(ty: Type, is_const: bool, is_ref: &mut bool, is_mut: &mut bool) -> Type {
    let owned = strip_ref(ty, is_ref, is_mut);
    let cube_type = prelude_type("CubeType");
    match is_const {
        true => owned,
        false => parse_quote![<#owned as #cube_type>::ExpandType],
    }
}
pub fn strip_ref(ty: Type, is_ref: &mut bool, is_mut: &mut bool) -> Type {
match ty {
Type::Reference(reference) => {
*is_ref = true;
*is_mut = *is_mut || reference.mutability.is_some();
*reference.elem
}
ty => ty,
}
}