leaf_protocol_macros/
lib.rs

1use std::{path::PathBuf, str::FromStr};
2
3use borsh::BorshSerialize;
4use iroh_base::hash::Hash;
5use leaf_protocol_types::*;
6use proc_macro::TokenStream;
7use quote::{format_ident, quote, quote_spanned, spanned::Spanned};
8use unsynn::{Parse, TokenTree};
9
/// Bails out of the enclosing function by returning a `compile_error!` that carries `$msg`,
/// located at the span of the tokens of `$spanned`.
///
/// Expands to a `return` statement, so it is only usable in statement position, and the enclosing
/// function's return type must be obtainable via `.into()` from a `proc_macro2::TokenStream`.
macro_rules! throw {
    ($spanned:expr, $msg:literal) => {
        return quote_spanned!($spanned.__span() => compile_error!($msg);).into();
    };
}
19
20type KeyValueAttribute = unsynn::Cons<
21    unsynn::Ident,
22    Option<
23        unsynn::Cons<
24            unsynn::Assign,
25            unsynn::Either<unsynn::LiteralString, unsynn::PathSepDelimitedVec<unsynn::Ident>>,
26        >,
27    >,
28>;
29
30/// Derive macro for the `Component` trait.
31///
32/// ```ignore
33/// #[derive(BorshSerialize, BorshDeserialize, HasBorshSchema, Component)]
34/// #[component(
35///     specification = "examples/schemas/ExampleData",
36///     schema_id = "ehlbg4aesvav6x4wt4bcdocci323a5cb2jnhyhiizj3qmbaqkk4a"
37/// )]
38/// struct ExampleData {
39///     name: String,
40///     age: u8,
41///     tags: Vec<String>,
42/// }
43/// ```
44///
45/// The attribute options for the `#[component()]` attribute are:
46///
47/// - `name = "ComponentName"` - Allows you to set the component name in the schema.
48/// - `schema_id = "ehlbg4aesvav6x4wt4bcdocci323a5cb2jnhyhiizj3qmbaqkk4a"` - Lets you add an
49///   assertion that the resulting schema digest matches the expected value.
50/// - `specification = "path/to/schema"` - Lets you specify the path to a directory containing the
51///   specification components.
52///
53/// ## Specification
54///
55/// The specification is made up of a list of components. The specification directory must contain 1
56/// file for each component you want to add to the specification.
57///
58/// The name of each component must contain the base32 encoded schema ID of the component. It may
59/// optionally be prefixed with an identifier ending with an underscore ( `_` ) before the schema
60/// ID. It may optionally be suffixed with a `.` and a file extension that will be ignored.
61///
62/// The macro will ignore any file in the directory starting with a `.` or with `README`.
63///
64/// The contents of each component file must be in the Borsh format associated to the component's
65/// schema ID.
66#[proc_macro_derive(Component, attributes(component))]
67pub fn derive_component(input: TokenStream) -> TokenStream {
68    let input = venial::parse_item(input.into()).unwrap();
69
70    let mut attr_name: Option<String> = None;
71    let mut attr_schema_id: Option<String> = None;
72    let mut attr_no_check_schema_id = false;
73    let mut attr_no_compute_schema_id = false;
74    let mut attr_specification: Option<String> = None;
75
76    for attr in input.attributes() {
77        if attr.path.len() != 1 {
78            continue;
79        }
80        let TokenTree::Ident(name) = &attr.path[0] else {
81            continue;
82        };
83        if name != &format_ident!("component") {
84            continue;
85        }
86
87        let mut value =
88            unsynn::TokenStream::from_iter(attr.value.get_value_tokens().iter().cloned())
89                .into_iter();
90        let Ok(key_value_attributes) =
91            unsynn::CommaDelimitedVec::<KeyValueAttribute>::parse(&mut value)
92        else {
93            throw!(attr.value, "Cannot parse attribute");
94        };
95
96        let mut ids = Vec::new();
97        for key_value in key_value_attributes.0 {
98            let key_value = key_value.value;
99            let ident = key_value.first;
100            let eq_value = key_value.second;
101
102            ids.push(ident.clone());
103
104            if ident == format_ident!("name") {
105                if let Some(eq_value) = eq_value {
106                    if let unsynn::Either::First(n) = eq_value.second {
107                        attr_name = Some(n.as_str().into());
108                    } else {
109                        throw!(ident, "name should be a string.");
110                    }
111                } else {
112                    throw!(ident, "name requires a value");
113                }
114            } else if ident == format_ident!("specification") {
115                if let Some(eq_value) = eq_value {
116                    if let unsynn::Either::First(s) = eq_value.second {
117                        attr_specification = Some(s.as_str().into());
118                    } else {
119                        throw!(ident, "specification should be a string.");
120                    }
121                } else {
122                    throw!(ident, "specification needs a value.");
123                }
124            } else if ident == "schema_id" {
125                if let Some(eq_value) = eq_value {
126                    if let unsynn::Either::First(s) = eq_value.second {
127                        attr_schema_id = Some(s.as_str().into());
128                    } else {
129                        throw!(ident, "schema_id should be a string.");
130                    }
131                } else {
132                    throw!(ident, "schema_id needs a value.");
133                }
134            } else if ident == "no_check_schema_id" {
135                if eq_value.is_none() {
136                    attr_no_check_schema_id = true;
137                } else {
138                    throw!(ident, "no_check_schema_id takes no value");
139                }
140            } else if ident == "no_compute_schema_id" {
141                if eq_value.is_none() {
142                    attr_no_compute_schema_id = true;
143                } else {
144                    throw!(ident, "no_compute_schema_id takes no value");
145                }
146            } else {
147                throw!(ident, "unrecognized setting");
148            }
149        }
150    }
151
152    let name = input.name();
153    let component_name = if let Some(component_name) = attr_name {
154        component_name
155    } else {
156        name.clone().unwrap().to_string()
157    };
158
159    let mut spec_files = Vec::new();
160    if let Some(specification) = &attr_specification {
161        let cargo_workspace_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
162        let specification_dir = cargo_workspace_dir.join(specification);
163        let specification_dir_read = std::fs::read_dir(specification_dir).unwrap();
164        for entry in specification_dir_read {
165            let entry = entry.unwrap();
166            let filename = entry.file_name().into_string().unwrap();
167            if entry.file_type().unwrap().is_file()
168                && !filename.starts_with('.')
169                && !filename.starts_with("README")
170            {
171                spec_files.push(entry.path());
172            }
173        }
174    }
175
176    let expected_schema_id = attr_schema_id.map(|x| Digest(Hash::from_str(&x).unwrap()));
177    let schema_id = if !attr_no_compute_schema_id {
178        let spec_hash: Digest = {
179            let components = spec_files
180                .into_iter()
181                .map(|path| {
182                    let schema_id_str = path.file_name().unwrap().to_str().unwrap();
183                    let schema_id_str = if let Some((_prefix, id)) = schema_id_str.rsplit_once('_')
184                    {
185                        if let Some((id, _suffix)) = id.split_once('.') {
186                            id
187                        } else {
188                            id
189                        }
190                    } else {
191                        schema_id_str
192                    };
193                    let schema_id = Digest(Hash::from_str(schema_id_str).unwrap());
194                    let mut buf = Vec::new();
195                    let data = std::fs::read(&path).unwrap();
196                    ComponentKind::Unencrypted(ComponentData {
197                        schema: schema_id,
198                        data,
199                    })
200                    .serialize(&mut buf)
201                    .unwrap();
202                    let component_id = Digest(Hash::from(iroh_blake3::hash(&buf)));
203
204                    ComponentEntry {
205                        schema_id: Some(schema_id),
206                        component_id,
207                    }
208                })
209                .collect::<Vec<_>>();
210            let mut entity = Entity { components };
211            entity.sort_components();
212            entity.compute_digest()
213        };
214
215        let mut schema_bytes = Vec::new();
216        (&component_name, spec_hash)
217            .serialize(&mut schema_bytes)
218            .unwrap();
219
220        Digest::new(&schema_bytes)
221    } else if let Some(expected) = expected_schema_id {
222        expected
223    } else {
224        throw!(
225            name,
226            "You must either provide a schema ID with a `no_compute_schema_id` flag,\
227            or add a `no_check_schema_id` and allow it to be computed"
228        );
229    };
230    let schema_id_bytes = *schema_id.0.as_bytes();
231
232    if !attr_no_check_schema_id && !attr_no_compute_schema_id {
233        let expected = expected_schema_id.unwrap();
234        if schema_id != expected {
235            panic!(
236                "Computed schema ID does not match expected:\
237                \ncomputed:{schema_id}\nexpected:{expected}"
238            )
239        }
240    }
241
242    quote! {
243        impl Component for #name {
244            fn schema_id() -> Digest {
245                Digest::from_bytes([#(#schema_id_bytes),*])
246            }
247        }
248    }
249    .into()
250}
251
252/// Derive macro fro the `HasBorshSchema` trait.
253///
254/// [`HasBorshSchema`] is required to implement [`Component`], and returns the borsh schema for the
255/// Rust type that can be used for the component specification.
256#[proc_macro_derive(HasBorshSchema)]
257pub fn derive_has_borsh_schema(input: TokenStream) -> TokenStream {
258    let input = venial::parse_item(input.into()).unwrap();
259
260    let Some(name) = input.name() else {
261        throw!(input, "Missing struct/enum name.");
262    };
263
264    fn fields_schema_expr(fields: &venial::Fields) -> proc_macro2::TokenStream {
265        match fields {
266            venial::Fields::Unit => {
267                quote! {
268                    BorshSchema::Null
269                }
270            }
271            venial::Fields::Tuple(fields) => {
272                if fields.fields.len() != 1 {
273                    throw!(
274                        fields,
275                        "Only tuples with one field may be used in BorshSchemas, \
276                        and the type of the field in the schema will \
277                        be that of the inner type in that case."
278                    );
279                }
280                let (field, _punct) = &fields.fields[0];
281                let ty = &field.ty;
282                quote! { <#ty>::borsh_schema() }
283            }
284            venial::Fields::Named(fields) => {
285                let mut field_exprs = Vec::new();
286                for field in fields.fields.items() {
287                    let name = &field.name;
288                    let ty = &field.ty;
289                    field_exprs.push(quote! {
290                       (stringify!(#name).to_string(), <#ty>::borsh_schema())
291                    });
292                }
293                quote! { BorshSchema::Struct { fields: vec![#(#field_exprs),*] } }
294            }
295        }
296    }
297
298    let schema_expr = match input {
299        venial::Item::Struct(s) => fields_schema_expr(&s.fields),
300        venial::Item::Enum(e) => {
301            let mut variant_exprs = Vec::new();
302            for variant in e.variants.items() {
303                let name = &variant.name;
304                let fields_schema = fields_schema_expr(&variant.fields);
305                variant_exprs.push(quote! { ( stringify!(#name).to_string(), #fields_schema) });
306            }
307            quote! { BorshSchema::Enum { variants: vec![#(#variant_exprs),*] } }
308        }
309        _ => {
310            throw!(
311                name,
312                "You may only derive HasBorshSchema on Structs, and Enums"
313            );
314        }
315    };
316
317    quote! {
318        impl HasBorshSchema for #name {
319            fn borsh_schema() -> BorshSchema {
320                #schema_expr
321            }
322        }
323    }
324    .into()
325}