1use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8 Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments,
9 Type, TypePath, parse_macro_input, spanned::Spanned,
10};
11
12fn parse_type_attribute(
18 name_value: &syn::MetaNameValue,
19 expected_name: &str,
20) -> syn::Result<proc_macro2::TokenStream> {
21 let name = name_value
22 .path
23 .get_ident()
24 .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
25 .to_string();
26
27 if name != expected_name {
28 return Err(syn::Error::new_spanned(
29 &name_value.path,
30 format!("expected '{}', found '{}'", expected_name, name),
31 ));
32 }
33
34 let expr = &name_value.value;
35 Ok(quote::quote! { #expr })
36}
37
38fn parse_int_attribute(
40 name_value: &syn::MetaNameValue,
41 expected_name: &str,
42) -> syn::Result<proc_macro2::TokenStream> {
43 let name = name_value
44 .path
45 .get_ident()
46 .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
47 .to_string();
48
49 if name != expected_name {
50 return Err(syn::Error::new_spanned(
51 &name_value.path,
52 format!("expected '{}', found '{}'", expected_name, name),
53 ));
54 }
55
56 if let Expr::Lit(ExprLit {
57 lit: Lit::Int(lit_int),
58 ..
59 }) = &name_value.value
60 {
61 Ok(quote! { #lit_int })
62 } else {
63 Err(syn::Error::new_spanned(
64 &name_value.value,
65 format!("{} must be an integer", expected_name),
66 ))
67 }
68}
69
70fn parse_flag_attribute(path: &syn::Path, expected_name: &str) -> syn::Result<bool> {
72 if path.is_ident(expected_name) {
73 Ok(true)
74 } else {
75 Err(syn::Error::new_spanned(
76 path,
77 format!("unknown flag attribute, expected '{}'", expected_name),
78 ))
79 }
80}
81
82fn extract_generic_type(ty: &Type, container_suffix: &str) -> Option<Type> {
84 if let Type::Path(TypePath { path, .. }) = ty {
85 if let Some(segment) = path.segments.last() {
86 if segment.ident.to_string().ends_with(container_suffix) {
87 if let PathArguments::AngleBracketed(args) = &segment.arguments {
88 if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
89 return Some(inner_type.clone());
90 }
91 }
92 }
93 }
94 }
95 None
96}
97
98struct InputAttributeParser {
100 id: Option<proc_macro2::TokenStream>,
101 output_type: Option<proc_macro2::TokenStream>,
102 storage_type: Option<proc_macro2::TokenStream>,
103 assume_changed: bool,
104}
105
106impl InputAttributeParser {
107 fn new() -> Self {
108 Self {
109 id: None,
110 output_type: None,
111 storage_type: None,
112 assume_changed: false,
113 }
114 }
115
116 fn parse_attribute_list(&mut self, attr: &Attribute) -> syn::Result<()> {
117 let meta = attr.meta.clone();
118 if let Meta::List(meta_list) = meta {
119 let parsed: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
120 meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
121
122 for meta in parsed {
123 match meta {
124 Meta::Path(path) => {
125 if parse_flag_attribute(&path, "assume_changed")? {
126 self.assume_changed = true;
127 }
128 }
129 Meta::NameValue(name_value) => {
130 let name = name_value
131 .path
132 .get_ident()
133 .ok_or_else(|| {
134 syn::Error::new_spanned(&name_value.path, "expected identifier")
135 })?
136 .to_string();
137
138 match name.as_str() {
139 "id" => {
140 self.id = Some(parse_int_attribute(&name_value, "id")?);
141 }
142 "output" => {
143 self.output_type =
144 Some(parse_type_attribute(&name_value, "output")?);
145 }
146 "storage" => {
147 self.storage_type =
148 Some(parse_type_attribute(&name_value, "storage")?);
149 }
150 _ => {
151 return Err(syn::Error::new_spanned(
152 &name_value.path,
153 format!(
154 "unknown attribute '{}' for Input derive. Valid attributes are: id, output, storage, assume_changed",
155 name
156 ),
157 ));
158 }
159 }
160 }
161 _ => {
162 return Err(syn::Error::new_spanned(
163 &meta,
164 "unsupported attribute format",
165 ));
166 }
167 }
168 }
169 }
170 Ok(())
171 }
172
173 fn validate(
174 &self,
175 attr_span: Span,
176 ) -> syn::Result<(
177 proc_macro2::TokenStream,
178 proc_macro2::TokenStream,
179 proc_macro2::TokenStream,
180 bool,
181 )> {
182 let id = self
183 .id
184 .clone()
185 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
186 let output_type = self
187 .output_type
188 .clone()
189 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'output' attribute"))?;
190 let storage_type = self
191 .storage_type
192 .clone()
193 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'storage' attribute"))?;
194
195 Ok((id, output_type, storage_type, self.assume_changed))
196 }
197}
198
199struct IntermediateAttributeParser {
201 id: Option<proc_macro2::TokenStream>,
202 assume_changed: bool,
203}
204
205impl IntermediateAttributeParser {
206 fn new() -> Self {
207 Self {
208 id: None,
209 assume_changed: false,
210 }
211 }
212
213 fn parse_meta(&mut self, meta: &Meta) -> syn::Result<()> {
214 match meta {
215 Meta::Path(path) => {
216 if parse_flag_attribute(path, "assume_changed")? {
217 self.assume_changed = true;
218 }
219 }
220 Meta::NameValue(name_value) => {
221 let name = name_value
222 .path
223 .get_ident()
224 .ok_or_else(|| {
225 syn::Error::new_spanned(&name_value.path, "expected identifier")
226 })?
227 .to_string();
228
229 match name.as_str() {
230 "id" => {
231 self.id = Some(parse_int_attribute(name_value, "id")?);
232 }
233 _ => {
234 return Err(syn::Error::new_spanned(
235 &name_value.path,
236 format!(
237 "unknown attribute '{}' for intermediate macro. Valid attributes are: id, assume_changed",
238 name
239 ),
240 ));
241 }
242 }
243 }
244 _ => {
245 return Err(syn::Error::new_spanned(
246 meta,
247 "unsupported attribute format",
248 ));
249 }
250 }
251 Ok(())
252 }
253
254 fn validate(&self, attr_span: Span) -> syn::Result<(proc_macro2::TokenStream, bool)> {
255 let id = self
256 .id
257 .clone()
258 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
259
260 Ok((id, self.assume_changed))
261 }
262}
263
264#[proc_macro_derive(Storage, attributes(inc_complete))]
312pub fn derive_storage(input: TokenStream) -> TokenStream {
313 let input = parse_macro_input!(input as DeriveInput);
314
315 let struct_name = &input.ident;
316
317 let fields = match &input.data {
319 Data::Struct(data) => match &data.fields {
320 Fields::Named(fields) => &fields.named,
321 _ => {
322 return syn::Error::new(
323 Span::call_site(),
324 "Storage derive only works on structs with named fields",
325 )
326 .to_compile_error()
327 .into();
328 }
329 },
330 _ => {
331 return syn::Error::new(Span::call_site(), "Storage derive only works on structs")
332 .to_compile_error()
333 .into();
334 }
335 };
336
337 let mut field_mappings = Vec::new();
339 let mut accumulated = Vec::new();
340
341 for field in fields {
342 let field_name = field.ident.as_ref().unwrap();
343 let field_type = &field.ty;
344
345 let mut skip_field = false;
347 let mut is_accumulated = false;
348 let mut manual_computation_type: Option<Type> = None;
349
350 for attr in &field.attrs {
351 if attr.path().is_ident("inc_complete") {
352 match attr.meta {
353 Meta::List(ref list) => {
354 let nested_result =
356 list.parse_args_with(|parser: syn::parse::ParseStream| {
357 while !parser.is_empty() {
358 let lookahead = parser.lookahead1();
359 if lookahead.peek(syn::Ident) {
360 let ident: syn::Ident = parser.parse()?;
361 if ident == "skip" {
362 skip_field = true;
363 } else if ident == "computation" {
364 parser.parse::<syn::Token![=]>()?;
365 manual_computation_type = Some(parser.parse()?);
366 } else if ident == "accumulate" {
367 is_accumulated = true;
368 } else {
369 return Err(syn::Error::new_spanned(
370 ident,
371 "expected 'skip' or 'computation'",
372 ));
373 }
374 } else {
375 return Err(lookahead.error());
376 }
377
378 if !parser.is_empty() {
379 parser.parse::<syn::Token![,]>()?;
380 }
381 }
382
383 Ok(())
384 });
385
386 if let Err(e) = nested_result {
387 return e.to_compile_error().into();
388 }
389 }
390 _ => {
391 return syn::Error::new_spanned(
392 attr,
393 "expected #[inc_complete(skip)] or #[inc_complete(computation = Type)]",
394 )
395 .to_compile_error()
396 .into();
397 }
398 }
399 }
400 }
401
402 if skip_field {
403 continue;
405 }
406
407 let computation_type = if let Some(manual_type) = manual_computation_type {
409 manual_type
411 } else if let Some(extracted_type) = extract_generic_type(field_type, "Storage") {
412 extracted_type
414 } else {
415 return syn::Error::new(
416 field.span(),
417 "Field must be a storage type like SingletonStorage<T>, HashMapStorage<T>, or use #[inc_complete(computation = Type)] to specify the type manually, or use #[inc_complete(skip)] to exclude it",
418 )
419 .to_compile_error()
420 .into();
421 };
422
423 let item = quote! { #field_name: #computation_type, };
424 if is_accumulated {
425 accumulated.push(item);
426 } else {
427 field_mappings.push(item);
428 }
429 }
430
431 let expanded = quote! {
433 inc_complete::impl_storage!(#struct_name, #(#field_mappings)* @accumulators { #(#accumulated)* });
434 };
435
436 TokenStream::from(expanded)
437}
438
439#[proc_macro_derive(Input, attributes(inc_complete))]
463pub fn derive_input(input: TokenStream) -> TokenStream {
464 let input = parse_macro_input!(input as DeriveInput);
465
466 let struct_name = &input.ident;
467
468 let inc_complete_attr = input
470 .attrs
471 .iter()
472 .find(|attr| attr.path().is_ident("inc_complete"));
473
474 let attr = match inc_complete_attr {
475 Some(attr) => attr,
476 None => {
477 return syn::Error::new(
478 Span::call_site(),
479 "Input derive requires #[inc_complete(...)] attribute",
480 )
481 .to_compile_error()
482 .into();
483 }
484 };
485
486 let mut parser = InputAttributeParser::new();
488 if let Err(err) = parser.parse_attribute_list(attr) {
489 return err.to_compile_error().into();
490 }
491
492 let (id, output_type, storage_type, assume_changed) = match parser.validate(attr.span()) {
493 Ok(values) => values,
494 Err(err) => return err.to_compile_error().into(),
495 };
496
497 let expanded = if assume_changed {
499 quote! {
500 inc_complete::define_input!(#id, assume_changed #struct_name -> #output_type, #storage_type);
501 }
502 } else {
503 quote! {
504 inc_complete::define_input!(#id, #struct_name -> #output_type, #storage_type);
505 }
506 };
507
508 TokenStream::from(expanded)
509}
510
511#[proc_macro_attribute]
542pub fn intermediate(args: TokenStream, input: TokenStream) -> TokenStream {
543 let args = parse_macro_input!(args with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
544 let input_fn = parse_macro_input!(input as syn::ItemFn);
545
546 match process_intermediate_function(args, input_fn) {
547 Ok(tokens) => tokens.into(),
548 Err(err) => err.to_compile_error().into(),
549 }
550}
551
552fn process_intermediate_function(
558 args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
559 input_fn: syn::ItemFn,
560) -> syn::Result<proc_macro2::TokenStream> {
561 let mut parser = IntermediateAttributeParser::new();
563 for arg in args {
564 parser.parse_meta(&arg)?;
565 }
566 let (id, assume_changed) = parser.validate(Span::call_site())?;
567
568 let fn_name = &input_fn.sig.ident;
570 let output_type = extract_return_type(&input_fn.sig.output)?;
571 let (computation_type, storage_type) = extract_types_from_params(&input_fn.sig.inputs)?;
572
573 let expanded = if assume_changed {
575 quote! {
576 #input_fn
577
578 inc_complete::define_intermediate!(#id, assume_changed #computation_type -> #output_type, #storage_type, #fn_name);
579 }
580 } else {
581 quote! {
582 #input_fn
583
584 inc_complete::define_intermediate!(#id, #computation_type -> #output_type, #storage_type, #fn_name);
585 }
586 };
587
588 Ok(expanded)
589}
590
591fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
593 match output {
594 syn::ReturnType::Type(_, ty) => Ok(quote! { #ty }),
595 syn::ReturnType::Default => Err(syn::Error::new(
596 Span::call_site(),
597 "function must have an explicit return type",
598 )),
599 }
600}
601
602fn extract_types_from_params(
604 inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
605) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
606 let mut iter = inputs.iter();
607
608 let first_arg = iter.next().ok_or_else(|| {
610 syn::Error::new(
611 Span::call_site(),
612 "function must have at least two parameters: (&ComputationType, &DbHandle<StorageType>)",
613 )
614 })?;
615
616 let computation_type = extract_reference_inner_type(
617 first_arg,
618 "first parameter must be a reference to the computation type (e.g., &ComputationType)",
619 )?;
620
621 let second_arg = iter.next().ok_or_else(|| {
623 syn::Error::new(
624 Span::call_site(),
625 "function must have a second parameter of type &DbHandle<StorageType>",
626 )
627 })?;
628
629 let storage_type = extract_dbhandle_storage_type(second_arg)?;
630
631 Ok((computation_type, storage_type))
632}
633
634fn extract_reference_inner_type(
636 arg: &syn::FnArg,
637 error_msg: &str,
638) -> syn::Result<proc_macro2::TokenStream> {
639 if let syn::FnArg::Typed(pat_type) = arg {
640 if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
641 let inner_type = &type_ref.elem;
642 Ok(quote! { #inner_type })
643 } else {
644 Err(syn::Error::new_spanned(arg, error_msg))
645 }
646 } else {
647 Err(syn::Error::new_spanned(arg, error_msg))
648 }
649}
650
651fn extract_dbhandle_storage_type(arg: &syn::FnArg) -> syn::Result<proc_macro2::TokenStream> {
653 let error_msg = "second parameter must be &DbHandle<StorageType>";
654
655 if let syn::FnArg::Typed(pat_type) = arg {
656 if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
657 if let syn::Type::Path(type_path) = type_ref.elem.as_ref() {
658 if let Some(segment) = type_path.path.segments.last() {
660 if segment.ident == "DbHandle" {
661 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
662 if let Some(syn::GenericArgument::Type(storage_ty)) = args.args.first()
663 {
664 return Ok(quote! { #storage_ty });
665 }
666 }
667 return Err(syn::Error::new_spanned(
668 arg,
669 "DbHandle must have a generic type parameter for the storage type",
670 ));
671 }
672 }
673 }
674 }
675 }
676
677 Err(syn::Error::new_spanned(arg, error_msg))
678}