cubecl_macros/
lib.rs

1#![cfg_attr(nightly, feature(proc_macro_span))]
2#![allow(clippy::large_enum_variant)]
3
4use core::panic;
5
6use darling::FromDeriveInput;
7use error::error_into_token_stream;
8use generate::autotune::generate_autotune_key;
9use parse::{
10    cube_impl::CubeImpl,
11    cube_trait::{CubeTrait, CubeTraitImpl},
12    cube_type::CubeType,
13    helpers::{RemoveHelpers, ReplaceIndices},
14    kernel::{Launch, from_tokens},
15};
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::{Item, visit_mut::VisitMut};
19
20mod error;
21mod expression;
22mod generate;
23mod operator;
24mod parse;
25mod paths;
26mod scope;
27mod statement;
28
29/// Mark a cube function, trait or implementation for expansion.
30///
31/// # Arguments
32/// * `launch` - generates a function to launch the kernel
33/// * `launch_unchecked` - generates a launch function without checks
34/// * `debug` - panics after generation to print the output to console
35/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for testing.
36///
37/// # Example
38///
39/// ```ignored
40/// # use cubecl_macros::cube;
41/// #[cube]
42/// fn my_addition(a: u32, b: u32) -> u32 {
43///     a + b
44/// }
45/// ```
46#[proc_macro_attribute]
47pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
48    match cube_impl(args, input.clone()) {
49        Ok(tokens) => tokens,
50        Err(e) => error_into_token_stream(e, input.into()).into(),
51    }
52}
53
54fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
55    let mut item: Item = syn::parse(input)?;
56    let args = from_tokens(args.into())?;
57
58    let tokens = match item.clone() {
59        Item::Fn(kernel) => {
60            let kernel = Launch::from_item_fn(kernel, args)?;
61            RemoveHelpers.visit_item_mut(&mut item);
62            ReplaceIndices.visit_item_mut(&mut item);
63
64            return Ok(TokenStream::from(quote! {
65                #[allow(dead_code, clippy::too_many_arguments)]
66                #item
67                #kernel
68            }));
69        }
70        Item::Trait(kernel_trait) => {
71            let expand_trait = CubeTrait::from_item_trait(kernel_trait)?;
72
73            Ok(TokenStream::from(quote! {
74                #expand_trait
75            }))
76        }
77        Item::Impl(item_impl) => {
78            if item_impl.trait_.is_some() {
79                let mut expand_impl = CubeTraitImpl::from_item_impl(
80                    item_impl,
81                    args.src_file,
82                    args.debug_symbols.is_present(),
83                )?;
84                let expand_impl = expand_impl.to_tokens_mut();
85
86                Ok(TokenStream::from(quote! {
87                    #expand_impl
88                }))
89            } else {
90                let mut expand_impl = CubeImpl::from_item_impl(
91                    item_impl,
92                    args.src_file,
93                    args.debug_symbols.is_present(),
94                )?;
95                let expand_impl = expand_impl.to_tokens_mut();
96
97                Ok(TokenStream::from(quote! {
98                    #expand_impl
99                }))
100            }
101        }
102        item => Err(syn::Error::new_spanned(
103            item,
104            "`#[cube]` is only supported on traits and functions",
105        ))?,
106    };
107
108    if args.debug.is_present() {
109        match tokens {
110            Ok(tokens) => panic!("{tokens}"),
111            Err(err) => panic!("{err}"),
112        };
113    }
114
115    tokens
116}
117
118/// Derive macro to define a cube type that is launched with a kernel
119#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
120pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
121    gen_cube_type(input, true)
122}
123
124/// Derive macro to define a cube type that is not launched
125#[proc_macro_derive(CubeType, attributes(expand, cube))]
126pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
127    gen_cube_type(input, false)
128}
129
130fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
131    let parsed = syn::parse(input);
132
133    let input = match &parsed {
134        Ok(val) => val,
135        Err(err) => return err.to_compile_error().into(),
136    };
137
138    let cube_type = match CubeType::from_derive_input(input) {
139        Ok(val) => val,
140        Err(err) => return err.write_errors().into(),
141    };
142
143    cube_type.generate(with_launch).into()
144}
145
146/// Attribute macro to define a type that can be used as a kernel comptime argument
147/// This derive Debug, Hash, PartialEq, Eq, Clone, Copy
148#[proc_macro_attribute]
149pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
150    let input: proc_macro2::TokenStream = input.into();
151    quote! {
152        #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
153        #input
154    }
155    .into()
156}
157
158/// Mark the contents of this macro as compile time values, turning off all expansion for this code
159/// and using it verbatim
160///
161/// # Example
162/// ```ignored
163/// #use cubecl_macros::cube;
164/// #fn some_rust_function(a: u32) -> u32 {}
165/// #[cube]
166/// fn do_stuff(input: u32) -> u32 {
167///     let comptime_value = comptime! { some_rust_function(3) };
168///     input + comptime_value
169/// }
170/// ```
171#[proc_macro]
172pub fn comptime(input: TokenStream) -> TokenStream {
173    let tokens: proc_macro2::TokenStream = input.into();
174    quote![{ #tokens }].into()
175}
176
177/// Mark the contents of this macro as an intrinsic, turning off all expansion for this code
178/// and calling it with the scope
179///
180/// # Example
181/// ```ignored
182/// #use cubecl_macros::cube;
183/// #[cube]
184/// fn do_stuff(input: u32) -> u32 {
185///     let comptime_value = intrinsic! { |scope| u32::elem_size(scope) };
186///     input + comptime_value
187/// }
188/// ```
189#[proc_macro]
190pub fn intrinsic(_input: TokenStream) -> TokenStream {
191    quote![{ cubecl::unexpanded!() }].into()
192}
193
194/// Makes the function return a compile time value
195/// Useful in a cube trait to have a part of the trait return comptime values
196///
197/// # Example
198/// ```ignored
199/// #use cubecl_macros::cube;
200/// #[cube]
201/// fn do_stuff(#[comptime] input: u32) -> comptime_type!(u32) {
202///     input + 5   
203/// }
204/// ```
205///
206/// TODO: calling a trait method returning comptime_type from
207/// within another trait method does not work
208#[proc_macro]
209pub fn comptime_type(input: TokenStream) -> TokenStream {
210    let tokens: proc_macro2::TokenStream = input.into();
211    quote![ #tokens ].into()
212}
213
214/// Insert a literal comment into the kernel source code.
215///
216/// # Example
217/// ```ignored
218/// #use cubecl_macros::cube;
219/// #[cube]
220/// fn do_stuff(input: u32) -> u32 {
221///     comment!("Add five to the input");
222///     input + 5
223/// }
224/// ```
225#[proc_macro]
226pub fn comment(input: TokenStream) -> TokenStream {
227    let tokens: proc_macro2::TokenStream = input.into();
228    quote![{ #tokens }].into()
229}
230
231/// Terminate the execution of the kernel for the current unit.
232///
233/// This terminates the execution of the unit even if nested inside many functions.
234///
235/// # Example
236/// ```ignored
237/// #use cubecl_macros::cube;
238/// #[cube]
239/// fn stop_if_more_than_ten(input: u32)  {
240///     if input > 10 {
241///         terminate!();
242///     }
243/// }
244/// ```
245#[proc_macro]
246pub fn terminate(input: TokenStream) -> TokenStream {
247    let tokens: proc_macro2::TokenStream = input.into();
248    quote![{ #tokens }].into()
249}
250
251/// Implements display and initialization for autotune keys.
252///
253/// # Helper
254///
255/// Use the `#[autotune(anchor)]` helper attribute to anchor a numerical value.
256/// This groups multiple numerical values into the same bucket.
257///
258/// For now, only an exponential function is supported, and it can be modified with `exp`.
259/// By default, the base is '2' and there are no `min` or `max` provided.
260///
261/// # Example
262/// ```ignore
263/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
264/// pub struct OperationKey {
265///     #[autotune(name = "Batch Size")]
266///     batch_size: usize,
267///     channels: usize,
268///     #[autotune(anchor(exp(min = 16, max = 1024, base = 2)))]
269///     height: usize,
270///     #[autotune(anchor)]
271///     width: usize,
272/// }
273/// ```
274#[proc_macro_derive(AutotuneKey, attributes(autotune))]
275pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
276    let input = syn::parse(input).unwrap();
277    match generate_autotune_key(input) {
278        Ok(tokens) => tokens.into(),
279        Err(e) => e.into_compile_error().into(),
280    }
281}