1use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8 Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments,
9 Type, TypePath, parse_macro_input, spanned::Spanned,
10};
11
12fn parse_type_attribute(
18 name_value: &syn::MetaNameValue,
19 expected_name: &str,
20) -> syn::Result<proc_macro2::TokenStream> {
21 let name = name_value
22 .path
23 .get_ident()
24 .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
25 .to_string();
26
27 if name != expected_name {
28 return Err(syn::Error::new_spanned(
29 &name_value.path,
30 format!("expected '{}', found '{}'", expected_name, name),
31 ));
32 }
33
34 let expr = &name_value.value;
35 Ok(quote::quote! { #expr })
36}
37
38fn parse_int_attribute(
40 name_value: &syn::MetaNameValue,
41 expected_name: &str,
42) -> syn::Result<proc_macro2::TokenStream> {
43 let name = name_value
44 .path
45 .get_ident()
46 .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
47 .to_string();
48
49 if name != expected_name {
50 return Err(syn::Error::new_spanned(
51 &name_value.path,
52 format!("expected '{}', found '{}'", expected_name, name),
53 ));
54 }
55
56 if let Expr::Lit(ExprLit {
57 lit: Lit::Int(lit_int),
58 ..
59 }) = &name_value.value
60 {
61 Ok(quote! { #lit_int })
62 } else {
63 Err(syn::Error::new_spanned(
64 &name_value.value,
65 format!("{} must be an integer", expected_name),
66 ))
67 }
68}
69
70fn parse_flag_attribute(path: &syn::Path, expected_name: &str) -> syn::Result<bool> {
72 if path.is_ident(expected_name) {
73 Ok(true)
74 } else {
75 Err(syn::Error::new_spanned(
76 path,
77 format!("unknown flag attribute, expected '{}'", expected_name),
78 ))
79 }
80}
81
82fn extract_generic_type(ty: &Type) -> Option<Type> {
84 if let Type::Path(TypePath { path, .. }) = ty {
85 if let Some(segment) = path.segments.last() {
86 if let PathArguments::AngleBracketed(args) = &segment.arguments {
87 if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
88 return Some(inner_type.clone());
89 }
90 }
91 }
92 }
93 None
94}
95
96struct InputAttributeParser {
98 id: Option<proc_macro2::TokenStream>,
99 output_type: Option<proc_macro2::TokenStream>,
100 storage_type: Option<proc_macro2::TokenStream>,
101 assume_changed: bool,
102}
103
104impl InputAttributeParser {
105 fn new() -> Self {
106 Self {
107 id: None,
108 output_type: None,
109 storage_type: None,
110 assume_changed: false,
111 }
112 }
113
114 fn parse_attribute_list(&mut self, attr: &Attribute) -> syn::Result<()> {
115 let meta = attr.meta.clone();
116 if let Meta::List(meta_list) = meta {
117 let parsed: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
118 meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
119
120 for meta in parsed {
121 match meta {
122 Meta::Path(path) => {
123 if parse_flag_attribute(&path, "assume_changed")? {
124 self.assume_changed = true;
125 }
126 }
127 Meta::NameValue(name_value) => {
128 let name = name_value
129 .path
130 .get_ident()
131 .ok_or_else(|| {
132 syn::Error::new_spanned(&name_value.path, "expected identifier")
133 })?
134 .to_string();
135
136 match name.as_str() {
137 "id" => {
138 self.id = Some(parse_int_attribute(&name_value, "id")?);
139 }
140 "output" => {
141 self.output_type =
142 Some(parse_type_attribute(&name_value, "output")?);
143 }
144 "storage" => {
145 self.storage_type =
146 Some(parse_type_attribute(&name_value, "storage")?);
147 }
148 _ => {
149 return Err(syn::Error::new_spanned(
150 &name_value.path,
151 format!(
152 "unknown attribute '{}' for Input derive. Valid attributes are: id, output, storage, assume_changed",
153 name
154 ),
155 ));
156 }
157 }
158 }
159 _ => {
160 return Err(syn::Error::new_spanned(
161 &meta,
162 "unsupported attribute format",
163 ));
164 }
165 }
166 }
167 }
168 Ok(())
169 }
170
171 fn validate(
172 &self,
173 attr_span: Span,
174 ) -> syn::Result<(
175 proc_macro2::TokenStream,
176 proc_macro2::TokenStream,
177 proc_macro2::TokenStream,
178 bool,
179 )> {
180 let id = self
181 .id
182 .clone()
183 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
184 let output_type = self
185 .output_type
186 .clone()
187 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'output' attribute"))?;
188 let storage_type = self
189 .storage_type
190 .clone()
191 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'storage' attribute"))?;
192
193 Ok((id, output_type, storage_type, self.assume_changed))
194 }
195}
196
197struct IntermediateAttributeParser {
199 id: Option<proc_macro2::TokenStream>,
200 assume_changed: bool,
201}
202
203impl IntermediateAttributeParser {
204 fn new() -> Self {
205 Self {
206 id: None,
207 assume_changed: false,
208 }
209 }
210
211 fn parse_meta(&mut self, meta: &Meta) -> syn::Result<()> {
212 match meta {
213 Meta::Path(path) => {
214 if parse_flag_attribute(path, "assume_changed")? {
215 self.assume_changed = true;
216 }
217 }
218 Meta::NameValue(name_value) => {
219 let name = name_value
220 .path
221 .get_ident()
222 .ok_or_else(|| {
223 syn::Error::new_spanned(&name_value.path, "expected identifier")
224 })?
225 .to_string();
226
227 match name.as_str() {
228 "id" => {
229 self.id = Some(parse_int_attribute(name_value, "id")?);
230 }
231 _ => {
232 return Err(syn::Error::new_spanned(
233 &name_value.path,
234 format!(
235 "unknown attribute '{}' for intermediate macro. Valid attributes are: id, assume_changed",
236 name
237 ),
238 ));
239 }
240 }
241 }
242 _ => {
243 return Err(syn::Error::new_spanned(
244 meta,
245 "unsupported attribute format",
246 ));
247 }
248 }
249 Ok(())
250 }
251
252 fn validate(&self, attr_span: Span) -> syn::Result<(proc_macro2::TokenStream, bool)> {
253 let id = self
254 .id
255 .clone()
256 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
257
258 Ok((id, self.assume_changed))
259 }
260}
261
262#[proc_macro_derive(Storage, attributes(inc_complete))]
310pub fn derive_storage(input: TokenStream) -> TokenStream {
311 let input = parse_macro_input!(input as DeriveInput);
312
313 let struct_name = &input.ident;
314
315 let fields = match &input.data {
317 Data::Struct(data) => match &data.fields {
318 Fields::Named(fields) => &fields.named,
319 _ => {
320 return syn::Error::new(
321 Span::call_site(),
322 "Storage derive only works on structs with named fields",
323 )
324 .to_compile_error()
325 .into();
326 }
327 },
328 _ => {
329 return syn::Error::new(Span::call_site(), "Storage derive only works on structs")
330 .to_compile_error()
331 .into();
332 }
333 };
334
335 let mut field_mappings = Vec::new();
337 let mut accumulated = Vec::new();
338
339 for field in fields {
340 let field_name = field.ident.as_ref().unwrap();
341 let field_type = &field.ty;
342
343 let mut skip_field = false;
345 let mut is_accumulated = false;
346 let mut manual_computation_type: Option<Type> = None;
347
348 for attr in &field.attrs {
349 if attr.path().is_ident("inc_complete") {
350 match attr.meta {
351 Meta::List(ref list) => {
352 let nested_result =
354 list.parse_args_with(|parser: syn::parse::ParseStream| {
355 while !parser.is_empty() {
356 let lookahead = parser.lookahead1();
357 if lookahead.peek(syn::Ident) {
358 let ident: syn::Ident = parser.parse()?;
359 if ident == "skip" {
360 skip_field = true;
361 } else if ident == "computation" {
362 parser.parse::<syn::Token![=]>()?;
363 manual_computation_type = Some(parser.parse()?);
364 } else if ident == "accumulate" {
365 is_accumulated = true;
366 } else {
367 return Err(syn::Error::new_spanned(
368 ident,
369 "expected 'skip' or 'computation'",
370 ));
371 }
372 } else {
373 return Err(lookahead.error());
374 }
375
376 if !parser.is_empty() {
377 parser.parse::<syn::Token![,]>()?;
378 }
379 }
380
381 Ok(())
382 });
383
384 if let Err(e) = nested_result {
385 return e.to_compile_error().into();
386 }
387 }
388 _ => {
389 return syn::Error::new_spanned(
390 attr,
391 "expected #[inc_complete(skip)] or #[inc_complete(computation = Type)]",
392 )
393 .to_compile_error()
394 .into();
395 }
396 }
397 }
398 }
399
400 if skip_field {
401 continue;
403 }
404
405 let computation_type = if let Some(manual_type) = manual_computation_type {
407 manual_type
409 } else if let Some(extracted_type) = extract_generic_type(field_type) {
410 extracted_type
412 } else {
413 return syn::Error::new(
414 field.span(),
415 "Field must be a storage type like SingletonStorage<T>, HashMapStorage<T>, or use #[inc_complete(computation = Type)] to specify the type manually, or use #[inc_complete(skip)] to exclude it",
416 )
417 .to_compile_error()
418 .into();
419 };
420
421 let item = quote! { #field_name: #computation_type, };
422 if is_accumulated {
423 accumulated.push(item);
424 } else {
425 field_mappings.push(item);
426 }
427 }
428
429 let expanded = quote! {
431 inc_complete::impl_storage!(#struct_name, #(#field_mappings)* @accumulators { #(#accumulated)* });
432 };
433
434 TokenStream::from(expanded)
435}
436
437#[proc_macro_derive(Input, attributes(inc_complete))]
461pub fn derive_input(input: TokenStream) -> TokenStream {
462 let input = parse_macro_input!(input as DeriveInput);
463
464 let struct_name = &input.ident;
465
466 let inc_complete_attr = input
468 .attrs
469 .iter()
470 .find(|attr| attr.path().is_ident("inc_complete"));
471
472 let attr = match inc_complete_attr {
473 Some(attr) => attr,
474 None => {
475 return syn::Error::new(
476 Span::call_site(),
477 "Input derive requires #[inc_complete(...)] attribute",
478 )
479 .to_compile_error()
480 .into();
481 }
482 };
483
484 let mut parser = InputAttributeParser::new();
486 if let Err(err) = parser.parse_attribute_list(attr) {
487 return err.to_compile_error().into();
488 }
489
490 let (id, output_type, storage_type, assume_changed) = match parser.validate(attr.span()) {
491 Ok(values) => values,
492 Err(err) => return err.to_compile_error().into(),
493 };
494
495 let expanded = if assume_changed {
497 quote! {
498 inc_complete::define_input!(#id, assume_changed #struct_name -> #output_type, #storage_type);
499 }
500 } else {
501 quote! {
502 inc_complete::define_input!(#id, #struct_name -> #output_type, #storage_type);
503 }
504 };
505
506 TokenStream::from(expanded)
507}
508
509#[proc_macro_attribute]
540pub fn intermediate(args: TokenStream, input: TokenStream) -> TokenStream {
541 let args = parse_macro_input!(args with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
542 let input_fn = parse_macro_input!(input as syn::ItemFn);
543
544 match process_intermediate_function(args, input_fn) {
545 Ok(tokens) => tokens.into(),
546 Err(err) => err.to_compile_error().into(),
547 }
548}
549
550fn process_intermediate_function(
556 args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
557 input_fn: syn::ItemFn,
558) -> syn::Result<proc_macro2::TokenStream> {
559 let mut parser = IntermediateAttributeParser::new();
561 for arg in args {
562 parser.parse_meta(&arg)?;
563 }
564 let (id, assume_changed) = parser.validate(Span::call_site())?;
565
566 let fn_name = &input_fn.sig.ident;
568 let output_type = extract_return_type(&input_fn.sig.output)?;
569 let (computation_type, storage_type) = extract_types_from_params(&input_fn.sig.inputs)?;
570
571 let expanded = if assume_changed {
573 quote! {
574 #input_fn
575
576 inc_complete::define_intermediate!(#id, assume_changed #computation_type -> #output_type, #storage_type, #fn_name);
577 }
578 } else {
579 quote! {
580 #input_fn
581
582 inc_complete::define_intermediate!(#id, #computation_type -> #output_type, #storage_type, #fn_name);
583 }
584 };
585
586 Ok(expanded)
587}
588
589fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
591 match output {
592 syn::ReturnType::Type(_, ty) => Ok(quote! { #ty }),
593 syn::ReturnType::Default => Err(syn::Error::new(
594 Span::call_site(),
595 "function must have an explicit return type",
596 )),
597 }
598}
599
600fn extract_types_from_params(
602 inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
603) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
604 let mut iter = inputs.iter();
605
606 let first_arg = iter.next().ok_or_else(|| {
608 syn::Error::new(
609 Span::call_site(),
610 "function must have at least two parameters: (&ComputationType, &DbHandle<StorageType>)",
611 )
612 })?;
613
614 let computation_type = extract_reference_inner_type(
615 first_arg,
616 "first parameter must be a reference to the computation type (e.g., &ComputationType)",
617 )?;
618
619 let second_arg = iter.next().ok_or_else(|| {
621 syn::Error::new(
622 Span::call_site(),
623 "function must have a second parameter of type &DbHandle<StorageType>",
624 )
625 })?;
626
627 let storage_type = extract_dbhandle_storage_type(second_arg)?;
628
629 Ok((computation_type, storage_type))
630}
631
632fn extract_reference_inner_type(
634 arg: &syn::FnArg,
635 error_msg: &str,
636) -> syn::Result<proc_macro2::TokenStream> {
637 if let syn::FnArg::Typed(pat_type) = arg {
638 if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
639 let inner_type = &type_ref.elem;
640 Ok(quote! { #inner_type })
641 } else {
642 Err(syn::Error::new_spanned(arg, error_msg))
643 }
644 } else {
645 Err(syn::Error::new_spanned(arg, error_msg))
646 }
647}
648
649fn extract_dbhandle_storage_type(arg: &syn::FnArg) -> syn::Result<proc_macro2::TokenStream> {
651 let error_msg = "second parameter must be &DbHandle<StorageType>";
652
653 if let syn::FnArg::Typed(pat_type) = arg {
654 if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
655 if let syn::Type::Path(type_path) = type_ref.elem.as_ref() {
656 if let Some(segment) = type_path.path.segments.last() {
658 if segment.ident == "DbHandle" {
659 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
660 if let Some(syn::GenericArgument::Type(storage_ty)) = args.args.first()
661 {
662 return Ok(quote! { #storage_ty });
663 }
664 }
665 return Err(syn::Error::new_spanned(
666 arg,
667 "DbHandle must have a generic type parameter for the storage type",
668 ));
669 }
670 }
671 }
672 }
673 }
674
675 Err(syn::Error::new_spanned(arg, error_msg))
676}