1#![allow(clippy::large_enum_variant)]
2
3use core::panic;
4
5use error::error_into_token_stream;
6use generate::autotune::generate_autotune_key;
7use parse::{
8 cube_impl::CubeImpl,
9 cube_trait::{CubeTrait, CubeTraitImpl},
10 helpers::{RemoveHelpers, ReplaceIndices},
11 kernel::{Launch, from_tokens},
12};
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{Item, visit_mut::VisitMut};
16
17use crate::{
18 generate::{assign::generate_cube_type_mut, into_runtime::generate_into_runtime},
19 parse::{
20 cube_type::generate_cube_type, derive_expand::generate_derive_expand,
21 helpers::ReplaceDefines,
22 },
23};
24
25mod error;
26mod expression;
27mod generate;
28mod operator;
29mod parse;
30mod paths;
31mod scope;
32mod statement;
33
34#[proc_macro_attribute]
59pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
60 match cube_impl(args, input.clone()) {
61 Ok(tokens) => tokens,
62 Err(e) => error_into_token_stream(e, input.into()).into(),
63 }
64}
65
66fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
67 let mut item: Item = syn::parse(input)?;
68 let args = from_tokens(args.into())?;
69
70 let tokens = match item.clone() {
71 Item::Fn(kernel) => {
72 let kernel = Launch::from_item_fn(kernel, args)?;
73 RemoveHelpers.visit_item_mut(&mut item);
74 ReplaceIndices.visit_item_mut(&mut item);
75 ReplaceDefines.visit_item_mut(&mut item);
76
77 return Ok(TokenStream::from(quote! {
78 #[allow(dead_code, clippy::too_many_arguments)]
79 #item
80 #kernel
81 }));
82 }
83 Item::Trait(kernel_trait) => {
84 let is_debug = args.debug.is_present();
85 let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
86
87 let tokens = TokenStream::from(quote! {
88 #expand_trait
89 });
90 if is_debug {
91 panic!("{tokens}");
92 }
93 return Ok(tokens);
94 }
95 Item::Impl(item_impl) => {
96 if item_impl.trait_.is_some() {
97 let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
98 let expand_impl = expand_impl.to_tokens_mut();
99
100 Ok(TokenStream::from(quote! {
101 #expand_impl
102 }))
103 } else {
104 let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
105 let expand_impl = expand_impl.to_tokens_mut();
106
107 Ok(TokenStream::from(quote! {
108 #expand_impl
109 }))
110 }
111 }
112 item => Err(syn::Error::new_spanned(
113 item,
114 "`#[cube]` is only supported on traits and functions",
115 ))?,
116 };
117
118 if args.debug.is_present() {
119 match tokens {
120 Ok(tokens) => panic!("{tokens}"),
121 Err(err) => panic!("{err}"),
122 };
123 }
124
125 tokens
126}
127
128#[proc_macro_derive(CubeLaunch, attributes(cube, launch))]
130pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
131 gen_cube_type(input, true)
132}
133
134#[proc_macro_derive(CubeType, attributes(cube))]
136pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
137 gen_cube_type(input, false)
138}
139
140fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
141 let parsed = syn::parse(input);
142
143 let input = match &parsed {
144 Ok(val) => val,
145 Err(err) => return err.to_compile_error().into(),
146 };
147
148 match generate_cube_type(input, with_launch) {
149 Ok(val) => val.into(),
150 Err(err) => err.to_compile_error().into(),
151 }
152}
153
154#[proc_macro_attribute]
157pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
158 let input: proc_macro2::TokenStream = input.into();
159 quote! {
160 #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
161 #input
162 }
163 .into()
164}
165
166#[proc_macro_attribute]
168pub fn derive_expand(metadata: TokenStream, input: TokenStream) -> TokenStream {
169 match generate_derive_expand(input.into(), metadata.into()) {
170 Ok(val) => val.into(),
171 Err(err) => err.to_compile_error().into(),
172 }
173}
174
175#[proc_macro]
189pub fn comptime(input: TokenStream) -> TokenStream {
190 let tokens: proc_macro2::TokenStream = input.into();
191 quote![{ #tokens }].into()
192}
193
194#[proc_macro]
207pub fn intrinsic(_input: TokenStream) -> TokenStream {
208 quote![{ cubecl::unexpanded!() }].into()
209}
210
211#[proc_macro]
226pub fn comptime_type(input: TokenStream) -> TokenStream {
227 let tokens: proc_macro2::TokenStream = input.into();
228 quote![ #tokens ].into()
229}
230
231#[proc_macro]
243pub fn comment(input: TokenStream) -> TokenStream {
244 let tokens: proc_macro2::TokenStream = input.into();
245 quote![{ #tokens }].into()
246}
247
248#[proc_macro]
264pub fn terminate(input: TokenStream) -> TokenStream {
265 let tokens: proc_macro2::TokenStream = input.into();
266 quote![{ #tokens }].into()
267}
268
269#[proc_macro_derive(AutotuneKey, attributes(autotune))]
294pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
295 let input = syn::parse(input).unwrap();
296 match generate_autotune_key(input) {
297 Ok(tokens) => tokens.into(),
298 Err(e) => e.into_compile_error().into(),
299 }
300}
301
302#[proc_macro_derive(IntoRuntime, attributes(cube))]
304pub fn derive_into_runtime(input: TokenStream) -> TokenStream {
305 let input = syn::parse(input).unwrap();
306 match generate_into_runtime(&input) {
307 Ok(tokens) => tokens.into(),
308 Err(e) => e.into_compile_error().into(),
309 }
310}
311
312#[proc_macro_derive(CubeTypeMut, attributes(cube))]
314pub fn derive_assign(input: TokenStream) -> TokenStream {
315 let input = syn::parse(input).unwrap();
316 match generate_cube_type_mut(&input) {
317 Ok(tokens) => tokens.into(),
318 Err(e) => e.into_compile_error().into(),
319 }
320}