cubecl_macros/
lib.rs

1use core::panic;
2
3use darling::FromDeriveInput;
4use error::error_into_token_stream;
5use generate::autotune::{generate_autotune_key, generate_autotune_set};
6use parse::{
7    cube_impl::CubeImpl,
8    cube_trait::{CubeTrait, CubeTraitImpl},
9    cube_type::CubeType,
10    helpers::{RemoveHelpers, ReplaceIndices},
11    kernel::{from_tokens, Launch},
12};
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{visit_mut::VisitMut, Item};
16
17mod error;
18mod expression;
19mod generate;
20mod operator;
21mod parse;
22mod paths;
23mod scope;
24mod statement;
25
26/// Mark a cube function, trait or implementation for expansion.
27///
28/// # Arguments
29/// * `launch` - generates a function to launch the kernel
30/// * `launch_unchecked` - generates a launch function without checks
31/// * `debug` - panics after generation to print the output to console
32/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for testing.
33///
34/// # Example
35///
36/// ```ignored
37/// # use cubecl_macros::cube;
38/// #[cube]
39/// fn my_addition(a: u32, b: u32) -> u32 {
40///     a + b
41/// }
42/// ```
43#[proc_macro_attribute]
44pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
45    match cube_impl(args, input.clone()) {
46        Ok(tokens) => tokens,
47        Err(e) => error_into_token_stream(e, input.into()).into(),
48    }
49}
50
51fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
52    let mut item: Item = syn::parse(input)?;
53    let args = from_tokens(args.into())?;
54
55    let tokens = match item.clone() {
56        Item::Fn(kernel) => {
57            let kernel = Launch::from_item_fn(kernel, args)?;
58            RemoveHelpers.visit_item_mut(&mut item);
59            ReplaceIndices.visit_item_mut(&mut item);
60
61            return Ok(TokenStream::from(quote! {
62                #[allow(dead_code, clippy::too_many_arguments)]
63                #item
64                #kernel
65            }));
66        }
67        Item::Trait(kernel_trait) => {
68            let expand_trait = CubeTrait::from_item_trait(kernel_trait)?;
69
70            Ok(TokenStream::from(quote! {
71                #expand_trait
72            }))
73        }
74        Item::Impl(item_impl) => {
75            if item_impl.trait_.is_some() {
76                let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl)?;
77                let expand_impl = expand_impl.to_tokens_mut();
78
79                Ok(TokenStream::from(quote! {
80                    #expand_impl
81                }))
82            } else {
83                let mut expand_impl = CubeImpl::from_item_impl(item_impl)?;
84                let expand_impl = expand_impl.to_tokens_mut();
85
86                Ok(TokenStream::from(quote! {
87                    #expand_impl
88                }))
89            }
90        }
91        item => Err(syn::Error::new_spanned(
92            item,
93            "`#[cube]` is only supported on traits and functions",
94        ))?,
95    };
96
97    if args.debug.is_present() {
98        match tokens {
99            Ok(tokens) => panic!("{tokens}"),
100            Err(err) => panic!("{err}"),
101        };
102    }
103
104    tokens
105}
106
107/// Derive macro to define a cube type that is launched with a kernel
108#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
109pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
110    // panic!("{gen}");
111    gen_cube_type(input, true)
112}
113
114/// Derive macro to define a cube type that is not launched
115#[proc_macro_derive(CubeType, attributes(expand, cube))]
116pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
117    gen_cube_type(input, false)
118}
119
120fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
121    let parsed = syn::parse(input);
122
123    let input = match &parsed {
124        Ok(val) => val,
125        Err(err) => return err.to_compile_error().into(),
126    };
127
128    let cube_type = match CubeType::from_derive_input(input) {
129        Ok(val) => val,
130        Err(err) => return err.write_errors().into(),
131    };
132
133    cube_type.generate(with_launch).into()
134}
135
136/// Mark the contents of this macro as compile time values, turning off all expansion for this code
137/// and using it verbatim
138///
139/// # Example
140/// ```ignored
141/// #use cubecl_macros::cube;
142/// #fn some_rust_function(a: u32) -> u32 {}
143/// #[cube]
144/// fn do_stuff(input: u32) -> u32 {
145///     let comptime_value = comptime! { some_rust_function(3) };
146///     input + comptime_value
147/// }
148/// ```
149#[proc_macro]
150pub fn comptime(input: TokenStream) -> TokenStream {
151    let tokens: proc_macro2::TokenStream = input.into();
152    quote![{ #tokens }].into()
153}
154
155/// Insert a literal comment into the kernel source code.
156///
157/// # Example
158/// ```ignored
159/// #use cubecl_macros::cube;
160/// #[cube]
161/// fn do_stuff(input: u32) -> u32 {
162///     comment!("Add five to the input");
163///     input + 5
164/// }
165/// ```
166#[proc_macro]
167pub fn comment(input: TokenStream) -> TokenStream {
168    let tokens: proc_macro2::TokenStream = input.into();
169    quote![{ #tokens }].into()
170}
171
172/// Implements display and initialization for autotune keys.
173///
174/// # Helper
175///
176/// Use the `#[autotune]` helper attribute to anchor fields to the next power of two, or rename
177/// the fields for the display implementation.
178///
179/// # Example
180/// ```ignore
181/// #[derive(AutotuneKey, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
182/// pub struct OperationKey {
183///     #[autotune(name = "Batch Size")]
184///     batch_size: usize,
185///     channels: usize,
186///     #[autotune(anchor(max = 1024))]
187///     height: usize,
188///     #[autotune(anchor)]
189///     width: usize,
190/// }
191/// ```
192#[proc_macro_derive(AutotuneKey, attributes(autotune))]
193pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
194    let input = syn::parse(input).unwrap();
195    match generate_autotune_key(input) {
196        Ok(tokens) => tokens.into(),
197        Err(e) => e.into_compile_error().into(),
198    }
199}
200
201/// Crates a tuning set with a specific signature. Should return a tuple of benchmark inputs.
202///
203/// # Arguments
204///
205/// * `name` - the name of the generated operations struct (default: `PascalCaseFnName`)
206/// * `key` - the name of the key input parameter (default: `key`)
207/// * `create_key` - path to function that creates the key. If not specified, `new` must be implemented manually.
208/// * `should_run` - path to override function for the `should_run` function of the set.
209/// * `operations` - ordered list of operations returned by this tune set
210///
211/// # Example
212///
213/// ```ignore
214/// #[tune(create_key = key_from_input, operations(operation_1, operation_2))]
215/// pub fn my_operations(key: MyKey, input: JitTensor<f32, 4>) -> JitTensor<f32, 4> {
216///     let bench_input = random_tensor_like(input, -1.0, 1.0);
217///     
218///     (bench_input)
219/// }
220///
221/// fn key_from_input(input: &JitTensor<f32, 4>) -> MyKey {
222///     MyKey::new(input.shape.dims)
223/// }
224/// ```
225#[proc_macro_attribute]
226pub fn tune(args: TokenStream, input: TokenStream) -> TokenStream {
227    match autotune_set_impl(args, input.clone()) {
228        Ok(tokens) => tokens,
229        Err(e) => error_into_token_stream(e, input.into()).into(),
230    }
231}
232
233fn autotune_set_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
234    let item = syn::parse(input)?;
235    let args = from_tokens(args.into())?;
236    Ok(generate_autotune_set(item, args)?.into())
237}