1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5#[proc_macro_derive(NodeBuilder)]
31pub fn node_builder_derive(input: TokenStream) -> TokenStream {
32 let input = parse_macro_input!(input as DeriveInput);
33
34 let node_name = &input.ident;
35 let builder_name = syn::Ident::new(&format!("{}Builder", node_name), node_name.span());
36
37 let has_config = if let Data::Struct(data) = &input.data {
39 if let Fields::Named(fields) = &data.fields {
40 fields
41 .named
42 .iter()
43 .any(|f| f.ident.as_ref().map(|i| i == "config").unwrap_or(false))
44 } else {
45 false
46 }
47 } else {
48 false
49 };
50
51 let config_type = if has_config {
53 if let Data::Struct(data) = &input.data {
54 if let Fields::Named(fields) = &data.fields {
55 fields
56 .named
57 .iter()
58 .find(|f| f.ident.as_ref().map(|i| i == "config").unwrap_or(false))
59 .map(|f| &f.ty)
60 } else {
61 None
62 }
63 } else {
64 None
65 }
66 } else {
67 None
68 };
69
70 let config_field = if let Some(config_ty) = config_type {
71 quote! {
72 config: Option<#config_ty>,
73 }
74 } else {
75 quote! {}
76 };
77
78 let config_init = if has_config {
79 quote! { config: None, }
80 } else {
81 quote! {}
82 };
83
84 let config_method = if let Some(config_ty) = config_type {
85 quote! {
86 pub fn config(mut self, config: #config_ty) -> Self {
88 self.config = Some(config);
89 self
90 }
91 }
92 } else {
93 quote! {}
94 };
95
96 let config_build = if has_config {
97 quote! {
98 config: self.config.expect("Config must be set before calling build()"),
99 }
100 } else {
101 quote! {}
102 };
103
104 let expanded = quote! {
105 pub struct #builder_name {
106 name: String,
107 inputs: Vec<crate::ir::Argument>,
108 outputs: Vec<crate::ir::Argument>,
109 #config_field
110 }
111
112 impl #builder_name {
113 pub fn new(name: impl Into<String>) -> Self {
115 Self {
116 name: name.into(),
117 inputs: vec![],
118 outputs: vec![],
119 #config_init
120 }
121 }
122
123 pub fn input_tensor(
125 mut self,
126 name: &str,
127 rank: usize,
128 dtype: burn_tensor::DType,
129 ) -> Self {
130 use crate::ir::{Argument, ArgType, TensorType};
131 self.inputs.push(Argument::new(
132 name,
133 ArgType::Tensor(TensorType {
134 dtype,
135 rank,
136 static_shape: None,
137 }),
138 ));
139 self
140 }
141
142 pub fn input_tensor_shape(
144 mut self,
145 name: &str,
146 shape: Vec<usize>,
147 dtype: burn_tensor::DType,
148 ) -> Self {
149 use crate::ir::{Argument, ArgType, TensorType};
150 self.inputs.push(Argument::new(
151 name,
152 ArgType::Tensor(TensorType {
153 dtype,
154 rank: shape.len(),
155 static_shape: Some(shape),
156 }),
157 ));
158 self
159 }
160
161 pub fn input_scalar(mut self, name: &str, dtype: burn_tensor::DType) -> Self {
163 use crate::ir::{Argument, ArgType};
164 self.inputs.push(Argument::new(name, ArgType::Scalar(dtype)));
165 self
166 }
167
168 pub fn input_shape(mut self, name: &str) -> Self {
170 use crate::ir::{Argument, ArgType};
171 self.inputs.push(Argument::new(name, ArgType::Shape(1)));
172 self
173 }
174
175 pub fn output_tensor(
177 mut self,
178 name: &str,
179 rank: usize,
180 dtype: burn_tensor::DType,
181 ) -> Self {
182 use crate::ir::{Argument, ArgType, TensorType};
183 self.outputs.push(Argument::new(
184 name,
185 ArgType::Tensor(TensorType {
186 dtype,
187 rank,
188 static_shape: None,
189 }),
190 ));
191 self
192 }
193
194 pub fn output_scalar(mut self, name: &str, dtype: burn_tensor::DType) -> Self {
196 use crate::ir::{Argument, ArgType};
197 self.outputs.push(Argument::new(name, ArgType::Scalar(dtype)));
198 self
199 }
200
201 pub fn output_shape(mut self, name: &str) -> Self {
203 use crate::ir::{Argument, ArgType};
204 self.outputs.push(Argument::new(name, ArgType::Shape(1)));
205 self
206 }
207
208 pub fn output_shape_with_size(mut self, name: &str, size: usize) -> Self {
210 use crate::ir::{Argument, ArgType};
211 self.outputs.push(Argument::new(name, ArgType::Shape(size)));
212 self
213 }
214
215 pub fn input_const_i64(mut self, name: &str, value: i64) -> Self {
217 use crate::ir::Argument;
218 let arg = Argument::from_const_i64(name, value);
219 self.inputs.push(arg);
220 self
221 }
222
223 #config_method
224
225 pub fn build(self) -> #node_name {
227 #node_name {
228 name: self.name,
229 inputs: self.inputs,
230 outputs: self.outputs,
231 #config_build
232 }
233 }
234 }
235 };
236
237 TokenStream::from(expanded)
238}