inc_complete_derive/
lib.rs

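//! Procedural macros for inc_complete: the `Storage` and `Input` derive macros and the
//! `intermediate` attribute macro. Each one expands to a call to the corresponding
//! declarative macro (`impl_storage!`, `define_input!`, `define_intermediate!`).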
3
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8    Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments,
9    Type, TypePath, parse_macro_input, spanned::Spanned,
10};
11
12// =============================================================================
13// Helper Functions for Attribute Parsing
14// =============================================================================
15
16/// Parse a type attribute from a Meta::NameValue
17fn parse_type_attribute(
18    name_value: &syn::MetaNameValue,
19    expected_name: &str,
20) -> syn::Result<proc_macro2::TokenStream> {
21    let name = name_value
22        .path
23        .get_ident()
24        .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
25        .to_string();
26
27    if name != expected_name {
28        return Err(syn::Error::new_spanned(
29            &name_value.path,
30            format!("expected '{}', found '{}'", expected_name, name),
31        ));
32    }
33
34    let expr = &name_value.value;
35    Ok(quote::quote! { #expr })
36}
37
38/// Parse an integer attribute from a Meta::NameValue
39fn parse_int_attribute(
40    name_value: &syn::MetaNameValue,
41    expected_name: &str,
42) -> syn::Result<proc_macro2::TokenStream> {
43    let name = name_value
44        .path
45        .get_ident()
46        .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
47        .to_string();
48
49    if name != expected_name {
50        return Err(syn::Error::new_spanned(
51            &name_value.path,
52            format!("expected '{}', found '{}'", expected_name, name),
53        ));
54    }
55
56    if let Expr::Lit(ExprLit {
57        lit: Lit::Int(lit_int),
58        ..
59    }) = &name_value.value
60    {
61        Ok(quote! { #lit_int })
62    } else {
63        Err(syn::Error::new_spanned(
64            &name_value.value,
65            format!("{} must be an integer", expected_name),
66        ))
67    }
68}
69
70/// Parse a boolean flag attribute from a Meta::Path
71fn parse_flag_attribute(path: &syn::Path, expected_name: &str) -> syn::Result<bool> {
72    if path.is_ident(expected_name) {
73        Ok(true)
74    } else {
75        Err(syn::Error::new_spanned(
76            path,
77            format!("unknown flag attribute, expected '{}'", expected_name),
78        ))
79    }
80}
81
82/// Extract generic type T from Container<T>
83fn extract_generic_type(ty: &Type) -> Option<Type> {
84    if let Type::Path(TypePath { path, .. }) = ty {
85        if let Some(segment) = path.segments.last() {
86            if let PathArguments::AngleBracketed(args) = &segment.arguments {
87                if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
88                    return Some(inner_type.clone());
89                }
90            }
91        }
92    }
93    None
94}
95
96/// Parser specifically for Input derive attributes
97struct InputAttributeParser {
98    id: Option<proc_macro2::TokenStream>,
99    output_type: Option<proc_macro2::TokenStream>,
100    storage_type: Option<proc_macro2::TokenStream>,
101    assume_changed: bool,
102}
103
104impl InputAttributeParser {
105    fn new() -> Self {
106        Self {
107            id: None,
108            output_type: None,
109            storage_type: None,
110            assume_changed: false,
111        }
112    }
113
114    fn parse_attribute_list(&mut self, attr: &Attribute) -> syn::Result<()> {
115        let meta = attr.meta.clone();
116        if let Meta::List(meta_list) = meta {
117            let parsed: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
118                meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
119
120            for meta in parsed {
121                match meta {
122                    Meta::Path(path) => {
123                        if parse_flag_attribute(&path, "assume_changed")? {
124                            self.assume_changed = true;
125                        }
126                    }
127                    Meta::NameValue(name_value) => {
128                        let name = name_value
129                            .path
130                            .get_ident()
131                            .ok_or_else(|| {
132                                syn::Error::new_spanned(&name_value.path, "expected identifier")
133                            })?
134                            .to_string();
135
136                        match name.as_str() {
137                            "id" => {
138                                self.id = Some(parse_int_attribute(&name_value, "id")?);
139                            }
140                            "output" => {
141                                self.output_type =
142                                    Some(parse_type_attribute(&name_value, "output")?);
143                            }
144                            "storage" => {
145                                self.storage_type =
146                                    Some(parse_type_attribute(&name_value, "storage")?);
147                            }
148                            _ => {
149                                return Err(syn::Error::new_spanned(
150                                    &name_value.path,
151                                    format!(
152                                        "unknown attribute '{}' for Input derive. Valid attributes are: id, output, storage, assume_changed",
153                                        name
154                                    ),
155                                ));
156                            }
157                        }
158                    }
159                    _ => {
160                        return Err(syn::Error::new_spanned(
161                            &meta,
162                            "unsupported attribute format",
163                        ));
164                    }
165                }
166            }
167        }
168        Ok(())
169    }
170
171    fn validate(
172        &self,
173        attr_span: Span,
174    ) -> syn::Result<(
175        proc_macro2::TokenStream,
176        proc_macro2::TokenStream,
177        proc_macro2::TokenStream,
178        bool,
179    )> {
180        let id = self
181            .id
182            .clone()
183            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
184        let output_type = self
185            .output_type
186            .clone()
187            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'output' attribute"))?;
188        let storage_type = self
189            .storage_type
190            .clone()
191            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'storage' attribute"))?;
192
193        Ok((id, output_type, storage_type, self.assume_changed))
194    }
195}
196
197/// Parser specifically for Intermediate macro attributes
198struct IntermediateAttributeParser {
199    id: Option<proc_macro2::TokenStream>,
200    assume_changed: bool,
201}
202
203impl IntermediateAttributeParser {
204    fn new() -> Self {
205        Self {
206            id: None,
207            assume_changed: false,
208        }
209    }
210
211    fn parse_meta(&mut self, meta: &Meta) -> syn::Result<()> {
212        match meta {
213            Meta::Path(path) => {
214                if parse_flag_attribute(path, "assume_changed")? {
215                    self.assume_changed = true;
216                }
217            }
218            Meta::NameValue(name_value) => {
219                let name = name_value
220                    .path
221                    .get_ident()
222                    .ok_or_else(|| {
223                        syn::Error::new_spanned(&name_value.path, "expected identifier")
224                    })?
225                    .to_string();
226
227                match name.as_str() {
228                    "id" => {
229                        self.id = Some(parse_int_attribute(name_value, "id")?);
230                    }
231                    _ => {
232                        return Err(syn::Error::new_spanned(
233                            &name_value.path,
234                            format!(
235                                "unknown attribute '{}' for intermediate macro. Valid attributes are: id, assume_changed",
236                                name
237                            ),
238                        ));
239                    }
240                }
241            }
242            _ => {
243                return Err(syn::Error::new_spanned(
244                    meta,
245                    "unsupported attribute format",
246                ));
247            }
248        }
249        Ok(())
250    }
251
252    fn validate(&self, attr_span: Span) -> syn::Result<(proc_macro2::TokenStream, bool)> {
253        let id = self
254            .id
255            .clone()
256            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
257
258        Ok((id, self.assume_changed))
259    }
260}
261
262// =============================================================================
263// Storage Derive Macro
264// =============================================================================
265
266/// Derive macro for Storage trait
267///
268/// Usage:
269/// ```rust
270/// #[derive(Storage)]
271/// struct MyStorage {
272///     numbers: SingletonStorage<Number>,
273///     strings: HashMapStorage<StringComputation>,
274/// }
275/// ```
276///
277/// This generates a call to `impl_storage!(MyStorage, numbers: Number, strings: StringComputation)`
278/// which provides the actual implementation.
279///
280/// ## Skip Attribute
281///
282/// Fields can be excluded from the Storage implementation using `#[inc_complete(skip)]`:
283///
284/// ```rust
285/// #[derive(Storage)]
286/// struct MyStorage {
287///     numbers: SingletonStorage<Number>,
288///
289///     #[inc_complete(skip)]
290///     metadata: String,  // This field won't be included in Storage implementation
291/// }
292///
293/// ## Computation Attribute
294///
295/// The computation type for a field can be manually specified using `#[inc_complete(computation = Type)]`.
296/// This is required if the field type is not already generic over the computation type. For
297/// example, `HashMapStorage<MyComputationType>` is generic over `MyComputationType` but
298/// `MyStringStorage` is not:
299///
300/// ```rust
301/// #[derive(Storage)]
302/// struct MyStorage {
303///     numbers: SingletonStorage<Number>,
304///
305///     #[inc_complete(computation = MyStringInput)]
306///     strings: MyStringStorage,
307/// }
308/// ```
309#[proc_macro_derive(Storage, attributes(inc_complete))]
310pub fn derive_storage(input: TokenStream) -> TokenStream {
311    let input = parse_macro_input!(input as DeriveInput);
312
313    let struct_name = &input.ident;
314
315    // Extract field information
316    let fields = match &input.data {
317        Data::Struct(data) => match &data.fields {
318            Fields::Named(fields) => &fields.named,
319            _ => {
320                return syn::Error::new(
321                    Span::call_site(),
322                    "Storage derive only works on structs with named fields",
323                )
324                .to_compile_error()
325                .into();
326            }
327        },
328        _ => {
329            return syn::Error::new(Span::call_site(), "Storage derive only works on structs")
330                .to_compile_error()
331                .into();
332        }
333    };
334
335    // Extract computation types from storage fields
336    let mut field_mappings = Vec::new();
337    let mut accumulated = Vec::new();
338
339    for field in fields {
340        let field_name = field.ident.as_ref().unwrap();
341        let field_type = &field.ty;
342
343        // Parse attributes
344        let mut skip_field = false;
345        let mut is_accumulated = false;
346        let mut manual_computation_type: Option<Type> = None;
347
348        for attr in &field.attrs {
349            if attr.path().is_ident("inc_complete") {
350                match attr.meta {
351                    Meta::List(ref list) => {
352                        // Parse nested attributes like #[inc_complete(skip)] or #[inc_complete(computation = Type)]
353                        let nested_result =
354                            list.parse_args_with(|parser: syn::parse::ParseStream| {
355                                while !parser.is_empty() {
356                                    let lookahead = parser.lookahead1();
357                                    if lookahead.peek(syn::Ident) {
358                                        let ident: syn::Ident = parser.parse()?;
359                                        if ident == "skip" {
360                                            skip_field = true;
361                                        } else if ident == "computation" {
362                                            parser.parse::<syn::Token![=]>()?;
363                                            manual_computation_type = Some(parser.parse()?);
364                                        } else if ident == "accumulate" {
365                                            is_accumulated = true;
366                                        } else {
367                                            return Err(syn::Error::new_spanned(
368                                                ident,
369                                                "expected 'skip' or 'computation'",
370                                            ));
371                                        }
372                                    } else {
373                                        return Err(lookahead.error());
374                                    }
375
376                                    if !parser.is_empty() {
377                                        parser.parse::<syn::Token![,]>()?;
378                                    }
379                                }
380
381                                Ok(())
382                            });
383
384                        if let Err(e) = nested_result {
385                            return e.to_compile_error().into();
386                        }
387                    }
388                    _ => {
389                        return syn::Error::new_spanned(
390                            attr,
391                            "expected #[inc_complete(skip)] or #[inc_complete(computation = Type)]",
392                        )
393                        .to_compile_error()
394                        .into();
395                    }
396                }
397            }
398        }
399
400        if skip_field {
401            // Skip this field - don't include it in the impl_storage! call
402            continue;
403        }
404
405        // Determine the computation type
406        let computation_type = if let Some(manual_type) = manual_computation_type {
407            // Use manually specified type
408            manual_type
409        } else if let Some(extracted_type) = extract_generic_type(field_type) {
410            // Try to extract the generic type from SingletonStorage<T>, HashMapStorage<T>, etc.
411            extracted_type
412        } else {
413            return syn::Error::new(
414                field.span(),
415                "Field must be a storage type like SingletonStorage<T>, HashMapStorage<T>, or use #[inc_complete(computation = Type)] to specify the type manually, or use #[inc_complete(skip)] to exclude it",
416            )
417            .to_compile_error()
418            .into();
419        };
420
421        let item = quote! { #field_name: #computation_type, };
422        if is_accumulated {
423            accumulated.push(item);
424        } else {
425            field_mappings.push(item);
426        }
427    }
428
429    // Generate a call to impl_storage! macro
430    let expanded = quote! {
431        inc_complete::impl_storage!(#struct_name, #(#field_mappings)* @accumulators { #(#accumulated)* });
432    };
433
434    TokenStream::from(expanded)
435}
436
437// =============================================================================
438// Input Derive Macro
439// =============================================================================
440
441/// Derive macro for Input computation
442///
443/// Usage:
444/// ```rust
445/// #[derive(Input)]
446/// #[inc_complete(id = 0, output = i32, storage = MyStorage)]
447/// struct MyInput;
448/// ```
449///
450/// This generates a call to `define_input!(0, MyInput -> i32, MyStorage)`
451/// which provides the actual implementation.
452///
453/// Required attributes:
454/// - `id`: The unique computation ID (integer)
455/// - `output`: The output type
456/// - `storage`: The storage type
457///
458/// Optional attribute:
459/// - `assume_changed`: If present, indicates that the input is assumed to have changed
460#[proc_macro_derive(Input, attributes(inc_complete))]
461pub fn derive_input(input: TokenStream) -> TokenStream {
462    let input = parse_macro_input!(input as DeriveInput);
463
464    let struct_name = &input.ident;
465
466    // Find the inc_complete attribute
467    let inc_complete_attr = input
468        .attrs
469        .iter()
470        .find(|attr| attr.path().is_ident("inc_complete"));
471
472    let attr = match inc_complete_attr {
473        Some(attr) => attr,
474        None => {
475            return syn::Error::new(
476                Span::call_site(),
477                "Input derive requires #[inc_complete(...)] attribute",
478            )
479            .to_compile_error()
480            .into();
481        }
482    };
483
484    // Parse the attribute arguments using the Input-specific parser
485    let mut parser = InputAttributeParser::new();
486    if let Err(err) = parser.parse_attribute_list(attr) {
487        return err.to_compile_error().into();
488    }
489
490    let (id, output_type, storage_type, assume_changed) = match parser.validate(attr.span()) {
491        Ok(values) => values,
492        Err(err) => return err.to_compile_error().into(),
493    };
494
495    // Generate the define_input! call
496    let expanded = if assume_changed {
497        quote! {
498            inc_complete::define_input!(#id, assume_changed #struct_name -> #output_type, #storage_type);
499        }
500    } else {
501        quote! {
502            inc_complete::define_input!(#id, #struct_name -> #output_type, #storage_type);
503        }
504    };
505
506    TokenStream::from(expanded)
507}
508
509// =============================================================================
510// Intermediate Macros
511// =============================================================================
512
513/// Procedural macro for defining intermediate computations directly on functions
514///
515/// Usage:
516/// ```rust
517/// // Assumes ComputeDouble struct is already defined
518/// #[intermediate(id = 1)]
519/// fn compute_double(_context: &ComputeDouble, db: &DbHandle<MyStorage>) -> i32 {
520///     db.get(InputValue) * 2
521/// }
522/// ```
523///
524/// This generates a call to `define_intermediate!` using the extracted information.
525/// Assumes the computation struct is already defined.
526///
527/// Required attributes:
528/// - `id`: The computation ID (integer)
529///
530/// Optional attribute:
531/// - `assume_changed`: If present, indicates that the input is assumed to have changed
532///
533/// The macro automatically extracts:
534/// - **Output type** from the function return type (`i32`)
535/// - **Computation type** from the first parameter (`ComputeDouble` from `&ComputeDouble`)
536/// - **Storage type** from the second parameter (`MyStorage` from `&DbHandle<MyStorage>`)
537/// - **Function name** automatically (`compute_double`)
538///
539#[proc_macro_attribute]
540pub fn intermediate(args: TokenStream, input: TokenStream) -> TokenStream {
541    let args = parse_macro_input!(args with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
542    let input_fn = parse_macro_input!(input as syn::ItemFn);
543
544    match process_intermediate_function(args, input_fn) {
545        Ok(tokens) => tokens.into(),
546        Err(err) => err.to_compile_error().into(),
547    }
548}
549
550// =============================================================================
551// Helper Functions for Intermediate Macros
552// =============================================================================
553
554/// Process the intermediate function and generate the appropriate code
555fn process_intermediate_function(
556    args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
557    input_fn: syn::ItemFn,
558) -> syn::Result<proc_macro2::TokenStream> {
559    // Parse the attributes using the Intermediate-specific parser
560    let mut parser = IntermediateAttributeParser::new();
561    for arg in args {
562        parser.parse_meta(&arg)?;
563    }
564    let (id, assume_changed) = parser.validate(Span::call_site())?;
565
566    // Extract information from the function signature
567    let fn_name = &input_fn.sig.ident;
568    let output_type = extract_return_type(&input_fn.sig.output)?;
569    let (computation_type, storage_type) = extract_types_from_params(&input_fn.sig.inputs)?;
570
571    // Generate the define_intermediate call, and optionally the struct
572    let expanded = if assume_changed {
573        quote! {
574            #input_fn
575
576            inc_complete::define_intermediate!(#id, assume_changed #computation_type -> #output_type, #storage_type, #fn_name);
577        }
578    } else {
579        quote! {
580            #input_fn
581
582            inc_complete::define_intermediate!(#id, #computation_type -> #output_type, #storage_type, #fn_name);
583        }
584    };
585
586    Ok(expanded)
587}
588
589/// Extract the return type from a function signature
590fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
591    match output {
592        syn::ReturnType::Type(_, ty) => Ok(quote! { #ty }),
593        syn::ReturnType::Default => Err(syn::Error::new(
594            Span::call_site(),
595            "function must have an explicit return type",
596        )),
597    }
598}
599
600/// Extract both computation type and storage type from function parameters
601fn extract_types_from_params(
602    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
603) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
604    let mut iter = inputs.iter();
605
606    // Extract computation type from first parameter
607    let first_arg = iter.next().ok_or_else(|| {
608        syn::Error::new(
609            Span::call_site(),
610            "function must have at least two parameters: (&ComputationType, &DbHandle<StorageType>)",
611        )
612    })?;
613
614    let computation_type = extract_reference_inner_type(
615        first_arg,
616        "first parameter must be a reference to the computation type (e.g., &ComputationType)",
617    )?;
618
619    // Extract storage type from second parameter
620    let second_arg = iter.next().ok_or_else(|| {
621        syn::Error::new(
622            Span::call_site(),
623            "function must have a second parameter of type &DbHandle<StorageType>",
624        )
625    })?;
626
627    let storage_type = extract_dbhandle_storage_type(second_arg)?;
628
629    Ok((computation_type, storage_type))
630}
631
632/// Extract inner type from a reference parameter
633fn extract_reference_inner_type(
634    arg: &syn::FnArg,
635    error_msg: &str,
636) -> syn::Result<proc_macro2::TokenStream> {
637    if let syn::FnArg::Typed(pat_type) = arg {
638        if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
639            let inner_type = &type_ref.elem;
640            Ok(quote! { #inner_type })
641        } else {
642            Err(syn::Error::new_spanned(arg, error_msg))
643        }
644    } else {
645        Err(syn::Error::new_spanned(arg, error_msg))
646    }
647}
648
649/// Extract storage type from &DbHandle<StorageType> parameter
650fn extract_dbhandle_storage_type(arg: &syn::FnArg) -> syn::Result<proc_macro2::TokenStream> {
651    let error_msg = "second parameter must be &DbHandle<StorageType>";
652
653    if let syn::FnArg::Typed(pat_type) = arg {
654        if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
655            if let syn::Type::Path(type_path) = type_ref.elem.as_ref() {
656                // Look for DbHandle<StorageType>
657                if let Some(segment) = type_path.path.segments.last() {
658                    if segment.ident == "DbHandle" {
659                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
660                            if let Some(syn::GenericArgument::Type(storage_ty)) = args.args.first()
661                            {
662                                return Ok(quote! { #storage_ty });
663                            }
664                        }
665                        return Err(syn::Error::new_spanned(
666                            arg,
667                            "DbHandle must have a generic type parameter for the storage type",
668                        ));
669                    }
670                }
671            }
672        }
673    }
674
675    Err(syn::Error::new_spanned(arg, error_msg))
676}