onnx_ir_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5/// Derive macro for generating node builders
6///
7/// Automatically generates a builder with methods for constructing node inputs/outputs.
8///
9/// # Example
10/// ```ignore
11/// #[derive(Debug, Clone, NodeBuilder)]
12/// pub struct AddNode {
13///     pub name: String,
14///     pub inputs: Vec<Argument>,
15///     pub outputs: Vec<Argument>,
16/// }
17/// ```
18///
19/// Generates `AddNodeBuilder` with:
20/// - `new(name)` - Create builder
21/// - `input_tensor(name, rank, dtype)` - Add tensor input (dynamic, no static shape)
22/// - `input_tensor_shape(name, shape, dtype)` - Add tensor input with static shape
23/// - `input_scalar(name, dtype)` - Add scalar input
24/// - `input_shape(name)` - Add shape input
25/// - `output_tensor(name, rank, dtype)` - Add output tensor
26/// - `output_scalar(name, dtype)` - Add scalar output
27/// - `output_shape(name)` - Add shape output
28/// - `config(config)` - Set config (if node has a config field)
29/// - `build()` - Build the node
30#[proc_macro_derive(NodeBuilder)]
31pub fn node_builder_derive(input: TokenStream) -> TokenStream {
32    let input = parse_macro_input!(input as DeriveInput);
33
34    let node_name = &input.ident;
35    let builder_name = syn::Ident::new(&format!("{}Builder", node_name), node_name.span());
36
37    // Check if the struct has a config field
38    let has_config = if let Data::Struct(data) = &input.data {
39        if let Fields::Named(fields) = &data.fields {
40            fields
41                .named
42                .iter()
43                .any(|f| f.ident.as_ref().map(|i| i == "config").unwrap_or(false))
44        } else {
45            false
46        }
47    } else {
48        false
49    };
50
51    // Extract config type if it exists
52    let config_type = if has_config {
53        if let Data::Struct(data) = &input.data {
54            if let Fields::Named(fields) = &data.fields {
55                fields
56                    .named
57                    .iter()
58                    .find(|f| f.ident.as_ref().map(|i| i == "config").unwrap_or(false))
59                    .map(|f| &f.ty)
60            } else {
61                None
62            }
63        } else {
64            None
65        }
66    } else {
67        None
68    };
69
70    let config_field = if let Some(config_ty) = config_type {
71        quote! {
72            config: Option<#config_ty>,
73        }
74    } else {
75        quote! {}
76    };
77
78    let config_init = if has_config {
79        quote! { config: None, }
80    } else {
81        quote! {}
82    };
83
84    let config_method = if let Some(config_ty) = config_type {
85        quote! {
86            /// Set the configuration
87            pub fn config(mut self, config: #config_ty) -> Self {
88                self.config = Some(config);
89                self
90            }
91        }
92    } else {
93        quote! {}
94    };
95
96    let config_build = if has_config {
97        quote! {
98            config: self.config.expect("Config must be set before calling build()"),
99        }
100    } else {
101        quote! {}
102    };
103
104    let expanded = quote! {
105        pub struct #builder_name {
106            name: String,
107            inputs: Vec<crate::ir::Argument>,
108            outputs: Vec<crate::ir::Argument>,
109            #config_field
110        }
111
112        impl #builder_name {
113            /// Create a new builder
114            pub fn new(name: impl Into<String>) -> Self {
115                Self {
116                    name: name.into(),
117                    inputs: vec![],
118                    outputs: vec![],
119                    #config_init
120                }
121            }
122
123            /// Add a tensor input (dynamic, no static shape)
124            pub fn input_tensor(
125                mut self,
126                name: &str,
127                rank: usize,
128                dtype: burn_tensor::DType,
129            ) -> Self {
130                use crate::ir::{Argument, ArgType, TensorType};
131                self.inputs.push(Argument::new(
132                    name,
133                    ArgType::Tensor(TensorType {
134                        dtype,
135                        rank,
136                        static_shape: None,
137                    }),
138                ));
139                self
140            }
141
142            /// Add a tensor input with static shape
143            pub fn input_tensor_shape(
144                mut self,
145                name: &str,
146                shape: Vec<usize>,
147                dtype: burn_tensor::DType,
148            ) -> Self {
149                use crate::ir::{Argument, ArgType, TensorType};
150                self.inputs.push(Argument::new(
151                    name,
152                    ArgType::Tensor(TensorType {
153                        dtype,
154                        rank: shape.len(),
155                        static_shape: Some(shape),
156                    }),
157                ));
158                self
159            }
160
161            /// Add a scalar input
162            pub fn input_scalar(mut self, name: &str, dtype: burn_tensor::DType) -> Self {
163                use crate::ir::{Argument, ArgType};
164                self.inputs.push(Argument::new(name, ArgType::Scalar(dtype)));
165                self
166            }
167
168            /// Add a shape input (rank 1 by default, since shapes are 1D arrays)
169            pub fn input_shape(mut self, name: &str) -> Self {
170                use crate::ir::{Argument, ArgType};
171                self.inputs.push(Argument::new(name, ArgType::Shape(1)));
172                self
173            }
174
175            /// Add an output tensor
176            pub fn output_tensor(
177                mut self,
178                name: &str,
179                rank: usize,
180                dtype: burn_tensor::DType,
181            ) -> Self {
182                use crate::ir::{Argument, ArgType, TensorType};
183                self.outputs.push(Argument::new(
184                    name,
185                    ArgType::Tensor(TensorType {
186                        dtype,
187                        rank,
188                        static_shape: None,
189                    }),
190                ));
191                self
192            }
193
194            /// Add a scalar output
195            pub fn output_scalar(mut self, name: &str, dtype: burn_tensor::DType) -> Self {
196                use crate::ir::{Argument, ArgType};
197                self.outputs.push(Argument::new(name, ArgType::Scalar(dtype)));
198                self
199            }
200
201            /// Add a shape output (size 1 by default, since shapes are 1D arrays of length 1)
202            pub fn output_shape(mut self, name: &str) -> Self {
203                use crate::ir::{Argument, ArgType};
204                self.outputs.push(Argument::new(name, ArgType::Shape(1)));
205                self
206            }
207
208            /// Add a shape output with a specific size (number of elements in the shape array)
209            pub fn output_shape_with_size(mut self, name: &str, size: usize) -> Self {
210                use crate::ir::{Argument, ArgType};
211                self.outputs.push(Argument::new(name, ArgType::Shape(size)));
212                self
213            }
214
215            /// Add a constant i64 scalar input with a known value
216            pub fn input_const_i64(mut self, name: &str, value: i64) -> Self {
217                use crate::ir::Argument;
218                let arg = Argument::from_const_i64(name, value);
219                self.inputs.push(arg);
220                self
221            }
222
223            #config_method
224
225            /// Build the node
226            pub fn build(self) -> #node_name {
227                #node_name {
228                    name: self.name,
229                    inputs: self.inputs,
230                    outputs: self.outputs,
231                    #config_build
232                }
233            }
234        }
235    };
236
237    TokenStream::from(expanded)
238}