Skip to main content

rialo_s_spl_discriminator_syn/
lib.rs

1//! Token parsing and generating library for the `spl-discriminator` library
2
3#![deny(missing_docs)]
4#![cfg_attr(not(test), forbid(unsafe_code))]
5
6mod error;
7pub mod parser;
8
9use proc_macro2::{Span, TokenStream};
10use quote::{quote, ToTokens};
11use sha2::{Digest, Sha256};
12use syn::{parse::Parse, Generics, Ident, Item, ItemEnum, ItemStruct, LitByteStr, WhereClause};
13
14use crate::{error::SplDiscriminateError, parser::parse_hash_input};
15
16/// "Builder" struct to implement the `SplDiscriminate` trait
17/// on an enum or struct
18pub struct SplDiscriminateBuilder {
19    /// The struct/enum identifier
20    pub ident: Ident,
21    /// The item's generic arguments (if any)
22    pub generics: Generics,
23    /// The item's where clause for generics (if any)
24    pub where_clause: Option<WhereClause>,
25    /// The TLV `hash_input`
26    pub hash_input: String,
27}
28
29impl TryFrom<ItemEnum> for SplDiscriminateBuilder {
30    type Error = SplDiscriminateError;
31
32    fn try_from(item_enum: ItemEnum) -> Result<Self, Self::Error> {
33        let ident = item_enum.ident;
34        let where_clause = item_enum.generics.where_clause.clone();
35        let generics = item_enum.generics;
36        let hash_input = parse_hash_input(&item_enum.attrs)?;
37        Ok(Self {
38            ident,
39            generics,
40            where_clause,
41            hash_input,
42        })
43    }
44}
45
46impl TryFrom<ItemStruct> for SplDiscriminateBuilder {
47    type Error = SplDiscriminateError;
48
49    fn try_from(item_struct: ItemStruct) -> Result<Self, Self::Error> {
50        let ident = item_struct.ident;
51        let where_clause = item_struct.generics.where_clause.clone();
52        let generics = item_struct.generics;
53        let hash_input = parse_hash_input(&item_struct.attrs)?;
54        Ok(Self {
55            ident,
56            generics,
57            where_clause,
58            hash_input,
59        })
60    }
61}
62
63impl Parse for SplDiscriminateBuilder {
64    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
65        let item = Item::parse(input)?;
66        match item {
67            Item::Enum(item_enum) => item_enum.try_into(),
68            Item::Struct(item_struct) => item_struct.try_into(),
69            _ => {
70                return Err(syn::Error::new(
71                    Span::call_site(),
72                    "Only enums and structs are supported",
73                ))
74            }
75        }
76        .map_err(|e| syn::Error::new(input.span(), format!("Failed to parse item: {}", e)))
77    }
78}
79
80impl ToTokens for SplDiscriminateBuilder {
81    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
82        tokens.extend::<TokenStream>(self.into());
83    }
84}
85
86impl From<&SplDiscriminateBuilder> for TokenStream {
87    fn from(builder: &SplDiscriminateBuilder) -> Self {
88        let ident = &builder.ident;
89        let generics = &builder.generics;
90        let where_clause = &builder.where_clause;
91        let bytes = get_discriminator_bytes(&builder.hash_input);
92        quote! {
93            impl #generics rialo_s_spl_discriminator::discriminator::SplDiscriminate for #ident #generics #where_clause {
94                const SPL_DISCRIMINATOR: rialo_s_spl_discriminator::discriminator::ArrayDiscriminator
95                    = rialo_s_spl_discriminator::discriminator::ArrayDiscriminator::new(*#bytes);
96            }
97        }
98    }
99}
100
101/// Returns the bytes for the TLV `hash_input` discriminator
102fn get_discriminator_bytes(hash_input: &str) -> LitByteStr {
103    LitByteStr::new(
104        &Sha256::digest(hash_input.as_bytes())[..8],
105        Span::call_site(),
106    )
107}