1use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8 Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments,
9 Type, TypePath, parse_macro_input, spanned::Spanned,
10};
11
12fn parse_type_attribute(
18 name_value: &syn::MetaNameValue,
19 expected_name: &str,
20) -> syn::Result<proc_macro2::TokenStream> {
21 let name = name_value
22 .path
23 .get_ident()
24 .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
25 .to_string();
26
27 if name != expected_name {
28 return Err(syn::Error::new_spanned(
29 &name_value.path,
30 format!("expected '{}', found '{}'", expected_name, name),
31 ));
32 }
33
34 let expr = &name_value.value;
35 Ok(quote::quote! { #expr })
36}
37
38fn parse_int_attribute(
40 name_value: &syn::MetaNameValue,
41 expected_name: &str,
42) -> syn::Result<proc_macro2::TokenStream> {
43 let name = name_value
44 .path
45 .get_ident()
46 .ok_or_else(|| syn::Error::new_spanned(&name_value.path, "expected identifier"))?
47 .to_string();
48
49 if name != expected_name {
50 return Err(syn::Error::new_spanned(
51 &name_value.path,
52 format!("expected '{}', found '{}'", expected_name, name),
53 ));
54 }
55
56 if let Expr::Lit(ExprLit {
57 lit: Lit::Int(lit_int),
58 ..
59 }) = &name_value.value
60 {
61 Ok(quote! { #lit_int })
62 } else {
63 Err(syn::Error::new_spanned(
64 &name_value.value,
65 format!("{} must be an integer", expected_name),
66 ))
67 }
68}
69
70fn parse_flag_attribute(path: &syn::Path, expected_name: &str) -> syn::Result<bool> {
72 if path.is_ident(expected_name) {
73 Ok(true)
74 } else {
75 Err(syn::Error::new_spanned(
76 path,
77 format!("unknown flag attribute, expected '{}'", expected_name),
78 ))
79 }
80}
81
82fn extract_generic_type(ty: &Type, container_suffix: &str) -> Option<Type> {
84 if let Type::Path(TypePath { path, .. }) = ty {
85 if let Some(segment) = path.segments.last() {
86 if segment.ident.to_string().ends_with(container_suffix) {
87 if let PathArguments::AngleBracketed(args) = &segment.arguments {
88 if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
89 return Some(inner_type.clone());
90 }
91 }
92 }
93 }
94 }
95 None
96}
97
98struct InputAttributeParser {
100 id: Option<proc_macro2::TokenStream>,
101 output_type: Option<proc_macro2::TokenStream>,
102 storage_type: Option<proc_macro2::TokenStream>,
103 assume_changed: bool,
104}
105
106impl InputAttributeParser {
107 fn new() -> Self {
108 Self {
109 id: None,
110 output_type: None,
111 storage_type: None,
112 assume_changed: false,
113 }
114 }
115
116 fn parse_attribute_list(&mut self, attr: &Attribute) -> syn::Result<()> {
117 let meta = attr.meta.clone();
118 if let Meta::List(meta_list) = meta {
119 let parsed: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
120 meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
121
122 for meta in parsed {
123 match meta {
124 Meta::Path(path) => {
125 if parse_flag_attribute(&path, "assume_changed")? {
126 self.assume_changed = true;
127 }
128 }
129 Meta::NameValue(name_value) => {
130 let name = name_value
131 .path
132 .get_ident()
133 .ok_or_else(|| {
134 syn::Error::new_spanned(&name_value.path, "expected identifier")
135 })?
136 .to_string();
137
138 match name.as_str() {
139 "id" => {
140 self.id = Some(parse_int_attribute(&name_value, "id")?);
141 }
142 "output" => {
143 self.output_type =
144 Some(parse_type_attribute(&name_value, "output")?);
145 }
146 "storage" => {
147 self.storage_type =
148 Some(parse_type_attribute(&name_value, "storage")?);
149 }
150 _ => {
151 return Err(syn::Error::new_spanned(
152 &name_value.path,
153 format!(
154 "unknown attribute '{}' for Input derive. Valid attributes are: id, output, storage, assume_changed",
155 name
156 ),
157 ));
158 }
159 }
160 }
161 _ => {
162 return Err(syn::Error::new_spanned(
163 &meta,
164 "unsupported attribute format",
165 ));
166 }
167 }
168 }
169 }
170 Ok(())
171 }
172
173 fn validate(
174 &self,
175 attr_span: Span,
176 ) -> syn::Result<(
177 proc_macro2::TokenStream,
178 proc_macro2::TokenStream,
179 proc_macro2::TokenStream,
180 bool,
181 )> {
182 let id = self
183 .id
184 .clone()
185 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
186 let output_type = self
187 .output_type
188 .clone()
189 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'output' attribute"))?;
190 let storage_type = self
191 .storage_type
192 .clone()
193 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'storage' attribute"))?;
194
195 Ok((id, output_type, storage_type, self.assume_changed))
196 }
197}
198
199struct IntermediateAttributeParser {
201 id: Option<proc_macro2::TokenStream>,
202 assume_changed: bool,
203}
204
205impl IntermediateAttributeParser {
206 fn new() -> Self {
207 Self {
208 id: None,
209 assume_changed: false,
210 }
211 }
212
213 fn parse_meta(&mut self, meta: &Meta) -> syn::Result<()> {
214 match meta {
215 Meta::Path(path) => {
216 if parse_flag_attribute(path, "assume_changed")? {
217 self.assume_changed = true;
218 }
219 }
220 Meta::NameValue(name_value) => {
221 let name = name_value
222 .path
223 .get_ident()
224 .ok_or_else(|| {
225 syn::Error::new_spanned(&name_value.path, "expected identifier")
226 })?
227 .to_string();
228
229 match name.as_str() {
230 "id" => {
231 self.id = Some(parse_int_attribute(name_value, "id")?);
232 }
233 _ => {
234 return Err(syn::Error::new_spanned(
235 &name_value.path,
236 format!(
237 "unknown attribute '{}' for intermediate macro. Valid attributes are: id, assume_changed",
238 name
239 ),
240 ));
241 }
242 }
243 }
244 _ => {
245 return Err(syn::Error::new_spanned(
246 meta,
247 "unsupported attribute format",
248 ));
249 }
250 }
251 Ok(())
252 }
253
254 fn validate(&self, attr_span: Span) -> syn::Result<(proc_macro2::TokenStream, bool)> {
255 let id = self
256 .id
257 .clone()
258 .ok_or_else(|| syn::Error::new(attr_span, "missing required 'id' attribute"))?;
259
260 Ok((id, self.assume_changed))
261 }
262}
263
264#[proc_macro_derive(Storage, attributes(inc_complete))]
312pub fn derive_storage(input: TokenStream) -> TokenStream {
313 let input = parse_macro_input!(input as DeriveInput);
314
315 let struct_name = &input.ident;
316
317 let fields = match &input.data {
319 Data::Struct(data) => match &data.fields {
320 Fields::Named(fields) => &fields.named,
321 _ => {
322 return syn::Error::new(
323 Span::call_site(),
324 "Storage derive only works on structs with named fields",
325 )
326 .to_compile_error()
327 .into();
328 }
329 },
330 _ => {
331 return syn::Error::new(Span::call_site(), "Storage derive only works on structs")
332 .to_compile_error()
333 .into();
334 }
335 };
336
337 let mut field_mappings = Vec::new();
339
340 for field in fields {
341 let field_name = field.ident.as_ref().unwrap();
342 let field_type = &field.ty;
343
344 let mut skip_field = false;
346 let mut manual_computation_type: Option<Type> = None;
347
348 for attr in &field.attrs {
349 if attr.path().is_ident("inc_complete") {
350 match attr.meta {
351 Meta::List(ref list) => {
352 let nested_result =
354 list.parse_args_with(|parser: syn::parse::ParseStream| {
355 let mut found_skip = false;
356 let mut found_computation_type: Option<Type> = None;
357
358 while !parser.is_empty() {
359 let lookahead = parser.lookahead1();
360 if lookahead.peek(syn::Ident) {
361 let ident: syn::Ident = parser.parse()?;
362 if ident == "skip" {
363 found_skip = true;
364 } else if ident == "computation" {
365 parser.parse::<syn::Token![=]>()?;
366 found_computation_type = Some(parser.parse()?);
367 } else {
368 return Err(syn::Error::new_spanned(
369 ident,
370 "expected 'skip' or 'computation'",
371 ));
372 }
373 } else {
374 return Err(lookahead.error());
375 }
376
377 if !parser.is_empty() {
378 parser.parse::<syn::Token![,]>()?;
379 }
380 }
381
382 Ok((found_skip, found_computation_type))
383 });
384
385 match nested_result {
386 Ok((found_skip, found_computation_type)) => {
387 if found_skip {
388 skip_field = true;
389 }
390 if let Some(computation_type) = found_computation_type {
391 manual_computation_type = Some(computation_type);
392 }
393 }
394 Err(e) => {
395 return e.to_compile_error().into();
396 }
397 }
398 }
399 _ => {
400 return syn::Error::new_spanned(
401 attr,
402 "expected #[inc_complete(skip)] or #[inc_complete(computation = Type)]",
403 )
404 .to_compile_error()
405 .into();
406 }
407 }
408 }
409 }
410
411 if skip_field {
412 continue;
414 }
415
416 let computation_type = if let Some(manual_type) = manual_computation_type {
418 manual_type
420 } else if let Some(extracted_type) = extract_generic_type(field_type, "Storage") {
421 extracted_type
423 } else {
424 return syn::Error::new(
425 field.span(),
426 "Field must be a storage type like SingletonStorage<T>, HashMapStorage<T>, or use #[inc_complete(computation = Type)] to specify the type manually, or use #[inc_complete(skip)] to exclude it",
427 )
428 .to_compile_error()
429 .into();
430 };
431
432 field_mappings.push(quote! { #field_name: #computation_type });
433 }
434
435 let expanded = quote! {
437 inc_complete::impl_storage!(#struct_name, #(#field_mappings),*);
438 };
439
440 TokenStream::from(expanded)
441}
442
443#[proc_macro_derive(Input, attributes(inc_complete))]
467pub fn derive_input(input: TokenStream) -> TokenStream {
468 let input = parse_macro_input!(input as DeriveInput);
469
470 let struct_name = &input.ident;
471
472 let inc_complete_attr = input
474 .attrs
475 .iter()
476 .find(|attr| attr.path().is_ident("inc_complete"));
477
478 let attr = match inc_complete_attr {
479 Some(attr) => attr,
480 None => {
481 return syn::Error::new(
482 Span::call_site(),
483 "Input derive requires #[inc_complete(...)] attribute",
484 )
485 .to_compile_error()
486 .into();
487 }
488 };
489
490 let mut parser = InputAttributeParser::new();
492 if let Err(err) = parser.parse_attribute_list(attr) {
493 return err.to_compile_error().into();
494 }
495
496 let (id, output_type, storage_type, assume_changed) = match parser.validate(attr.span()) {
497 Ok(values) => values,
498 Err(err) => return err.to_compile_error().into(),
499 };
500
501 let expanded = if assume_changed {
503 quote! {
504 inc_complete::define_input!(#id, assume_changed #struct_name -> #output_type, #storage_type);
505 }
506 } else {
507 quote! {
508 inc_complete::define_input!(#id, #struct_name -> #output_type, #storage_type);
509 }
510 };
511
512 TokenStream::from(expanded)
513}
514
515#[proc_macro_attribute]
546pub fn intermediate(args: TokenStream, input: TokenStream) -> TokenStream {
547 let args = parse_macro_input!(args with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
548 let input_fn = parse_macro_input!(input as syn::ItemFn);
549
550 match process_intermediate_function(args, input_fn) {
551 Ok(tokens) => tokens.into(),
552 Err(err) => err.to_compile_error().into(),
553 }
554}
555
556fn process_intermediate_function(
562 args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
563 input_fn: syn::ItemFn,
564) -> syn::Result<proc_macro2::TokenStream> {
565 let mut parser = IntermediateAttributeParser::new();
567 for arg in args {
568 parser.parse_meta(&arg)?;
569 }
570 let (id, assume_changed) = parser.validate(Span::call_site())?;
571
572 let fn_name = &input_fn.sig.ident;
574 let output_type = extract_return_type(&input_fn.sig.output)?;
575 let (computation_type, storage_type) = extract_types_from_params(&input_fn.sig.inputs)?;
576
577 let expanded = if assume_changed {
579 quote! {
580 #input_fn
581
582 inc_complete::define_intermediate!(#id, assume_changed #computation_type -> #output_type, #storage_type, #fn_name);
583 }
584 } else {
585 quote! {
586 #input_fn
587
588 inc_complete::define_intermediate!(#id, #computation_type -> #output_type, #storage_type, #fn_name);
589 }
590 };
591
592 Ok(expanded)
593}
594
595fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
597 match output {
598 syn::ReturnType::Type(_, ty) => Ok(quote! { #ty }),
599 syn::ReturnType::Default => Err(syn::Error::new(
600 Span::call_site(),
601 "function must have an explicit return type",
602 )),
603 }
604}
605
606fn extract_types_from_params(
608 inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
609) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
610 let mut iter = inputs.iter();
611
612 let first_arg = iter.next().ok_or_else(|| {
614 syn::Error::new(
615 Span::call_site(),
616 "function must have at least two parameters: (&ComputationType, &DbHandle<StorageType>)",
617 )
618 })?;
619
620 let computation_type = extract_reference_inner_type(
621 first_arg,
622 "first parameter must be a reference to the computation type (e.g., &ComputationType)",
623 )?;
624
625 let second_arg = iter.next().ok_or_else(|| {
627 syn::Error::new(
628 Span::call_site(),
629 "function must have a second parameter of type &DbHandle<StorageType>",
630 )
631 })?;
632
633 let storage_type = extract_dbhandle_storage_type(second_arg)?;
634
635 Ok((computation_type, storage_type))
636}
637
638fn extract_reference_inner_type(
640 arg: &syn::FnArg,
641 error_msg: &str,
642) -> syn::Result<proc_macro2::TokenStream> {
643 if let syn::FnArg::Typed(pat_type) = arg {
644 if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
645 let inner_type = &type_ref.elem;
646 Ok(quote! { #inner_type })
647 } else {
648 Err(syn::Error::new_spanned(arg, error_msg))
649 }
650 } else {
651 Err(syn::Error::new_spanned(arg, error_msg))
652 }
653}
654
655fn extract_dbhandle_storage_type(arg: &syn::FnArg) -> syn::Result<proc_macro2::TokenStream> {
657 let error_msg = "second parameter must be &DbHandle<StorageType>";
658
659 if let syn::FnArg::Typed(pat_type) = arg {
660 if let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() {
661 if let syn::Type::Path(type_path) = type_ref.elem.as_ref() {
662 if let Some(segment) = type_path.path.segments.last() {
664 if segment.ident == "DbHandle" {
665 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
666 if let Some(syn::GenericArgument::Type(storage_ty)) = args.args.first()
667 {
668 return Ok(quote! { #storage_ty });
669 }
670 }
671 return Err(syn::Error::new_spanned(
672 arg,
673 "DbHandle must have a generic type parameter for the storage type",
674 ));
675 }
676 }
677 }
678 }
679 }
680
681 Err(syn::Error::new_spanned(arg, error_msg))
682}