1#![cfg_attr(nightly, feature(proc_macro_span))]
2#![allow(clippy::large_enum_variant)]
3
4use core::panic;
5
6use darling::FromDeriveInput;
7use error::error_into_token_stream;
8use generate::autotune::generate_autotune_key;
9use parse::{
10 cube_impl::CubeImpl,
11 cube_trait::{CubeTrait, CubeTraitImpl},
12 cube_type::CubeType,
13 helpers::{RemoveHelpers, ReplaceIndices},
14 kernel::{Launch, from_tokens},
15};
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::{Item, visit_mut::VisitMut};
19
20mod error;
21mod expression;
22mod generate;
23mod operator;
24mod parse;
25mod paths;
26mod scope;
27mod statement;
28
29#[proc_macro_attribute]
47pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
48 match cube_impl(args, input.clone()) {
49 Ok(tokens) => tokens,
50 Err(e) => error_into_token_stream(e, input.into()).into(),
51 }
52}
53
54fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
55 let mut item: Item = syn::parse(input)?;
56 let args = from_tokens(args.into())?;
57
58 let tokens = match item.clone() {
59 Item::Fn(kernel) => {
60 let kernel = Launch::from_item_fn(kernel, args)?;
61 RemoveHelpers.visit_item_mut(&mut item);
62 ReplaceIndices.visit_item_mut(&mut item);
63
64 return Ok(TokenStream::from(quote! {
65 #[allow(dead_code, clippy::too_many_arguments)]
66 #item
67 #kernel
68 }));
69 }
70 Item::Trait(kernel_trait) => {
71 let expand_trait = CubeTrait::from_item_trait(kernel_trait)?;
72
73 Ok(TokenStream::from(quote! {
74 #expand_trait
75 }))
76 }
77 Item::Impl(item_impl) => {
78 if item_impl.trait_.is_some() {
79 let mut expand_impl = CubeTraitImpl::from_item_impl(
80 item_impl,
81 args.src_file,
82 args.debug_symbols.is_present(),
83 )?;
84 let expand_impl = expand_impl.to_tokens_mut();
85
86 Ok(TokenStream::from(quote! {
87 #expand_impl
88 }))
89 } else {
90 let mut expand_impl = CubeImpl::from_item_impl(
91 item_impl,
92 args.src_file,
93 args.debug_symbols.is_present(),
94 )?;
95 let expand_impl = expand_impl.to_tokens_mut();
96
97 Ok(TokenStream::from(quote! {
98 #expand_impl
99 }))
100 }
101 }
102 item => Err(syn::Error::new_spanned(
103 item,
104 "`#[cube]` is only supported on traits and functions",
105 ))?,
106 };
107
108 if args.debug.is_present() {
109 match tokens {
110 Ok(tokens) => panic!("{tokens}"),
111 Err(err) => panic!("{err}"),
112 };
113 }
114
115 tokens
116}
117
118#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
120pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
121 gen_cube_type(input, true)
122}
123
124#[proc_macro_derive(CubeType, attributes(expand, cube))]
126pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
127 gen_cube_type(input, false)
128}
129
130fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
131 let parsed = syn::parse(input);
132
133 let input = match &parsed {
134 Ok(val) => val,
135 Err(err) => return err.to_compile_error().into(),
136 };
137
138 let cube_type = match CubeType::from_derive_input(input) {
139 Ok(val) => val,
140 Err(err) => return err.write_errors().into(),
141 };
142
143 cube_type.generate(with_launch).into()
144}
145
146#[proc_macro_attribute]
149pub fn derive_cube_comptime(_metadata: TokenStream, input: TokenStream) -> TokenStream {
150 let input: proc_macro2::TokenStream = input.into();
151 quote! {
152 #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
153 #input
154 }
155 .into()
156}
157
158#[proc_macro]
172pub fn comptime(input: TokenStream) -> TokenStream {
173 let tokens: proc_macro2::TokenStream = input.into();
174 quote![{ #tokens }].into()
175}
176
177#[proc_macro]
190pub fn intrinsic(_input: TokenStream) -> TokenStream {
191 quote![{ cubecl::unexpanded!() }].into()
192}
193
194#[proc_macro]
209pub fn comptime_type(input: TokenStream) -> TokenStream {
210 let tokens: proc_macro2::TokenStream = input.into();
211 quote![ #tokens ].into()
212}
213
214#[proc_macro]
226pub fn comment(input: TokenStream) -> TokenStream {
227 let tokens: proc_macro2::TokenStream = input.into();
228 quote![{ #tokens }].into()
229}
230
231#[proc_macro]
246pub fn terminate(input: TokenStream) -> TokenStream {
247 let tokens: proc_macro2::TokenStream = input.into();
248 quote![{ #tokens }].into()
249}
250
251#[proc_macro_derive(AutotuneKey, attributes(autotune))]
275pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
276 let input = syn::parse(input).unwrap();
277 match generate_autotune_key(input) {
278 Ok(tokens) => tokens.into(),
279 Err(e) => e.into_compile_error().into(),
280 }
281}