extern crate proc_macro;
use itertools::izip;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, GenericArgument, PathArguments, Type};
use crate::party_traits::{
find_fields_in_struct,
hasher::HASHER_STR,
is_type,
scribe::TRANSCRIPT_STR,
};
/// Key passed to `find_fields_in_struct` (and used to index its results) to locate the
/// session-id field: a field marked with `#[session_id]` or named `session_id`.
const SESSION_ID_STR: &str = "session_id";
/// Key passed to `find_fields_in_struct` (and used to index its results) to locate the
/// fields annotated with `#[party]`.
const PARTY_STR: &str = "party";
/// Generates an expression that refreshes the value behind a `&mut` binding named `elem`.
///
/// `ty` is the type of the value `elem` points to; the returned tokens expect
/// `elem: &mut <ty>` to be in scope at the splice site. Supported shapes:
///
/// * `Vec<T>` and `[T; N]` — refreshes every element, recursing into `T`;
/// * non-empty tuples — refreshes each component, recursing into its type;
/// * parenthesized / macro-grouped types — looks through the grouping;
/// * anything else — calls `.refresh()` on the value.
///
/// Working off a named binding (rather than returning a method-call suffix to append to an
/// expression) is what lets containers and tuples nest in any combination.
fn generate_refresh_for_iterable(ty: &Type) -> proc_macro2::TokenStream {
    match ty {
        // `$t:ty` fragments from `macro_rules!` arrive as invisible groups, and `(T)` as a
        // parenthesized type; unwrap both to see the real shape.
        Type::Group(group) => generate_refresh_for_iterable(&group.elem),
        Type::Paren(paren) => generate_refresh_for_iterable(&paren.elem),
        Type::Path(type_path) => match vec_element_type(type_path) {
            Some(inner_type) => refresh_each_element(inner_type),
            None => refresh_leaf(),
        },
        Type::Array(array) => refresh_each_element(&array.elem),
        Type::Tuple(tuple) if !tuple.elems.is_empty() => refresh_tuple_components(tuple),
        _ => refresh_leaf(),
    }
}

/// Refreshes the value behind `elem` directly through its `refresh` method.
fn refresh_leaf() -> proc_macro2::TokenStream {
    quote! { elem.refresh() }
}

/// Returns `T` when `type_path` names `Vec<T>` (matched on the last path segment only).
fn vec_element_type(type_path: &syn::TypePath) -> Option<&Type> {
    let last_segment = type_path.path.segments.last()?;
    if last_segment.ident != "Vec" {
        return None;
    }
    match &last_segment.arguments {
        PathArguments::AngleBracketed(args) => match args.args.first() {
            Some(GenericArgument::Type(inner_type)) => Some(inner_type),
            _ => None,
        },
        _ => None,
    }
}

