cubecl_macros/
lib.rs

1#![allow(clippy::large_enum_variant)]
2
3use core::panic;
4
5use darling::FromDeriveInput;
6use error::error_into_token_stream;
7use generate::autotune::generate_autotune_key;
8use parse::{
9    cube_impl::CubeImpl,
10    cube_trait::{CubeTrait, CubeTraitImpl},
11    cube_type::CubeType,
12    helpers::{RemoveHelpers, ReplaceIndices},
13    kernel::{Launch, from_tokens},
14};
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::{Item, visit_mut::VisitMut};
18
19mod error;
20mod expression;
21mod generate;
22mod operator;
23mod parse;
24mod paths;
25mod scope;
26mod statement;
27
28/// Mark a cube function, trait or implementation for expansion.
29///
30/// # Arguments
31/// * `launch` - generates a function to launch the kernel
32/// * `launch_unchecked` - generates a launch function without checks
33/// * `debug` - panics after generation to print the output to console
34/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for
35///   testing.
36///
37/// # Trait arguments
38/// * `expand_base_traits` - base traits for the expanded "second half" of a trait with methods.
39/// * `self_type` - the type used for the `self` parameter of the expanded "second half" of a trait
40///   with methods. You shouldn't need to touch this unless you specifically need to dynamically
41///   dispatch an expanded trait.
42///
43/// # Example
44///
45/// ```ignored
46/// # use cubecl_macros::cube;
47/// #[cube]
48/// fn my_addition(a: u32, b: u32) -> u32 {
49///     a + b
50/// }
51/// ```
52#[proc_macro_attribute]
53pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
54    match cube_impl(args, input.clone()) {
55        Ok(tokens) => tokens,
56        Err(e) => error_into_token_stream(e, input.into()).into(),
57    }
58}
59
60fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
61    let mut item: Item = syn::parse(input)?;
62    let args = from_tokens(args.into())?;
63
64    let tokens = match item.clone() {
65        Item::Fn(kernel) => {
66            let kernel = Launch::from_item_fn(kernel, args)?;
67            RemoveHelpers.visit_item_mut(&mut item);
68            ReplaceIndices.visit_item_mut(&mut item);
69
70            return Ok(TokenStream::from(quote! {
71                #[allow(dead_code, clippy::too_many_arguments)]
72                #item
73                #kernel
74            }));
75        }
76        Item::Trait(kernel_trait) => {
77            let is_debug = args.debug.is_present();
78            let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
79
80            let tokens = TokenStream::from(quote! {
81                #expand_trait
82            });
83            if is_debug {
84                panic!("{tokens}");
85            }
86            return Ok(tokens);
87        }
88        Item::Impl(item_impl) => {
89            if item_impl.trait_.is_some() {
90                let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
91                let expand_impl = expand_impl.to_tokens_mut();
92
93                Ok(TokenStream::from(quote! {
94                    #expand_impl
95                }))
96            } else {
97                let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
98                let expand_impl = expand_impl.to_tokens_mut();
99
100                Ok(TokenStream::from(quote! {
101                    #expand_impl
102                }))
103            }
104        }
105        item => Err(syn::Error::new_spanned(
106            item,
107            "`#[cube]` is only supported on traits and functions",
108        ))?,
109    };
110
111    if args.debug.is_present() {
112        match tokens {
113            Ok(tokens) => panic!("{tokens}"),
114            Err(err) => panic!("{err}"),
115        };
116    }
117
118    tokens
119}
120
121/// Derive macro to define a cube type that is launched with a kernel
122#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
123pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
124    gen_cube_type(input, true)
125}
126
127/// Derive macro to define a cube type that is not launched
128#[proc_macro_derive(CubeType, attributes(expand, cube))]
129pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
130    gen_cube_type(input, false)
131}
132
133fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
134    let parsed = syn::parse(input);
135
136    let input = match &parsed {
137        Ok(val) => val,
138        Err(err) => return err.to_compile_error().into(),
139    };
140
141    let cube_type = match CubeType::from_derive_input(input) {
142        Ok(val) => val,
143        Err(err) => return err.write_errors().into(),
144    };
145
146    cube_type.generate(with_launch).into()
147}
148
149/// Attribute macro to define a type that can be used as a kernel comptime
150/// argument This derive Debug, Hash, PartialEq, Eq, Clone, Copy
151#[proc_macro_attribute]
152pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
153    let input: proc_macro2::TokenStream = input.into();
154    quote! {
155        #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
156        #input
157    }
158    .into()
159}
160
161/// Mark the contents of this macro as compile time values, turning off all
162/// expansion for this code and using it verbatim
163///
164/// # Example
165/// ```ignored
166/// #use cubecl_macros::cube;
167/// #fn some_rust_function(a: u32) -> u32 {}
168/// #[cube]
169/// fn do_stuff(input: u32) -> u32 {
170///     let comptime_value = comptime! { some_rust_function(3) };
171///     input + comptime_value
172/// }
173/// ```
174#[proc_macro]
175pub fn comptime(input: TokenStream) -> TokenStream {
176    let tokens: proc_macro2::TokenStream = input.into();
177    quote![{ #tokens }].into()
178}
179
180/// Mark the contents of this macro as an intrinsic, turning off all expansion
181/// for this code and calling it with the scope
182///
183/// # Example
184/// ```ignored
185/// #use cubecl_macros::cube;
186/// #[cube]
187/// fn do_stuff(input: u32) -> u32 {
188///     let comptime_value = intrinsic! { |scope| u32::elem_size(scope) };
189///     input + comptime_value
190/// }
191/// ```
192#[proc_macro]
193pub fn intrinsic(_input: TokenStream) -> TokenStream {
194    quote![{ cubecl::unexpanded!() }].into()
195}
196
197/// Makes the function return a compile time value
198/// Useful in a cube trait to have a part of the trait return comptime values
199///
200/// # Example
201/// ```ignored
202/// #use cubecl_macros::cube;
203/// #[cube]
204/// fn do_stuff(#[comptime] input: u32) -> comptime_type!(u32) {
205///     input + 5   
206/// }
207/// ```
208///
209/// TODO: calling a trait method returning comptime_type from
210/// within another trait method does not work
211#[proc_macro]
212pub fn comptime_type(input: TokenStream) -> TokenStream {
213    let tokens: proc_macro2::TokenStream = input.into();
214    quote![ #tokens ].into()
215}
216
217/// Insert a literal comment into the kernel source code.
218///
219/// # Example
220/// ```ignored
221/// #use cubecl_macros::cube;
222/// #[cube]
223/// fn do_stuff(input: u32) -> u32 {
224///     comment!("Add five to the input");
225///     input + 5
226/// }
227/// ```
228#[proc_macro]
229pub fn comment(input: TokenStream) -> TokenStream {
230    let tokens: proc_macro2::TokenStream = input.into();
231    quote![{ #tokens }].into()
232}
233
234/// Terminate the execution of the kernel for the current unit.
235///
236/// This terminates the execution of the unit even if nested inside many
237/// functions.
238///
239/// # Example
240/// ```ignored
241/// #use cubecl_macros::cube;
242/// #[cube]
243/// fn stop_if_more_than_ten(input: u32)  {
244///     if input > 10 {
245///         terminate!();
246///     }
247/// }
248/// ```
249#[proc_macro]
250pub fn terminate(input: TokenStream) -> TokenStream {
251    let tokens: proc_macro2::TokenStream = input.into();
252    quote![{ #tokens }].into()
253}
254
255/// Implements display and initialization for autotune keys.
256///
257/// # Helper
258///
259/// Use the `#[autotune(anchor)]` helper attribute to anchor a numerical value.
260/// This groups multiple numerical values into the same bucket.
261///
262/// For now, only an exponential function is supported, and it can be modified
263/// with `exp`. By default, the base is '2' and there are no `min` or `max`
264/// provided.
265///
266/// # Example
267/// ```ignore
268/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
269/// pub struct OperationKey {
270///     #[autotune(name = "Batch Size")]
271///     batch_size: usize,
272///     channels: usize,
273///     #[autotune(anchor(exp(min = 16, max = 1024, base = 2)))]
274///     height: usize,
275///     #[autotune(anchor)]
276///     width: usize,
277/// }
278/// ```
279#[proc_macro_derive(AutotuneKey, attributes(autotune))]
280pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
281    let input = syn::parse(input).unwrap();
282    match generate_autotune_key(input) {
283        Ok(tokens) => tokens.into(),
284        Err(e) => e.into_compile_error().into(),
285    }
286}