cubecl_macros/lib.rs
1#![allow(clippy::large_enum_variant)]
2
3use core::panic;
4
5use darling::FromDeriveInput;
6use error::error_into_token_stream;
7use generate::autotune::generate_autotune_key;
8use parse::{
9 cube_impl::CubeImpl,
10 cube_trait::{CubeTrait, CubeTraitImpl},
11 cube_type::CubeType,
12 helpers::{RemoveHelpers, ReplaceIndices},
13 kernel::{Launch, from_tokens},
14};
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::{Item, visit_mut::VisitMut};
18
19mod error;
20mod expression;
21mod generate;
22mod operator;
23mod parse;
24mod paths;
25mod scope;
26mod statement;
27
28/// Mark a cube function, trait or implementation for expansion.
29///
30/// # Arguments
31/// * `launch` - generates a function to launch the kernel
32/// * `launch_unchecked` - generates a launch function without checks
33/// * `debug` - panics after generation to print the output to console
34/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for
35/// testing.
36///
37/// # Trait arguments
38/// * `expand_base_traits` - base traits for the expanded "second half" of a trait with methods.
39/// * `self_type` - the type used for the `self` parameter of the expanded "second half" of a trait
40/// with methods. You shouldn't need to touch this unless you specifically need to dynamically
41/// dispatch an expanded trait.
42///
43/// # Example
44///
45/// ```ignored
46/// # use cubecl_macros::cube;
47/// #[cube]
48/// fn my_addition(a: u32, b: u32) -> u32 {
49/// a + b
50/// }
51/// ```
52#[proc_macro_attribute]
53pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
54 match cube_impl(args, input.clone()) {
55 Ok(tokens) => tokens,
56 Err(e) => error_into_token_stream(e, input.into()).into(),
57 }
58}
59
60fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
61 let mut item: Item = syn::parse(input)?;
62 let args = from_tokens(args.into())?;
63
64 let tokens = match item.clone() {
65 Item::Fn(kernel) => {
66 let kernel = Launch::from_item_fn(kernel, args)?;
67 RemoveHelpers.visit_item_mut(&mut item);
68 ReplaceIndices.visit_item_mut(&mut item);
69
70 return Ok(TokenStream::from(quote! {
71 #[allow(dead_code, clippy::too_many_arguments)]
72 #item
73 #kernel
74 }));
75 }
76 Item::Trait(kernel_trait) => {
77 let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
78
79 return Ok(TokenStream::from(quote! {
80 #expand_trait
81 }));
82 }
83 Item::Impl(item_impl) => {
84 if item_impl.trait_.is_some() {
85 let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
86 let expand_impl = expand_impl.to_tokens_mut();
87
88 Ok(TokenStream::from(quote! {
89 #expand_impl
90 }))
91 } else {
92 let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
93 let expand_impl = expand_impl.to_tokens_mut();
94
95 Ok(TokenStream::from(quote! {
96 #expand_impl
97 }))
98 }
99 }
100 item => Err(syn::Error::new_spanned(
101 item,
102 "`#[cube]` is only supported on traits and functions",
103 ))?,
104 };
105
106 if args.debug.is_present() {
107 match tokens {
108 Ok(tokens) => panic!("{tokens}"),
109 Err(err) => panic!("{err}"),
110 };
111 }
112
113 tokens
114}
115
116/// Derive macro to define a cube type that is launched with a kernel
117#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
118pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
119 gen_cube_type(input, true)
120}
121
122/// Derive macro to define a cube type that is not launched
123#[proc_macro_derive(CubeType, attributes(expand, cube))]
124pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
125 gen_cube_type(input, false)
126}
127
128fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
129 let parsed = syn::parse(input);
130
131 let input = match &parsed {
132 Ok(val) => val,
133 Err(err) => return err.to_compile_error().into(),
134 };
135
136 let cube_type = match CubeType::from_derive_input(input) {
137 Ok(val) => val,
138 Err(err) => return err.write_errors().into(),
139 };
140
141 cube_type.generate(with_launch).into()
142}
143
144/// Attribute macro to define a type that can be used as a kernel comptime
145/// argument This derive Debug, Hash, PartialEq, Eq, Clone, Copy
146#[proc_macro_attribute]
147pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
148 let input: proc_macro2::TokenStream = input.into();
149 quote! {
150 #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
151 #input
152 }
153 .into()
154}
155
156/// Mark the contents of this macro as compile time values, turning off all
157/// expansion for this code and using it verbatim
158///
159/// # Example
160/// ```ignored
161/// #use cubecl_macros::cube;
162/// #fn some_rust_function(a: u32) -> u32 {}
163/// #[cube]
164/// fn do_stuff(input: u32) -> u32 {
165/// let comptime_value = comptime! { some_rust_function(3) };
166/// input + comptime_value
167/// }
168/// ```
169#[proc_macro]
170pub fn comptime(input: TokenStream) -> TokenStream {
171 let tokens: proc_macro2::TokenStream = input.into();
172 quote![{ #tokens }].into()
173}
174
175/// Mark the contents of this macro as an intrinsic, turning off all expansion
176/// for this code and calling it with the scope
177///
178/// # Example
179/// ```ignored
180/// #use cubecl_macros::cube;
181/// #[cube]
182/// fn do_stuff(input: u32) -> u32 {
183/// let comptime_value = intrinsic! { |scope| u32::elem_size(scope) };
184/// input + comptime_value
185/// }
186/// ```
187#[proc_macro]
188pub fn intrinsic(_input: TokenStream) -> TokenStream {
189 quote![{ cubecl::unexpanded!() }].into()
190}
191
192/// Makes the function return a compile time value
193/// Useful in a cube trait to have a part of the trait return comptime values
194///
195/// # Example
196/// ```ignored
197/// #use cubecl_macros::cube;
198/// #[cube]
199/// fn do_stuff(#[comptime] input: u32) -> comptime_type!(u32) {
200/// input + 5
201/// }
202/// ```
203///
204/// TODO: calling a trait method returning comptime_type from
205/// within another trait method does not work
206#[proc_macro]
207pub fn comptime_type(input: TokenStream) -> TokenStream {
208 let tokens: proc_macro2::TokenStream = input.into();
209 quote![ #tokens ].into()
210}
211
212/// Insert a literal comment into the kernel source code.
213///
214/// # Example
215/// ```ignored
216/// #use cubecl_macros::cube;
217/// #[cube]
218/// fn do_stuff(input: u32) -> u32 {
219/// comment!("Add five to the input");
220/// input + 5
221/// }
222/// ```
223#[proc_macro]
224pub fn comment(input: TokenStream) -> TokenStream {
225 let tokens: proc_macro2::TokenStream = input.into();
226 quote![{ #tokens }].into()
227}
228
229/// Terminate the execution of the kernel for the current unit.
230///
231/// This terminates the execution of the unit even if nested inside many
232/// functions.
233///
234/// # Example
235/// ```ignored
236/// #use cubecl_macros::cube;
237/// #[cube]
238/// fn stop_if_more_than_ten(input: u32) {
239/// if input > 10 {
240/// terminate!();
241/// }
242/// }
243/// ```
244#[proc_macro]
245pub fn terminate(input: TokenStream) -> TokenStream {
246 let tokens: proc_macro2::TokenStream = input.into();
247 quote![{ #tokens }].into()
248}
249
250/// Implements display and initialization for autotune keys.
251///
252/// # Helper
253///
254/// Use the `#[autotune(anchor)]` helper attribute to anchor a numerical value.
255/// This groups multiple numerical values into the same bucket.
256///
257/// For now, only an exponential function is supported, and it can be modified
258/// with `exp`. By default, the base is '2' and there are no `min` or `max`
259/// provided.
260///
261/// # Example
262/// ```ignore
263/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
264/// pub struct OperationKey {
265/// #[autotune(name = "Batch Size")]
266/// batch_size: usize,
267/// channels: usize,
268/// #[autotune(anchor(exp(min = 16, max = 1024, base = 2)))]
269/// height: usize,
270/// #[autotune(anchor)]
271/// width: usize,
272/// }
273/// ```
274#[proc_macro_derive(AutotuneKey, attributes(autotune))]
275pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
276 let input = syn::parse(input).unwrap();
277 match generate_autotune_key(input) {
278 Ok(tokens) => tokens.into(),
279 Err(e) => e.into_compile_error().into(),
280 }
281}