cubecl_macros/lib.rs
1#![allow(clippy::large_enum_variant)]
2
3use core::panic;
4
5use darling::FromDeriveInput;
6use error::error_into_token_stream;
7use generate::autotune::generate_autotune_key;
8use parse::{
9 cube_impl::CubeImpl,
10 cube_trait::{CubeTrait, CubeTraitImpl},
11 cube_type::CubeType,
12 helpers::{RemoveHelpers, ReplaceIndices},
13 kernel::{Launch, from_tokens},
14};
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::{Item, visit_mut::VisitMut};
18
19mod error;
20mod expression;
21mod generate;
22mod operator;
23mod parse;
24mod paths;
25mod scope;
26mod statement;
27
28/// Mark a cube function, trait or implementation for expansion.
29///
30/// # Arguments
31/// * `launch` - generates a function to launch the kernel
32/// * `launch_unchecked` - generates a launch function without checks
33/// * `debug` - panics after generation to print the output to console
34/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for
35/// testing.
36///
37/// # Trait arguments
38/// * `expand_base_traits` - base traits for the expanded "second half" of a trait with methods.
39/// * `self_type` - the type used for the `self` parameter of the expanded "second half" of a trait
40/// with methods. You shouldn't need to touch this unless you specifically need to dynamically
41/// dispatch an expanded trait.
42///
43/// # Example
44///
45/// ```ignored
46/// # use cubecl_macros::cube;
47/// #[cube]
48/// fn my_addition(a: u32, b: u32) -> u32 {
49/// a + b
50/// }
51/// ```
52#[proc_macro_attribute]
53pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
54 match cube_impl(args, input.clone()) {
55 Ok(tokens) => tokens,
56 Err(e) => error_into_token_stream(e, input.into()).into(),
57 }
58}
59
60fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
61 let mut item: Item = syn::parse(input)?;
62 let args = from_tokens(args.into())?;
63
64 let tokens = match item.clone() {
65 Item::Fn(kernel) => {
66 let kernel = Launch::from_item_fn(kernel, args)?;
67 RemoveHelpers.visit_item_mut(&mut item);
68 ReplaceIndices.visit_item_mut(&mut item);
69
70 return Ok(TokenStream::from(quote! {
71 #[allow(dead_code, clippy::too_many_arguments)]
72 #item
73 #kernel
74 }));
75 }
76 Item::Trait(kernel_trait) => {
77 let is_debug = args.debug.is_present();
78 let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
79
80 let tokens = TokenStream::from(quote! {
81 #expand_trait
82 });
83 if is_debug {
84 panic!("{tokens}");
85 }
86 return Ok(tokens);
87 }
88 Item::Impl(item_impl) => {
89 if item_impl.trait_.is_some() {
90 let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
91 let expand_impl = expand_impl.to_tokens_mut();
92
93 Ok(TokenStream::from(quote! {
94 #expand_impl
95 }))
96 } else {
97 let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
98 let expand_impl = expand_impl.to_tokens_mut();
99
100 Ok(TokenStream::from(quote! {
101 #expand_impl
102 }))
103 }
104 }
105 item => Err(syn::Error::new_spanned(
106 item,
107 "`#[cube]` is only supported on traits and functions",
108 ))?,
109 };
110
111 if args.debug.is_present() {
112 match tokens {
113 Ok(tokens) => panic!("{tokens}"),
114 Err(err) => panic!("{err}"),
115 };
116 }
117
118 tokens
119}
120
121/// Derive macro to define a cube type that is launched with a kernel
122#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
123pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
124 gen_cube_type(input, true)
125}
126
127/// Derive macro to define a cube type that is not launched
128#[proc_macro_derive(CubeType, attributes(expand, cube))]
129pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
130 gen_cube_type(input, false)
131}
132
133fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
134 let parsed = syn::parse(input);
135
136 let input = match &parsed {
137 Ok(val) => val,
138 Err(err) => return err.to_compile_error().into(),
139 };
140
141 let cube_type = match CubeType::from_derive_input(input) {
142 Ok(val) => val,
143 Err(err) => return err.write_errors().into(),
144 };
145
146 cube_type.generate(with_launch).into()
147}
148
149/// Attribute macro to define a type that can be used as a kernel comptime
150/// argument This derive Debug, Hash, PartialEq, Eq, Clone, Copy
151#[proc_macro_attribute]
152pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
153 let input: proc_macro2::TokenStream = input.into();
154 quote! {
155 #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
156 #input
157 }
158 .into()
159}
160
161/// Mark the contents of this macro as compile time values, turning off all
162/// expansion for this code and using it verbatim
163///
164/// # Example
165/// ```ignored
166/// #use cubecl_macros::cube;
167/// #fn some_rust_function(a: u32) -> u32 {}
168/// #[cube]
169/// fn do_stuff(input: u32) -> u32 {
170/// let comptime_value = comptime! { some_rust_function(3) };
171/// input + comptime_value
172/// }
173/// ```
174#[proc_macro]
175pub fn comptime(input: TokenStream) -> TokenStream {
176 let tokens: proc_macro2::TokenStream = input.into();
177 quote![{ #tokens }].into()
178}
179
180/// Mark the contents of this macro as an intrinsic, turning off all expansion
181/// for this code and calling it with the scope
182///
183/// # Example
184/// ```ignored
185/// #use cubecl_macros::cube;
186/// #[cube]
187/// fn do_stuff(input: u32) -> u32 {
188/// let comptime_value = intrinsic! { |scope| u32::elem_size(scope) };
189/// input + comptime_value
190/// }
191/// ```
192#[proc_macro]
193pub fn intrinsic(_input: TokenStream) -> TokenStream {
194 quote![{ cubecl::unexpanded!() }].into()
195}
196
197/// Makes the function return a compile time value
198/// Useful in a cube trait to have a part of the trait return comptime values
199///
200/// # Example
201/// ```ignored
202/// #use cubecl_macros::cube;
203/// #[cube]
204/// fn do_stuff(#[comptime] input: u32) -> comptime_type!(u32) {
205/// input + 5
206/// }
207/// ```
208///
209/// TODO: calling a trait method returning comptime_type from
210/// within another trait method does not work
211#[proc_macro]
212pub fn comptime_type(input: TokenStream) -> TokenStream {
213 let tokens: proc_macro2::TokenStream = input.into();
214 quote![ #tokens ].into()
215}
216
217/// Insert a literal comment into the kernel source code.
218///
219/// # Example
220/// ```ignored
221/// #use cubecl_macros::cube;
222/// #[cube]
223/// fn do_stuff(input: u32) -> u32 {
224/// comment!("Add five to the input");
225/// input + 5
226/// }
227/// ```
228#[proc_macro]
229pub fn comment(input: TokenStream) -> TokenStream {
230 let tokens: proc_macro2::TokenStream = input.into();
231 quote![{ #tokens }].into()
232}
233
234/// Terminate the execution of the kernel for the current unit.
235///
236/// This terminates the execution of the unit even if nested inside many
237/// functions.
238///
239/// # Example
240/// ```ignored
241/// #use cubecl_macros::cube;
242/// #[cube]
243/// fn stop_if_more_than_ten(input: u32) {
244/// if input > 10 {
245/// terminate!();
246/// }
247/// }
248/// ```
249#[proc_macro]
250pub fn terminate(input: TokenStream) -> TokenStream {
251 let tokens: proc_macro2::TokenStream = input.into();
252 quote![{ #tokens }].into()
253}
254
255/// Implements display and initialization for autotune keys.
256///
257/// # Helper
258///
259/// Use the `#[autotune(anchor)]` helper attribute to anchor a numerical value.
260/// This groups multiple numerical values into the same bucket.
261///
262/// For now, only an exponential function is supported, and it can be modified
263/// with `exp`. By default, the base is '2' and there are no `min` or `max`
264/// provided.
265///
266/// # Example
267/// ```ignore
268/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
269/// pub struct OperationKey {
270/// #[autotune(name = "Batch Size")]
271/// batch_size: usize,
272/// channels: usize,
273/// #[autotune(anchor(exp(min = 16, max = 1024, base = 2)))]
274/// height: usize,
275/// #[autotune(anchor)]
276/// width: usize,
277/// }
278/// ```
279#[proc_macro_derive(AutotuneKey, attributes(autotune))]
280pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
281 let input = syn::parse(input).unwrap();
282 match generate_autotune_key(input) {
283 Ok(tokens) => tokens.into(),
284 Err(e) => e.into_compile_error().into(),
285 }
286}