/// Refreshes every element of the `Vec`/array behind `elem`, rebinding each element as
/// `elem` for the recursive step.
fn refresh_each_element(element_ty: &Type) -> proc_macro2::TokenStream {
    let inner_refresh = generate_refresh_for_iterable(element_ty);
    // The trailing `;` discards whatever `refresh` returns so the closure yields `()`.
    quote! {
        elem.iter_mut().for_each(|elem| { #inner_refresh; })
    }
}

/// Refreshes each component of the (non-empty) tuple behind `elem`.
fn refresh_tuple_components(tuple: &syn::TypeTuple) -> proc_macro2::TokenStream {
    let bindings: Vec<_> = (0..tuple.elems.len())
        .map(|i| quote::format_ident!("elem_{}", i))
        .collect();
    let component_refreshes = izip!(&tuple.elems, &bindings).map(|(component_ty, binding)| {
        let inner_refresh = generate_refresh_for_iterable(component_ty);
        quote! {
            {
                let elem = #binding;
                #inner_refresh;
            }
        }
    });
    // A comma after every binding keeps 1-tuples a tuple pattern: `(elem_0,)` rather than
    // the parenthesized pattern `(elem_0)`. Destructuring `&mut (A, B, ..)` with match
    // ergonomics yields `&mut A`, `&mut B`, ... bindings.
    quote! {
        {
            let (#(#bindings,)*) = elem;
            #(#component_refreshes)*
        }
    }
}

/// Implementation of `#[derive(Party)]`.
///
/// Emits an `impl crate::types::party::Party` for the input struct with:
///
/// * `session_id` — borrows the struct's own `SessionId` field when it has one, otherwise
///   delegates to the first `#[party]` field;
/// * `protocol_name` — `PROTOCOL_INFO.name()` (a `PROTOCOL_INFO` item must be in scope at
///   the derive site) when the struct owns its session id, otherwise
///   `"<StructName> - <first party field's protocol name>"`;
/// * `refresh` — refreshes, in order: the struct's own session id (from the transcript
///   field if present, else from `protocol_name`), every `#[party]` field (recursing
///   through `Vec`, arrays and tuples), and the hasher field keyed by the session id.
///
/// Instead of an impl, returns the error produced by `find_fields_in_struct`, or a
/// `compile_error!` when there is more than one session-id, transcript or hasher field,
/// when the session-id field is not a `SessionId`, or when the struct has neither a
/// session-id field nor a `#[party]` field.
pub fn derive_party(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = &input.ident;
    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let field_results = match find_fields_in_struct(
        &input,
        [
            SESSION_ID_STR.to_string(),
            PARTY_STR.to_string(),
            TRANSCRIPT_STR.to_string(),
            HASHER_STR.to_string(),
        ],
    ) {
        Ok(results) => results,
        Err(e) => return e,
    };

    let session_id_field = match &field_results[SESSION_ID_STR].as_slice() {
        [] => None,
        [(ident, ty)] if is_type(ty, "SessionId") => Some(ident.clone()),
        [(_, _)] => return quote! { compile_error!("Found a field marked with #[session_id] or named `session_id` but not of type `SessionId`.") }.into(),
        _ => return quote! { compile_error!("Found multiple fields marked with #[session_id] or named `session_id`. Only one is allowed.") }.into(),
    };

    let transcript_matches = &field_results[TRANSCRIPT_STR];
    if transcript_matches.len() > 1 {
        return quote! { compile_error!("Party derive macro found multiple fields marked with #[transcript] or named `transcript`. Only one is allowed.") }.into();
    }
    let transcript_field = transcript_matches.first().map(|(ident, _)| ident);

    let hasher_matches = &field_results[HASHER_STR];
    if hasher_matches.len() > 1 {
        return quote! { compile_error!("Party derive macro found multiple fields marked with #[hasher] or named `hasher`. Only one is allowed.") }.into();
    }
    let hasher_field = hasher_matches.first().map(|(ident, _)| ident);

    let party_annotated_fields = &field_results[PARTY_STR];
    if session_id_field.is_none() && party_annotated_fields.is_empty() {
        return quote! {
            compile_error!("Party derive macro requires a field of type `SessionId` or a field annotated with #[party].");
        }.into();
    }

    let session_id_impl = match (session_id_field.as_ref(), party_annotated_fields.first()) {
        (Some(session_id), _) => quote! { &self.#session_id },
        (None, Some((party, _))) => quote! { self.#party.session_id() },
        (None, None) => quote! { unreachable!() },
    };

    let protocol_name_impl = if session_id_field.is_some() {
        quote! { PROTOCOL_INFO.name().to_string() }
    } else {
        let (first_party_field, _) = party_annotated_fields
            .first()
            .expect("a #[party] field exists: structs without a session id were rejected above");
        let type_name = struct_name.to_string();
        quote! { format!("{} - {}", #type_name, self.#first_party_field.protocol_name()) }
    };

    let refresh_session_id = match (session_id_field.as_ref(), transcript_field) {
        (Some(session_id), Some(transcript)) => {
            quote! { self.#session_id.refresh_from(&mut self.#transcript); }
        }
        // `&*self` reborrows as `&Self`. Passing `&self` (a `&&mut Self`) to the
        // fully-qualified trait call would infer the trait's `Self` as `&mut Self`.
        (Some(session_id), None) => quote! {
            let protocol_name = crate::types::party::Party::protocol_name(&*self);
            self.#session_id.refresh_with(&protocol_name);
        },
        (None, _) => quote! {},
    };

    // Each party field is bound as `elem` so the generated refresh code (which always
    // operates on `elem`) can recurse through `Vec`s, arrays and tuples.
    let refresh_party_fields = party_annotated_fields.iter().map(|(ident, field_ty)| {
        let refresh_code = generate_refresh_for_iterable(field_ty);
        quote! {
            {
                let elem = &mut self.#ident;
                #refresh_code;
            }
        }
    });

    // Runs after the session id and party fields are refreshed, so the hasher is keyed
    // by the fresh session id.
    let refresh_hasher_fields = if let Some(hasher) = hasher_field {
        quote! { primitives::hashing::TweakableHasher::refresh(&mut self.#hasher, #session_id_impl); }
    } else {
        quote! {}
    };

    quote! {
        impl #impl_generics crate::types::party::Party for #struct_name #ty_generics #where_clause {
            fn session_id(&self) -> &primitives::types::SessionId {
                #session_id_impl
            }
            fn protocol_name(&self) -> String {
                #protocol_name_impl
            }
            fn refresh(&mut self) {
                #refresh_session_id
                #( #refresh_party_fields )*
                #refresh_hasher_fields
            }
        }
    }
    .into()
}