use core::panic;

use darling::FromDeriveInput;
use error::error_into_token_stream;
use generate::autotune::{generate_autotune_key, generate_autotune_set};
use parse::{
    cube_impl::CubeImpl,
    cube_trait::{CubeTrait, CubeTraitImpl},
    cube_type::CubeType,
    helpers::{RemoveHelpers, ReplaceIndices},
    kernel::{from_tokens, Launch},
};
use proc_macro::TokenStream;
use quote::quote;
use syn::{visit_mut::VisitMut, Item};

mod error;
mod expression;
mod generate;
mod operator;
mod parse;
mod paths;
mod scope;
mod statement;

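/// Attribute macro for cube kernels: functions are expanded into a launchable
/// kernel via `Launch`, traits into an expanded trait via `CubeTrait`, and
/// impl blocks via `CubeTraitImpl` or `CubeImpl`. Errors are merged back into
/// the original input with `error_into_token_stream`. Passing `debug` in the
/// attribute arguments panics with the generated tokens (or the error) so the
/// expansion can be inspected.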
#[proc_macro_attribute]
pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
    match cube_impl(args, input.clone()) {
        Ok(tokens) => tokens,
        Err(e) => error_into_token_stream(e, input.into()).into(),
    }
}

fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
    let mut item: Item = syn::parse(input)?;
    let args = from_tokens(args.into())?;

    let tokens = match item.clone() {
        Item::Fn(kernel) => {
            let kernel = Launch::from_item_fn(kernel, args)?;
            RemoveHelpers.visit_item_mut(&mut item);
            ReplaceIndices.visit_item_mut(&mut item);

            return Ok(TokenStream::from(quote! {
                #[allow(dead_code, clippy::too_many_arguments)]
                #item
                #kernel
            }));
        }
        Item::Trait(kernel_trait) => {
            let expand_trait = CubeTrait::from_item_trait(kernel_trait)?;

            Ok(TokenStream::from(quote! {
                #expand_trait
            }))
        }
        Item::Impl(item_impl) => {
            if item_impl.trait_.is_some() {
                let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl)?;
                let expand_impl = expand_impl.to_tokens_mut();

                Ok(TokenStream::from(quote! {
                    #expand_impl
                }))
            } else {
                let mut expand_impl = CubeImpl::from_item_impl(item_impl)?;
                let expand_impl = expand_impl.to_tokens_mut();

                Ok(TokenStream::from(quote! {
                    #expand_impl
                }))
            }
        }
        item => Err(syn::Error::new_spanned(
            item,
            "`#[cube]` is only supported on functions, traits, and impl blocks",
        ))?,
    };

    if args.debug.is_present() {
        match tokens {
            Ok(tokens) => panic!("{tokens}"),
            Err(err) => panic!("{err}"),
        };
    }

    tokens
}

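/// Derive macro generating the cube type expansion with launch support
/// (see `gen_cube_type` with `with_launch = true`).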
#[proc_macro_derive(CubeLaunch, attributes(expand, cube))]
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
    gen_cube_type(input, true)
}

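/// Derive macro generating the cube type expansion without launch support
/// (see `gen_cube_type` with `with_launch = false`).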
#[proc_macro_derive(CubeType, attributes(expand, cube))]
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
    gen_cube_type(input, false)
}

fn gen_cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
    let parsed = syn::parse(input);

    let input = match &parsed {
        Ok(val) => val,
        Err(err) => return err.to_compile_error().into(),
    };

    let cube_type = match CubeType::from_derive_input(input) {
        Ok(val) => val,
        Err(err) => return err.write_errors().into(),
    };

    cube_type.generate(with_launch).into()
}

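/// Function-like macro used inside `#[cube]` functions to mark an expression
/// as compile-time (expansion-time) code. At the proc-macro level it only
/// re-emits its input wrapped in a block; the `#[cube]` parser gives it its
/// semantics.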
#[proc_macro]
pub fn comptime(input: TokenStream) -> TokenStream {
    let tokens: proc_macro2::TokenStream = input.into();
    quote![{ #tokens }].into()
}

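/// Function-like macro for attaching a comment inside `#[cube]` functions.
/// Like `comptime`, it only re-emits its input wrapped in a block here and is
/// recognized during `#[cube]` expansion.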
#[proc_macro]
pub fn comment(input: TokenStream) -> TokenStream {
    let tokens: proc_macro2::TokenStream = input.into();
    quote![{ #tokens }].into()
}

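/// Derive macro implementing an autotune key for the annotated type; the
/// generated code comes from `generate_autotune_key`.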
#[proc_macro_derive(AutotuneKey, attributes(autotune))]
pub fn derive_autotune_key(input: TokenStream) -> TokenStream {
    let input = syn::parse(input).unwrap();
    match generate_autotune_key(input) {
        Ok(tokens) => tokens.into(),
        Err(e) => e.into_compile_error().into(),
    }
}

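/// Attribute macro turning the annotated item into an autotune set via
/// `generate_autotune_set`. Errors are merged back into the original input
/// with `error_into_token_stream`, mirroring `cube`.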
#[proc_macro_attribute]
pub fn tune(args: TokenStream, input: TokenStream) -> TokenStream {
    match autotune_set_impl(args, input.clone()) {
        Ok(tokens) => tokens,
        Err(e) => error_into_token_stream(e, input.into()).into(),
    }
}

fn autotune_set_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
    let item = syn::parse(input)?;
    let args = from_tokens(args.into())?;
    Ok(generate_autotune_set(item, args)?.into())
}