Skip to main content

linera_views_derive/
lib.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The procedural macros for the crate `linera-views`.
5
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, parse_quote, ItemStruct, Type};
10
/// Options accepted by the `#[view(...)]` helper attribute placed on a struct.
#[derive(Debug, deluxe::ParseAttributes)]
#[deluxe(attributes(view))]
struct StructAttrs {
    /// Context type given explicitly as `#[view(context = ...)]`, if any.
    context: Option<syn::Type>,
}
16
/// The parts of a struct's generics that the generated `impl` blocks reuse.
struct Constraints<'a> {
    /// Predicates of the struct's own `where` clause (empty when it has none).
    input_constraints: Vec<&'a syn::WherePredicate>,
    /// Generics in the form placed right after the `impl` keyword.
    impl_generics: syn::ImplGenerics<'a>,
    /// Generics in the form placed right after the type name.
    type_generics: syn::TypeGenerics<'a>,
}
22
23impl<'a> Constraints<'a> {
24    fn get(item: &'a syn::ItemStruct) -> Self {
25        let (impl_generics, type_generics, maybe_where_clause) = item.generics.split_for_impl();
26        let input_constraints = maybe_where_clause
27            .map(|w| w.predicates.iter())
28            .into_iter()
29            .flatten()
30            .collect();
31
32        Self {
33            input_constraints,
34            impl_generics,
35            type_generics,
36        }
37    }
38}
39
40fn get_extended_entry(e: Type) -> TokenStream2 {
41    let syn::Type::Path(typepath) = e else {
42        panic!("The type should be a path");
43    };
44    let path_segment = typepath.path.segments.into_iter().next().unwrap();
45    let ident = path_segment.ident;
46    let arguments = path_segment.arguments;
47    quote! { #ident :: #arguments }
48}
49
50fn generate_view_code(input: ItemStruct, root: bool) -> TokenStream2 {
51    let Constraints {
52        input_constraints,
53        impl_generics,
54        type_generics,
55    } = Constraints::get(&input);
56
57    let attrs: StructAttrs = deluxe::parse_attributes(&input).unwrap();
58    let context = attrs.context.unwrap_or_else(|| {
59        let ident = &input
60            .generics
61            .type_params()
62            .next()
63            .expect("no `context` given and no type parameters")
64            .ident;
65        parse_quote! { #ident }
66    });
67
68    let struct_name = &input.ident;
69    let field_types: Vec<_> = input.fields.iter().map(|field| &field.ty).collect();
70
71    let mut name_quotes = Vec::new();
72    let mut rollback_quotes = Vec::new();
73    let mut pre_save_quotes = Vec::new();
74    let mut delete_view_quotes = Vec::new();
75    let mut clear_quotes = Vec::new();
76    let mut has_pending_changes_quotes = Vec::new();
77    let mut num_init_keys_quotes = Vec::new();
78    let mut pre_load_keys_quotes = Vec::new();
79    let mut post_load_keys_quotes = Vec::new();
80    for (idx, e) in input.fields.iter().enumerate() {
81        let name = e.ident.clone().unwrap();
82        let delete_view_ident = format_ident!("deleted{}", idx);
83        let idx_lit = syn::LitInt::new(&idx.to_string(), Span::call_site());
84        let g = get_extended_entry(e.ty.clone());
85        name_quotes.push(quote! { #name });
86        rollback_quotes.push(quote! { self.#name.rollback(); });
87        pre_save_quotes.push(quote! { let #delete_view_ident = self.#name.pre_save(batch)?; });
88        delete_view_quotes.push(quote! { #delete_view_ident });
89        clear_quotes.push(quote! { self.#name.clear(); });
90        has_pending_changes_quotes.push(quote! {
91            if self.#name.has_pending_changes().await {
92                return true;
93            }
94        });
95        num_init_keys_quotes.push(quote! { #g :: NUM_INIT_KEYS });
96        pre_load_keys_quotes.push(quote! {
97            let index = #idx_lit;
98            let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
99            keys.extend(#g :: pre_load(&context.clone_with_base_key(base_key))?);
100        });
101        post_load_keys_quotes.push(quote! {
102            let index = #idx_lit;
103            let pos_next = pos + #g :: NUM_INIT_KEYS;
104            let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
105            let #name = #g :: post_load(context.clone_with_base_key(base_key), &values[pos..pos_next])?;
106            pos = pos_next;
107        });
108    }
109
110    let first_name_quote = name_quotes
111        .first()
112        .expect("list of names should be non-empty");
113
114    let load_metrics = if root && cfg!(feature = "metrics") {
115        quote! {
116            #[cfg(not(target_arch = "wasm32"))]
117            linera_views::metrics::increment_counter(
118                &linera_views::metrics::LOAD_VIEW_COUNTER,
119                stringify!(#struct_name),
120                &context.base_key().bytes,
121            );
122            #[cfg(not(target_arch = "wasm32"))]
123            use linera_views::metrics::prometheus_util::MeasureLatency as _;
124            let _latency = linera_views::metrics::LOAD_VIEW_LATENCY.measure_latency();
125        }
126    } else {
127        quote! {}
128    };
129
130    quote! {
131        impl #impl_generics linera_views::views::View for #struct_name #type_generics
132        where
133            #context: linera_views::context::Context,
134            #(#input_constraints,)*
135            #(#field_types: linera_views::views::View<Context = #context>,)*
136        {
137            const NUM_INIT_KEYS: usize = #(<#field_types as linera_views::views::View>::NUM_INIT_KEYS)+*;
138
139            type Context = #context;
140
141            fn context(&self) -> &#context {
142                use linera_views::views::View;
143                self.#first_name_quote.context()
144            }
145
146            fn pre_load(context: &#context) -> Result<Vec<Vec<u8>>, linera_views::ViewError> {
147                use linera_views::context::Context as _;
148                let mut keys = Vec::new();
149                #(#pre_load_keys_quotes)*
150                Ok(keys)
151            }
152
153            fn post_load(context: #context, values: &[Option<Vec<u8>>]) -> Result<Self, linera_views::ViewError> {
154                use linera_views::context::Context as _;
155                let mut pos = 0;
156                #(#post_load_keys_quotes)*
157                Ok(Self {#(#name_quotes),*})
158            }
159
160            async fn load(context: #context) -> Result<Self, linera_views::ViewError> {
161                use linera_views::{context::Context as _, store::ReadableKeyValueStore as _};
162                #load_metrics
163                if Self::NUM_INIT_KEYS == 0 {
164                    Self::post_load(context, &[])
165                } else {
166                    let keys = Self::pre_load(&context)?;
167                    let values = context.store().read_multi_values_bytes(&keys).await?;
168                    Self::post_load(context, &values)
169                }
170            }
171
172
173            fn rollback(&mut self) {
174                #(#rollback_quotes)*
175            }
176
177            async fn has_pending_changes(&self) -> bool {
178                #(#has_pending_changes_quotes)*
179                false
180            }
181
182            fn pre_save(&self, batch: &mut linera_views::batch::Batch) -> Result<bool, linera_views::ViewError> {
183                #(#pre_save_quotes)*
184                Ok( #(#delete_view_quotes)&&* )
185            }
186
187            fn post_save(&mut self) {
188                #(self.#name_quotes.post_save();)*
189            }
190
191            fn clear(&mut self) {
192                #(#clear_quotes)*
193            }
194        }
195    }
196}
197
198fn generate_root_view_code(input: ItemStruct) -> TokenStream2 {
199    let Constraints {
200        input_constraints,
201        impl_generics,
202        type_generics,
203    } = Constraints::get(&input);
204    let struct_name = &input.ident;
205
206    let metrics_code = if cfg!(feature = "metrics") {
207        quote! {
208            #[cfg(not(target_arch = "wasm32"))]
209            linera_views::metrics::increment_counter(
210                &linera_views::metrics::SAVE_VIEW_COUNTER,
211                stringify!(#struct_name),
212                &self.context().base_key().bytes,
213            );
214        }
215    } else {
216        quote! {}
217    };
218
219    let write_batch_with_metrics = if cfg!(feature = "metrics") {
220        quote! {
221            if !batch.is_empty() {
222                #[cfg(not(target_arch = "wasm32"))]
223                let start = std::time::Instant::now();
224                self.context().store().write_batch(batch).await?;
225                #[cfg(not(target_arch = "wasm32"))]
226                {
227                    let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
228                    linera_views::metrics::SAVE_VIEW_LATENCY
229                        .with_label_values(&[stringify!(#struct_name)])
230                        .observe(latency_ms);
231                }
232            }
233        }
234    } else {
235        quote! {
236            if !batch.is_empty() {
237                self.context().store().write_batch(batch).await?;
238            }
239        }
240    };
241
242    quote! {
243        impl #impl_generics linera_views::views::RootView for #struct_name #type_generics
244        where
245            #(#input_constraints,)*
246            Self: linera_views::views::View,
247        {
248            async fn save(&mut self) -> Result<(), linera_views::ViewError> {
249                use linera_views::{context::Context as _, batch::Batch, store::WritableKeyValueStore as _, views::View as _};
250                #metrics_code
251                let mut batch = Batch::new();
252                self.pre_save(&mut batch)?;
253                #write_batch_with_metrics
254                self.post_save();
255                Ok(())
256            }
257        }
258    }
259}
260
261fn generate_hash_view_code(input: ItemStruct) -> TokenStream2 {
262    let Constraints {
263        input_constraints,
264        impl_generics,
265        type_generics,
266    } = Constraints::get(&input);
267    let struct_name = &input.ident;
268
269    let field_types = input.fields.iter().map(|field| &field.ty);
270    let mut field_hashes_mut = Vec::new();
271    let mut field_hashes = Vec::new();
272    for e in &input.fields {
273        let name = e.ident.as_ref().unwrap();
274        field_hashes_mut.push(quote! { hasher.write_all(self.#name.hash_mut().await?.as_ref())?; });
275        field_hashes.push(quote! { hasher.write_all(self.#name.hash().await?.as_ref())?; });
276    }
277
278    quote! {
279        impl #impl_generics linera_views::views::HashableView for #struct_name #type_generics
280        where
281            #(#field_types: linera_views::views::HashableView,)*
282            #(#input_constraints,)*
283            Self: linera_views::views::View,
284        {
285            type Hasher = linera_views::sha3::Sha3_256;
286
287            async fn hash_mut(&mut self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
288                use linera_views::views::{Hasher, HashableView};
289                use std::io::Write;
290                let mut hasher = Self::Hasher::default();
291                #(#field_hashes_mut)*
292                Ok(hasher.finalize())
293            }
294
295            async fn hash(&self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
296                use linera_views::views::{Hasher, HashableView};
297                use std::io::Write;
298                let mut hasher = Self::Hasher::default();
299                #(#field_hashes)*
300                Ok(hasher.finalize())
301            }
302        }
303    }
304}
305
306fn generate_crypto_hash_code(input: ItemStruct) -> TokenStream2 {
307    let Constraints {
308        input_constraints,
309        impl_generics,
310        type_generics,
311    } = Constraints::get(&input);
312    let field_types = input.fields.iter().map(|field| &field.ty);
313    let struct_name = &input.ident;
314    let hash_type = syn::Ident::new(&format!("{struct_name}Hash"), Span::call_site());
315    quote! {
316        impl #impl_generics linera_views::views::CryptoHashView
317        for #struct_name #type_generics
318        where
319            #(#field_types: linera_views::views::HashableView,)*
320            #(#input_constraints,)*
321            Self: linera_views::views::View,
322        {
323            async fn crypto_hash(&self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
324                use linera_base::crypto::{BcsHashable, CryptoHash};
325                use linera_views::{
326                    batch::Batch,
327                    generic_array::GenericArray,
328                    sha3::{digest::OutputSizeUser, Sha3_256},
329                    views::HashableView,
330                };
331                use serde::{Serialize, Deserialize};
332                #[derive(Serialize, Deserialize)]
333                struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
334                impl<'de> BcsHashable<'de> for #hash_type {}
335                let hash = self.hash().await?;
336                Ok(CryptoHash::new(&#hash_type(hash)))
337            }
338
339            async fn crypto_hash_mut(&mut self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
340                use linera_base::crypto::{BcsHashable, CryptoHash};
341                use linera_views::{
342                    batch::Batch,
343                    generic_array::GenericArray,
344                    sha3::{digest::OutputSizeUser, Sha3_256},
345                    views::HashableView,
346                };
347                use serde::{Serialize, Deserialize};
348                #[derive(Serialize, Deserialize)]
349                struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
350                impl<'de> BcsHashable<'de> for #hash_type {}
351                let hash = self.hash_mut().await?;
352                Ok(CryptoHash::new(&#hash_type(hash)))
353            }
354        }
355    }
356}
357
358fn generate_clonable_view_code(input: ItemStruct) -> TokenStream2 {
359    let Constraints {
360        input_constraints,
361        impl_generics,
362        type_generics,
363    } = Constraints::get(&input);
364    let struct_name = &input.ident;
365
366    let mut clone_constraints = vec![];
367    let mut clone_fields = vec![];
368
369    for field in &input.fields {
370        let name = &field.ident;
371        let ty = &field.ty;
372        clone_constraints.push(quote! { #ty: ClonableView });
373        clone_fields.push(quote! { #name: self.#name.clone_unchecked()? });
374    }
375
376    quote! {
377        impl #impl_generics linera_views::views::ClonableView for #struct_name #type_generics
378        where
379            #(#input_constraints,)*
380            #(#clone_constraints,)*
381            Self: linera_views::views::View,
382        {
383            fn clone_unchecked(&mut self) -> Result<Self, linera_views::ViewError> {
384                Ok(Self {
385                    #(#clone_fields,)*
386                })
387            }
388        }
389    }
390}
391
392#[proc_macro_derive(View, attributes(view))]
393pub fn derive_view(input: TokenStream) -> TokenStream {
394    let input = parse_macro_input!(input as ItemStruct);
395    generate_view_code(input, false).into()
396}
397
398#[proc_macro_derive(HashableView, attributes(view))]
399pub fn derive_hash_view(input: TokenStream) -> TokenStream {
400    let input = parse_macro_input!(input as ItemStruct);
401    let mut stream = generate_view_code(input.clone(), false);
402    stream.extend(generate_hash_view_code(input));
403    stream.into()
404}
405
406#[proc_macro_derive(RootView, attributes(view))]
407pub fn derive_root_view(input: TokenStream) -> TokenStream {
408    let input = parse_macro_input!(input as ItemStruct);
409    let mut stream = generate_view_code(input.clone(), true);
410    stream.extend(generate_root_view_code(input));
411    stream.into()
412}
413
414#[proc_macro_derive(CryptoHashView, attributes(view))]
415pub fn derive_crypto_hash_view(input: TokenStream) -> TokenStream {
416    let input = parse_macro_input!(input as ItemStruct);
417    let mut stream = generate_view_code(input.clone(), false);
418    stream.extend(generate_hash_view_code(input.clone()));
419    stream.extend(generate_crypto_hash_code(input));
420    stream.into()
421}
422
423#[proc_macro_derive(CryptoHashRootView, attributes(view))]
424pub fn derive_crypto_hash_root_view(input: TokenStream) -> TokenStream {
425    let input = parse_macro_input!(input as ItemStruct);
426    let mut stream = generate_view_code(input.clone(), true);
427    stream.extend(generate_root_view_code(input.clone()));
428    stream.extend(generate_hash_view_code(input.clone()));
429    stream.extend(generate_crypto_hash_code(input));
430    stream.into()
431}
432
/// Derives `View` (as a root view), `RootView` and `HashableView` for a struct.
///
/// Only compiled when this crate is built for tests.
#[proc_macro_derive(HashableRootView, attributes(view))]
#[cfg(test)]
pub fn derive_hashable_root_view(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemStruct);
    let view_impl = generate_view_code(item.clone(), true);
    let root_impl = generate_root_view_code(item.clone());
    let hash_impl = generate_hash_view_code(item);
    TokenStream::from(quote! { #view_impl #root_impl #hash_impl })
}
442
443#[proc_macro_derive(ClonableView, attributes(view))]
444pub fn derive_clonable_view(input: TokenStream) -> TokenStream {
445    let input = parse_macro_input!(input as ItemStruct);
446    generate_clonable_view_code(input).into()
447}
448
/// Snapshot tests for the code generators, run over several ways of specifying the context.
#[cfg(test)]
pub mod tests {

    use quote::quote;
    use syn::{parse_quote, AngleBracketedGenericArguments};

    use crate::*;

    /// Parses `tokens` as a whole Rust file and formats it with `prettyplease`.
    ///
    /// Panics if the tokens are not a syntactically valid file.
    fn pretty(tokens: TokenStream2) -> String {
        prettyplease::unparse(
            &syn::parse2::<syn::File>(tokens).expect("failed to parse test output"),
        )
    }

    /// Snapshots the root-view flavour of the generated `View` impl for every test case.
    #[test]
    fn test_generate_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                // The generated code differs with the `metrics` feature, so the snapshot name
                // does too.
                format!(
                    "test_generate_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_view_code(input, true))
            );
        }
    }

    /// Snapshots the generated `HashableView` impl for every test case.
    #[test]
    fn test_generate_hash_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!("test_generate_hash_view_code_{}", context.name),
                pretty(generate_hash_view_code(input))
            );
        }
    }

    /// Snapshots the generated `RootView` impl for every test case.
    #[test]
    fn test_generate_root_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                // The generated code differs with the `metrics` feature, so the snapshot name
                // does too.
                format!(
                    "test_generate_root_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_root_view_code(input))
            );
        }
    }

    /// Snapshots the generated `CryptoHashView` impl for every test case.
    ///
    /// No explicit snapshot name is passed here, unlike in the tests above.
    #[test]
    fn test_generate_crypto_hash_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_crypto_hash_code(input)));
        }
    }

    /// Snapshots the generated `ClonableView` impl for every test case.
    ///
    /// No explicit snapshot name is passed here, unlike in the first three tests.
    #[test]
    fn test_generate_clonable_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_clonable_view_code(input)));
        }
    }

    /// Describes one way of declaring the test struct and its context type.
    #[derive(Clone)]
    pub struct SpecificContextInfo {
        /// Suffix used in snapshot names to tell the test cases apart.
        name: String,
        /// The `#[view(context = ...)]` attribute to put on the struct, if any.
        attribute: Option<TokenStream2>,
        /// The context type used inside the field types.
        context: Type,
        /// The generic parameter list of the test struct.
        generics: AngleBracketedGenericArguments,
        /// The `where` clause of the test struct, if any.
        where_clause: Option<TokenStream2>,
    }

    impl SpecificContextInfo {
        /// The case without a `#[view]` attribute: the struct is generic over `C`, which is
        /// also the context type used in its fields.
        pub fn empty() -> Self {
            SpecificContextInfo {
                name: "C".to_string(),
                attribute: None,
                context: syn::parse_quote! { C },
                generics: syn::parse_quote! { <C> },
                where_clause: None,
            }
        }

        /// A case where `context` is given through `#[view(context = ...)]` and the struct has
        /// no generic parameters.
        ///
        /// The name is the token text of `context` with spaces removed and each of `:`, `<`
        /// and `>` replaced by `_`.
        pub fn new(context: syn::Type) -> Self {
            let name = quote! { #context };
            SpecificContextInfo {
                name: format!("{name}")
                    .replace(' ', "")
                    .replace([':', '<', '>'], "_"),
                attribute: Some(quote! { #[view(context = #context)] }),
                context,
                generics: parse_quote! { <> },
                where_clause: None,
            }
        }

        /// Sets the `where_clause` to a dummy value for test cases with a where clause.
        ///
        /// Also adds a `MyParam` generic type parameter to the `generics` field, which is the type
        /// constrained by the dummy predicate in the `where_clause`.
        pub fn with_dummy_where_clause(mut self) -> Self {
            self.generics.args.push(parse_quote! { MyParam });
            self.where_clause = Some(quote! {
                where MyParam: Send + Sync + 'static,
            });
            self.name.push_str("_with_where");

            self
        }

        /// Yields every test case: the generic-parameter case followed by three explicit
        /// context types, each once as is and once with the dummy `where` clause.
        pub fn test_cases() -> impl Iterator<Item = Self> {
            Some(Self::empty())
                .into_iter()
                .chain(
                    [
                        syn::parse_quote! { CustomContext },
                        syn::parse_quote! { custom::path::to::ContextType },
                        syn::parse_quote! { custom::GenericContext<T> },
                    ]
                    .into_iter()
                    .map(Self::new),
                )
                .flat_map(|case| [case.clone(), case.with_dummy_where_clause()])
        }

        /// Builds the `TestView` struct for this case: a `RegisterView` field and a
        /// `CollectionView` field, both using this case's context type.
        pub fn test_view_input(&self) -> ItemStruct {
            let SpecificContextInfo {
                attribute,
                context,
                generics,
                where_clause,
                ..
            } = self;

            parse_quote! {
                #attribute
                struct TestView #generics
                #where_clause
                {
                    register: RegisterView<#context, usize>,
                    collection: CollectionView<#context, usize, RegisterView<#context, usize>>,
                }
            }
        }
    }
}