inc_complete_derive/
lib.rs

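//! Procedural macros for `inc_complete`: the `Storage` and `Input` derive macros and the
//! `intermediate` attribute macro.
//! Each macro expands to a call to the matching declarative macro
//! (`impl_storage!`, `define_input!` or `define_intermediate!`) built from the annotated item.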
3
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8    Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments,
9    Type, TypePath, parse_macro_input, spanned::Spanned,
10};
11
12// =============================================================================
13// Helper Functions for Attribute Parsing
14// =============================================================================
15
16/// Parse a type attribute from a Meta::NameValue
17fn parse_type_attribute(
18    name_value: &syn::MetaNameValue,
19    expected_name: &str,
20) -> syn::Result<proc_macro2::TokenStream> {
21    let name = name_value
22        .path
23        .get_ident()
24        .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
25        .to_string();
26
27    if name != expected_name {
28        return Err(syn::Error::new_spanned(
29            &name_value.path,
30            format!("expected '{}', found '{}'", expected_name, name),
31        ));
32    }
33
34    let expr = &name_value.value;
35    Ok(quote::quote! { #expr })
36}
37
38/// Parse an integer attribute from a Meta::NameValue
39fn parse_int_attribute(
40    name_value: &syn::MetaNameValue,
41    expected_name: &str,
42) -> syn::Result<proc_macro2::TokenStream> {
43    let name = name_value
44        .path
45        .get_ident()
46        .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
47        .to_string();
48
49    if name != expected_name {
50        return Err(syn::Error::new_spanned(
51            &name_value.path,
52            format!("expected '{}', found '{}'", expected_name, name),
53        ));
54    }
55
56    if let Expr::Lit(ExprLit {
57        lit: Lit::Int(lit_int),
58        ..
59    }) = &name_value.value
60    {
61        Ok(quote! { #lit_int })
62    } else {
63        Err(syn::Error::new_spanned(
64            &name_value.value,
65            format!("{} must be an integer", expected_name),
66        ))
67    }
68}
69
70/// Parse a boolean flag attribute from a Meta::Path
71fn parse_flag_attribute(path: &syn::Path, expected_name: &str) -> syn::Result<bool> {
72    if path.is_ident(expected_name) {
73        Ok(true)
74    } else {
75        Err(syn::Error::new_spanned(
76            path,
77            format!("unknown flag attribute, expected '{}'", expected_name),
78        ))
79    }
80}
81
82/// Extract generic type T from Container<T>
83fn extract_generic_type(ty: &Type, container_suffix: &str) -> Option<Type> {
84    if let Type::Path(TypePath { path, .. }) = ty {
85        if let Some(segment) = path.segments.last() {
86            if segment.ident.to_string().ends_with(container_suffix) {
87                if let PathArguments::AngleBracketed(args) = &segment.arguments {
88                    if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
89                        return Some(inner_type.clone());
90                    }
91                }
92            }
93        }
94    }
95    None
96}
97
98/// Parser specifically for Input derive attributes
99struct InputAttributeParser {
100    id: Option<proc_macro2::TokenStream>,
101    output_type: Option<proc_macro2::TokenStream>,
102    storage_type: Option<proc_macro2::TokenStream>,
103    assume_changed: bool,
104}
105
106impl InputAttributeParser {
107    fn new() -> Self {
108        Self {
109            id: None,
110            output_type: None,
111            storage_type: None,
112            assume_changed: false,
113        }
114    }
115
116    fn parse_attribute_list(&mut self, attr: &Attribute) -> syn::Result<()> {
117        let meta = attr.meta.clone();
118        if let Meta::List(meta_list) = meta {
119            let parsed: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
120                meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
121
122            for meta in parsed {
123                match meta {
124                    Meta::Path(path) => {
125                        if parse_flag_attribute(&path, "assume_changed")? {
126                            self.assume_changed = true;
127                        }
128                    }
129                    Meta::NameValue(name_value) => {
130                        let name = name_value
131                            .path
132                            .get_ident()
133                            .ok_or_else(|| {
134                                syn::Error::new_spanned(&name_value.path, "expected identifier")
135                            })?
136                            .to_string();
137
138                        match name.as_str() {
139                            "id" => {
140                                self.id = Some(parse_int_attribute(&name_value, "id")?);
141                            }
142                            "output" => {
143                                self.output_type =
144                                    Some(parse_type_attribute(&name_value, "output")?);
145                            }
146                            "storage" => {
147                                self.storage_type =
148                                    Some(parse_type_attribute(&name_value, "storage")?);
149                            }
150                            _ => {
151                                return Err(syn::Error::new_spanned(
152                                    &name_value.path,
153                                    format!(
154                                        "unknown attribute '{}' for Input derive. Valid attributes are: id, output, storage, assume_changed",
155                                        name
156                                    ),
157                                ));
158                            }
159                        }
160                    }
161                    _ => {
162                        return Err(syn::Error::new_spanned(
163                            &meta,
164                            "unsupported attribute format",
165                        ));
166                    }
167                }
168            }
169        }
170        Ok(())
171    }
172
173    fn validate(
174        &self,
175        attr_span: Span,
176    ) -> syn::Result<(
177        proc_macro2::TokenStream,
178        proc_macro2::TokenStream,
179        proc_macro2::TokenStream,
180        bool,
181    )> {
182        let id = self
183            .id
184            .clone()
185            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
186        let output_type = self
187            .output_type
188            .clone()
189            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'output' attribute"))?;
190        let storage_type = self
191            .storage_type
192            .clone()
193            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'storage' attribute"))?;
194
195        Ok((id, output_type, storage_type, self.assume_changed))
196    }
197}
198
199/// Parser specifically for Intermediate macro attributes
200struct IntermediateAttributeParser {
201    id: Option<proc_macro2::TokenStream>,
202    assume_changed: bool,
203}
204
205impl IntermediateAttributeParser {
206    fn new() -> Self {
207        Self {
208            id: None,
209            assume_changed: false,
210        }
211    }
212
213    fn parse_meta(&mut self, meta: &Meta) -> syn::Result<()> {
214        match meta {
215            Meta::Path(path) => {
216                if parse_flag_attribute(path, "assume_changed")? {
217                    self.assume_changed = true;
218                }
219            }
220            Meta::NameValue(name_value) => {
221                let name = name_value
222                    .path
223                    .get_ident()
224                    .ok_or_else(|| {
225                        syn::Error::new_spanned(&name_value.path, "expected identifier")
226                    })?
227                    .to_string();
228
229                match name.as_str() {
230                    "id" => {
231                        self.id = Some(parse_int_attribute(name_value, "id")?);
232                    }
233                    _ => {
234                        return Err(syn::Error::new_spanned(
235                            &name_value.path,
236                            format!(
237                                "unknown attribute '{}' for intermediate macro. Valid attributes are: id, assume_changed",
238                                name
239                            ),
240                        ));
241                    }
242                }
243            }
244            _ => {
245                return Err(syn::Error::new_spanned(
246                    meta,
247                    "unsupported attribute format",
248                ));
249            }
250        }
251        Ok(())
252    }
253
254    fn validate(&self, attr_span: Span) -> syn::Result<(proc_macro2::TokenStream, bool)> {
255        let id = self
256            .id
257            .clone()
258            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
259
260        Ok((id, self.assume_changed))
261    }
262}
263
264// =============================================================================
265// Storage Derive Macro
266// =============================================================================
267
268/// Derive macro for Storage trait
269///
270/// Usage:
271/// ```rust
272/// #[derive(Storage)]
273/// struct MyStorage {
274///     numbers: SingletonStorage<Number>,
275///     strings: HashMapStorage<StringComputation>,
276/// }
277/// ```
278///
279/// This generates a call to `impl_storage!(MyStorage, numbers: Number, strings: StringComputation)`
280/// which provides the actual implementation.
281///
282/// ## Skip Attribute
283///
284/// Fields can be excluded from the Storage implementation using `#[inc_complete(skip)]`:
285///
286/// ```rust
287/// #[derive(Storage)]
288/// struct MyStorage {
289///     numbers: SingletonStorage<Number>,
290///
291///     #[inc_complete(skip)]
292///     metadata: String,  // This field won't be included in Storage implementation
293/// }
294///
295/// ## Computation Attribute
296///
297/// The computation type for a field can be manually specified using `#[inc_complete(computation = Type)]`.
298/// This is required if the field type is not already generic over the computation type. For
299/// example, `HashMapStorage<MyComputationType>` is generic over `MyComputationType` but
300/// `MyStringStorage` is not:
301///
302/// ```rust
303/// #[derive(Storage)]
304/// struct MyStorage {
305///     numbers: SingletonStorage<Number>,
306///
307///     #[inc_complete(computation = MyStringInput)]
308///     strings: MyStringStorage,
309/// }
310/// ```
311#[proc_macro_derive(Storage, attributes(inc_complete))]
312pub fn derive_storage(input: TokenStream) -> TokenStream {
313    let input = parse_macro_input!(input as DeriveInput);
314
315    let struct_name = &input.ident;
316
317    // Extract field information
318    let fields = match &input.data {
319        Data::Struct(data) => match &data.fields {
320            Fields::Named(fields) => &fields.named,
321            _ => {
322                return syn::Error::new(
323                    Span::call_site(),
324                    "Storage derive only works on structs with named fields",
325                )
326                .to_compile_error()
327                .into();
328            }
329        },
330        _ => {
331            return syn::Error::new(Span::call_site(), "Storage derive only works on structs")
332                .to_compile_error()
333                .into();
334        }
335    };
336
337    // Extract computation types from storage fields
338    let mut field_mappings = Vec::new();
339    let mut accumulated = Vec::new();
340
341    for field in fields {
342        let field_name = field.ident.as_ref().unwrap();
343        let field_type = &field.ty;
344
345        // Parse attributes
346        let mut skip_field = false;
347        let mut is_accumulated = false;
348        let mut manual_computation_type: Option<Type> = None;
349
350        for attr in &field.attrs {
351            if attr.path().is_ident("inc_complete") {
352                match attr.meta {
353                    Meta::List(ref list) => {
354                        // Parse nested attributes like #[inc_complete(skip)] or #[inc_complete(computation = Type)]
355                        let nested_result =
356                            list.parse_args_with(|parser: syn::parse::ParseStream| {
357                                while !parser.is_empty() {
358                                    let lookahead = parser.lookahead1();
359                                    if lookahead.peek(syn::Ident) {
360                                        let ident: syn::Ident = parser.parse()?;
361                                        if ident == "skip" {
362                                            skip_field = true;
363                                        } else if ident == "computation" {
364                                            parser.parse::<syn::Token![=]>()?;
365                                            manual_computation_type = Some(parser.parse()?);
366                                        } else if ident == "accumulate" {
367                                            is_accumulated = true;
368                                        } else {
369                                            return Err(syn::Error::new_spanned(
370                                                ident,
371                                                "expected 'skip' or 'computation'",
372                                            ));
373                                        }
374                                    } else {
375                                        return Err(lookahead.error());
376                                    }
377
378                                    if !parser.is_empty() {
379                                        parser.parse::<syn::Token![,]>()?;
380                                    }
381                                }
382
383                                Ok(())
384                            });
385
386                        if let Err(e) = nested_result {
387                            return e.to_compile_error().into();
388                        }
389                    }
390                    _ => {
391                        return syn::Error::new_spanned(
392                            attr,
393                            "expected #[inc_complete(skip)] or #[inc_complete(computation = Type)]",
394                        )
395                        .to_compile_error()
396                        .into();
397                    }
398                }
399            }
400        }
401
402        if skip_field {
403            // Skip this field - don't include it in the impl_storage! call
404            continue;
405        }
406
407        // Determine the computation type
408        let computation_type = if let Some(manual_type) = manual_computation_type {
409            // Use manually specified type
410            manual_type
411        } else if let Some(extracted_type) = extract_generic_type(field_type, "Storage") {
412            // Try to extract the generic type from SingletonStorage<T>, HashMapStorage<T>, etc.
413            extracted_type
414        } else {
415            return syn::Error::new(
416                field.span(),
417                "Field must be a storage type like SingletonStorage<T>, HashMapStorage<T>, or use #[inc_complete(computation = Type)] to specify the type manually, or use #[inc_complete(skip)] to exclude it",
418            )
419            .to_compile_error()
420            .into();
421        };
422
423        let item = quote! { #field_name: #computation_type, };
424        if is_accumulated {
425            accumulated.push(item);
426        } else {
427            field_mappings.push(item);
428        }
429    }
430
431    // Generate a call to impl_storage! macro
432    let expanded = quote! {
433        inc_complete::impl_storage!(#struct_name, #(#field_mappings)* @accumulators { #(#accumulated)* });
434    };
435
436    TokenStream::from(expanded)
437}
438
439// =============================================================================
440// Input Derive Macro
441// =============================================================================
442
443/// Derive macro for Input computation
444///
445/// Usage:
446/// ```rust
447/// #[derive(Input)]
448/// #[inc_complete(id = 0, output = i32, storage = MyStorage)]
449/// struct MyInput;
450/// ```
451///
452/// This generates a call to `define_input!(0, MyInput -> i32, MyStorage)`
453/// which provides the actual implementation.
454///
455/// Required attributes:
456/// - `id`: The unique computation ID (integer)
457/// - `output`: The output type
458/// - `storage`: The storage type
459///
460/// Optional attribute:
461/// - `assume_changed`: If present, indicates that the input is assumed to have changed
462#[proc_macro_derive(Input, attributes(inc_complete))]
463pub fn derive_input(input: TokenStream) -> TokenStream {
464    let input = parse_macro_input!(input as DeriveInput);
465
466    let struct_name = &input.ident;
467
468    // Find the inc_complete attribute
469    let inc_complete_attr = input
470        .attrs
471        .iter()
472        .find(|attr| attr.path().is_ident("inc_complete"));
473
474    let attr = match inc_complete_attr {
475        Some(attr) => attr,
476        None => {
477            return syn::Error::new(
478                Span::call_site(),
479                "Input derive requires #[inc_complete(...)] attribute",
480            )
481            .to_compile_error()
482            .into();
483        }
484    };
485
486    // Parse the attribute arguments using the Input-specific parser
487    let mut parser = InputAttributeParser::new();
488    if let Err(err) = parser.parse_attribute_list(attr) {
489        return err.to_compile_error().into();
490    }
491
492    let (id, output_type, storage_type, assume_changed) = match parser.validate(attr.span()) {
493        Ok(values) => values,
494        Err(err) => return err.to_compile_error().into(),
495    };
496
497    // Generate the define_input! call
498    let expanded = if assume_changed {
499        quote! {
500            inc_complete::define_input!(#id, assume_changed #struct_name -> #output_type, #storage_type);
501        }
502    } else {
503        quote! {
504            inc_complete::define_input!(#id, #struct_name -> #output_type, #storage_type);
505        }
506    };
507
508    TokenStream::from(expanded)
509}
510
511// =============================================================================
512// Intermediate Macros
513// =============================================================================
514
515/// Procedural macro for defining intermediate computations directly on functions
516///
517/// Usage:
518/// ```rust
519/// // Assumes ComputeDouble struct is already defined
520/// #[intermediate(id = 1)]
521/// fn compute_double(_context: &ComputeDouble, db: &DbHandle<MyStorage>) -> i32 {
522///     db.get(InputValue) * 2
523/// }
524/// ```
525///
526/// This generates a call to `define_intermediate!` using the extracted information.
527/// Assumes the computation struct is already defined.
528///
529/// Required attributes:
530/// - `id`: The computation ID (integer)
531///
532/// Optional attribute:
533/// - `assume_changed`: If present, indicates that the input is assumed to have changed
534///
535/// The macro automatically extracts:
536/// - **Output type** from the function return type (`i32`)
537/// - **Computation type** from the first parameter (`ComputeDouble` from `&ComputeDouble`)
538/// - **Storage type** from the second parameter (`MyStorage` from `&DbHandle<MyStorage>`)
539/// - **Function name** automatically (`compute_double`)
540///
541#[proc_macro_attribute]
542pub fn intermediate(args: TokenStream, input: TokenStream) -> TokenStream {
543    let args = parse_macro_input!(args with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
544    let input_fn = parse_macro_input!(input as syn::ItemFn);
545
546    match process_intermediate_function(args, input_fn) {
547        Ok(tokens) => tokens.into(),
548        Err(err) => err.to_compile_error().into(),
549    }
550}
551
552// =============================================================================
553// Helper Functions for Intermediate Macros
554// =============================================================================
555
556/// Process the intermediate function and generate the appropriate code
557fn process_intermediate_function(
558    args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
559    input_fn: syn::ItemFn,
560) -> syn::Result<proc_macro2::TokenStream> {
561    // Parse the attributes using the Intermediate-specific parser
562    let mut parser = IntermediateAttributeParser::new();
563    for arg in args {
564        parser.parse_meta(&arg)?;
565    }
566    let (id, assume_changed) = parser.validate(Span::call_site())?;
567
568    // Extract information from the function signature
569    let fn_name = &input_fn.sig.ident;
570    let output_type = extract_return_type(&input_fn.sig.output)?;
571    let (computation_type, storage_type) = extract_types_from_params(&input_fn.sig.inputs)?;
572
573    // Generate the define_intermediate call, and optionally the struct
574    let expanded = if assume_changed {
575        quote! {
576            #input_fn
577
578            inc_complete::define_intermediate!(#id, assume_changed #computation_type -> #output_type, #storage_type, #fn_name);
579        }
580    } else {
581        quote! {
582            #input_fn
583
584            inc_complete::define_intermediate!(#id, #computation_type -> #output_type, #storage_type, #fn_name);
585        }
586    };
587
588    Ok(expanded)
589}
590
591/// Extract the return type from a function signature
592fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
593    match output {
594        syn::ReturnType::Type(_, ty) => Ok(quote! { #ty }),
595        syn::ReturnType::Default => Err(syn::Error::new(
596            Span::call_site(),
597            "function must have an explicit return type",
598        )),
599    }
600}
601
602/// Extract both computation type and storage type from function parameters
603fn extract_types_from_params(
604    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
605) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
606    let mut iter = inputs.iter();
607
608    // Extract computation type from first parameter
609    let first_arg = iter.next().ok_or_else(|| {
610        syn::Error::new(
611            Span::call_site(),
612            "function must have at least two parameters: (&ComputationType, &DbHandle<StorageType>)",
613        )
614    })?;
615
616    let computation_type = extract_reference_inner_type(
617        first_arg,
618        "first parameter must be a reference to the computation type (e.g., &ComputationType)",
619    )?;
620
621    // Extract storage type from second parameter
622    let second_arg = iter.next().ok_or_else(|| {
623        syn::Error::new(
624            Span::call_site(),
625            "function must have a second parameter of type &DbHandle<StorageType>",
626        )
627    })?;
628
629    let storage_type = extract_dbhandle_storage_type(second_arg)?;
630
631    Ok((computation_type, storage_type))
632}
633
634/// Extract inner type from a reference parameter
635fn extract_reference_inner_type(
636    arg: &syn::FnArg,
637    error_msg: &str,
638) -> syn::Result<proc_macro2::TokenStream> {
639    if let syn::FnArg::Typed(pat_type) = arg {
640        if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
641            let inner_type = &type_ref.elem;
642            Ok(quote! { #inner_type })
643        } else {
644            Err(syn::Error::new_spanned(arg, error_msg))
645        }
646    } else {
647        Err(syn::Error::new_spanned(arg, error_msg))
648    }
649}
650
651/// Extract storage type from &DbHandle<StorageType> parameter
652fn extract_dbhandle_storage_type(arg: &syn::FnArg) -> syn::Result<proc_macro2::TokenStream> {
653    let error_msg = "second parameter must be &DbHandle<StorageType>";
654
655    if let syn::FnArg::Typed(pat_type) = arg {
656        if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
657            if let syn::Type::Path(type_path) = type_ref.elem.as_ref() {
658                // Look for DbHandle<StorageType>
659                if let Some(segment) = type_path.path.segments.last() {
660                    if segment.ident == "DbHandle" {
661                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
662                            if let Some(syn::GenericArgument::Type(storage_ty)) = args.args.first()
663                            {
664                                return Ok(quote! { #storage_ty });
665                            }
666                        }
667                        return Err(syn::Error::new_spanned(
668                            arg,
669                            "DbHandle must have a generic type parameter for the storage type",
670                        ));
671                    }
672                }
673            }
674        }
675    }
676
677    Err(syn::Error::new_spanned(arg, error_msg))
678}