1use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, parse_quote, ItemStruct, Type};
10
11#[derive(Debug, deluxe::ParseAttributes)]
12#[deluxe(attributes(view))]
13struct StructAttrs {
14 context: Option<syn::Type>,
15}
16
17struct Constraints<'a> {
18 input_constraints: Vec<&'a syn::WherePredicate>,
19 impl_generics: syn::ImplGenerics<'a>,
20 type_generics: syn::TypeGenerics<'a>,
21}
22
23impl<'a> Constraints<'a> {
24 fn get(item: &'a syn::ItemStruct) -> Self {
25 let (impl_generics, type_generics, maybe_where_clause) = item.generics.split_for_impl();
26 let input_constraints = maybe_where_clause
27 .map(|w| w.predicates.iter())
28 .into_iter()
29 .flatten()
30 .collect();
31
32 Self {
33 input_constraints,
34 impl_generics,
35 type_generics,
36 }
37 }
38}
39
40fn get_extended_entry(e: Type) -> TokenStream2 {
41 let syn::Type::Path(typepath) = e else {
42 panic!("The type should be a path");
43 };
44 let path_segment = typepath.path.segments.into_iter().next().unwrap();
45 let ident = path_segment.ident;
46 let arguments = path_segment.arguments;
47 quote! { #ident :: #arguments }
48}
49
50fn generate_view_code(input: ItemStruct, root: bool) -> TokenStream2 {
51 let Constraints {
52 input_constraints,
53 impl_generics,
54 type_generics,
55 } = Constraints::get(&input);
56
57 let attrs: StructAttrs = deluxe::parse_attributes(&input).unwrap();
58 let context = attrs.context.unwrap_or_else(|| {
59 let ident = &input
60 .generics
61 .type_params()
62 .next()
63 .expect("no `context` given and no type parameters")
64 .ident;
65 parse_quote! { #ident }
66 });
67
68 let struct_name = &input.ident;
69 let field_types: Vec<_> = input.fields.iter().map(|field| &field.ty).collect();
70
71 let mut name_quotes = Vec::new();
72 let mut rollback_quotes = Vec::new();
73 let mut pre_save_quotes = Vec::new();
74 let mut delete_view_quotes = Vec::new();
75 let mut clear_quotes = Vec::new();
76 let mut has_pending_changes_quotes = Vec::new();
77 let mut num_init_keys_quotes = Vec::new();
78 let mut pre_load_keys_quotes = Vec::new();
79 let mut post_load_keys_quotes = Vec::new();
80 for (idx, e) in input.fields.iter().enumerate() {
81 let name = e.ident.clone().unwrap();
82 let delete_view_ident = format_ident!("deleted{}", idx);
83 let idx_lit = syn::LitInt::new(&idx.to_string(), Span::call_site());
84 let g = get_extended_entry(e.ty.clone());
85 name_quotes.push(quote! { #name });
86 rollback_quotes.push(quote! { self.#name.rollback(); });
87 pre_save_quotes.push(quote! { let #delete_view_ident = self.#name.pre_save(batch)?; });
88 delete_view_quotes.push(quote! { #delete_view_ident });
89 clear_quotes.push(quote! { self.#name.clear(); });
90 has_pending_changes_quotes.push(quote! {
91 if self.#name.has_pending_changes().await {
92 return true;
93 }
94 });
95 num_init_keys_quotes.push(quote! { #g :: NUM_INIT_KEYS });
96 pre_load_keys_quotes.push(quote! {
97 let index = #idx_lit;
98 let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
99 keys.extend(#g :: pre_load(&context.clone_with_base_key(base_key))?);
100 });
101 post_load_keys_quotes.push(quote! {
102 let index = #idx_lit;
103 let pos_next = pos + #g :: NUM_INIT_KEYS;
104 let base_key = context.base_key().derive_tag_key(linera_views::views::MIN_VIEW_TAG, &index)?;
105 let #name = #g :: post_load(context.clone_with_base_key(base_key), &values[pos..pos_next])?;
106 pos = pos_next;
107 });
108 }
109
110 let first_name_quote = name_quotes
111 .first()
112 .expect("list of names should be non-empty");
113
114 let load_metrics = if root && cfg!(feature = "metrics") {
115 quote! {
116 #[cfg(not(target_arch = "wasm32"))]
117 linera_views::metrics::increment_counter(
118 &linera_views::metrics::LOAD_VIEW_COUNTER,
119 stringify!(#struct_name),
120 &context.base_key().bytes,
121 );
122 #[cfg(not(target_arch = "wasm32"))]
123 use linera_views::metrics::prometheus_util::MeasureLatency as _;
124 let _latency = linera_views::metrics::LOAD_VIEW_LATENCY.measure_latency();
125 }
126 } else {
127 quote! {}
128 };
129
130 quote! {
131 impl #impl_generics linera_views::views::View for #struct_name #type_generics
132 where
133 #context: linera_views::context::Context,
134 #(#input_constraints,)*
135 #(#field_types: linera_views::views::View<Context = #context>,)*
136 {
137 const NUM_INIT_KEYS: usize = #(<#field_types as linera_views::views::View>::NUM_INIT_KEYS)+*;
138
139 type Context = #context;
140
141 fn context(&self) -> &#context {
142 use linera_views::views::View;
143 self.#first_name_quote.context()
144 }
145
146 fn pre_load(context: &#context) -> Result<Vec<Vec<u8>>, linera_views::ViewError> {
147 use linera_views::context::Context as _;
148 let mut keys = Vec::new();
149 #(#pre_load_keys_quotes)*
150 Ok(keys)
151 }
152
153 fn post_load(context: #context, values: &[Option<Vec<u8>>]) -> Result<Self, linera_views::ViewError> {
154 use linera_views::context::Context as _;
155 let mut pos = 0;
156 #(#post_load_keys_quotes)*
157 Ok(Self {#(#name_quotes),*})
158 }
159
160 async fn load(context: #context) -> Result<Self, linera_views::ViewError> {
161 use linera_views::{context::Context as _, store::ReadableKeyValueStore as _};
162 #load_metrics
163 if Self::NUM_INIT_KEYS == 0 {
164 Self::post_load(context, &[])
165 } else {
166 let keys = Self::pre_load(&context)?;
167 let values = context.store().read_multi_values_bytes(&keys).await?;
168 Self::post_load(context, &values)
169 }
170 }
171
172
173 fn rollback(&mut self) {
174 #(#rollback_quotes)*
175 }
176
177 async fn has_pending_changes(&self) -> bool {
178 #(#has_pending_changes_quotes)*
179 false
180 }
181
182 fn pre_save(&self, batch: &mut linera_views::batch::Batch) -> Result<bool, linera_views::ViewError> {
183 #(#pre_save_quotes)*
184 Ok( #(#delete_view_quotes)&&* )
185 }
186
187 fn post_save(&mut self) {
188 #(self.#name_quotes.post_save();)*
189 }
190
191 fn clear(&mut self) {
192 #(#clear_quotes)*
193 }
194 }
195 }
196}
197
198fn generate_root_view_code(input: ItemStruct) -> TokenStream2 {
199 let Constraints {
200 input_constraints,
201 impl_generics,
202 type_generics,
203 } = Constraints::get(&input);
204 let struct_name = &input.ident;
205
206 let metrics_code = if cfg!(feature = "metrics") {
207 quote! {
208 #[cfg(not(target_arch = "wasm32"))]
209 linera_views::metrics::increment_counter(
210 &linera_views::metrics::SAVE_VIEW_COUNTER,
211 stringify!(#struct_name),
212 &self.context().base_key().bytes,
213 );
214 }
215 } else {
216 quote! {}
217 };
218
219 let write_batch_with_metrics = if cfg!(feature = "metrics") {
220 quote! {
221 if !batch.is_empty() {
222 #[cfg(not(target_arch = "wasm32"))]
223 let start = std::time::Instant::now();
224 self.context().store().write_batch(batch).await?;
225 #[cfg(not(target_arch = "wasm32"))]
226 {
227 let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
228 linera_views::metrics::SAVE_VIEW_LATENCY
229 .with_label_values(&[stringify!(#struct_name)])
230 .observe(latency_ms);
231 }
232 }
233 }
234 } else {
235 quote! {
236 if !batch.is_empty() {
237 self.context().store().write_batch(batch).await?;
238 }
239 }
240 };
241
242 quote! {
243 impl #impl_generics linera_views::views::RootView for #struct_name #type_generics
244 where
245 #(#input_constraints,)*
246 Self: linera_views::views::View,
247 {
248 async fn save(&mut self) -> Result<(), linera_views::ViewError> {
249 use linera_views::{context::Context as _, batch::Batch, store::WritableKeyValueStore as _, views::View as _};
250 #metrics_code
251 let mut batch = Batch::new();
252 self.pre_save(&mut batch)?;
253 #write_batch_with_metrics
254 self.post_save();
255 Ok(())
256 }
257 }
258 }
259}
260
261fn generate_hash_view_code(input: ItemStruct) -> TokenStream2 {
262 let Constraints {
263 input_constraints,
264 impl_generics,
265 type_generics,
266 } = Constraints::get(&input);
267 let struct_name = &input.ident;
268
269 let field_types = input.fields.iter().map(|field| &field.ty);
270 let mut field_hashes_mut = Vec::new();
271 let mut field_hashes = Vec::new();
272 for e in &input.fields {
273 let name = e.ident.as_ref().unwrap();
274 field_hashes_mut.push(quote! { hasher.write_all(self.#name.hash_mut().await?.as_ref())?; });
275 field_hashes.push(quote! { hasher.write_all(self.#name.hash().await?.as_ref())?; });
276 }
277
278 quote! {
279 impl #impl_generics linera_views::views::HashableView for #struct_name #type_generics
280 where
281 #(#field_types: linera_views::views::HashableView,)*
282 #(#input_constraints,)*
283 Self: linera_views::views::View,
284 {
285 type Hasher = linera_views::sha3::Sha3_256;
286
287 async fn hash_mut(&mut self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
288 use linera_views::views::{Hasher, HashableView};
289 use std::io::Write;
290 let mut hasher = Self::Hasher::default();
291 #(#field_hashes_mut)*
292 Ok(hasher.finalize())
293 }
294
295 async fn hash(&self) -> Result<<Self::Hasher as linera_views::views::Hasher>::Output, linera_views::ViewError> {
296 use linera_views::views::{Hasher, HashableView};
297 use std::io::Write;
298 let mut hasher = Self::Hasher::default();
299 #(#field_hashes)*
300 Ok(hasher.finalize())
301 }
302 }
303 }
304}
305
306fn generate_crypto_hash_code(input: ItemStruct) -> TokenStream2 {
307 let Constraints {
308 input_constraints,
309 impl_generics,
310 type_generics,
311 } = Constraints::get(&input);
312 let field_types = input.fields.iter().map(|field| &field.ty);
313 let struct_name = &input.ident;
314 let hash_type = syn::Ident::new(&format!("{struct_name}Hash"), Span::call_site());
315 quote! {
316 impl #impl_generics linera_views::views::CryptoHashView
317 for #struct_name #type_generics
318 where
319 #(#field_types: linera_views::views::HashableView,)*
320 #(#input_constraints,)*
321 Self: linera_views::views::View,
322 {
323 async fn crypto_hash(&self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
324 use linera_base::crypto::{BcsHashable, CryptoHash};
325 use linera_views::{
326 batch::Batch,
327 generic_array::GenericArray,
328 sha3::{digest::OutputSizeUser, Sha3_256},
329 views::HashableView,
330 };
331 use serde::{Serialize, Deserialize};
332 #[derive(Serialize, Deserialize)]
333 struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
334 impl<'de> BcsHashable<'de> for #hash_type {}
335 let hash = self.hash().await?;
336 Ok(CryptoHash::new(&#hash_type(hash)))
337 }
338
339 async fn crypto_hash_mut(&mut self) -> Result<linera_base::crypto::CryptoHash, linera_views::ViewError> {
340 use linera_base::crypto::{BcsHashable, CryptoHash};
341 use linera_views::{
342 batch::Batch,
343 generic_array::GenericArray,
344 sha3::{digest::OutputSizeUser, Sha3_256},
345 views::HashableView,
346 };
347 use serde::{Serialize, Deserialize};
348 #[derive(Serialize, Deserialize)]
349 struct #hash_type(GenericArray<u8, <Sha3_256 as OutputSizeUser>::OutputSize>);
350 impl<'de> BcsHashable<'de> for #hash_type {}
351 let hash = self.hash_mut().await?;
352 Ok(CryptoHash::new(&#hash_type(hash)))
353 }
354 }
355 }
356}
357
358fn generate_clonable_view_code(input: ItemStruct) -> TokenStream2 {
359 let Constraints {
360 input_constraints,
361 impl_generics,
362 type_generics,
363 } = Constraints::get(&input);
364 let struct_name = &input.ident;
365
366 let mut clone_constraints = vec![];
367 let mut clone_fields = vec![];
368
369 for field in &input.fields {
370 let name = &field.ident;
371 let ty = &field.ty;
372 clone_constraints.push(quote! { #ty: ClonableView });
373 clone_fields.push(quote! { #name: self.#name.clone_unchecked()? });
374 }
375
376 quote! {
377 impl #impl_generics linera_views::views::ClonableView for #struct_name #type_generics
378 where
379 #(#input_constraints,)*
380 #(#clone_constraints,)*
381 Self: linera_views::views::View,
382 {
383 fn clone_unchecked(&mut self) -> Result<Self, linera_views::ViewError> {
384 Ok(Self {
385 #(#clone_fields,)*
386 })
387 }
388 }
389 }
390}
391
392#[proc_macro_derive(View, attributes(view))]
393pub fn derive_view(input: TokenStream) -> TokenStream {
394 let input = parse_macro_input!(input as ItemStruct);
395 generate_view_code(input, false).into()
396}
397
398#[proc_macro_derive(HashableView, attributes(view))]
399pub fn derive_hash_view(input: TokenStream) -> TokenStream {
400 let input = parse_macro_input!(input as ItemStruct);
401 let mut stream = generate_view_code(input.clone(), false);
402 stream.extend(generate_hash_view_code(input));
403 stream.into()
404}
405
406#[proc_macro_derive(RootView, attributes(view))]
407pub fn derive_root_view(input: TokenStream) -> TokenStream {
408 let input = parse_macro_input!(input as ItemStruct);
409 let mut stream = generate_view_code(input.clone(), true);
410 stream.extend(generate_root_view_code(input));
411 stream.into()
412}
413
414#[proc_macro_derive(CryptoHashView, attributes(view))]
415pub fn derive_crypto_hash_view(input: TokenStream) -> TokenStream {
416 let input = parse_macro_input!(input as ItemStruct);
417 let mut stream = generate_view_code(input.clone(), false);
418 stream.extend(generate_hash_view_code(input.clone()));
419 stream.extend(generate_crypto_hash_code(input));
420 stream.into()
421}
422
423#[proc_macro_derive(CryptoHashRootView, attributes(view))]
424pub fn derive_crypto_hash_root_view(input: TokenStream) -> TokenStream {
425 let input = parse_macro_input!(input as ItemStruct);
426 let mut stream = generate_view_code(input.clone(), true);
427 stream.extend(generate_root_view_code(input.clone()));
428 stream.extend(generate_hash_view_code(input.clone()));
429 stream.extend(generate_crypto_hash_code(input));
430 stream.into()
431}
432
/// Derives `View`, `RootView` and `HashableView` for a root view.
///
/// NOTE(review): the `#[cfg(test)]` below means this derive is only compiled when this crate's
/// own tests are built, so downstream crates cannot use it. Confirm this is intentional.
#[proc_macro_derive(HashableRootView, attributes(view))]
#[cfg(test)]
pub fn derive_hashable_root_view(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemStruct);
    let view_impl = generate_view_code(item.clone(), true);
    let root_impl = generate_root_view_code(item.clone());
    let hash_impl = generate_hash_view_code(item);
    [view_impl, root_impl, hash_impl]
        .into_iter()
        .collect::<TokenStream2>()
        .into()
}
442
443#[proc_macro_derive(ClonableView, attributes(view))]
444pub fn derive_clonable_view(input: TokenStream) -> TokenStream {
445 let input = parse_macro_input!(input as ItemStruct);
446 generate_clonable_view_code(input).into()
447}
448
/// Snapshot tests for the code generators.
///
/// Each test renders the generated code with `prettyplease` and compares it with an `insta`
/// snapshot, once for every case yielded by `SpecificContextInfo::test_cases`.
#[cfg(test)]
pub mod tests {

    use quote::quote;
    use syn::{parse_quote, AngleBracketedGenericArguments};

    use crate::*;

    /// Pretty-prints generated tokens as a Rust source file so that snapshots are readable.
    ///
    /// Panics if the tokens do not parse as a `syn::File`.
    fn pretty(tokens: TokenStream2) -> String {
        prettyplease::unparse(
            &syn::parse2::<syn::File>(tokens).expect("failed to parse test output"),
        )
    }

    /// Snapshots `generate_view_code` with `root = true`. The snapshot name gets a `_metrics`
    /// suffix when the `metrics` feature is enabled, because the output then differs.
    #[test]
    fn test_generate_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!(
                    "test_generate_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_view_code(input, true))
            );
        }
    }

    /// Snapshots `generate_hash_view_code`, one named snapshot per context case.
    #[test]
    fn test_generate_hash_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!("test_generate_hash_view_code_{}", context.name),
                pretty(generate_hash_view_code(input))
            );
        }
    }

    /// Snapshots `generate_root_view_code`; like the view test, the name depends on `metrics`.
    #[test]
    fn test_generate_root_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(
                format!(
                    "test_generate_root_view_code{}_{}",
                    if cfg!(feature = "metrics") {
                        "_metrics"
                    } else {
                        ""
                    },
                    context.name,
                ),
                pretty(generate_root_view_code(input))
            );
        }
    }

    /// Snapshots `generate_crypto_hash_code`. The snapshots are unnamed, so insta assigns names
    /// automatically and the iteration order of `test_cases` matters.
    #[test]
    fn test_generate_crypto_hash_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_crypto_hash_code(input)));
        }
    }

    /// Snapshots `generate_clonable_view_code`, using unnamed (auto-named) snapshots.
    #[test]
    fn test_generate_clonable_view_code() {
        for context in SpecificContextInfo::test_cases() {
            let input = context.test_view_input();
            insta::assert_snapshot!(pretty(generate_clonable_view_code(input)));
        }
    }

    /// One way of supplying the view context to the test struct.
    #[derive(Clone)]
    pub struct SpecificContextInfo {
        // Suffix used in snapshot names.
        name: String,
        // Optional `#[view(context = ...)]` attribute placed on the struct.
        attribute: Option<TokenStream2>,
        // Context type used in the field types of the test struct.
        context: Type,
        // Generic parameter list of the test struct (possibly empty `<>`).
        generics: AngleBracketedGenericArguments,
        // Optional `where` clause of the test struct.
        where_clause: Option<TokenStream2>,
    }

    impl SpecificContextInfo {
        /// Case where the context is the struct's own type parameter `C` and no attribute is used.
        pub fn empty() -> Self {
            SpecificContextInfo {
                name: "C".to_string(),
                attribute: None,
                context: syn::parse_quote! { C },
                generics: syn::parse_quote! { <C> },
                where_clause: None,
            }
        }

        /// Case where the context is the concrete type `context`, given through
        /// `#[view(context = ...)]`. The name is `context` with spaces removed and
        /// `:`, `<`, `>` replaced by `_`, so it is usable in snapshot names.
        pub fn new(context: syn::Type) -> Self {
            let name = quote! { #context };
            SpecificContextInfo {
                name: format!("{name}")
                    .replace(' ', "")
                    .replace([':', '<', '>'], "_"),
                attribute: Some(quote! { #[view(context = #context)] }),
                context,
                generics: parse_quote! { <> },
                where_clause: None,
            }
        }

        /// Adds a `MyParam` type parameter with a `where` clause, so that the generators'
        /// forwarding of user `where` predicates is covered.
        pub fn with_dummy_where_clause(mut self) -> Self {
            self.generics.args.push(parse_quote! { MyParam });
            self.where_clause = Some(quote! {
                where MyParam: Send + Sync + 'static,
            });
            self.name.push_str("_with_where");

            self
        }

        /// All test cases: the `empty` case plus three explicit contexts (plain, multi-segment
        /// path, generic), each both without and with the dummy `where` clause.
        pub fn test_cases() -> impl Iterator<Item = Self> {
            Some(Self::empty())
                .into_iter()
                .chain(
                    [
                        syn::parse_quote! { CustomContext },
                        syn::parse_quote! { custom::path::to::ContextType },
                        syn::parse_quote! { custom::GenericContext<T> },
                    ]
                    .into_iter()
                    .map(Self::new),
                )
                .flat_map(|case| [case.clone(), case.with_dummy_where_clause()])
        }

        /// Builds the `TestView` struct used as generator input: a `RegisterView` field and a
        /// `CollectionView` field, both parameterized by this case's context.
        pub fn test_view_input(&self) -> ItemStruct {
            let SpecificContextInfo {
                attribute,
                context,
                generics,
                where_clause,
                ..
            } = self;

            parse_quote! {
                #attribute
                struct TestView #generics
                #where_clause
                {
                    register: RegisterView<#context, usize>,
                    collection: CollectionView<#context, usize, RegisterView<#context, usize>>,
                }
            }
        }
    }
}