use crate::{
expression::{Block, Expression},
paths::{frontend_type, prelude_type},
scope::Context,
statement::{Pattern, Statement},
};
use darling::{FromMeta, ast::NestedMeta, util::Flag};
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, quote};
use std::{collections::HashMap, iter};
use syn::{
Expr, FnArg, Generics, Ident, ItemFn, LitStr, ReturnType, Signature, TraitItemFn, Type,
TypeMacro, Visibility, parse, parse_quote, punctuated::Punctuated, spanned::Spanned,
visit_mut::VisitMut,
};
use super::{desugar::Desugar, helpers::is_comptime_attr, statement::parse_pat};
/// Options accepted by the kernel attribute, parsed with `darling::FromMeta`
/// (presumably through [`from_tokens`] — confirm at the call site).
#[derive(Default, FromMeta, Clone)]
pub(crate) struct KernelArgs {
    /// Marks the kernel as a launch kernel; see [`KernelArgs::is_launch`].
    /// Launch kernels are rejected if they declare a return type (`Launch::from_item_fn`).
    pub launch: Flag,
    /// Like `launch`, also makes [`KernelArgs::is_launch`] return `true`.
    /// NOTE(review): the "unchecked" difference is handled outside this file.
    pub launch_unchecked: Flag,
    /// Forces debug symbols on when building the kernel's `Context`.
    pub debug_symbols: Flag,
    /// Opts out of debug symbols that the `debug_symbols` cfg would otherwise enable.
    pub no_debug_symbols: Flag,
    /// Not consumed in this file; presumably configures fast-math behavior — TODO confirm.
    pub fast_math: Option<Expr>,
    /// Not consumed in this file.
    pub debug: Flag,
    /// Not consumed in this file.
    pub create_dummy_kernel: Flag,
    /// Not consumed in this file; presumably a launch cluster dimension — TODO confirm.
    pub cluster_dim: Option<Expr>,
    /// Not consumed in this file.
    pub src_file: Option<LitStr>,
    /// Not consumed in this file.
    pub expand_base_traits: Option<String>,
    /// How a `self` receiver is normalized in the expanded signature (see [`SelfType`]).
    #[darling(default)]
    pub self_type: SelfType,
    /// Not consumed in this file (see [`AddressType`]).
    #[darling(default)]
    pub address_type: AddressType,
}
/// How a `self` receiver is passed in the expanded kernel.
///
/// Used by `KernelParam::from_param` to wrap the receiver's normalized type, and by
/// `KernelSignature::from_signature` to synthesize the receiver argument.
#[derive(Default, FromMeta, PartialEq, Eq, Clone, Copy)]
pub(crate) enum SelfType {
    /// `self` by value; the normalized type is used unchanged.
    #[default]
    Owned,
    /// `&self`; the normalized type is wrapped in `&`.
    Ref,
    /// `&mut self`; the normalized type is wrapped in `&mut`.
    RefMut,
}
/// Address type option of the kernel attribute.
///
/// NOTE(review): not consumed in this file; presumably selects the integer width used for
/// indexing/addressing in generated kernels (32-bit by default) — confirm against the consumer.
#[derive(Default, FromMeta, PartialEq, Eq, Clone, Copy)]
pub(crate) enum AddressType {
    #[default]
    U32,
    U64,
    Dynamic,
}
/// Parses an attribute argument list (`a, b = 1, c(...)`) into any `darling::FromMeta` type.
///
/// # Errors
///
/// Returns a `syn::Error` if the tokens are not a valid meta list, or if `T` rejects them.
pub fn from_tokens<T: FromMeta>(tokens: TokenStream) -> syn::Result<T> {
    let nested = NestedMeta::parse_meta_list(tokens)?;
    // darling errors convert into `syn::Error` through `From`, so `?` suffices here.
    let parsed = T::from_list(&nested)?;
    Ok(parsed)
}
impl KernelArgs {
    /// Returns `true` if either the `launch` or the `launch_unchecked` flag was given.
    pub fn is_launch(&self) -> bool {
        [&self.launch, &self.launch_unchecked]
            .into_iter()
            .any(|flag| flag.is_present())
    }
}
/// Result of scanning a kernel's generics for single-bound numeric type parameters.
pub struct GenericAnalysis {
    /// Maps a generic type parameter (e.g. `F`) to its replacement path segment
    /// (`FloatExpand<N>`, `IntExpand<N>` or `NumericExpand<N>`), where `N` is the
    /// registration order; see `GenericAnalysis::from_generics`.
    pub map: HashMap<syn::Ident, syn::PathSegment>,
}
impl GenericAnalysis {
pub fn process_generics(&self, ty: &syn::Generics) -> TokenStream {
let mut output = quote![];
if ty.params.is_empty() {
return output;
}
for param in ty.params.pairs() {
match param.value() {
syn::GenericParam::Type(type_param) => {
if let Some(ty) = self.map.get(&type_param.ident) {
output.extend(quote![#ty,]);
} else {
let ident = &type_param.ident;
output.extend(quote![#ident,]);
}
}
_ => todo!(),
}
}
quote! {
::<#output>
}
}
pub fn register_types(
&self,
mut name_mapping: HashMap<Ident, (Ident, Option<usize>)>,
) -> TokenStream {
let mut output = quote![];
for (name, ty) in self.map.iter() {
let name = match name_mapping.remove(name) {
Some((name, index)) => match index {
Some(index) => {
quote! { self.#name[#index].into() }
}
None => quote! { self.#name.into() },
},
None => quote! {#name::as_type_native_unchecked()},
};
output.extend(quote! {
builder
.scope
.register_type::<#ty>(#name);
});
}
if !name_mapping.is_empty() {
for key in name_mapping.keys() {
let err = syn::Error::new_spanned(
key,
format!("Generic `{key}` isn't defined correctly. Only `Float`, `Int` and `Numeric` generics can be defined with only a single trait bound."),
).into_compile_error();
output.extend(err);
}
}
output
}
pub fn process_ty(&self, ty: &syn::Type) -> syn::Type {
let type_path = match &ty {
Type::Path(type_path) => type_path,
_ => return ty.clone(),
};
let path = &type_path.path;
let mut returned = syn::Path {
leading_colon: path.leading_colon,
segments: syn::punctuated::Punctuated::new(),
};
for pair in path.segments.pairs() {
let segment = pair.value();
let punc = pair.punct();
if let Some(segment) = self.map.get(&segment.ident) {
returned.segments.push_value(segment.clone());
} else {
match &segment.arguments {
syn::PathArguments::AngleBracketed(arg) => {
let mut args = syn::punctuated::Punctuated::new();
arg.args.iter().for_each(|arg| match arg {
syn::GenericArgument::Type(ty) => {
let ty = self.process_ty(ty);
args.push(syn::GenericArgument::Type(ty));
}
_ => args.push_value(arg.clone()),
});
let segment = syn::PathSegment {
ident: segment.ident.clone(),
arguments: syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments {
colon2_token: arg.colon2_token,
lt_token: arg.lt_token,
args,
gt_token: arg.gt_token,
},
),
};
returned.segments.push_value(segment);
}
_ => returned.segments.push_value((*segment).clone()),
}
}
if let Some(punc) = punc {
returned.segments.push_punct(**punc)
}
}
syn::Type::Path(syn::TypePath {
qself: type_path.qself.clone(),
path: returned,
})
}
pub fn from_generics(generics: &syn::Generics) -> Self {
let mut map = HashMap::new();
for param in generics.params.pairs() {
let type_param = if let syn::GenericParam::Type(type_param) = param.value() {
type_param
} else {
continue;
};
if type_param.bounds.len() > 1 {
continue;
}
if let Some(syn::TypeParamBound::Trait(trait_bound)) = type_param.bounds.first()
&& let Some(bound) = trait_bound.path.get_ident()
{
let name = bound.to_string();
let index = map.len() as u8;
match name.as_str() {
"Float" => {
map.insert(type_param.ident.clone(), parse_quote!(FloatExpand<#index>));
}
"Int" => {
map.insert(type_param.ident.clone(), parse_quote!(IntExpand<#index>));
}
"Numeric" => {
map.insert(
type_param.ident.clone(),
parse_quote!(NumericExpand<#index>),
);
}
_ => {}
};
};
}
Self { map }
}
}
/// Everything needed to generate launch code for a kernel; built by `Launch::from_item_fn`.
pub struct Launch {
    /// Kernel attribute options.
    pub args: KernelArgs,
    /// Visibility of the original function item; the lowered `func` itself is always `pub`.
    pub vis: Visibility,
    /// The lowered kernel function.
    pub func: KernelFn,
    /// Function generics without those supplied by `#[define(...)]` parameters,
    /// followed by `__R: Runtime`.
    pub kernel_generics: Generics,
    /// `kernel_generics` with a leading `'kernel` lifetime.
    pub launch_generics: Generics,
    /// Numeric generic analysis of the function's original generics.
    pub analysis: GenericAnalysis,
}
/// A kernel function lowered into the macro's own IR; see `KernelFn::from_sig_and_block`.
#[derive(Clone)]
pub struct KernelFn {
    /// Visibility of the generated function.
    pub vis: Visibility,
    /// Parsed signature.
    pub sig: KernelSignature,
    /// Lowered body.
    pub body: KernelBody,
    /// Name passed in by the caller (the function's identifier in `Launch::from_item_fn`).
    pub full_name: String,
    /// Span used for generated code (`Span::call_site()` in `from_sig_and_block`).
    pub span: Span,
    /// Scope context the body was lowered in.
    pub context: Context,
    /// Kernel attribute options.
    pub args: KernelArgs,
}
/// Body of a kernel function.
// NOTE(review): the lint is presumably silenced because `Block` is much larger than
// `TokenStream` and boxing it was not considered worthwhile — confirm intent.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum KernelBody {
    /// Body lowered into the macro IR (what `KernelFn::from_sig_and_block` produces).
    Block(Block),
    /// Raw tokens emitted as-is; not constructed in this file.
    Verbatim(TokenStream),
}
/// Parsed signature of a kernel function; see `KernelSignature::from_signature`.
#[derive(Clone, Debug)]
pub struct KernelSignature {
    /// Function identifier.
    pub name: Ident,
    /// All inputs in declaration order, including a receiver (named `self`) if present.
    pub parameters: Vec<KernelParam>,
    /// Classified return type.
    pub returns: KernelReturns,
    /// Generics exactly as written on the function.
    pub generics: Generics,
    /// Synthesized receiver (`self`, `&self` or `&mut self` per `SelfType`) when a
    /// parameter is named `self`.
    pub receiver_arg: Option<FnArg>,
}
impl KernelSignature {
pub fn runtime_params(&self) -> impl Iterator<Item = &KernelParam> {
self.parameters.iter().filter(|it| !it.is_const)
}
}
/// Return type of a kernel, classified by `KernelSignature::from_signature`.
#[derive(Clone, Debug)]
pub enum KernelReturns {
    /// An ordinary return type (or `()` when none is written).
    ExpandType(Type),
    /// A type to be used as written: the inner type of `comptime_type!(...)`, or `Self`
    /// after `KernelSignature::plain_self`.
    Plain(Type),
}
impl KernelReturns {
    /// Returns a clone of the wrapped type, whichever variant holds it.
    pub fn ty(&self) -> Type {
        let (KernelReturns::ExpandType(ty) | KernelReturns::Plain(ty)) = self;
        ty.clone()
    }
}
/// A generic type parameter supplied by a `#[define(...)]` kernel parameter.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum DefinedGeneric {
    /// From `#[define(T)]`.
    Single(Ident),
    /// From `#[define(A, B, ...)]`: the identifier and its position in the list.
    Multiple(Ident, usize),
}
impl DefinedGeneric {
    /// Returns `true` if this defined generic names `ident_input`.
    pub fn contains_ident(&self, ident_input: &Ident) -> bool {
        let (DefinedGeneric::Single(ident) | DefinedGeneric::Multiple(ident, _)) = self;
        ident == ident_input
    }
}
/// One kernel input; see `KernelParam::from_param`.
#[derive(Clone, Debug)]
pub struct KernelParam {
    /// Binding name (`self` for receivers).
    pub name: Ident,
    /// Type as written.
    pub ty: Type,
    /// Type with any outer reference stripped. For runtime parameters it is wrapped as
    /// `<T as CubeType>::ExpandType`; receivers are additionally re-wrapped per `SelfType`.
    pub normalized_ty: Type,
    /// Generics supplied by a `#[define(...)]` attribute on this parameter.
    pub defines: Vec<DefinedGeneric>,
    /// Set when the parameter carries a comptime attribute or `#[define(...)]`.
    pub is_const: bool,
    /// Whether the binding or the referenced value is mutable.
    pub is_mut: bool,
    /// Whether the parameter is passed by reference.
    pub is_ref: bool,
}
impl KernelParam {
pub fn from_param(param: FnArg, args: &KernelArgs) -> syn::Result<Self> {
let param = match param {
FnArg::Typed(param) => param,
FnArg::Receiver(param) => {
let mut is_ref = false;
let mut is_mut = false;
let normalized_ty =
normalize_kernel_ty(*param.ty.clone(), false, &mut is_ref, &mut is_mut);
let normalized_ty = match args.self_type {
SelfType::Owned => normalized_ty,
SelfType::Ref => parse_quote!(&#normalized_ty),
SelfType::RefMut => parse_quote!(&mut #normalized_ty),
};
is_mut = param.mutability.is_some();
is_ref = param.reference.is_some();
return Ok(KernelParam {
name: Ident::new("self", param.span()),
ty: *param.ty,
normalized_ty,
defines: Vec::new(),
is_const: false,
is_mut,
is_ref,
});
}
};
let Pattern {
ident,
mut is_ref,
mut is_mut,
..
} = parse_pat(*param.pat.clone())?;
let mut is_const = false;
let mut defines = Vec::new();
for attr in param.attrs.iter() {
if is_comptime_attr(attr) {
is_const = true;
}
if attr.path().is_ident("define") {
match attr.parse_args::<Ident>() {
Ok(ident) => {
defines.push(DefinedGeneric::Single(ident));
}
Err(_) => {
let list = attr.meta.require_list().expect("Wrong syntax.");
let tokens = list.tokens.to_string();
let names = tokens.split(",");
for (i, name) in names.enumerate() {
let ident = Ident::new(name.trim(), attr.span());
defines.push(DefinedGeneric::Multiple(ident, i));
}
}
};
is_const = true;
}
}
let ty = *param.ty.clone();
let normalized_ty = normalize_kernel_ty(*param.ty, is_const, &mut is_ref, &mut is_mut);
Ok(Self {
name: ident,
ty,
defines,
normalized_ty,
is_const,
is_mut,
is_ref,
})
}
pub fn ty_owned(&self) -> Type {
strip_ref(self.ty.clone(), &mut false, &mut false)
}
pub fn plain_normalized_self(&mut self) {
if let Type::Path(pat) = &self.ty
&& pat
.path
.get_ident()
.filter(|ident| *ident == "Self")
.is_some()
{
self.normalized_ty = self.ty.clone();
}
}
}
impl KernelSignature {
pub fn from_signature(sig: Signature, args: &KernelArgs) -> syn::Result<Self> {
let name = sig.ident;
let generics = sig.generics;
let returns = match sig.output {
syn::ReturnType::Default => KernelReturns::ExpandType(parse_quote![()]),
syn::ReturnType::Type(_, ty) => match *ty.clone() {
Type::Macro(TypeMacro { mac }) => {
if mac.path.is_ident("comptime_type") {
let inner_type = parse::<Type>(mac.tokens.into())
.expect("Interior of comptime_type macro should be a valid type.");
KernelReturns::Plain(inner_type)
} else {
panic!("Only comptime_type macro supported on return type")
}
}
_ => KernelReturns::ExpandType(*ty),
},
};
let parameters = sig
.inputs
.into_iter()
.map(|it| KernelParam::from_param(it, args))
.collect::<Result<Vec<_>, _>>()?;
let receiver_arg = if parameters.iter().any(|it| it.name == "self") {
Some(match args.self_type {
SelfType::Owned => parse_quote!(self),
SelfType::Ref => parse_quote!(&self),
SelfType::RefMut => parse_quote!(&mut self),
})
} else {
None
};
Ok(KernelSignature {
generics,
name,
parameters,
returns,
receiver_arg,
})
}
pub fn from_trait_fn(function: TraitItemFn, args: &KernelArgs) -> syn::Result<Self> {
Self::from_signature(function.sig, args)
}
pub fn plain_self(&mut self) {
if let Type::Path(pat) = self.returns.ty()
&& pat.path.is_ident("Self")
{
self.returns = KernelReturns::Plain(self.returns.ty());
}
for param in self.parameters.iter_mut() {
if let Type::Path(pat) = ¶m.ty
&& pat.path.is_ident("Self")
{
param.normalized_ty = parse_quote!(Self);
}
}
}
}
impl KernelFn {
pub fn from_sig_and_block(
vis: Visibility,
sig: Signature,
mut block: syn::Block,
full_name: String,
args: &KernelArgs,
) -> syn::Result<Self> {
let cfg_debug = cfg!(debug_symbols) && !args.no_debug_symbols.is_present();
let debug_symbols = cfg_debug || args.debug_symbols.is_present();
let span = Span::call_site();
let sig = KernelSignature::from_signature(sig, args)?;
let mut context = Context::new(sig.returns.ty(), debug_symbols);
context.extend(sig.parameters.clone());
Desugar.visit_block_mut(&mut block);
let (mut block, _) = context.in_scope(|ctx| Block::from_block(block, ctx))?;
Self::patch_mut_owned_inputs(&mut block, &sig);
Ok(KernelFn {
vis,
sig,
body: KernelBody::Block(block),
full_name,
span,
context,
args: args.clone(),
})
}
fn patch_mut_owned_inputs(block: &mut Block, sig: &KernelSignature) {
let mut mappings = Vec::new();
let into_mut = frontend_type("IntoMut");
for s in sig.parameters.iter() {
if !s.is_ref && s.is_mut {
let name = s.name.clone();
let expression = Expression::Verbatim {
tokens: quote! {
let mut #name = #into_mut::into_mut(#name, scope);
},
};
let stmt = Statement::Expression {
expression: Box::new(expression),
terminated: false,
};
mappings.push(stmt);
}
}
if !mappings.is_empty() {
mappings.append(&mut block.inner);
block.inner = mappings;
}
}
}
impl Launch {
pub fn from_item_fn(function: ItemFn, args: KernelArgs) -> syn::Result<Self> {
let runtime = prelude_type("Runtime");
let ret = function.sig.output.clone();
let vis = function.vis;
let full_name = function.sig.ident.to_string();
let func = KernelFn::from_sig_and_block(
Visibility::Public(parse_quote![pub]),
function.sig,
*function.block,
full_name,
&args,
)?;
if args.is_launch()
&& let ReturnType::Type(arrow, ty) = &ret
{
let mut ts = arrow.to_token_stream();
ts.extend(ty.into_token_stream());
return Err(syn::Error::new_spanned(
ts,
format!(
"This is a launch kernel and cannot have a return type. Remove `-> {}`. Use mutable output arguments instead in order to get values out from kernels.",
ty.into_token_stream()
),
));
}
let mut kernel_generics = func.sig.generics.clone();
kernel_generics.params.clear();
for param in func.sig.generics.params.iter() {
if let syn::GenericParam::Type(tp) = param
&& func
.sig
.parameters
.iter()
.any(|p| p.defines.iter().any(|d| d.contains_ident(&tp.ident)))
{
continue;
};
kernel_generics.params.push(param.clone());
}
kernel_generics.params.push(parse_quote![__R: #runtime]);
let mut launch_generics = kernel_generics.clone();
launch_generics.params =
Punctuated::from_iter(iter::once(parse_quote!['kernel]).chain(launch_generics.params));
let analysis = GenericAnalysis::from_generics(&func.sig.generics);
Ok(Launch {
args,
vis,
func,
launch_generics,
kernel_generics,
analysis,
})
}
}
/// Strips an outer reference from `ty`, recording it in `is_ref`/`is_mut`.
///
/// Compile-time (`is_const`) types are returned stripped. Runtime types are mapped to
/// `<T as CubeType>::ExpandType`.
fn normalize_kernel_ty(ty: Type, is_const: bool, is_ref: &mut bool, is_mut: &mut bool) -> Type {
    let owned_ty = strip_ref(ty, is_ref, is_mut);
    let cube_type = prelude_type("CubeType");
    if is_const {
        return owned_ty;
    }
    parse_quote![<#owned_ty as #cube_type>::ExpandType]
}
pub fn strip_ref(ty: Type, is_ref: &mut bool, is_mut: &mut bool) -> Type {
match ty {
Type::Reference(reference) => {
*is_ref = true;
*is_mut = *is_mut || reference.mutability.is_some();
*reference.elem
}
ty => ty,
}
}