cubecl_macros_internal/
lib.rs

1use generate::{
2    op_args::generate_op_args,
3    operation::{generate_opcode, generate_operation},
4};
5use proc_macro::TokenStream;
6use type_hash::type_hash_impl;
7
8mod generate;
9mod parse;
10mod type_hash;
11
12/// *Internal macro*
13///
14/// Generates an implementation of `OperationArgs` for this type. All fields must implement
15/// `FromArgList`.
16#[doc(hidden)]
17#[proc_macro_derive(OperationArgs, attributes(args))]
18pub fn derive_operation_args(input: TokenStream) -> TokenStream {
19    let input = syn::parse(input).unwrap();
20    match generate_op_args(input) {
21        Ok(tokens) => tokens.into(),
22        Err(e) => e.into_compile_error().into(),
23    }
24}
25
26/// Generates reflection info for an operation. Generates an opcode enum and an implementation of
27/// `OperationReflect` that deconstructs and reconstructs the typed version. All variant fields must
28/// implement `OperationArgs`, or `OperationReflect` if the variant is nested. Uses the `operation`
29/// helper attribute.
30///
31/// # Arguments
32///
33/// * `opcode_name` - the name of the generated opcode enum (required)
34/// * `pure` - marks this entire operation as pure
35/// * `commutative` - marks this entire operation as commutative
36///
37/// # Variant arguments
38///
39/// * `pure` - Marks this variant as pure
40/// * `commutative` - Marks this variant as commutative
41///
42#[doc(hidden)]
43#[proc_macro_derive(OperationReflect, attributes(operation))]
44pub fn derive_operation(input: TokenStream) -> TokenStream {
45    let input = syn::parse(input).unwrap();
46    match generate_operation(input) {
47        Ok(tokens) => tokens.into(),
48        Err(e) => e.into_compile_error().into(),
49    }
50}
51
52/// Generates an opcode enum for an operation, without implementation `OperationReflect`. Allows for
53/// manual implementation.
54///
55/// Use `self.__match_opcode()` to get the opcode for an operation.
56///
57/// # Arguments
58///
59/// * `opcode_name` - the name of the generated opcode enum (required)
60#[doc(hidden)]
61#[proc_macro_derive(OperationCode, attributes(operation))]
62pub fn derive_opcode(input: TokenStream) -> TokenStream {
63    let input = syn::parse(input).unwrap();
64    match generate_opcode(input) {
65        Ok(tokens) => tokens.into(),
66        Err(e) => e.into_compile_error().into(),
67    }
68}
69
70#[proc_macro_derive(TypeHash, attributes(type_hash))]
71pub fn derive_type_hash(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
72    let input = syn::parse(input).unwrap();
73    type_hash_impl(input).into()
74}