inc_complete_derive/
lib.rs

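//! Procedural macros for inc_complete: the `Storage` and `Input` derive macros and the
//! `intermediate` attribute macro. Each one expands to a call to the corresponding
//! declarative macro (`impl_storage!`, `define_input!`, `define_intermediate!`).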
3
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8    Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments,
9    Type, TypePath, parse_macro_input, spanned::Spanned,
10};
11
12// =============================================================================
13// Helper Functions for Attribute Parsing
14// =============================================================================
15
16/// Parse a type attribute from a Meta::NameValue
17fn parse_type_attribute(
18    name_value: &syn::MetaNameValue,
19    expected_name: &str,
20) -> syn::Result<proc_macro2::TokenStream> {
21    let name = name_value
22        .path
23        .get_ident()
24        .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
25        .to_string();
26
27    if name != expected_name {
28        return Err(syn::Error::new_spanned(
29            &name_value.path,
30            format!("expected '{}', found '{}'", expected_name, name),
31        ));
32    }
33
34    let expr = &name_value.value;
35    Ok(quote::quote! { #expr })
36}
37
38/// Parse an integer attribute from a Meta::NameValue
39fn parse_int_attribute(
40    name_value: &syn::MetaNameValue,
41    expected_name: &str,
42) -> syn::Result<proc_macro2::TokenStream> {
43    let name = name_value
44        .path
45        .get_ident()
46        .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
47        .to_string();
48
49    if name != expected_name {
50        return Err(syn::Error::new_spanned(
51            &name_value.path,
52            format!("expected '{}', found '{}'", expected_name, name),
53        ));
54    }
55
56    if let Expr::Lit(ExprLit {
57        lit: Lit::Int(lit_int),
58        ..
59    }) = &name_value.value
60    {
61        Ok(quote! { #lit_int })
62    } else {
63        Err(syn::Error::new_spanned(
64            &name_value.value,
65            format!("{} must be an integer", expected_name),
66        ))
67    }
68}
69
70/// Parse a boolean flag attribute from a Meta::Path
71fn parse_flag_attribute(path: &syn::Path, expected_name: &str) -> syn::Result<bool> {
72    if path.is_ident(expected_name) {
73        Ok(true)
74    } else {
75        Err(syn::Error::new_spanned(
76            path,
77            format!("unknown flag attribute, expected '{}'", expected_name),
78        ))
79    }
80}
81
82/// Extract generic type T from Container<T>
83fn extract_generic_type(ty: &Type, container_suffix: &str) -> Option<Type> {
84    if let Type::Path(TypePath { path, .. }) = ty {
85        if let Some(segment) = path.segments.last() {
86            if segment.ident.to_string().ends_with(container_suffix) {
87                if let PathArguments::AngleBracketed(args) = &segment.arguments {
88                    if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
89                        return Some(inner_type.clone());
90                    }
91                }
92            }
93        }
94    }
95    None
96}
97
98/// Parser specifically for Input derive attributes
99struct InputAttributeParser {
100    id: Option<proc_macro2::TokenStream>,
101    output_type: Option<proc_macro2::TokenStream>,
102    storage_type: Option<proc_macro2::TokenStream>,
103    assume_changed: bool,
104}
105
106impl InputAttributeParser {
107    fn new() -> Self {
108        Self {
109            id: None,
110            output_type: None,
111            storage_type: None,
112            assume_changed: false,
113        }
114    }
115
116    fn parse_attribute_list(&mut self, attr: &Attribute) -> syn::Result<()> {
117        let meta = attr.meta.clone();
118        if let Meta::List(meta_list) = meta {
119            let parsed: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
120                meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
121
122            for meta in parsed {
123                match meta {
124                    Meta::Path(path) => {
125                        if parse_flag_attribute(&path, "assume_changed")? {
126                            self.assume_changed = true;
127                        }
128                    }
129                    Meta::NameValue(name_value) => {
130                        let name = name_value
131                            .path
132                            .get_ident()
133                            .ok_or_else(|| {
134                                syn::Error::new_spanned(&name_value.path, "expected identifier")
135                            })?
136                            .to_string();
137
138                        match name.as_str() {
139                            "id" => {
140                                self.id = Some(parse_int_attribute(&name_value, "id")?);
141                            }
142                            "output" => {
143                                self.output_type =
144                                    Some(parse_type_attribute(&name_value, "output")?);
145                            }
146                            "storage" => {
147                                self.storage_type =
148                                    Some(parse_type_attribute(&name_value, "storage")?);
149                            }
150                            _ => {
151                                return Err(syn::Error::new_spanned(
152                                    &name_value.path,
153                                    format!(
154                                        "unknown attribute '{}' for Input derive. Valid attributes are: id, output, storage, assume_changed",
155                                        name
156                                    ),
157                                ));
158                            }
159                        }
160                    }
161                    _ => {
162                        return Err(syn::Error::new_spanned(
163                            &meta,
164                            "unsupported attribute format",
165                        ));
166                    }
167                }
168            }
169        }
170        Ok(())
171    }
172
173    fn validate(
174        &self,
175        attr_span: Span,
176    ) -> syn::Result<(
177        proc_macro2::TokenStream,
178        proc_macro2::TokenStream,
179        proc_macro2::TokenStream,
180        bool,
181    )> {
182        let id = self
183            .id
184            .clone()
185            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
186        let output_type = self
187            .output_type
188            .clone()
189            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'output' attribute"))?;
190        let storage_type = self
191            .storage_type
192            .clone()
193            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'storage' attribute"))?;
194
195        Ok((id, output_type, storage_type, self.assume_changed))
196    }
197}
198
199/// Parser specifically for Intermediate macro attributes
200struct IntermediateAttributeParser {
201    id: Option<proc_macro2::TokenStream>,
202    assume_changed: bool,
203}
204
205impl IntermediateAttributeParser {
206    fn new() -> Self {
207        Self {
208            id: None,
209            assume_changed: false,
210        }
211    }
212
213    fn parse_meta(&mut self, meta: &Meta) -> syn::Result<()> {
214        match meta {
215            Meta::Path(path) => {
216                if parse_flag_attribute(path, "assume_changed")? {
217                    self.assume_changed = true;
218                }
219            }
220            Meta::NameValue(name_value) => {
221                let name = name_value
222                    .path
223                    .get_ident()
224                    .ok_or_else(|| {
225                        syn::Error::new_spanned(&name_value.path, "expected identifier")
226                    })?
227                    .to_string();
228
229                match name.as_str() {
230                    "id" => {
231                        self.id = Some(parse_int_attribute(name_value, "id")?);
232                    }
233                    _ => {
234                        return Err(syn::Error::new_spanned(
235                            &name_value.path,
236                            format!(
237                                "unknown attribute '{}' for intermediate macro. Valid attributes are: id, assume_changed",
238                                name
239                            ),
240                        ));
241                    }
242                }
243            }
244            _ => {
245                return Err(syn::Error::new_spanned(
246                    meta,
247                    "unsupported attribute format",
248                ));
249            }
250        }
251        Ok(())
252    }
253
254    fn validate(&self, attr_span: Span) -> syn::Result<(proc_macro2::TokenStream, bool)> {
255        let id = self
256            .id
257            .clone()
258            .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
259
260        Ok((id, self.assume_changed))
261    }
262}
263
264// =============================================================================
265// Storage Derive Macro
266// =============================================================================
267
268/// Derive macro for Storage trait
269///
270/// Usage:
271/// ```rust
272/// #[derive(Storage)]
273/// struct MyStorage {
274///     numbers: SingletonStorage<Number>,
275///     strings: HashMapStorage<StringComputation>,
276/// }
277/// ```
278///
279/// This generates a call to `impl_storage!(MyStorage, numbers: Number, strings: StringComputation)`
280/// which provides the actual implementation.
281///
282/// ## Skip Attribute
283///
284/// Fields can be excluded from the Storage implementation using `#[inc_complete(skip)]`:
285///
286/// ```rust
287/// #[derive(Storage)]
288/// struct MyStorage {
289///     numbers: SingletonStorage<Number>,
290///
291///     #[inc_complete(skip)]
292///     metadata: String,  // This field won't be included in Storage implementation
293/// }
294///
295/// ## Computation Attribute
296///
297/// The computation type for a field can be manually specified using `#[inc_complete(computation = Type)]`.
298/// This is required if the field type is not already generic over the computation type. For
299/// example, `HashMapStorage<MyComputationType>` is generic over `MyComputationType` but
300/// `MyStringStorage` is not:
301///
302/// ```rust
303/// #[derive(Storage)]
304/// struct MyStorage {
305///     numbers: SingletonStorage<Number>,
306///
307///     #[inc_complete(computation = MyStringInput)]
308///     strings: MyStringStorage,
309/// }
310/// ```
311#[proc_macro_derive(Storage, attributes(inc_complete))]
312pub fn derive_storage(input: TokenStream) -> TokenStream {
313    let input = parse_macro_input!(input as DeriveInput);
314
315    let struct_name = &input.ident;
316
317    // Extract field information
318    let fields = match &input.data {
319        Data::Struct(data) => match &data.fields {
320            Fields::Named(fields) => &fields.named,
321            _ => {
322                return syn::Error::new(
323                    Span::call_site(),
324                    "Storage derive only works on structs with named fields",
325                )
326                .to_compile_error()
327                .into();
328            }
329        },
330        _ => {
331            return syn::Error::new(Span::call_site(), "Storage derive only works on structs")
332                .to_compile_error()
333                .into();
334        }
335    };
336
337    // Extract computation types from storage fields
338    let mut field_mappings = Vec::new();
339
340    for field in fields {
341        let field_name = field.ident.as_ref().unwrap();
342        let field_type = &field.ty;
343
344        // Parse attributes
345        let mut skip_field = false;
346        let mut manual_computation_type: Option<Type> = None;
347
348        for attr in &field.attrs {
349            if attr.path().is_ident("inc_complete") {
350                match attr.meta {
351                    Meta::List(ref list) => {
352                        // Parse nested attributes like #[inc_complete(skip)] or #[inc_complete(computation = Type)]
353                        let nested_result =
354                            list.parse_args_with(|parser: syn::parse::ParseStream| {
355                                let mut found_skip = false;
356                                let mut found_computation_type: Option<Type> = None;
357
358                                while !parser.is_empty() {
359                                    let lookahead = parser.lookahead1();
360                                    if lookahead.peek(syn::Ident) {
361                                        let ident: syn::Ident = parser.parse()?;
362                                        if ident == "skip" {
363                                            found_skip = true;
364                                        } else if ident == "computation" {
365                                            parser.parse::<syn::Token![=]>()?;
366                                            found_computation_type = Some(parser.parse()?);
367                                        } else {
368                                            return Err(syn::Error::new_spanned(
369                                                ident,
370                                                "expected 'skip' or 'computation'",
371                                            ));
372                                        }
373                                    } else {
374                                        return Err(lookahead.error());
375                                    }
376
377                                    if !parser.is_empty() {
378                                        parser.parse::<syn::Token![,]>()?;
379                                    }
380                                }
381
382                                Ok((found_skip, found_computation_type))
383                            });
384
385                        match nested_result {
386                            Ok((found_skip, found_computation_type)) => {
387                                if found_skip {
388                                    skip_field = true;
389                                }
390                                if let Some(computation_type) = found_computation_type {
391                                    manual_computation_type = Some(computation_type);
392                                }
393                            }
394                            Err(e) => {
395                                return e.to_compile_error().into();
396                            }
397                        }
398                    }
399                    _ => {
400                        return syn::Error::new_spanned(
401                            attr,
402                            "expected #[inc_complete(skip)] or #[inc_complete(computation = Type)]",
403                        )
404                        .to_compile_error()
405                        .into();
406                    }
407                }
408            }
409        }
410
411        if skip_field {
412            // Skip this field - don't include it in the impl_storage! call
413            continue;
414        }
415
416        // Determine the computation type
417        let computation_type = if let Some(manual_type) = manual_computation_type {
418            // Use manually specified type
419            manual_type
420        } else if let Some(extracted_type) = extract_generic_type(field_type, "Storage") {
421            // Try to extract the generic type from SingletonStorage<T>, HashMapStorage<T>, etc.
422            extracted_type
423        } else {
424            return syn::Error::new(
425                field.span(),
426                "Field must be a storage type like SingletonStorage<T>, HashMapStorage<T>, or use #[inc_complete(computation = Type)] to specify the type manually, or use #[inc_complete(skip)] to exclude it",
427            )
428            .to_compile_error()
429            .into();
430        };
431
432        field_mappings.push(quote! { #field_name: #computation_type });
433    }
434
435    // Generate a call to impl_storage! macro
436    let expanded = quote! {
437        inc_complete::impl_storage!(#struct_name, #(#field_mappings),*);
438    };
439
440    TokenStream::from(expanded)
441}
442
443// =============================================================================
444// Input Derive Macro
445// =============================================================================
446
447/// Derive macro for Input computation
448///
449/// Usage:
450/// ```rust
451/// #[derive(Input)]
452/// #[inc_complete(id = 0, output = i32, storage = MyStorage)]
453/// struct MyInput;
454/// ```
455///
456/// This generates a call to `define_input!(0, MyInput -> i32, MyStorage)`
457/// which provides the actual implementation.
458///
459/// Required attributes:
460/// - `id`: The unique computation ID (integer)
461/// - `output`: The output type
462/// - `storage`: The storage type
463///
464/// Optional attribute:
465/// - `assume_changed`: If present, indicates that the input is assumed to have changed
466#[proc_macro_derive(Input, attributes(inc_complete))]
467pub fn derive_input(input: TokenStream) -> TokenStream {
468    let input = parse_macro_input!(input as DeriveInput);
469
470    let struct_name = &input.ident;
471
472    // Find the inc_complete attribute
473    let inc_complete_attr = input
474        .attrs
475        .iter()
476        .find(|attr| attr.path().is_ident("inc_complete"));
477
478    let attr = match inc_complete_attr {
479        Some(attr) => attr,
480        None => {
481            return syn::Error::new(
482                Span::call_site(),
483                "Input derive requires #[inc_complete(...)] attribute",
484            )
485            .to_compile_error()
486            .into();
487        }
488    };
489
490    // Parse the attribute arguments using the Input-specific parser
491    let mut parser = InputAttributeParser::new();
492    if let Err(err) = parser.parse_attribute_list(attr) {
493        return err.to_compile_error().into();
494    }
495
496    let (id, output_type, storage_type, assume_changed) = match parser.validate(attr.span()) {
497        Ok(values) => values,
498        Err(err) => return err.to_compile_error().into(),
499    };
500
501    // Generate the define_input! call
502    let expanded = if assume_changed {
503        quote! {
504            inc_complete::define_input!(#id, assume_changed #struct_name -> #output_type, #storage_type);
505        }
506    } else {
507        quote! {
508            inc_complete::define_input!(#id, #struct_name -> #output_type, #storage_type);
509        }
510    };
511
512    TokenStream::from(expanded)
513}
514
515// =============================================================================
516// Intermediate Macros
517// =============================================================================
518
519/// Procedural macro for defining intermediate computations directly on functions
520///
521/// Usage:
522/// ```rust
523/// // Assumes ComputeDouble struct is already defined
524/// #[intermediate(id = 1)]
525/// fn compute_double(_context: &ComputeDouble, db: &DbHandle<MyStorage>) -> i32 {
526///     db.get(InputValue) * 2
527/// }
528/// ```
529///
530/// This generates a call to `define_intermediate!` using the extracted information.
531/// Assumes the computation struct is already defined.
532///
533/// Required attributes:
534/// - `id`: The computation ID (integer)
535///
536/// Optional attribute:
537/// - `assume_changed`: If present, indicates that the input is assumed to have changed
538///
539/// The macro automatically extracts:
540/// - **Output type** from the function return type (`i32`)
541/// - **Computation type** from the first parameter (`ComputeDouble` from `&ComputeDouble`)
542/// - **Storage type** from the second parameter (`MyStorage` from `&DbHandle<MyStorage>`)
543/// - **Function name** automatically (`compute_double`)
544///
545#[proc_macro_attribute]
546pub fn intermediate(args: TokenStream, input: TokenStream) -> TokenStream {
547    let args = parse_macro_input!(args with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
548    let input_fn = parse_macro_input!(input as syn::ItemFn);
549
550    match process_intermediate_function(args, input_fn) {
551        Ok(tokens) => tokens.into(),
552        Err(err) => err.to_compile_error().into(),
553    }
554}
555
556// =============================================================================
557// Helper Functions for Intermediate Macros
558// =============================================================================
559
560/// Process the intermediate function and generate the appropriate code
561fn process_intermediate_function(
562    args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
563    input_fn: syn::ItemFn,
564) -> syn::Result<proc_macro2::TokenStream> {
565    // Parse the attributes using the Intermediate-specific parser
566    let mut parser = IntermediateAttributeParser::new();
567    for arg in args {
568        parser.parse_meta(&arg)?;
569    }
570    let (id, assume_changed) = parser.validate(Span::call_site())?;
571
572    // Extract information from the function signature
573    let fn_name = &input_fn.sig.ident;
574    let output_type = extract_return_type(&input_fn.sig.output)?;
575    let (computation_type, storage_type) = extract_types_from_params(&input_fn.sig.inputs)?;
576
577    // Generate the define_intermediate call, and optionally the struct
578    let expanded = if assume_changed {
579        quote! {
580            #input_fn
581
582            inc_complete::define_intermediate!(#id, assume_changed #computation_type -> #output_type, #storage_type, #fn_name);
583        }
584    } else {
585        quote! {
586            #input_fn
587
588            inc_complete::define_intermediate!(#id, #computation_type -> #output_type, #storage_type, #fn_name);
589        }
590    };
591
592    Ok(expanded)
593}
594
595/// Extract the return type from a function signature
596fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
597    match output {
598        syn::ReturnType::Type(_, ty) => Ok(quote! { #ty }),
599        syn::ReturnType::Default => Err(syn::Error::new(
600            Span::call_site(),
601            "function must have an explicit return type",
602        )),
603    }
604}
605
606/// Extract both computation type and storage type from function parameters
607fn extract_types_from_params(
608    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
609) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
610    let mut iter = inputs.iter();
611
612    // Extract computation type from first parameter
613    let first_arg = iter.next().ok_or_else(|| {
614        syn::Error::new(
615            Span::call_site(),
616            "function must have at least two parameters: (&ComputationType, &DbHandle<StorageType>)",
617        )
618    })?;
619
620    let computation_type = extract_reference_inner_type(
621        first_arg,
622        "first parameter must be a reference to the computation type (e.g., &ComputationType)",
623    )?;
624
625    // Extract storage type from second parameter
626    let second_arg = iter.next().ok_or_else(|| {
627        syn::Error::new(
628            Span::call_site(),
629            "function must have a second parameter of type &DbHandle<StorageType>",
630        )
631    })?;
632
633    let storage_type = extract_dbhandle_storage_type(second_arg)?;
634
635    Ok((computation_type, storage_type))
636}
637
638/// Extract inner type from a reference parameter
639fn extract_reference_inner_type(
640    arg: &syn::FnArg,
641    error_msg: &str,
642) -> syn::Result<proc_macro2::TokenStream> {
643    if let syn::FnArg::Typed(pat_type) = arg {
644        if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
645            let inner_type = &type_ref.elem;
646            Ok(quote! { #inner_type })
647        } else {
648            Err(syn::Error::new_spanned(arg, error_msg))
649        }
650    } else {
651        Err(syn::Error::new_spanned(arg, error_msg))
652    }
653}
654
655/// Extract storage type from &DbHandle<StorageType> parameter
656fn extract_dbhandle_storage_type(arg: &syn::FnArg) -> syn::Result<proc_macro2::TokenStream> {
657    let error_msg = "second parameter must be &DbHandle<StorageType>";
658
659    if let syn::FnArg::Typed(pat_type) = arg {
660        if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
661            if let syn::Type::Path(type_path) = type_ref.elem.as_ref() {
662                // Look for DbHandle<StorageType>
663                if let Some(segment) = type_path.path.segments.last() {
664                    if segment.ident == "DbHandle" {
665                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
666                            if let Some(syn::GenericArgument::Type(storage_ty)) = args.args.first()
667                            {
668                                return Ok(quote! { #storage_ty });
669                            }
670                        }
671                        return Err(syn::Error::new_spanned(
672                            arg,
673                            "DbHandle must have a generic type parameter for the storage type",
674                        ));
675                    }
676                }
677            }
678        }
679    }
680
681    Err(syn::Error::new_spanned(arg, error_msg))
682}