1#![deny(warnings)]
116
117extern crate proc_macro;
118
119use proc_macro::TokenStream;
120use proc_macro2::TokenStream as TokenStream2;
121use quote::{format_ident, quote};
122use syn::{
123 Data, DeriveInput, Error, Field, Fields, Index, Variant, parse_macro_input,
124 punctuated::Punctuated, spanned::Spanned, token::Comma,
125};
126
127enum StructFields<'a> {
129 Named(&'a Punctuated<Field, Comma>),
130 Unnamed(&'a Punctuated<Field, Comma>),
131}
132
133fn extract_struct_fields<'a>(
135 input: &'a DeriveInput,
136 trait_name: &str,
137) -> Result<StructFields<'a>, Error> {
138 let name = &input.ident;
139 match &input.data {
140 Data::Struct(data) => match &data.fields {
141 Fields::Named(fields) => Ok(StructFields::Named(&fields.named)),
142 Fields::Unnamed(fields) => Ok(StructFields::Unnamed(&fields.unnamed)),
143 Fields::Unit => Err(Error::new(
144 input.span(),
145 format!("{trait_name} cannot be derived for unit struct `{name}`"),
146 )),
147 },
148 Data::Enum(_) => Err(Error::new(input.span(), enum_mismatch_msg(trait_name, name))),
149 Data::Union(_) => Err(Error::new(
150 input.span(),
151 format!("{trait_name} cannot be derived for union `{name}`"),
152 )),
153 }
154}
155
156fn extract_enum_variants<'a>(
158 input: &'a DeriveInput,
159 trait_name: &str,
160) -> Result<&'a Punctuated<Variant, Comma>, Error> {
161 let name = &input.ident;
162 match &input.data {
163 Data::Enum(data) => Ok(&data.variants),
164 Data::Struct(_) => Err(Error::new(input.span(), struct_mismatch_msg(trait_name, name))),
165 Data::Union(_) => Err(Error::new(
166 input.span(),
167 format!("{trait_name} cannot be derived for union `{name}`"),
168 )),
169 }
170}
171
172fn struct_mismatch_msg(trait_name: &str, name: &syn::Ident) -> String {
173 format!("{trait_name} cannot be derived for struct `{name}`")
174}
175
176fn enum_mismatch_msg(trait_name: &str, name: &syn::Ident) -> String {
177 format!("{trait_name} cannot be derived for enum `{name}`")
178}
179
180fn ensure_no_explicit_discriminants(
182 variants: &Punctuated<Variant, Comma>,
183 trait_name: &str,
184 enum_name: &syn::Ident,
185) -> Result<(), Error> {
186 for variant in variants {
187 if variant.discriminant.is_some() {
188 return Err(Error::new(
189 variant.span(),
190 format!(
191 "{trait_name} cannot be derived for enum `{enum_name}` with explicit \
192 discriminants"
193 ),
194 ));
195 }
196 }
197 Ok(())
198}
199
200#[proc_macro_derive(DeriveFromFeltRepr)]
219pub fn derive_from_felt_repr(input: TokenStream) -> TokenStream {
220 let input = parse_macro_input!(input as DeriveInput);
221
222 let expanded = derive_from_felt_repr_impl(
223 &input,
224 quote!(miden_field_repr),
225 quote!(miden_field_repr::Felt),
226 );
227 match expanded {
228 Ok(ts) => ts,
229 Err(err) => err.into_compile_error().into(),
230 }
231}
232
233fn derive_from_felt_repr_impl(
234 input: &DeriveInput,
235 felt_repr_crate: TokenStream2,
236 felt_ty: TokenStream2,
237) -> Result<TokenStream, Error> {
238 let name = &input.ident;
239 let generics = &input.generics;
240 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
241
242 let trait_name = "FromFeltRepr";
243 let expanded = match &input.data {
244 Data::Struct(_) => match extract_struct_fields(input, trait_name)? {
245 StructFields::Named(fields) => {
246 let field_names: Vec<_> =
247 fields.iter().map(|field| field.ident.as_ref().unwrap()).collect();
248 let field_types: Vec<_> = fields.iter().map(|field| &field.ty).collect();
249 quote! {
250 impl #impl_generics #felt_repr_crate::FromFeltRepr for #name #ty_generics #where_clause {
251 #[inline(always)]
252 fn from_felt_repr(reader: &mut #felt_repr_crate::FeltReader<'_>) -> Self {
253 Self {
254 #(#field_names: <#field_types as #felt_repr_crate::FromFeltRepr>::from_felt_repr(reader)),*
255 }
256 }
257 }
258 }
259 }
260 StructFields::Unnamed(fields) => {
261 let field_types: Vec<_> = fields.iter().map(|field| &field.ty).collect();
262 let reads = field_types.iter().map(|ty| {
263 quote! { <#ty as #felt_repr_crate::FromFeltRepr>::from_felt_repr(reader) }
264 });
265 quote! {
266 impl #impl_generics #felt_repr_crate::FromFeltRepr for #name #ty_generics #where_clause {
267 #[inline(always)]
268 fn from_felt_repr(reader: &mut #felt_repr_crate::FeltReader<'_>) -> Self {
269 Self(#(#reads),*)
270 }
271 }
272 }
273 }
274 },
275 Data::Enum(_) => {
276 let variants = extract_enum_variants(input, trait_name)?;
277 ensure_no_explicit_discriminants(variants, trait_name, name)?;
278
279 let arms = variants.iter().enumerate().map(|(variant_ordinal, variant)| {
280 let variant_ident = &variant.ident;
281 let tag = variant_ordinal as u32;
282 match &variant.fields {
283 Fields::Unit => quote! { #tag => Self::#variant_ident },
284 Fields::Unnamed(fields) => {
285 let field_types: Vec<_> = fields.unnamed.iter().map(|f| &f.ty).collect();
286 let reads = field_types.iter().map(|ty| {
287 quote! { <#ty as #felt_repr_crate::FromFeltRepr>::from_felt_repr(reader) }
288 });
289 quote! { #tag => Self::#variant_ident(#(#reads),*) }
290 }
291 Fields::Named(fields) => {
292 let field_idents: Vec<_> = fields
293 .named
294 .iter()
295 .map(|f| f.ident.as_ref().expect("named field"))
296 .collect();
297 let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
298 let reads = field_idents.iter().zip(field_types.iter()).map(|(ident, ty)| {
299 quote! { #ident: <#ty as #felt_repr_crate::FromFeltRepr>::from_felt_repr(reader) }
300 });
301 quote! { #tag => Self::#variant_ident { #(#reads),* } }
302 }
303 }
304 });
305
306 quote! {
307 impl #impl_generics #felt_repr_crate::FromFeltRepr for #name #ty_generics #where_clause {
308 #[inline(always)]
309 fn from_felt_repr(reader: &mut #felt_repr_crate::FeltReader<'_>) -> Self {
310 let tag: u32 = <u32 as #felt_repr_crate::FromFeltRepr>::from_felt_repr(reader);
311 match tag {
312 #(#arms,)*
313 other => panic!("Unknown `{}` enum variant tag: {}", stringify!(#name), other),
314 }
315 }
316 }
317 }
318 }
319 Data::Union(_) => {
320 return Err(Error::new(
321 input.span(),
322 format!("{trait_name} cannot be derived for union `{name}`"),
323 ));
324 }
325 };
326
327 let expanded = quote! {
328 #expanded
329
330 impl #impl_generics From<&[#felt_ty]> for #name #ty_generics #where_clause {
331 #[inline(always)]
332 fn from(felts: &[#felt_ty]) -> Self {
333 let mut reader = #felt_repr_crate::FeltReader::new(felts);
334 <Self as #felt_repr_crate::FromFeltRepr>::from_felt_repr(&mut reader)
335 }
336 }
337 };
338
339 Ok(expanded.into())
340}
341
342#[proc_macro_derive(DeriveToFeltRepr)]
361pub fn derive_to_felt_repr(input: TokenStream) -> TokenStream {
362 let input = parse_macro_input!(input as DeriveInput);
363
364 match derive_to_felt_repr_impl(&input, quote!(miden_field_repr)) {
365 Ok(ts) => ts,
366 Err(err) => err.into_compile_error().into(),
367 }
368}
369
370fn derive_to_felt_repr_impl(
371 input: &DeriveInput,
372 felt_repr_crate: TokenStream2,
373) -> Result<TokenStream, Error> {
374 let name = &input.ident;
375 let generics = &input.generics;
376 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
377
378 let trait_name = "ToFeltRepr";
379 let expanded = match &input.data {
380 Data::Struct(_) => match extract_struct_fields(input, trait_name)? {
381 StructFields::Named(fields) => {
382 let field_names: Vec<_> =
383 fields.iter().map(|field| field.ident.as_ref().unwrap()).collect();
384 quote! {
385 impl #impl_generics #felt_repr_crate::ToFeltRepr for #name #ty_generics #where_clause {
386 fn write_felt_repr(&self, writer: &mut #felt_repr_crate::FeltWriter<'_>) {
387 #(#felt_repr_crate::ToFeltRepr::write_felt_repr(&self.#field_names, writer);)*
388 }
389 }
390 }
391 }
392 StructFields::Unnamed(fields) => {
393 let field_indexes: Vec<Index> = (0..fields.len()).map(Index::from).collect();
394 quote! {
395 impl #impl_generics #felt_repr_crate::ToFeltRepr for #name #ty_generics #where_clause {
396 fn write_felt_repr(&self, writer: &mut #felt_repr_crate::FeltWriter<'_>) {
397 #(#felt_repr_crate::ToFeltRepr::write_felt_repr(&self.#field_indexes, writer);)*
398 }
399 }
400 }
401 }
402 },
403 Data::Enum(_) => {
404 let variants = extract_enum_variants(input, trait_name)?;
405 ensure_no_explicit_discriminants(variants, trait_name, name)?;
406
407 let arms = variants.iter().enumerate().map(|(variant_ordinal, variant)| {
408 let variant_ident = &variant.ident;
409 let tag = variant_ordinal as u32;
410
411 match &variant.fields {
412 Fields::Unit => quote! {
413 Self::#variant_ident => {
414 #felt_repr_crate::ToFeltRepr::write_felt_repr(&(#tag as u32), writer);
415 return;
416 }
417 },
418 Fields::Unnamed(fields) => {
419 let bindings: Vec<_> = (0..fields.unnamed.len())
420 .map(|i| format_ident!("__field{i}"))
421 .collect();
422 quote! {
423 Self::#variant_ident(#(#bindings),*) => {
424 #felt_repr_crate::ToFeltRepr::write_felt_repr(&(#tag as u32), writer);
425 #(#felt_repr_crate::ToFeltRepr::write_felt_repr(#bindings, writer);)*
426 return;
427 }
428 }
429 }
430 Fields::Named(fields) => {
431 let bindings: Vec<_> = fields
432 .named
433 .iter()
434 .map(|f| f.ident.as_ref().expect("named field"))
435 .collect();
436 quote! {
437 Self::#variant_ident { #(#bindings),* } => {
438 #felt_repr_crate::ToFeltRepr::write_felt_repr(&(#tag as u32), writer);
439 #(#felt_repr_crate::ToFeltRepr::write_felt_repr(#bindings, writer);)*
440 return;
441 }
442 }
443 }
444 }
445 });
446
447 quote! {
448 impl #impl_generics #felt_repr_crate::ToFeltRepr for #name #ty_generics #where_clause {
449 #[inline(always)]
450 fn write_felt_repr(&self, writer: &mut #felt_repr_crate::FeltWriter<'_>) {
451 match self {
452 #(#arms,)*
453 }
454 }
455 }
456 }
457 }
458 Data::Union(_) => {
459 return Err(Error::new(
460 input.span(),
461 format!("{trait_name} cannot be derived for union `{name}`"),
462 ));
463 }
464 };
465
466 Ok(expanded.into())
467}