aptos_crypto_derive_link/
lib.rs1#![forbid(unsafe_code)]
5
6#![forbid(unsafe_code)]
99
100extern crate proc_macro;
101
102mod hasher;
103mod unions;
104
105use hasher::camel_to_snake;
106use proc_macro::TokenStream;
107use proc_macro2::Span;
108use quote::quote;
109use std::iter::FromIterator;
110use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Ident};
111use unions::*;
112
113#[proc_macro_derive(SilentDisplay)]
114pub fn silent_display(source: TokenStream) -> TokenStream {
115 let ast: DeriveInput = syn::parse(source).expect("Incorrect macro input");
116 let name = &ast.ident;
117 let gen = quote! {
118 impl ::std::fmt::Display for #name {
120 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
121 write!(f, "<elided secret for {}>", stringify!(#name))
122 }
123 }
124 };
125 gen.into()
126}
127
128#[proc_macro_derive(SilentDebug)]
129pub fn silent_debug(source: TokenStream) -> TokenStream {
130 let ast: DeriveInput = syn::parse(source).expect("Incorrect macro input");
131 let name = &ast.ident;
132 let gen = quote! {
133 impl ::std::fmt::Debug for #name {
135 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
136 write!(f, "<elided secret for {}>", stringify!(#name))
137 }
138 }
139 };
140 gen.into()
141}
142
143#[proc_macro_derive(DeserializeKey)]
145pub fn deserialize_key(source: TokenStream) -> TokenStream {
146 let ast: DeriveInput = syn::parse(source).expect("Incorrect macro input");
147 let name = &ast.ident;
148 let name_string = name.to_string();
149 let gen = quote! {
150 impl<'de> ::serde::Deserialize<'de> for #name {
151 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
152 where
153 D: ::serde::Deserializer<'de>,
154 {
155 if deserializer.is_human_readable() {
156 let encoded_key = <String>::deserialize(deserializer)?;
157 ValidCryptoMaterialStringExt::from_encoded_string(encoded_key.as_str())
158 .map_err(<D::Error as ::serde::de::Error>::custom)
159 } else {
160 #[derive(::serde::Deserialize)]
164 #[serde(rename = #name_string)]
165 struct Value<'a>(&'a [u8]);
166
167 let value = Value::deserialize(deserializer)?;
168 #name::try_from(value.0).map_err(|s| {
169 <D::Error as ::serde::de::Error>::custom(format!("{} with {}", s, #name_string))
170 })
171 }
172 }
173 }
174 };
175 gen.into()
176}
177
178#[proc_macro_derive(SerializeKey)]
180pub fn serialize_key(source: TokenStream) -> TokenStream {
181 let ast: DeriveInput = syn::parse(source).expect("Incorrect macro input");
182 let name = &ast.ident;
183 let name_string = name.to_string();
184 let gen = quote! {
185 impl ::serde::Serialize for #name {
186 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
187 where
188 S: ::serde::Serializer,
189 {
190 if serializer.is_human_readable() {
191 self.to_encoded_string()
192 .map_err(<S::Error as ::serde::ser::Error>::custom)
193 .and_then(|str| serializer.serialize_str(&str[..]))
194 } else {
195 serializer.serialize_newtype_struct(
197 #name_string,
198 serde_bytes::Bytes::new(&ValidCryptoMaterial::to_bytes(self).as_slice()),
199 )
200 }
201 }
202 }
203 };
204 gen.into()
205}
206
207#[proc_macro_derive(Deref)]
208pub fn derive_deref(input: TokenStream) -> TokenStream {
209 let item = syn::parse(input).expect("Incorrect macro input");
210 let (field_ty, field_access) = parse_newtype_fields(&item);
211
212 let name = &item.ident;
213 let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
214
215 quote!(
216 impl #impl_generics ::std::ops::Deref for #name #ty_generics
217 #where_clause
218 {
219 type Target = #field_ty;
220
221 fn deref(&self) -> &Self::Target {
222 #field_access
223 }
224 }
225 )
226 .into()
227}
228
229#[proc_macro_derive(ValidCryptoMaterial)]
230pub fn derive_enum_valid_crypto_material(input: TokenStream) -> TokenStream {
231 let ast = parse_macro_input!(input as DeriveInput);
232
233 let name = &ast.ident;
234 match ast.data {
235 Data::Enum(ref variants) => impl_enum_valid_crypto_material(name, variants),
236 Data::Struct(_) | Data::Union(_) => {
237 panic!("#[derive(ValidCryptoMaterial)] is only defined for enums")
238 }
239 }
240}
241
242#[proc_macro_derive(PublicKey, attributes(PrivateKeyType))]
243pub fn derive_enum_publickey(input: TokenStream) -> TokenStream {
244 let ast = parse_macro_input!(input as DeriveInput);
245
246 let name = &ast.ident;
247 let private_key_type = get_type_from_attrs(&ast.attrs, "PrivateKeyType").unwrap();
248 match ast.data {
249 Data::Enum(ref variants) => impl_enum_publickey(name, private_key_type, variants),
250 Data::Struct(_) | Data::Union(_) => {
251 panic!("#[derive(PublicKey)] is only defined for enums")
252 }
253 }
254}
255
256#[proc_macro_derive(PrivateKey, attributes(PublicKeyType))]
257pub fn derive_enum_privatekey(input: TokenStream) -> TokenStream {
258 let ast = parse_macro_input!(input as DeriveInput);
259
260 let name = &ast.ident;
261 let public_key_type = get_type_from_attrs(&ast.attrs, "PublicKeyType").unwrap();
262 match ast.data {
263 Data::Enum(ref variants) => impl_enum_privatekey(name, public_key_type, variants),
264 Data::Struct(_) | Data::Union(_) => {
265 panic!("#[derive(PrivateKey)] is only defined for enums")
266 }
267 }
268}
269
270#[proc_macro_derive(VerifyingKey, attributes(PrivateKeyType, SignatureType))]
271pub fn derive_enum_verifyingkey(input: TokenStream) -> TokenStream {
272 let ast = parse_macro_input!(input as DeriveInput);
273
274 let name = &ast.ident;
275 let private_key_type = get_type_from_attrs(&ast.attrs, "PrivateKeyType").unwrap();
276 let signature_type = get_type_from_attrs(&ast.attrs, "SignatureType").unwrap();
277 match ast.data {
278 Data::Enum(ref variants) => {
279 impl_enum_verifyingkey(name, private_key_type, signature_type, variants)
280 }
281 Data::Struct(_) | Data::Union(_) => {
282 panic!("#[derive(PrivateKey)] is only defined for enums")
283 }
284 }
285}
286
287#[proc_macro_derive(SigningKey, attributes(PublicKeyType, SignatureType))]
288pub fn derive_enum_signingkey(input: TokenStream) -> TokenStream {
289 let ast = parse_macro_input!(input as DeriveInput);
290
291 let name = &ast.ident;
292 let public_key_type = get_type_from_attrs(&ast.attrs, "PublicKeyType").unwrap();
293 let signature_type = get_type_from_attrs(&ast.attrs, "SignatureType").unwrap();
294 match ast.data {
295 Data::Enum(ref variants) => {
296 impl_enum_signingkey(name, public_key_type, signature_type, variants)
297 }
298 Data::Struct(_) | Data::Union(_) => {
299 panic!("#[derive(PrivateKey)] is only defined for enums")
300 }
301 }
302}
303
304#[proc_macro_derive(Signature, attributes(PublicKeyType, PrivateKeyType))]
305pub fn derive_enum_signature(input: TokenStream) -> TokenStream {
306 let ast = parse_macro_input!(input as DeriveInput);
307
308 let name = &ast.ident;
309 let public_key_type = get_type_from_attrs(&ast.attrs, "PublicKeyType").unwrap();
310 let private_key_type = get_type_from_attrs(&ast.attrs, "PrivateKeyType").unwrap();
311 match ast.data {
312 Data::Enum(ref variants) => {
313 impl_enum_signature(name, public_key_type, private_key_type, variants)
314 }
315 Data::Struct(_) | Data::Union(_) => {
316 panic!("#[derive(PrivateKey)] is only defined for enums")
317 }
318 }
319}
320
321#[proc_macro_derive(CryptoHasher)]
325pub fn hasher_dispatch(input: TokenStream) -> TokenStream {
326 let item = parse_macro_input!(input as DeriveInput);
327 let hasher_name = Ident::new(
328 &format!("{}Hasher", &item.ident.to_string()),
329 Span::call_site(),
330 );
331 let snake_name = camel_to_snake(&item.ident.to_string());
332 let static_seed_name = Ident::new(
333 &format!("{}_SEED", snake_name.to_uppercase()),
334 Span::call_site(),
335 );
336
337 let static_hasher_name = Ident::new(
338 &format!("{}_HASHER", snake_name.to_uppercase()),
339 Span::call_site(),
340 );
341 let type_name = &item.ident;
342 let param = if item.generics.params.is_empty() {
343 quote!()
344 } else {
345 let args = proc_macro2::TokenStream::from_iter(
346 std::iter::repeat(quote!(())).take(item.generics.params.len()),
347 );
348 quote!(<#args>)
349 };
350
351 let out = quote!(
352 #[derive(Clone)]
354 pub struct #hasher_name(aptos_crypto::hash::DefaultHasher);
355
356 static #static_seed_name: aptos_crypto::_once_cell::sync::OnceCell<[u8; 32]> = aptos_crypto::_once_cell::sync::OnceCell::new();
357
358 impl #hasher_name {
359 fn new() -> Self {
360 let name = aptos_crypto::_serde_name::trace_name::<#type_name #param>()
361 .expect("The `CryptoHasher` macro only applies to structs and enums");
362 #hasher_name(
363 aptos_crypto::hash::DefaultHasher::new(&name.as_bytes()))
364 }
365 }
366
367 static #static_hasher_name: aptos_crypto::_once_cell::sync::Lazy<#hasher_name> =
368 aptos_crypto::_once_cell::sync::Lazy::new(|| #hasher_name::new());
369
370
371 impl std::default::Default for #hasher_name
372 {
373 fn default() -> Self {
374 #static_hasher_name.clone()
375 }
376 }
377
378 impl aptos_crypto::hash::CryptoHasher for #hasher_name {
379 fn seed() -> &'static [u8; 32] {
380 #static_seed_name.get_or_init(|| {
381 let name = aptos_crypto::_serde_name::trace_name::<#type_name #param>()
382 .expect("The `CryptoHasher` macro only applies to structs and enums.").as_bytes();
383 aptos_crypto::hash::DefaultHasher::prefixed_hash(&name)
384 })
385 }
386
387 fn update(&mut self, bytes: &[u8]) {
388 self.0.update(bytes);
389 }
390
391 fn finish(self) -> aptos_crypto::hash::HashValue {
392 self.0.finish()
393 }
394 }
395
396 impl std::io::Write for #hasher_name {
397 fn write(&mut self, bytes: &[u8]) -> std::io::Result<usize> {
398 use aptos_crypto::hash::CryptoHasher;
399
400 self.0.update(bytes);
401 Ok(bytes.len())
402 }
403 fn flush(&mut self) -> std::io::Result<()> {
404 Ok(())
405 }
406 }
407
408 );
409 out.into()
410}
411
412#[proc_macro_derive(BCSCryptoHash)]
413pub fn bcs_crypto_hash_dispatch(input: TokenStream) -> TokenStream {
414 let ast = parse_macro_input!(input as DeriveInput);
415 let name = &ast.ident;
416 let hasher_name = Ident::new(&format!("{}Hasher", &name.to_string()), Span::call_site());
417 let error_msg = syn::LitStr::new(
418 &format!("BCS serialization of {} should not fail", name),
419 Span::call_site(),
420 );
421 let generics = add_trait_bounds(ast.generics);
422 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
423 let out = quote!(
424 impl #impl_generics aptos_crypto::hash::CryptoHash for #name #ty_generics #where_clause {
425 type Hasher = #hasher_name;
426
427 fn hash(&self) -> aptos_crypto::hash::HashValue {
428 use aptos_crypto::hash::CryptoHasher;
429
430 let mut state = Self::Hasher::default();
431 bcs::serialize_into(&mut state, &self).expect(#error_msg);
432 state.finish()
433 }
434 }
435 );
436 out.into()
437}
438
439fn add_trait_bounds(mut generics: syn::Generics) -> syn::Generics {
440 for param in generics.params.iter_mut() {
441 if let syn::GenericParam::Type(type_param) = param {
442 type_param.bounds.push(parse_quote!(Serialize));
443 }
444 }
445 generics
446}