1#![forbid(unsafe_code)]
31#![forbid(rustdoc::broken_intra_doc_links)]
32#![forbid(rustdoc::private_intra_doc_links)]
33#![forbid(missing_docs)]
34#![forbid(rustdoc::missing_crate_level_docs)]
35#![forbid(rustdoc::private_doc_tests)]
37#![forbid(rustdoc::invalid_codeblock_attributes)]
38#![forbid(rustdoc::invalid_html_tags)]
39#![forbid(rustdoc::invalid_rust_codeblocks)]
40#![forbid(rustdoc::bare_urls)]
41#![forbid(rustdoc::unescaped_backticks)]
42#![forbid(rustdoc::redundant_explicit_links)]
43
use proc_macro::TokenStream;
use quote::quote;
use syn::{
    parse_macro_input, parse_quote, Data, DataStruct, DeriveInput, Field, GenericArgument, Index,
    Member, PathArguments, Type, TypePath,
};
47
48#[proc_macro_derive(Module)]
54pub fn derive_module(input: TokenStream) -> TokenStream {
55 let input = parse_macro_input!(input as DeriveInput);
56 let struct_name = &input.ident;
57
58 let mut field_iterators = quote! {
59 trait __MarkerTraitRef: Sized {
60 fn __iterate_by_ref(self, res: &mut Vec<(String, &zyx::Tensor)>, label: &str) {}
61 }
62
63 struct __MarkerStructRef<T>(T);
64
65 impl<'a, T: zyx::Module> __MarkerStructRef<&'a T> {
66 fn __iterate_by_ref(self, res: &mut Vec<(String, &'a zyx::Tensor)>, label: &str) {
67 res.extend(self.0.iter_tensors().map(|(k, t)| (format!("{label}.{k}"), t)));
68 }
69 }
70
71 impl<'a, T> __MarkerTraitRef for __MarkerStructRef<&'a T>{}
72
73 let mut res = Vec::<(String, &zyx::Tensor)>::new();
74 };
75
76 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
77 for field in fields.iter() {
78 let field_name = match &field.ident {
79 Some(ident) => ident,
80 None => panic!("Unnamed fields are not supported"),
81 };
82 let field_name_str = field_name.to_string();
83
84 let field_ty: &syn::Type = &field.ty;
85
86 use std::string::ToString;
87 if quote! { #field_ty }.to_string() == "Tensor" {
88 field_iterators = quote! {
89 #field_iterators
90 res.push((#field_name_str.to_string(), &self.#field_name));
91 }
92 } else if quote! { #field_ty }.to_string() == "Option < Tensor >" {
93 field_iterators = quote! {
94 #field_iterators
95 if let Some(tensor) = &self.#field_name {
96 res.push((#field_name_str.to_string(), tensor));
97 }
98 }
99 } else {
100 field_iterators = quote! {
101 #field_iterators
102 __MarkerStructRef::<&#field_ty>::__iterate_by_ref(__MarkerStructRef(&self.#field_name), &mut res, #field_name_str);
103 };
104 }
105 }
106 }
107
108 let mut mut_field_iterators = quote! {
109 trait __MarkerTraitRef: Sized {
110 fn __iterate_by_ref(mut self, res: &mut Vec<(String, &mut zyx::Tensor)>, label: &str) {}
111 }
112
113 struct __MarkerStructRef<T>(T);
114
115 impl<'a, T: zyx::Module> __MarkerStructRef<&'a mut T> {
116 fn __iterate_by_ref(mut self, res: &mut Vec<(String, &'a mut zyx::Tensor)>, label: &str) {
117 res.extend(self.0.iter_tensors_mut().map(|(k, t)| (format!("{label}.{k}"), t)));
118 }
119 }
120
121 impl<'a, T> __MarkerTraitRef for __MarkerStructRef<&'a mut T>{}
122
123 let mut res = Vec::<(String, &mut zyx::Tensor)>::new();
124 };
125
126 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
127 for field in fields.iter() {
128 let field_name = match &field.ident {
129 Some(ident) => ident,
130 None => panic!("Unnamed fields are not supported"),
131 };
132 let field_name_str = field_name.to_string();
133
134 let field_ty: &syn::Type = &field.ty;
135
136 use std::string::ToString;
137 if quote! { #field_ty }.to_string() == "Tensor" {
138 mut_field_iterators = quote! {
139 #mut_field_iterators
140 res.push((#field_name_str.to_string(), &mut self.#field_name));
141 }
142 } else if quote! { #field_ty }.to_string() == "Option < Tensor >" {
143 mut_field_iterators = quote! {
144 #mut_field_iterators
145 if let Some(tensor) = &mut self.#field_name {
146 res.push((#field_name_str.to_string(), tensor));
147 }
148 }
149 } else {
150 mut_field_iterators = quote! {
151 #mut_field_iterators
152 __MarkerStructRef::<&mut #field_ty>::__iterate_by_ref(__MarkerStructRef(&mut self.#field_name), &mut res, #field_name_str);
153 };
154 }
155 }
156 }
157
158 let expanded = quote! {
159 impl zyx::Module for #struct_name {
160 fn iter<'a>(&'a self) -> impl Iterator<Item = &'a zyx::Tensor> {
161 self.into_iter()
162 }
163
164 fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut zyx::Tensor> {
165 self.into_iter()
166 }
167
168 fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a zyx::Tensor)> {
169 #field_iterators
170 res.into_iter()
171 }
172
173 fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut zyx::Tensor)> {
174 #mut_field_iterators
175 res.into_iter()
176 }
177 }
178 };
179
180 let mut field_iterators = quote! {
182 trait __MarkerTraitRef<'a> {
183 fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx::Tensor>) {}
184 }
185
186 struct __MarkerStructRef<T: Copy>(T);
187
188 impl<'a, T: IntoIterator<Item = &'a zyx::Tensor> + Copy> __MarkerStructRef<T> {
189 fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx::Tensor>) {
190 res.extend(self.0.into_iter());
191 }
192 }
193
194 impl<'a, T: Copy> __MarkerTraitRef<'a> for __MarkerStructRef<T>{}
195
196 let mut res = Vec::<&zyx::Tensor>::new();
197 };
198
199 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
200 for field in fields.iter() {
201 let field_name = match &field.ident {
202 Some(ident) => ident,
203 None => panic!("Unnamed fields are not supported"),
204 };
205 let field_ty: &syn::Type = &field.ty;
206 use std::string::ToString;
207 if quote! { #field_ty }.to_string() == "Tensor" {
208 field_iterators = quote! {
209 #field_iterators
210 res.push(&self.#field_name);
211 }
212 } else {
213 field_iterators = quote! {
214 #field_iterators
215 __MarkerStructRef::<&#field_ty>::__iterate_by_ref(&__MarkerStructRef(&self.#field_name), &mut res);
216 };
217 }
218 }
219 }
220
221 let expanded = quote! {
222 #expanded
223
224 impl<'a> IntoIterator for &'a #struct_name {
225 type Item = &'a zyx::Tensor;
226 type IntoIter = std::vec::IntoIter<&'a zyx::Tensor>;
227
228 fn into_iter(self) -> Self::IntoIter {
229 #field_iterators
230 res.into_iter()
231 }
232 }
233 };
234
235 let mut field_iterators = quote! {
237 trait MarkerTraitMut<'a>: Sized {
238 fn iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx::Tensor>) {}
239 }
240
241 struct MarkerStructMut<T>(T);
242
243 impl<'a, T: IntoIterator<Item = &'a mut zyx::Tensor>> MarkerStructMut<T> {
244 fn iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx::Tensor>) {
245 res.extend(self.0.into_iter());
246 }
247 }
248
249 impl<'a, T> MarkerTraitMut<'a> for MarkerStructMut<T>{}
250
251 let mut res = Vec::<&mut zyx::Tensor>::new();
252 };
253
254 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
255 for field in fields.iter() {
256 let field_name = match &field.ident {
257 Some(ident) => ident,
258 None => panic!("Unnamed fields are not supported"),
259 };
260 let field_ty: &syn::Type = &field.ty;
261 use std::string::ToString;
262 if quote! { #field_ty }.to_string() == "Tensor" {
263 field_iterators = quote! {
264 #field_iterators
265 res.push(&mut self.#field_name);
266 }
267 } else {
268 field_iterators = quote! {
269 #field_iterators
270 MarkerStructMut::<&mut #field_ty>::iterate_by_mut(MarkerStructMut(&mut self.#field_name), &mut res);
271 };
272 }
273 }
274 }
275
276 let expanded = quote! {
277 #expanded
278
279 impl<'a> IntoIterator for &'a mut #struct_name {
280 type Item = &'a mut zyx::Tensor;
281 type IntoIter = std::vec::IntoIter<&'a mut zyx::Tensor>;
282
283 fn into_iter(self) -> Self::IntoIter {
284 #field_iterators
285 res.into_iter()
286 }
287 }
288 };
289
290 TokenStream::from(expanded)
291}