1#![doc = include_str!("../README.md")]
16
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote, ToTokens};
19use syn::{
20 parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, DataEnum,
21 DataStruct, DeriveInput, Error, Fields, GenericParam, Generics, Ident, Index, Lifetime,
22 LifetimeDef, Result, Variant, Visibility,
23};
24
25#[derive(Default, Debug)]
26struct GenericFragments {
27 impl_generics: TokenStream,
29 ty_generics: TokenStream,
30 where_clause: TokenStream,
31
32 ref_name: TokenStream,
34
35 ref_impl_generics: TokenStream,
37 ref_ty_generics: TokenStream,
38 ref_where_clause: TokenStream,
40
41 phantom_types: TokenStream,
42}
43
44fn to_tokens<T: ToTokens>(t: &T) -> TokenStream {
45 let mut tokens = TokenStream::new();
46 t.to_tokens(&mut tokens);
47 tokens
48}
49
50fn warning_inhibit() -> TokenStream {
51 quote! {
52 #[allow(unused_parens, non_camel_case_types, unused_variables, dead_code, missing_docs)]
53 }
54}
55
56fn make_generics(ref_name: &Ident, generics: &Generics) -> GenericFragments {
57 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
58
59 let mut ref_generics = generics.clone();
62 let zeroio_lifetime = GenericParam::Lifetime(LifetimeDef::new(Lifetime::new(
63 "'zeroio_deserialize",
64 generics.span().unwrap().into(),
65 )));
66 ref_generics.params.push(zeroio_lifetime);
67 let (ref_impl_generics, ref_ty_generics, ref_where_clause) = ref_generics.split_for_impl();
68 let mut phantom_types = Punctuated::<TokenStream, Comma>::from_iter(
73 generics.params.iter().filter_map(|p| match p {
74 GenericParam::Type(p) => Some(to_tokens(&p.ident)),
75 _ => None,
76 }),
77 );
78 phantom_types.push(quote! {&'zeroio_deserialize ()});
79
80 GenericFragments {
81 impl_generics: to_tokens(&impl_generics),
82 ty_generics: to_tokens(&ty_generics),
83 where_clause: to_tokens(&where_clause),
84 ref_name: to_tokens(&ref_name),
85 ref_impl_generics: to_tokens(&ref_impl_generics),
86 ref_ty_generics: to_tokens(&ref_ty_generics),
87 ref_where_clause: to_tokens(&ref_where_clause),
89 phantom_types: to_tokens(&phantom_types),
90 }
91}
92
93#[derive(Debug)]
95struct StructFragments {
96 field_name: Vec<Ident>,
97
98 fixed_words: TokenStream,
100 tot_len: TokenStream,
101
102 fill: TokenStream,
104
105 accessors: TokenStream,
107
108 deserialize_impl: TokenStream,
110
111 tuple_id: Vec<TokenStream>,
113
114 declare_ref: TokenStream,
116
117 from_ref: TokenStream,
119}
120
121fn make_struct_impls(
125 selfname: &TokenStream,
126 vis: &Visibility,
127 fields: &Fields,
128 generics: &GenericFragments,
129 is_enum: bool,
130) -> StructFragments {
131 let field_name: Vec<Ident> = fields.iter().flat_map(|f| &f.ident).cloned().collect();
133 let field_ty: Vec<TokenStream> = fields.iter().map(|f| to_tokens(&f.ty)).collect();
134
135 let tuple_id: Vec<_> = (0..fields.len())
137 .map(|idx| match is_enum {
138 true => to_tokens(&format_ident!("elem{idx}")),
139 false => {
140 let idx = Index::from(idx);
142 quote! { self . #idx }
143 }
144 })
145 .collect();
146
147 let field_offsets: Vec<_> = (0..field_ty.len())
150 .map(|idx| {
151 if idx == 0 {
152 quote! { 0 }
153 } else {
154 let part_ty = &field_ty[0..idx];
155 quote! {
156 #( <#part_ty as risc0_zeroio::Deserialize<'_>> :: FIXED_WORDS)+*
157 }
158 }
159 })
160 .collect();
161
162 let fixed_words = match &fields {
163 Fields::Unit => quote! {0},
164 Fields::Named(_) | Fields::Unnamed(_) => quote! {
165 #(<#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS)+*
166 },
167 };
168 let selfref = if is_enum {
169 quote! {}
170 } else {
171 quote! {self.}
172 };
173 let tot_len = match &fields {
174 Fields::Unit => quote! {0},
175 Fields::Named(_) => quote! {
176 #(#selfref #field_name . tot_len())+*
177 },
178 Fields::Unnamed(_) => quote! {
179 #(#tuple_id . tot_len())+*
180 },
181 };
182
183 let fill = match &fields {
184 Fields::Unit => {
185 quote! { Ok(()) }
186 }
187 Fields::Named(_) => {
188 quote! {
189 let pos: usize = 0;
190 #(
191 #selfref #field_name . fill(&mut _buf.descend(
192 pos, <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS)?, _a)?;
193 let pos = pos + <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS;
194 )*
195
196 Ok(())
197 }
198 }
199 Fields::Unnamed(_) => {
200 quote! {
201 let pos: usize = 0;
202 #(
203 #tuple_id . fill(&mut _buf.descend(
204 pos, <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS)?, _a)?;
205 let pos = pos + <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS;
206 )*
207
208 Ok(())
209 }
210 }
211 };
212
213 let accessors = match &fields {
215 Fields::Unit => TokenStream::new(),
216 Fields::Named(_) => quote! {
217 #(
219 #vis fn #field_name(&self) ->
220 <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::RefType {
221 <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(
222 &self.buf[#field_offsets ..]
223 )
224 }
225 )*
226 },
227 Fields::Unnamed(_) => {
228 let tuple_method = (0..fields.len()).map(|idx| format_ident!("elem{}", idx));
229
230 quote! {#(
231 #vis fn #tuple_method(&self) -> <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::RefType {
232 <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(
233 &self.buf[#field_offsets ..]
234 )
235 }
236 )*}
237 }
238 };
239
240 let from_ref = match &fields {
241 Fields::Unit => {
242 quote! { #selfname }
243 }
244 Fields::Named(_) => quote! {
245 #selfname{#(
246 #field_name:
247 <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::from_ref(
248 &<#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(&_val.buf[#field_offsets ..]))
249 ),*}
250 },
251 Fields::Unnamed(_) => quote! {
252 #selfname(#(
253 <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::from_ref(
254 &<#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(&_val.buf[#field_offsets ..]))
255 ),*)},
256 };
257
258 let GenericFragments {
259 ref_name,
260 ref_ty_generics,
261 phantom_types,
262 ..
263 } = generics;
264
265 let inhibit_warns = warning_inhibit();
266 let declare_ref = quote! {
267 #inhibit_warns
268 #vis struct #ref_name #ref_ty_generics {
269 buf: &'zeroio_deserialize [u32],
270 phantom: core::marker::PhantomData <(#phantom_types)>,
271 }
272 };
273
274 let deserialize_impl = match &fields {
275 Fields::Unit => quote! {
276
277 type RefType = #ref_name #ref_ty_generics;
278
279 const FIXED_WORDS : usize = 0;
280
281 fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self::RefType {
282 Self::RefType{phantom: core::marker::PhantomData}
283 }
284
285 fn from_ref(_val: &Self::RefType) -> Self {
286 #from_ref
287 }
288 },
289 Fields::Named(_) | Fields::Unnamed(_) => quote! {
290 type RefType = #ref_name #ref_ty_generics;
291
292 const FIXED_WORDS : usize =
293 #(<#field_ty as risc0_zeroio::Deserialize<'_>>::FIXED_WORDS)+* ;
294
295 fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self::RefType {
296 Self::RefType { buf: _buf, phantom: core::marker::PhantomData }
297 }
298
299 fn from_ref(_val: &Self::RefType) -> Self {
300 #from_ref
301 }
302 },
303 };
304
305 StructFragments {
306 tot_len,
307 fill,
308 declare_ref,
309 deserialize_impl,
310 tuple_id,
311 accessors,
312 fixed_words,
313 field_name,
314 from_ref,
315 }
316}
317
318fn emit_serialize_struct(input: &DeriveInput, st: &DataStruct) -> Result<TokenStream> {
319 let name = &input.ident;
321
322 let genfrags @ GenericFragments {
323 impl_generics,
324 ty_generics,
325 where_clause,
326 ..
327 } = &make_generics(&format_ident!("{}Ref", name), &input.generics);
328
329 let StructFragments {
330 fixed_words,
331 tot_len,
332 fill,
333 ..
334 } = make_struct_impls("e! {#name}, &input.vis, &st.fields, &genfrags, false);
335
336 let inhibit_warns = warning_inhibit();
338 let expanded = quote! {
339 #inhibit_warns
340 impl #impl_generics risc0_zeroio::Serialize for #name #ty_generics #where_clause {
341 const FIXED_WORDS : usize = #fixed_words;
342
343 fn tot_len(&self) -> usize { #tot_len }
344
345 fn fill(&self, _buf: & mut risc0_zeroio::AllocBuf,
346 _a: &mut risc0_zeroio::Alloc) -> risc0_zeroio::Result<()> {
347 #fill
348 }
349 }
350 };
351
352 Ok(expanded.into())
353}
354
355fn emit_deserialize_struct(input: &DeriveInput, fields: &Fields) -> Result<TokenStream> {
356 let name = &input.ident;
358
359 let genfrags @ GenericFragments {
360 ty_generics,
361 where_clause,
362 ref_name,
363 ref_impl_generics,
364 ref_ty_generics,
365 ref_where_clause,
366 ..
367 } = &make_generics(&format_ident!("{}Ref", name), &input.generics);
368
369 let StructFragments {
370 accessors,
371 declare_ref,
372 deserialize_impl,
373 ..
374 } = make_struct_impls("e! {#name}, &input.vis, fields, &genfrags, false);
375
376 let inhibit_warns = warning_inhibit();
377 let vis = &input.vis;
378 let expanded = quote! {
379 #declare_ref
380
381 #inhibit_warns
382 impl #ref_impl_generics #ref_name #ref_ty_generics #ref_where_clause {
383 #accessors
384
385 #vis fn into_orig(&self) -> #name #ty_generics {
386 <#name #ty_generics as risc0_zeroio::Deserialize<'zeroio_deserialize>>::from_ref(&self)
387 }
388 }
389
390 #inhibit_warns
392 impl #ref_impl_generics risc0_zeroio::Deserialize<'zeroio_deserialize> for #name #ty_generics #where_clause {
393 #deserialize_impl
394 }
395 };
396
397 Ok(expanded.into())
398}
399
400fn make_var_generated_type(name: &Ident, var_name: &Ident) -> Ident {
401 format_ident!("{}Ref", format!("{}::{}", name, var_name).replace(":", "_"))
402}
403
404struct VarFragments {
406 var: Variant,
407
408 var_id: usize,
409 var_name: Ident,
410 var_ref_ty: Ident,
412
413 st: StructFragments,
414}
415
416fn make_var_frags(name: &Ident, vis: &Visibility, var: &Variant, var_id: usize) -> VarFragments {
417 let var_name = var.ident.clone();
418 let var_ref_ty = make_var_generated_type(name, &var_name);
419 let generics = make_generics(&var_ref_ty, &Generics::default());
420 let st = make_struct_impls(
421 "e! {#name :: #var_name},
422 vis,
423 &var.fields,
424 &generics,
425 true,
426 );
427 VarFragments {
428 var: var.clone(),
429 var_id,
430 var_name,
431 var_ref_ty,
432 st,
433 }
434}
435
436fn make_enum_frags(name: &Ident, vis: &Visibility, en: &DataEnum) -> Vec<VarFragments> {
437 en.variants
438 .iter()
439 .enumerate()
440 .map(|(var_id, var)| make_var_frags(name, vis, var, var_id))
441 .collect()
442}
443
444fn emit_serialize_enum(input: &DeriveInput, en: &DataEnum) -> Result<TokenStream> {
445 let name = &input.ident;
447
448 let GenericFragments {
449 impl_generics,
450 ty_generics,
451 where_clause,
452 ..
453 } = make_generics(&format_ident!("{}Ref", name), &input.generics);
454
455 let vars = make_enum_frags(name, &input.vis, en);
456
457 let match_tot_len = vars.iter().map(
458 |VarFragments {
459 var,
460 var_name,
461 st:
462 StructFragments {
463 tot_len,
464 field_name,
465 tuple_id,
466 ..
467 },
468 ..
469 }| {
470 match var.fields {
471 Fields::Unit => quote! {
472 #name :: #var_name => #tot_len
473 },
474 Fields::Named(_) => quote! {
475 #name :: #var_name{ #(#field_name),* } => #tot_len
476 },
477 Fields::Unnamed(_) => quote! {
478 #name :: #var_name( #(#tuple_id),* ) => #tot_len
479 },
480 }
481 },
482 );
483
484 let match_and_fill = vars.iter().map(
485 |VarFragments {
486 var,
487 var_id,
488 var_name,
489 st:
490 StructFragments {
491 fill,
492 field_name,
493 tuple_id,
494 fixed_words,
495 ..
496 },
497 ..
498 }| {
499 let var_id = *var_id as u32;
500 let enumfill = quote! {
501 let mut vardata = _a.alloc(#fixed_words)?;
502 {
503 let _buf = &mut vardata;
504 #fill
505 }?;
506 _enumbuf.fill_from([#var_id, vardata.rel_ptr_from(_enumbuf)])
507 };
508 match var.fields {
509 Fields::Unit => quote! {
510 #name :: #var_name => { #enumfill }
511 },
512 Fields::Named(_) => quote! {
513 #name :: #var_name{ #(#field_name),* } => { #enumfill }
514 },
515 Fields::Unnamed(_) => quote! {
516 #name :: #var_name( #(#tuple_id).* ) => { #enumfill }
517 },
518 }
519 },
520 );
521
522 let inhibit_warns = warning_inhibit();
524 let expanded = quote! {
525 #inhibit_warns
526 impl #impl_generics risc0_zeroio::Serialize for #name #ty_generics #where_clause {
527 const FIXED_WORDS: usize = 2;
529
530 fn tot_len(&self) -> usize {
531 <Self as risc0_zeroio::Serialize>::FIXED_WORDS + match self {
532 #(#match_tot_len,)*
533 }
534 }
535
536 fn fill(&self, _enumbuf: &mut risc0_zeroio::AllocBuf, _a: &mut risc0_zeroio::Alloc)
537 -> risc0_zeroio::Result<()> {
538 match self {
539 #(#match_and_fill,)*
540 }
541 }
542 }
543 };
544 Ok(expanded.into())
545}
546
547fn emit_deserialize_enum(input: &DeriveInput, en: &DataEnum) -> Result<TokenStream> {
548 let name = &input.ident;
550 let vis = &input.vis;
551
552 let GenericFragments {
553 ty_generics,
554 where_clause,
555 ref_name,
556 ref_impl_generics,
557 ref_ty_generics,
558 ..
559 } = &make_generics(&format_ident!("{}Ref", name), &input.generics);
560
561 let vars = &make_enum_frags(name, &input.vis, en);
562
563 let var_name: &Vec<_> = &vars.iter().map(|var| &var.var_name).collect();
564 let var_ref_ty = vars.iter().map(|var| &var.var_ref_ty);
565
566 let declare_ref = vars.iter().map(|var| &var.st.declare_ref);
567
568 let match_and_deser = vars.iter().map(
569 |VarFragments {
570 var_id,
571 var_name,
572 var_ref_ty,
573 ..
574 }| {
575 let var_id = *var_id as u32;
576 quote! {
577 #var_id => Self::RefType::#var_name(
578 #var_ref_ty::<'zeroio_deserialize>::deserialize_from(&_buf[ptr as usize..]))
579 }
580 },
581 );
582
583 let ref_impl = vars.iter().map(
584 |VarFragments {
585 var_ref_ty,
586 st:
587 StructFragments {
588 accessors,
589 from_ref,
590 ..
591 },
592 ..
593 }| {
594 quote! {
595 impl<'zeroio_deserialize> #var_ref_ty <'zeroio_deserialize> {
596 #accessors
597
598 fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self {
599 Self{buf: _buf, phantom: core::marker::PhantomData}
600 }
601
602 pub fn into_orig(&self) -> #name {
603 let _val = self;
604 #from_ref
605 }
606 }
607 }
608 },
609 );
610
611 let inhibit_warns = warning_inhibit();
612 let expanded = quote! {
613 #inhibit_warns
614 #vis enum #ref_name #ref_ty_generics {#(
615 #var_name(#var_ref_ty<'zeroio_deserialize>),
616 )*}
617
618 #(#declare_ref)*
620 #(#ref_impl)*
621
622 #inhibit_warns
624 impl #ref_impl_generics risc0_zeroio::Deserialize<'zeroio_deserialize> for #name #ty_generics #where_clause {
625 type RefType = #ref_name #ref_ty_generics;
626
627 const FIXED_WORDS : usize = 2;
628
629 fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self::RefType {
630 let (id, ptr) = (_buf[0], _buf[1]);
631 match id {
632 #(#match_and_deser,)*
633 _ => panic!("Unknown variant id {}", id)
634 }
635 }
636
637 fn from_ref(_val: &Self::RefType) -> Self {
638 match _val {#(
639 Self::RefType::#var_name(ref var) => var.into_orig()
640 ,)*}
641 }
642 }
643 };
644
645 Ok(expanded.into())
646}
647
648fn debug_dump(ty: &str, ident: &Ident, res: &mut TokenStream) {
652 if cfg!(feature = "debug-derive") {
653 let filename = format!("{ty}-{ident}.rs");
654 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
655 let path = std::path::Path::new(&manifest_dir)
656 .join("src")
657 .join(&filename);
658 std::fs::write(&path, format!("{}", res)).unwrap();
659
660 let pathname = path.display().to_string();
661 *res = quote! {
662 include!(#pathname);
663 };
664 }
665}
666
667#[proc_macro_derive(Serialize)]
668pub fn derive_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
669 let input = parse_macro_input!(input as DeriveInput);
670
671 let mut res = match &input.data {
672 syn::Data::Struct(ref st) => emit_serialize_struct(&input, st),
673 syn::Data::Enum(en) => emit_serialize_enum(&input, &en),
674 _ => Err(Error::new(
675 input.span().unwrap().into(),
676 "Zeroio derive only supports structs and enums",
677 )),
678 }
679 .unwrap_or_else(|err| Error::to_compile_error(&err).into());
680 debug_dump("ser", &input.ident, &mut res);
681 res.into()
682}
683
684#[proc_macro_derive(Deserialize)]
685pub fn derive_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
686 let input = parse_macro_input!(input as DeriveInput);
687
688 let mut res = match &input.data {
689 syn::Data::Struct(st) => emit_deserialize_struct(&input, &st.fields),
690 syn::Data::Enum(en) => emit_deserialize_enum(&input, &en),
691 _ => Err(Error::new(
692 input.span().unwrap().into(),
693 "Zeroio derive only supports structs and enums",
694 )),
695 }
696 .unwrap_or_else(|err| Error::to_compile_error(&err).into());
697 debug_dump("deser", &input.ident, &mut res);
698 res.into()
699}