Skip to main content

cubecl_macros/
lib.rs

1#![allow(clippy::large_enum_variant)]
2
3use core::panic;
4
5use error::error_into_token_stream;
6use generate::autotune::generate_autotune_key;
7use parse::{
8    cube_impl::CubeImpl,
9    cube_trait::{CubeTrait, CubeTraitImpl},
10    helpers::{RemoveHelpers, ReplaceIndices},
11    kernel::{Launch, from_tokens},
12};
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{Item, visit_mut::VisitMut};
16
17use crate::{
18    generate::{assign::generate_cube_type_mut, into_runtime::generate_into_runtime},
19    parse::{
20        cube_type::generate_cube_type, derive_expand::generate_derive_expand,
21        helpers::ReplaceDefines,
22    },
23};
24
25mod error;
26mod expression;
27mod generate;
28mod operator;
29mod parse;
30mod paths;
31mod scope;
32mod statement;
33
34/// Mark a cube function, trait or implementation for expansion.
35///
36/// # Arguments
37/// * `launch` - generates a function to launch the kernel
38/// * `launch_unchecked` - generates a launch function without checks
39/// * `debug` - panics after generation to print the output to console
40/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for
41///   testing.
42///
43/// # Trait arguments
44/// * `expand_base_traits` - base traits for the expanded "second half" of a trait with methods.
45/// * `self_type` - the type used for the `self` parameter of the expanded "second half" of a trait
46///   with methods. You shouldn't need to touch this unless you specifically need to dynamically
47///   dispatch an expanded trait.
48///
49/// # Example
50///
51/// ```ignored
52/// # use cubecl_macros::cube;
53/// #[cube]
54/// fn my_addition(a: u32, b: u32) -> u32 {
55///     a + b
56/// }
57/// ```
58#[proc_macro_attribute]
59pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
60    match cube_impl(args, input.clone()) {
61        Ok(tokens) => tokens,
62        Err(e) => error_into_token_stream(e, input.into()).into(),
63    }
64}
65
66fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
67    let mut item: Item = syn::parse(input)?;
68    let args = from_tokens(args.into())?;
69
70    let tokens = match item.clone() {
71        Item::Fn(kernel) => {
72            let kernel = Launch::from_item_fn(kernel, args)?;
73            RemoveHelpers.visit_item_mut(&mut item);
74            ReplaceIndices.visit_item_mut(&mut item);
75            ReplaceDefines.visit_item_mut(&mut item);
76
77            return Ok(TokenStream::from(quote! {
78                #[allow(dead_code, clippy::too_many_arguments)]
79                #item
80                #kernel
81            }));
82        }
83        Item::Trait(kernel_trait) => {
84            let is_debug = args.debug.is_present();
85            let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
86
87            let tokens = TokenStream::from(quote! {
88                #expand_trait
89            });
90            if is_debug {
91                panic!("{tokens}");
92            }
93            return Ok(tokens);
94        }
95        Item::Impl(item_impl) => {
96            if item_impl.trait_.is_some() {
97                let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
98                let expand_impl = expand_impl.to_tokens_mut();
99
100                Ok(TokenStream::from(quote! {
101                    #expand_impl
102                }))
103            } else {
104                let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
105                let expand_impl = expand_impl.to_tokens_mut();
106
107                Ok(TokenStream::from(quote! {
108                    #expand_impl
109                }))
110            }
111        }
112        item => Err(syn::Error::new_spanned(
113            item,
114            "`#[cube]` is only supported on traits and functions",
115        ))?,
116    };
117
118    if args.debug.is_present() {
119        match tokens {
120            Ok(tokens) => panic!("{tokens}"),
121            Err(err) => panic!("{err}"),
122        };
123    }
124
125    tokens
126}
127
128/// Derive macro to define a cube type that is launched with a kernel
129#[proc_macro_derive(CubeLaunch, attributes(cube, launch))]
130pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
131    gen_cube_type(input, true)
132}
133
134/// Derive macro to define a cube type that is not launched
135#[proc_macro_derive(CubeType, attributes(cube))]
136pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
137    gen_cube_type(input, false)
138}
139
140fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
141    let parsed = syn::parse(input);
142
143    let input = match &parsed {
144        Ok(val) => val,
145        Err(err) => return err.to_compile_error().into(),
146    };
147
148    match generate_cube_type(input, with_launch) {
149        Ok(val) => val.into(),
150        Err(err) => err.to_compile_error().into(),
151    }
152}
153
154/// Attribute macro to define a type that can be used as a kernel comptime
155/// argument This derive Debug, Hash, `PartialEq`, Eq, Clone, Copy
156#[proc_macro_attribute]
157pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
158    let input: proc_macro2::TokenStream = input.into();
159    quote! {
160        #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
161        #input
162    }
163    .into()
164}
165
166/// Attribute macro to derive cube traits for existing structs, without redefining that struct.
167#[proc_macro_attribute]
168pub fn derive_expand(metadata: TokenStream, input: TokenStream) -> TokenStream {
169    match generate_derive_expand(input.into(), metadata.into()) {
170        Ok(val) => val.into(),
171        Err(err) => err.to_compile_error().into(),
172    }
173}
174
175/// Mark the contents of this macro as compile time values, turning off all
176/// expansion for this code and using it verbatim
177///
178/// # Example
179/// ```ignored
180/// #use cubecl_macros::cube;
181/// #fn some_rust_function(a: u32) -> u32 {}
182/// #[cube]
183/// fn do_stuff(input: u32) -> u32 {
184///     let comptime_value = comptime! { some_rust_function(3) };
185///     input + comptime_value
186/// }
187/// ```
188#[proc_macro]
189pub fn comptime(input: TokenStream) -> TokenStream {
190    let tokens: proc_macro2::TokenStream = input.into();
191    quote![{ #tokens }].into()
192}
193
194/// Mark the contents of this macro as an intrinsic, turning off all expansion
195/// for this code and calling it with the scope
196///
197/// # Example
198/// ```ignored
199/// #use cubecl_macros::cube;
200/// #[cube]
201/// fn do_stuff(input: u32) -> u32 {
202///     let comptime_value = intrinsic! { |scope| u32::elem_size(scope) };
203///     input + comptime_value
204/// }
205/// ```
206#[proc_macro]
207pub fn intrinsic(_input: TokenStream) -> TokenStream {
208    quote![{ cubecl::unexpanded!() }].into()
209}
210
211/// Makes the function return a compile time value
212/// Useful in a cube trait to have a part of the trait return comptime values
213///
214/// # Example
215/// ```ignored
216/// #use cubecl_macros::cube;
217/// #[cube]
218/// fn do_stuff(#[comptime] input: u32) -> comptime_type!(u32) {
219///     input + 5   
220/// }
221/// ```
222///
223/// TODO: calling a trait method returning `comptime_type` from
224/// within another trait method does not work
225#[proc_macro]
226pub fn comptime_type(input: TokenStream) -> TokenStream {
227    let tokens: proc_macro2::TokenStream = input.into();
228    quote![ #tokens ].into()
229}
230
231/// Insert a literal comment into the kernel source code.
232///
233/// # Example
234/// ```ignored
235/// #use cubecl_macros::cube;
236/// #[cube]
237/// fn do_stuff(input: u32) -> u32 {
238///     comment!("Add five to the input");
239///     input + 5
240/// }
241/// ```
242#[proc_macro]
243pub fn comment(input: TokenStream) -> TokenStream {
244    let tokens: proc_macro2::TokenStream = input.into();
245    quote![{ #tokens }].into()
246}
247
248/// Terminate the execution of the kernel for the current unit.
249///
250/// This terminates the execution of the unit even if nested inside many
251/// functions.
252///
253/// # Example
254/// ```ignored
255/// #use cubecl_macros::cube;
256/// #[cube]
257/// fn stop_if_more_than_ten(input: u32)  {
258///     if input > 10 {
259///         terminate!();
260///     }
261/// }
262/// ```
263#[proc_macro]
264pub fn terminate(input: TokenStream) -> TokenStream {
265    let tokens: proc_macro2::TokenStream = input.into();
266    quote![{ #tokens }].into()
267}
268
269/// Implements display and initialization for autotune keys.
270///
271/// # Helper
272///
273/// Use the `#[autotune(anchor)]` helper attribute to anchor a numerical value.
274/// This groups multiple numerical values into the same bucket.
275///
276/// For now, only an exponential function is supported, and it can be modified
277/// with `exp`. By default, the base is '2' and there are no `min` or `max`
278/// provided.
279///
280/// # Example
281/// ```ignore
282/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
283/// pub struct OperationKey {
284///     #[autotune(name = "Batch Size")]
285///     batch_size: usize,
286///     channels: usize,
287///     #[autotune(anchor(exp(min = 16, max = 1024, base = 2)))]
288///     height: usize,
289///     #[autotune(anchor)]
290///     width: usize,
291/// }
292/// ```
293#[proc_macro_derive(AutotuneKey, attributes(autotune))]
294pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
295    let input = syn::parse(input).unwrap();
296    match generate_autotune_key(input) {
297        Ok(tokens) => tokens.into(),
298        Err(e) => e.into_compile_error().into(),
299    }
300}
301
302/// Implements `IntoRuntime` for a `CubeType`
303#[proc_macro_derive(IntoRuntime, attributes(cube))]
304pub fn derive_into_runtime(input: TokenStream) -> TokenStream {
305    let input = syn::parse(input).unwrap();
306    match generate_into_runtime(&input) {
307        Ok(tokens) => tokens.into(),
308        Err(e) => e.into_compile_error().into(),
309    }
310}
311
312/// Implements mutability for a `CubeType`
313#[proc_macro_derive(CubeTypeMut, attributes(cube))]
314pub fn derive_assign(input: TokenStream) -> TokenStream {
315    let input = syn::parse(input).unwrap();
316    match generate_cube_type_mut(&input) {
317        Ok(tokens) => tokens.into(),
318        Err(e) => e.into_compile_error().into(),
319    }
320}