// cubecl_macros/lib.rs
use core::panic;
use darling::FromDeriveInput;
use error::error_into_token_stream;
use generate::autotune::{generate_autotune_key, generate_autotune_set};
use parse::{
cube_impl::CubeImpl,
cube_trait::{CubeTrait, CubeTraitImpl},
cube_type::CubeType,
helpers::{RemoveHelpers, ReplaceIndices},
kernel::{from_tokens, Launch},
};
use proc_macro::TokenStream;
use quote::quote;
use syn::{visit_mut::VisitMut, Item};
mod error;
mod expression;
mod generate;
mod operator;
mod parse;
mod paths;
mod scope;
mod statement;
/// Mark a cube function, trait or implementation for expansion.
///
/// # Arguments
/// * `launch` - generates a function to launch the kernel
/// * `launch_unchecked` - generates a launch function without checks
/// * `debug` - panics after generation to print the output to console
/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for testing.
///
/// # Errors
///
/// Expansion failures are not propagated as panics: the `syn::Error` is handed to
/// `error_into_token_stream` together with the original input tokens, and the resulting
/// tokens are emitted in place of the expansion.
///
/// # Example
///
/// ```ignore
/// # use cubecl_macros::cube;
/// #[cube]
/// fn my_addition(a: u32, b: u32) -> u32 {
///     a + b
/// }
/// ```
#[proc_macro_attribute]
pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
    match cube_impl(args, input.clone()) {
        Ok(tokens) => tokens,
        Err(e) => error_into_token_stream(e, input.into()).into(),
    }
}
fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
let mut item: Item = syn::parse(input)?;
let args = from_tokens(args.into())?;
let tokens = match item.clone() {
Item::Fn(kernel) => {
let kernel = Launch::from_item_fn(kernel, args)?;
RemoveHelpers.visit_item_mut(&mut item);
ReplaceIndices.visit_item_mut(&mut item);
return Ok(TokenStream::from(quote! {
#[allow(dead_code, clippy::too_many_arguments)]
#item
#kernel
}));
}
Item::Trait(kernel_trait) => {
let expand_trait = CubeTrait::from_item_trait(kernel_trait)?;
Ok(TokenStream::from(quote! {
#expand_trait
}))
}
Item::Impl(item_impl) => {
if item_impl.trait_.is_some() {
let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl)?;
let expand_impl = expand_impl.to_tokens_mut();
Ok(TokenStream::from(quote! {
#expand_impl
}))
} else {
let mut expand_impl = CubeImpl::from_item_impl(item_impl)?;
let expand_impl = expand_impl.to_tokens_mut();
Ok(TokenStream::from(quote! {
#expand_impl
}))
}
}
item => Err(syn::Error::new_spanned(
item,
"`#[cube]` is only supported on traits and functions",
))?,
};
if args.debug.is_present() {
match tokens {
Ok(tokens) => panic!("{tokens}"),
Err(err) => panic!("{err}"),
};
}
tokens
}
/// Derive macro to define a cube type that is launched with a kernel.
///
/// Shares its implementation with the `CubeType` derive; the only difference is that
/// launch support is requested when generating the expansion.
#[proc_macro_derive(CubeLaunch, attributes(expand))]
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
    let with_launch = true;
    gen_cube_type(input, with_launch)
}
/// Derive macro to define a cube type that is not launched.
///
/// Shares its implementation with the `CubeLaunch` derive, but asks for the expansion
/// without launch support.
#[proc_macro_derive(CubeType, attributes(expand))]
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
    let with_launch = false;
    gen_cube_type(input, with_launch)
}
/// Shared implementation behind the `CubeType` and `CubeLaunch` derives.
///
/// Parses the derive input, reads it through `CubeType::from_derive_input` (darling's
/// `FromDeriveInput`) and generates the expansion, passing `with_launch` through.
/// Parse errors and darling errors are emitted as compile-error tokens rather than panics.
fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
    let derive_input = match syn::parse(input) {
        Ok(derive_input) => derive_input,
        Err(parse_err) => return parse_err.to_compile_error().into(),
    };

    match CubeType::from_derive_input(&derive_input) {
        Ok(cube_type) => cube_type.generate(with_launch).into(),
        Err(darling_err) => darling_err.write_errors().into(),
    }
}
/// Mark the contents of this macro as compile time values, turning off all expansion for this code
/// and using it verbatim.
///
/// As a standalone macro, the expansion is simply the input tokens wrapped in a block
/// expression (`{ ... }`).
///
/// # Example
/// ```ignore
/// # use cubecl_macros::cube;
/// # fn some_rust_function(a: u32) -> u32 { a }
/// #[cube]
/// fn do_stuff(input: u32) -> u32 {
///     let comptime_value = comptime! { some_rust_function(3) };
///     input + comptime_value
/// }
/// ```
#[proc_macro]
pub fn comptime(input: TokenStream) -> TokenStream {
    let tokens: proc_macro2::TokenStream = input.into();
    quote![{ #tokens }].into()
}
/// Implements display and initialization for autotune keys.
///
/// # Helper
///
/// Use the `#[autotune]` helper attribute to anchor fields to the next power of two, or rename
/// the fields for the display implementation.
///
/// # Errors
///
/// Both input parse failures and errors from `generate_autotune_key` are reported as compile
/// errors at the derive site instead of panicking inside the compiler.
///
/// # Example
/// ```ignore
/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// pub struct OperationKey {
///     #[autotune(name = "Batch Size")]
///     batch_size: usize,
///     channels: usize,
///     #[autotune(anchor(max = 1024))]
///     height: usize,
///     #[autotune(anchor)]
///     width: usize,
/// }
/// ```
#[proc_macro_derive(AutotuneKey, attributes(autotune))]
pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
    // A panic here would surface as an opaque "proc-macro derive panicked" message;
    // a compile error points the user at the offending input instead.
    let input = match syn::parse(input) {
        Ok(input) => input,
        Err(err) => return err.into_compile_error().into(),
    };
    match generate_autotune_key(input) {
        Ok(tokens) => tokens.into(),
        Err(e) => e.into_compile_error().into(),
    }
}
/// Creates a tuning set with a specific signature. Should return a tuple of benchmark inputs.
///
/// # Arguments
///
/// * `name` - the name of the generated operations struct (default: `PascalCaseFnName`)
/// * `key` - the name of the key input parameter (default: `key`)
/// * `create_key` - path to function that creates the key. If not specified, `new` must be implemented manually.
/// * `should_run` - path to override function for the `should_run` function of the set.
/// * `operations` - ordered list of operations returned by this tune set
///
/// # Errors
///
/// Expansion failures are handed to `error_into_token_stream` together with the original
/// input tokens, and the resulting tokens are emitted instead of panicking.
///
/// # Example
///
/// ```ignore
/// #[tune(create_key = key_from_input, operations(operation_1, operation_2))]
/// pub fn my_operations(key: MyKey, input: JitTensor<f32, 4>) -> JitTensor<f32, 4> {
///     let bench_input = random_tensor_like(input, -1.0, 1.0);
///
///     (bench_input)
/// }
///
/// fn key_from_input(input: &JitTensor<f32, 4>) -> MyKey {
///     MyKey::new(input.shape.dims)
/// }
/// ```
#[proc_macro_attribute]
pub fn tune(args: TokenStream, input: TokenStream) -> TokenStream {
    match autotune_set_impl(args, input.clone()) {
        Ok(tokens) => tokens,
        Err(e) => error_into_token_stream(e, input.into()).into(),
    }
}
/// Fallible core of `#[tune]`.
///
/// The annotated item is parsed first, then the attribute arguments. Both are handed to
/// `generate_autotune_set`, and its output is converted into a compiler token stream.
fn autotune_set_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
    let annotated = syn::parse(input)?;
    let tune_args = from_tokens(args.into())?;
    let expanded = generate_autotune_set(annotated, tune_args)?;
    Ok(expanded.into())
}