#![allow(clippy::large_enum_variant)]
use core::panic;
use error::error_into_token_stream;
use generate::autotune::generate_autotune_key;
use parse::{
cube_impl::CubeImpl,
cube_trait::{CubeTrait, CubeTraitImpl},
helpers::{RemoveHelpers, ReplaceIndices},
kernel::{Launch, from_tokens},
};
use proc_macro::TokenStream;
use quote::quote;
use syn::{Item, visit_mut::VisitMut};
use crate::{
generate::{assign::generate_cube_type_mut, into_runtime::generate_into_runtime},
parse::{
cube_type::generate_cube_type, derive_expand::generate_derive_expand,
helpers::ReplaceDefines,
},
};
mod error;
mod expression;
mod generate;
mod operator;
mod parse;
mod paths;
mod scope;
mod statement;
/// Attribute macro entry point for `#[cube]`.
///
/// Expansion is delegated to `cube_impl`. If it fails, the error is rendered by
/// `error_into_token_stream`, which also receives the original, unmodified item tokens.
#[proc_macro_attribute]
pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
    // Keep an untouched copy of the item for the error path.
    let original_item = input.clone();
    cube_impl(args, input)
        .unwrap_or_else(|err| error_into_token_stream(err, original_item.into()).into())
}
/// Expands an item annotated with `#[cube]`.
///
/// * Free functions: `Launch::from_item_fn` builds the launch code. The original function is
///   rewritten by the `RemoveHelpers`, `ReplaceIndices` and `ReplaceDefines` visitors. Both are
///   then emitted side by side.
/// * Traits: expanded through `CubeTrait::from_item_trait`.
/// * `impl` blocks: expanded through `CubeTraitImpl` when a trait is being implemented, and
///   through `CubeImpl` for inherent impls.
///
/// Any other item kind is rejected with an error spanned on that item.
///
/// # Errors
///
/// Returns the `syn::Error` produced while parsing the item or the macro arguments. It also
/// returns any error reported by the expanders above.
///
/// # Panics
///
/// For traits and `impl` blocks, a present `debug` argument makes this function panic with the
/// generated tokens so they can be inspected. For free functions the arguments (including
/// `debug`) are handed to `Launch::from_item_fn` and the flag is not checked here.
fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
    let mut item: Item = syn::parse(input)?;
    let args = from_tokens(args.into())?;

    match item.clone() {
        Item::Fn(kernel) => {
            // `args` is moved into the launch generator; this branch does no debug handling.
            let kernel = Launch::from_item_fn(kernel, args)?;
            RemoveHelpers.visit_item_mut(&mut item);
            ReplaceIndices.visit_item_mut(&mut item);
            ReplaceDefines.visit_item_mut(&mut item);
            Ok(TokenStream::from(quote! {
                #[allow(dead_code, clippy::too_many_arguments)]
                #item
                #kernel
            }))
        }
        Item::Trait(kernel_trait) => {
            // Read the flag before `args` is moved into the trait expander.
            let is_debug = args.debug.is_present();
            let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
            let tokens = TokenStream::from(quote! {
                #expand_trait
            });
            if is_debug {
                panic!("{tokens}");
            }
            Ok(tokens)
        }
        Item::Impl(item_impl) => {
            let tokens = if item_impl.trait_.is_some() {
                let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
                let expand_impl = expand_impl.to_tokens_mut();
                TokenStream::from(quote! {
                    #expand_impl
                })
            } else {
                let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
                let expand_impl = expand_impl.to_tokens_mut();
                TokenStream::from(quote! {
                    #expand_impl
                })
            };
            // Errors were already propagated by `?`, so only successful output can reach here.
            if args.debug.is_present() {
                panic!("{tokens}");
            }
            Ok(tokens)
        }
        item => Err(syn::Error::new_spanned(
            item,
            "`#[cube]` is only supported on traits, impl blocks and functions",
        )),
    }
}
/// Derives `CubeLaunch` for the annotated type (helper attributes: `cube`, `launch`).
///
/// Uses the same generator as the `CubeType` derive, with launch support switched on.
#[proc_macro_derive(CubeLaunch, attributes(cube, launch))]
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
    let with_launch = true;
    gen_cube_type(input, with_launch)
}
/// Derives `CubeType` for the annotated type (helper attribute: `cube`).
///
/// Uses the same generator as the `CubeLaunch` derive, with launch support switched off.
#[proc_macro_derive(CubeType, attributes(cube))]
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
    let with_launch = false;
    gen_cube_type(input, with_launch)
}
/// Shared implementation of the `CubeType` and `CubeLaunch` derives.
///
/// Parses the derive input and hands it to `generate_cube_type`. `with_launch` is forwarded
/// unchanged. Parse errors and generator errors are both emitted as compile errors.
fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
    let item = match syn::parse(input) {
        Ok(item) => item,
        Err(err) => return err.to_compile_error().into(),
    };
    let generated = generate_cube_type(&item, with_launch);
    match generated {
        Err(err) => err.to_compile_error().into(),
        Ok(tokens) => tokens.into(),
    }
}
#[proc_macro_attribute]
pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
let input: proc_macro2::TokenStream = input.into();
quote! {
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
#input
}
.into()
}
/// Attribute that delegates to `generate_derive_expand`.
///
/// The annotated item is passed first and the attribute arguments second. Generator errors
/// are emitted as compile errors.
#[proc_macro_attribute]
pub fn derive_expand(metadata: TokenStream, input: TokenStream) -> TokenStream {
    let generated = generate_derive_expand(input.into(), metadata.into());
    match generated {
        Err(err) => err.to_compile_error().into(),
        Ok(tokens) => tokens.into(),
    }
}
#[proc_macro]
pub fn comptime(input: TokenStream) -> TokenStream {
let tokens: proc_macro2::TokenStream = input.into();
quote![{ #tokens }].into()
}
/// Function-like macro whose input is discarded.
///
/// It always expands to the block `{ cubecl::unexpanded!() }`.
#[proc_macro]
pub fn intrinsic(_input: TokenStream) -> TokenStream {
    let expansion = quote! {
        {
            cubecl::unexpanded!()
        }
    };
    expansion.into()
}
#[proc_macro]
pub fn comptime_type(input: TokenStream) -> TokenStream {
let tokens: proc_macro2::TokenStream = input.into();
quote![ #tokens ].into()
}
#[proc_macro]
pub fn comment(input: TokenStream) -> TokenStream {
let tokens: proc_macro2::TokenStream = input.into();
quote![{ #tokens }].into()
}
#[proc_macro]
pub fn terminate(input: TokenStream) -> TokenStream {
let tokens: proc_macro2::TokenStream = input.into();
quote![{ #tokens }].into()
}
/// Derives `AutotuneKey` for the annotated type (helper attribute: `autotune`).
///
/// Failures to parse the derive input, and errors reported by `generate_autotune_key`, are
/// both emitted as compile errors instead of panicking inside the macro.
#[proc_macro_derive(AutotuneKey, attributes(autotune))]
pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
    // A panic here would surface only as an opaque "proc-macro derive panicked" message;
    // a compile error keeps syn's span and description.
    let input = match syn::parse(input) {
        Ok(input) => input,
        Err(err) => return err.into_compile_error().into(),
    };
    match generate_autotune_key(input) {
        Ok(tokens) => tokens.into(),
        Err(e) => e.into_compile_error().into(),
    }
}
/// Derives `IntoRuntime` for the annotated type (helper attribute: `cube`).
///
/// Failures to parse the derive input, and errors reported by `generate_into_runtime`, are
/// both emitted as compile errors instead of panicking inside the macro.
#[proc_macro_derive(IntoRuntime, attributes(cube))]
pub fn derive_into_runtime(input: TokenStream) -> TokenStream {
    // Report parse failures as spanned compile errors rather than panicking.
    let input = match syn::parse(input) {
        Ok(input) => input,
        Err(err) => return err.into_compile_error().into(),
    };
    match generate_into_runtime(&input) {
        Ok(tokens) => tokens.into(),
        Err(e) => e.into_compile_error().into(),
    }
}
/// Derives `CubeTypeMut` for the annotated type (helper attribute: `cube`).
///
/// Failures to parse the derive input, and errors reported by `generate_cube_type_mut`, are
/// both emitted as compile errors instead of panicking inside the macro.
#[proc_macro_derive(CubeTypeMut, attributes(cube))]
pub fn derive_assign(input: TokenStream) -> TokenStream {
    // Report parse failures as spanned compile errors rather than panicking.
    let input = match syn::parse(input) {
        Ok(input) => input,
        Err(err) => return err.into_compile_error().into(),
    };
    match generate_cube_type_mut(&input) {
        Ok(tokens) => tokens.into(),
        Err(e) => e.into_compile_error().into(),
    }
}