cubecl_macros/
lib.rs

1#![allow(clippy::large_enum_variant)]
2
3use core::panic;
4
5use darling::FromDeriveInput;
6use error::error_into_token_stream;
7use generate::autotune::generate_autotune_key;
8use parse::{
9    cube_impl::CubeImpl,
10    cube_trait::{CubeTrait, CubeTraitImpl},
11    cube_type::CubeType,
12    helpers::{RemoveHelpers, ReplaceIndices},
13    kernel::{Launch, from_tokens},
14};
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::{Item, visit_mut::VisitMut};
18
19mod error;
20mod expression;
21mod generate;
22mod operator;
23mod parse;
24mod paths;
25mod scope;
26mod statement;
27
28/// Mark a cube function, trait or implementation for expansion.
29///
30/// # Arguments
31/// * `launch` - generates a function to launch the kernel
32/// * `launch_unchecked` - generates a launch function without checks
33/// * `debug` - panics after generation to print the output to console
34/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for
35///   testing.
36///
37/// # Trait arguments
38/// * `expand_base_traits` - base traits for the expanded "second half" of a trait with methods.
39/// * `self_type` - the type used for the `self` parameter of the expanded "second half" of a trait
40///   with methods. You shouldn't need to touch this unless you specifically need to dynamically
41///   dispatch an expanded trait.
42///
43/// # Example
44///
45/// ```ignored
46/// # use cubecl_macros::cube;
47/// #[cube]
48/// fn my_addition(a: u32, b: u32) -> u32 {
49///     a + b
50/// }
51/// ```
52#[proc_macro_attribute]
53pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
54    match cube_impl(args, input.clone()) {
55        Ok(tokens) => tokens,
56        Err(e) => error_into_token_stream(e, input.into()).into(),
57    }
58}
59
60fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
61    let mut item: Item = syn::parse(input)?;
62    let args = from_tokens(args.into())?;
63
64    let tokens = match item.clone() {
65        Item::Fn(kernel) => {
66            let kernel = Launch::from_item_fn(kernel, args)?;
67            RemoveHelpers.visit_item_mut(&mut item);
68            ReplaceIndices.visit_item_mut(&mut item);
69
70            return Ok(TokenStream::from(quote! {
71                #[allow(dead_code, clippy::too_many_arguments)]
72                #item
73                #kernel
74            }));
75        }
76        Item::Trait(kernel_trait) => {
77            let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
78
79            return Ok(TokenStream::from(quote! {
80                #expand_trait
81            }));
82        }
83        Item::Impl(item_impl) => {
84            if item_impl.trait_.is_some() {
85                let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
86                let expand_impl = expand_impl.to_tokens_mut();
87
88                Ok(TokenStream::from(quote! {
89                    #expand_impl
90                }))
91            } else {
92                let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
93                let expand_impl = expand_impl.to_tokens_mut();
94
95                Ok(TokenStream::from(quote! {
96                    #expand_impl
97                }))
98            }
99        }
100        item => Err(syn::Error::new_spanned(
101            item,
102            "`#[cube]` is only supported on traits and functions",
103        ))?,
104    };
105
106    if args.debug.is_present() {
107        match tokens {
108            Ok(tokens) => panic!("{tokens}"),
109            Err(err) => panic!("{err}"),
110        };
111    }
112
113    tokens
114}
115
116/// Derive macro to define a cube type that is launched with a kernel
117#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
118pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
119    gen_cube_type(input, true)
120}
121
122/// Derive macro to define a cube type that is not launched
123#[proc_macro_derive(CubeType, attributes(expand, cube))]
124pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
125    gen_cube_type(input, false)
126}
127
128fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
129    let parsed = syn::parse(input);
130
131    let input = match &parsed {
132        Ok(val) => val,
133        Err(err) => return err.to_compile_error().into(),
134    };
135
136    let cube_type = match CubeType::from_derive_input(input) {
137        Ok(val) => val,
138        Err(err) => return err.write_errors().into(),
139    };
140
141    cube_type.generate(with_launch).into()
142}
143
144/// Attribute macro to define a type that can be used as a kernel comptime
145/// argument This derive Debug, Hash, PartialEq, Eq, Clone, Copy
146#[proc_macro_attribute]
147pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
148    let input: proc_macro2::TokenStream = input.into();
149    quote! {
150        #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
151        #input
152    }
153    .into()
154}
155
156/// Mark the contents of this macro as compile time values, turning off all
157/// expansion for this code and using it verbatim
158///
159/// # Example
160/// ```ignored
161/// #use cubecl_macros::cube;
162/// #fn some_rust_function(a: u32) -> u32 {}
163/// #[cube]
164/// fn do_stuff(input: u32) -> u32 {
165///     let comptime_value = comptime! { some_rust_function(3) };
166///     input + comptime_value
167/// }
168/// ```
169#[proc_macro]
170pub fn comptime(input: TokenStream) -> TokenStream {
171    let tokens: proc_macro2::TokenStream = input.into();
172    quote![{ #tokens }].into()
173}
174
175/// Mark the contents of this macro as an intrinsic, turning off all expansion
176/// for this code and calling it with the scope
177///
178/// # Example
179/// ```ignored
180/// #use cubecl_macros::cube;
181/// #[cube]
182/// fn do_stuff(input: u32) -> u32 {
183///     let comptime_value = intrinsic! { |scope| u32::elem_size(scope) };
184///     input + comptime_value
185/// }
186/// ```
187#[proc_macro]
188pub fn intrinsic(_input: TokenStream) -> TokenStream {
189    quote![{ cubecl::unexpanded!() }].into()
190}
191
192/// Makes the function return a compile time value
193/// Useful in a cube trait to have a part of the trait return comptime values
194///
195/// # Example
196/// ```ignored
197/// #use cubecl_macros::cube;
198/// #[cube]
199/// fn do_stuff(#[comptime] input: u32) -> comptime_type!(u32) {
200///     input + 5   
201/// }
202/// ```
203///
204/// TODO: calling a trait method returning comptime_type from
205/// within another trait method does not work
206#[proc_macro]
207pub fn comptime_type(input: TokenStream) -> TokenStream {
208    let tokens: proc_macro2::TokenStream = input.into();
209    quote![ #tokens ].into()
210}
211
212/// Insert a literal comment into the kernel source code.
213///
214/// # Example
215/// ```ignored
216/// #use cubecl_macros::cube;
217/// #[cube]
218/// fn do_stuff(input: u32) -> u32 {
219///     comment!("Add five to the input");
220///     input + 5
221/// }
222/// ```
223#[proc_macro]
224pub fn comment(input: TokenStream) -> TokenStream {
225    let tokens: proc_macro2::TokenStream = input.into();
226    quote![{ #tokens }].into()
227}
228
229/// Terminate the execution of the kernel for the current unit.
230///
231/// This terminates the execution of the unit even if nested inside many
232/// functions.
233///
234/// # Example
235/// ```ignored
236/// #use cubecl_macros::cube;
237/// #[cube]
238/// fn stop_if_more_than_ten(input: u32)  {
239///     if input > 10 {
240///         terminate!();
241///     }
242/// }
243/// ```
244#[proc_macro]
245pub fn terminate(input: TokenStream) -> TokenStream {
246    let tokens: proc_macro2::TokenStream = input.into();
247    quote![{ #tokens }].into()
248}
249
250/// Implements display and initialization for autotune keys.
251///
252/// # Helper
253///
254/// Use the `#[autotune(anchor)]` helper attribute to anchor a numerical value.
255/// This groups multiple numerical values into the same bucket.
256///
257/// For now, only an exponential function is supported, and it can be modified
258/// with `exp`. By default, the base is '2' and there are no `min` or `max`
259/// provided.
260///
261/// # Example
262/// ```ignore
263/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
264/// pub struct OperationKey {
265///     #[autotune(name = "Batch Size")]
266///     batch_size: usize,
267///     channels: usize,
268///     #[autotune(anchor(exp(min = 16, max = 1024, base = 2)))]
269///     height: usize,
270///     #[autotune(anchor)]
271///     width: usize,
272/// }
273/// ```
274#[proc_macro_derive(AutotuneKey, attributes(autotune))]
275pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
276    let input = syn::parse(input).unwrap();
277    match generate_autotune_key(input) {
278        Ok(tokens) => tokens.into(),
279        Err(e) => e.into_compile_error().into(),
280    }
281}