shards_macro/
lib.rs

1extern crate proc_macro;
2use std::{boxed, collections::HashSet};
3
4use convert_case::Casing;
5use itertools::Itertools;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::quote;
9use syn::{
10  punctuated::Punctuated, token::Comma, Expr, Field, Ident, ImplItem, Lit, LitInt, LitStr, Meta,
11};
12
13// type Error = boxed::Box<dyn std::error::Error>;
14enum Error {
15  CompileError(proc_macro2::TokenStream),
16  Generic(boxed::Box<dyn std::error::Error>),
17}
18
19impl From<&str> for Error {
20  fn from(value: &str) -> Self {
21    Error::Generic(value.into())
22  }
23}
24
25impl From<String> for Error {
26  fn from(value: String) -> Self {
27    Error::Generic(value.into())
28  }
29}
30
31impl From<syn::Error> for Error {
32  fn from(value: syn::Error) -> Self {
33    Error::CompileError(value.into_compile_error())
34  }
35}
36
37impl Error {
38  fn to_compile_error2(self) -> proc_macro2::TokenStream {
39    match self {
40      Error::CompileError(stream) => stream,
41      Error::Generic(err) => syn::Error::new(Span::call_site(), err.to_string()).to_compile_error(),
42    }
43  }
44  fn to_compile_error(self) -> proc_macro::TokenStream {
45    self.to_compile_error2().into()
46  }
47  fn extended(self, e: Error) -> Self {
48    let mut stream = self.to_compile_error2();
49    stream.extend(e.to_compile_error2());
50    Error::CompileError(stream)
51  }
52}
53
54lazy_static::lazy_static! {
55  static ref IMPLS_TO_CHECK: Vec<&'static str> = vec![
56    "compose",
57    "warmup",
58    "mutate",
59    "crossover",
60    "get_state",
61    "set_state",
62    "reset_state",
63  ];
64  static ref IMPLS_TO_CHECK_SET : HashSet<&'static str> = HashSet::from_iter(IMPLS_TO_CHECK.iter().cloned());
65}
66
67struct ParamSingle {
68  name: String,
69  var_name: syn::Ident,
70  desc: String,
71  types: syn::Expr,
72}
73
74struct ParamSet {
75  type_name: syn::Type,
76  var_name: syn::Ident,
77  has_custom_interface: bool,
78}
79
80enum Param {
81  Single(ParamSingle),
82  Set(ParamSet),
83}
84
85fn get_field_name(fld: &Field) -> String {
86  if let Some(id) = &fld.ident {
87    id.to_string()
88  } else {
89    "".to_string()
90  }
91}
92
93fn get_expr_str_lit(expr: &Expr) -> Result<String, Error> {
94  if let syn::Expr::Lit(lit) = expr {
95    if let syn::Lit::Str(str) = &lit.lit {
96      Ok(str.value())
97    } else {
98      Err("Value must be a string literal".into())
99    }
100  } else {
101    Err("Value must be a string literal".into())
102  }
103}
104
105fn get_expr_bool_lit(expr: &Expr) -> Result<bool, Error> {
106  if let syn::Expr::Lit(lit) = expr {
107    if let syn::Lit::Bool(str) = &lit.lit {
108      Ok(str.value())
109    } else {
110      Err("Value must be a bool literal".into())
111    }
112  } else {
113    Err("Value must be a bool literal".into())
114  }
115}
116
117struct EnumInfoAttr {
118  id: Expr,
119  name: Expr,
120  desc: Expr,
121}
122
123fn read_enum_info_attr(attrs: &Vec<syn::Attribute>) -> Result<EnumInfoAttr, Error> {
124  for attr in attrs {
125    if attr.path().is_ident("enum_info") {
126      let args = attr.parse_args_with(Punctuated::<syn::Expr, Comma>::parse_terminated)?;
127      return if let Some((id, name, desc)) =
128        args.into_pairs().map(|x| x.into_value()).collect_tuple()
129      {
130        Ok(EnumInfoAttr { id, name, desc })
131      } else {
132        Err("shards_enum attribute must have 3 arguments: (Id, Name, Description)".into())
133      };
134    }
135  }
136  Err("Missing shards_enum attribute".into())
137}
138
139fn read_enum_value_attr(attrs: &Vec<syn::Attribute>) -> Result<Option<LitStr>, Error> {
140  for attr in attrs {
141    if attr.path().is_ident("enum_value") {
142      return Ok(Some(attr.parse_args()?));
143    }
144  }
145  Ok(None)
146}
147
148fn generate_enum_wrapper(enum_: syn::ItemEnum) -> Result<TokenStream, Error> {
149  let vis = enum_.vis;
150  let enum_id = enum_.ident;
151  let enum_name = enum_id.to_string();
152
153  let mut value_ids = Vec::new();
154  let mut value_str_ids = Vec::new();
155  let mut value_desc_lits = Vec::new();
156  let mut value_name_lits = Vec::new();
157
158  let shards_enum_attr = read_enum_info_attr(&enum_.attrs)?;
159
160  for var in &enum_.variants {
161    let var_name = var.ident.to_string();
162    let desc_lit = read_enum_value_attr(&var.attrs)?;
163
164    value_ids.push(var.ident.clone());
165    value_str_ids.push(Ident::new(
166      &format!("{}_str", var_name),
167      proc_macro2::Span::call_site(),
168    ));
169    value_name_lits.push(LitStr::new(&var_name, proc_macro2::Span::call_site()));
170
171    if let Some(lit) = desc_lit {
172      value_desc_lits.push(lit);
173    } else {
174      value_desc_lits.push(LitStr::new("", proc_macro2::Span::call_site()));
175    }
176  }
177
178  let enum_info_id = Ident::new(
179    &format!("{}EnumInfo", enum_name),
180    proc_macro2::Span::call_site(),
181  );
182
183  let enum_name_upper = enum_name.to_uppercase();
184
185  let enum_info_instance_id = Ident::new(
186    &format!("{}_ENUM_INFO", enum_name_upper),
187    proc_macro2::Span::call_site(),
188  );
189
190  let typedef_id = Ident::new(
191    &format!("{}_TYPE", enum_name_upper),
192    proc_macro2::Span::call_site(),
193  );
194
195  let typedef_vec_id = Ident::new(
196    &format!("{}_TYPES", enum_name_upper),
197    proc_macro2::Span::call_site(),
198  );
199
200  let enum_id_expr = shards_enum_attr.id;
201  let enum_name_expr = shards_enum_attr.name;
202  let enum_desc_expr = shards_enum_attr.desc;
203
204  Ok(
205    quote! {
206      #vis struct #enum_info_id {
207        name: &'static str,
208        help: shards::types::OptionalString,
209        enum_type: shards::types::Type,
210        labels: shards::types::Strings,
211        values: Vec<i32>,
212        descriptions: shards::types::OptionalStrings,
213      }
214
215      lazy_static::lazy_static! {
216        #vis static ref #enum_info_instance_id: #enum_info_id = #enum_info_id::new();
217        #vis static ref #typedef_id: shards::types::Type = #enum_info_instance_id.enum_type;
218        #vis static ref #typedef_vec_id: shards::types::Types = vec![*#typedef_id];
219      }
220
221      impl shards::core::EnumRegister for #enum_id {
222        fn register() {
223          let e = unsafe { &#enum_info_instance_id.enum_type.details.enumeration };
224          shards::core::register_enum_internal(e.vendorId, e.typeId, (&*#enum_info_instance_id).into());
225        }
226      }
227
228      #[allow(non_upper_case_globals)]
229      impl<'a> #enum_info_id {
230        #(
231          pub const #value_ids: shards::SHEnum = #enum_id::#value_ids as i32;
232        )*
233        #(
234          pub const #value_str_ids: &'static str = shards::cstr!(#value_name_lits);
235        )*
236
237        fn new() -> Self {
238          let mut labels = shards::types::Strings::new();
239          #(
240            labels.push(Self::#value_str_ids);
241          )*
242
243          let mut descriptions = shards::types::OptionalStrings::new();
244          #(
245            descriptions.push(shards::types::OptionalString(shards::shccstr!(#value_desc_lits)));
246          )*
247
248          Self {
249            name: shards::cstr!(#enum_name_expr),
250            help: shards::types::OptionalString(shards::shccstr!(#enum_desc_expr)),
251            enum_type: shards::types::Type::enumeration(shards::types::FRAG_CC, shards::fourCharacterCode(*#enum_id_expr)),
252            labels,
253            values: vec![#(Self::#value_ids,)*],
254            descriptions,
255          }
256        }
257      }
258
259      impl TryFrom<i32> for #enum_id {
260        type Error = &'static str;
261        fn try_from(value: i32) -> Result<Self, Self::Error> {
262          match value {
263            #(#enum_info_id::#value_ids => Ok(#enum_id::#value_ids),)*
264            _ => Err("Invalid enum value"),
265          }
266        }
267      }
268
269      impl From<#enum_id> for i32 {
270        fn from(value: #enum_id) -> Self {
271          match value {
272            #(#enum_id::#value_ids => #enum_info_id::#value_ids,)*
273          }
274        }
275      }
276
277      impl From<#enum_id> for shards::types::Var {
278        fn from(value: #enum_id) -> Self {
279          let e = unsafe { &#typedef_id.details.enumeration };
280          Self {
281            valueType: shards::SHType_Enum,
282            payload: shards::SHVarPayload {
283              __bindgen_anon_1: shards::SHVarPayload__bindgen_ty_1 {
284                __bindgen_anon_3: shards::SHVarPayload__bindgen_ty_1__bindgen_ty_3 {
285                  enumValue: value.into(),
286                  enumVendorId: e.vendorId,
287                  enumTypeId: e.typeId,
288                },
289              },
290            },
291            ..Default::default()
292          }
293        }
294      }
295
296      impl TryFrom<&shards::types::Var> for #enum_id {
297        type Error = &'static str;
298        fn try_from(value: &shards::types::Var) -> Result<Self, Self::Error> {
299          if value.valueType != shards::SHType_Enum {
300            return Err("Value is not an enum");
301          }
302
303          let e = unsafe { &value.payload.__bindgen_anon_1.__bindgen_anon_3 } ;
304          let e1 = unsafe { &#typedef_id.details.enumeration };
305          if e.enumVendorId != e1.vendorId {
306            return Err("Enum vendor id does not match");
307          }
308          if e.enumTypeId != e1.typeId {
309            return Err("Enum type id does not match");
310          }
311          e.enumValue.try_into()
312        }
313      }
314
315      impl From<&#enum_info_id> for shards::shardsc::SHEnumInfo {
316        fn from(info: &#enum_info_id) -> Self {
317          Self {
318            name: info.name.as_ptr() as *const std::os::raw::c_char,
319            help: info.help.0,
320            labels: info.labels.s,
321            values: shards::shardsc::SHEnums {
322              elements: (&info.values).as_ptr() as *mut i32,
323              len: info.values.len() as u32,
324              cap: 0
325            },
326            descriptions: (&info.descriptions).into(),
327          }
328        }
329      }
330    }
331    .into(),
332  )
333}
334
335#[proc_macro_derive(shards_enum, attributes(enum_info, enum_value))]
336pub fn derive_shards_enum(enum_def: TokenStream) -> TokenStream {
337  let enum_: syn::ItemEnum = syn::parse_macro_input!(enum_def);
338
339  match generate_enum_wrapper(enum_) {
340    Ok(result) => {
341      // eprintln!("derive_shards_enum:\n{}", result);
342      result
343    }
344    Err(err) => err.to_compile_error(),
345  }
346}
347
348fn parse_param_single(fld: &syn::Field, attr: &syn::Attribute) -> Result<ParamSingle, Error> {
349  let Meta::List(list) = &attr.meta else {
350    panic!("Param attribute must be a list");
351  };
352  let args = list
353    .parse_args_with(Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated)
354    .expect("Expected parsing");
355
356  if let Some((name, desc, types)) = args.into_pairs().map(|x| x.into_value()).collect_tuple() {
357    let name = get_expr_str_lit(&name)?;
358    let desc = get_expr_str_lit(&desc)?;
359    Ok(ParamSingle {
360      name,
361      var_name: fld.ident.clone().expect("Expected field name"),
362      desc,
363      types,
364    })
365  } else {
366    Err(
367      syn::Error::new(
368        attr.bracket_token.span.open(),
369        "Param attribute must have 3 arguments: (Name, Description, [Type1, Type2,...]/Types)",
370      )
371      .into(),
372    )
373  }
374}
375
376fn crc32(name: String) -> u32 {
377  let crc = crc::Crc::<u32>::new(&crc::CRC_32_BZIP2);
378  let checksum = crc.checksum(name.as_bytes());
379  checksum
380}
381
382struct Warmable {
383  warmup: proc_macro2::TokenStream,
384  cleanup: proc_macro2::TokenStream,
385}
386
387fn default_warmable(fld: &Field) -> Warmable {
388  let ident: &Ident = fld.ident.as_ref().expect("Expected field name");
389  Warmable {
390    warmup: quote! {self.#ident.warmup(context)?;},
391    cleanup: quote! {self.#ident.cleanup(context);},
392  }
393}
394
395fn to_warmable(
396  fld: &Field,
397  is_param_set: bool,
398  param_set_has_custom_interface: bool,
399) -> Option<Warmable> {
400  let rust_type = &fld.ty;
401  let ident: &Ident = fld.ident.as_ref().expect("Expected field name");
402  if let syn::Type::Path(p) = &rust_type {
403    let last_type_id = &p.path.segments.last().expect("Empty path").ident;
404    if last_type_id == "ParamVar" {
405      return Some(Warmable {
406        warmup: quote! {self.#ident.warmup(context);},
407        cleanup: quote! {self.#ident.cleanup(context);},
408      });
409    } else if is_param_set {
410      if param_set_has_custom_interface {
411        return Some(default_warmable(fld));
412      } else {
413        return Some(Warmable {
414          warmup: quote! {self.#ident.warmup_helper(context)?;},
415          cleanup: quote! {self.#ident.cleanup_helper(context)?;},
416        });
417      }
418    } else if last_type_id == "ShardsVar" {
419      return Some(default_warmable(fld));
420    }
421  }
422  return None;
423}
424
425#[derive(Default)]
426struct ShardFields {
427  params: Vec<Param>,
428  required: Option<syn::Ident>,
429  warmables: Vec<Warmable>,
430}
431
432fn parse_param_set_has_custom_interface(attr: &syn::Attribute) -> Result<bool, Error> {
433  let Meta::List(list) = &attr.meta else {
434    return Ok(false);
435  };
436
437  let args = list
438    .parse_args_with(Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated)
439    .expect("Expected parsing");
440
441  if let Some((b,)) = args.into_pairs().map(|x| x.into_value()).collect_tuple() {
442    let b = get_expr_bool_lit(&b)?;
443    return Ok(b);
444  }
445
446  Ok(false)
447}
448
449fn parse_shard_fields<'a>(
450  fields: impl IntoIterator<Item = &'a Field>,
451) -> Result<ShardFields, Error> {
452  let mut result = ShardFields::default();
453  for fld in fields {
454    let name: String = get_field_name(&fld);
455
456    for attr in &fld.attrs {
457      if attr.path().is_ident("shard_param") {
458        match parse_param_single(&fld, &attr) {
459          Ok(param) => {
460            result.params.push(Param::Single(param));
461            if let Some(warmable) = to_warmable(&fld, false, false) {
462              result.warmables.push(warmable);
463            }
464          }
465          Err(e) => {
466            return Err(e.extended(format!("Failed to parse param for field {}", name).into()))
467          }
468        }
469      } else if attr.path().is_ident("shard_param_set") {
470        let has_custom_interface = parse_param_set_has_custom_interface(&attr)?;
471
472        let param_set_ty = fld.ty.clone();
473        result.params.push(Param::Set(ParamSet {
474          type_name: param_set_ty,
475          var_name: fld.ident.clone().expect("Expected field name"),
476          has_custom_interface,
477        }));
478        result
479          .warmables
480          .push(to_warmable(fld, true, has_custom_interface).unwrap());
481      } else if attr.path().is_ident("shard_required") {
482        result.required = Some(fld.ident.as_ref().expect("Expected field name").clone());
483      } else if attr.path().is_ident("shard_warmup") {
484        if let Some(warmable) = to_warmable(&fld, false, false) {
485          result.warmables.push(warmable);
486        }
487      }
488    }
489  }
490  Ok(result)
491}
492
493struct ShardInfoAttr {
494  name: Expr,
495  desc: Expr,
496}
497
498fn read_shard_info_attr(
499  err_span: Span,
500  attrs: &Vec<syn::Attribute>,
501) -> Result<ShardInfoAttr, Error> {
502  for attr in attrs {
503    if attr.path().is_ident("shard_info") {
504      let args: Punctuated<Expr, Comma> =
505        attr.parse_args_with(Punctuated::<syn::Expr, Comma>::parse_terminated)?;
506      return if let Some((name, desc)) = args.into_pairs().map(|x| x.into_value()).collect_tuple() {
507        // Check if desc is a string literal and not empty
508        if let Expr::Lit(syn::ExprLit {
509          lit: Lit::Str(lit_str),
510          ..
511        }) = &desc
512        {
513          if lit_str.value().trim().is_empty() {
514            return Err("Description must not be empty".into());
515          }
516        } else {
517          return Err("Description must be a string literal".into());
518        }
519        Ok(ShardInfoAttr { name, desc })
520      } else {
521        Err("shard_info attribute must have 2 arguments: (Name, Description)".into())
522      };
523    }
524  }
525  Err(syn::Error::new(err_span, "Missing shard_info attribute").into())
526}
527
528struct ParameterAccessor {
529  get: proc_macro2::TokenStream,
530  set: proc_macro2::TokenStream,
531}
532
533fn generate_parameter_accessor(
534  offset_id: Ident,
535  in_id: Ident,
536  p: &Param,
537) -> Result<ParameterAccessor, Error> {
538  match p {
539    Param::Single(single) => {
540      let var_name = &single.var_name;
541      Ok(ParameterAccessor {
542        get: quote! {
543          if #in_id == #offset_id {
544            return (&self.#var_name).into();
545          }
546          #offset_id += 1;
547        },
548        set: quote! {
549          if(#in_id == #offset_id) {
550            return self.#var_name.set_param(value);
551          }
552          #offset_id += 1;
553        },
554      })
555    }
556    Param::Set(set) => {
557      let var_name = &set.var_name;
558      let set_type = &set.type_name;
559      Ok(ParameterAccessor {
560        get: quote! {
561          let local_id = #in_id - #offset_id;
562          if local_id >= 0 && local_id < (#set_type::num_params() as i32) {
563              return (&mut self.#var_name).get_param(local_id);
564          }
565          #offset_id += #set_type::num_params() as i32;
566        },
567        set: quote! {
568          let local_id = #in_id - #offset_id;
569          if local_id >= 0 && local_id < (#set_type::num_params() as i32) {
570              return (&mut self.#var_name).set_param(local_id, value);
571          }
572          #offset_id += #set_type::num_params() as i32;
573        },
574      })
575    }
576  }
577}
578
579fn generate_parameter_accessors(params: &Vec<Param>) -> Result<proc_macro2::TokenStream, Error> {
580  let static_params = params
581    .iter()
582    .filter_map(|p| match p {
583      Param::Single(single) => Some(single),
584      Param::Set(_) => None,
585    })
586    .collect::<Vec<_>>();
587  let is_complex = static_params.len() != params.len();
588
589  Ok(if is_complex {
590    let offset_id = Ident::new("offset", proc_macro2::Span::call_site());
591    let in_id = Ident::new("index", proc_macro2::Span::call_site());
592    let accessors = params
593      .iter()
594      .map(|p| generate_parameter_accessor(offset_id.clone(), in_id.clone(), p))
595      .collect::<Result<Vec<_>, _>>()?;
596    let getters = accessors.iter().map(|x| &x.get);
597    let setters = accessors.iter().map(|x| &x.set);
598
599    quote! {
600      fn set_param(&mut self, #in_id: i32, value: &shards::types::Var) -> std::result::Result<(), &'static str> {
601        let mut #offset_id: i32 = 0;
602        #(#setters)*
603        Err("Invalid parameter index")
604      }
605
606      fn get_param(&mut self, #in_id: i32) -> shards::types::Var {
607        let mut #offset_id: i32 = 0;
608        #(#getters)*
609        shards::types::Var::default()
610      }
611    }
612  } else {
613    let params_idents: Vec<_> = static_params.iter().map(|x| &x.var_name).collect();
614    let params_indices: Vec<_> = (0..static_params.len())
615      .map(|x| LitInt::new(&format!("{}", x), proc_macro2::Span::call_site()))
616      .collect();
617
618    quote! {
619      fn set_param(&mut self, index: i32, value: &shards::types::Var) -> std::result::Result<(), &'static str> {
620        match index {
621          #(
622            #params_indices => self.#params_idents.set_param(value),
623          )*
624          _ => Err("Invalid parameter index"),
625        }
626      }
627
628      fn get_param(&mut self, index: i32) -> shards::types::Var {
629        match index {
630          #(
631            #params_indices => (&self.#params_idents).into(),
632          )*
633          _ => shards::types::Var::default(),
634        }
635      }
636    }
637  })
638}
639
640struct ParamWrapperCode {
641  prelude: proc_macro2::TokenStream,
642  // warmup bodies
643  warmups: Vec<proc_macro2::TokenStream>,
644  // cleanup bodies, note that you should reverse these before using them
645  cleanups_rev: Vec<proc_macro2::TokenStream>,
646  // get_param & set_param
647  accessors: proc_macro2::TokenStream,
648  // Id for static parameters
649  params_static_id: Ident,
650  composes: Vec<proc_macro2::TokenStream>,
651  shard_fields: ShardFields,
652}
653
654fn generate_param_wrapper_code(struct_: &syn::ItemStruct) -> Result<ParamWrapperCode, Error> {
655  let struct_id = &struct_.ident;
656  let struct_name_upper = struct_id.to_string().to_uppercase();
657  let struct_name_lower = struct_id.to_string().to_case(convert_case::Case::Snake);
658
659  let shard_fields = parse_shard_fields(&struct_.fields)?;
660  let params = &shard_fields.params;
661
662  let static_params = params
663    .iter()
664    .filter_map(|p| match p {
665      Param::Single(single) => Some(single),
666      Param::Set(_) => None,
667    })
668    .collect::<Vec<_>>();
669
670  let mut array_initializers = Vec::new();
671  let param_names: Vec<_> = static_params
672    .iter()
673    .map(|x| LitStr::new(&x.name, proc_macro2::Span::call_site()))
674    .collect();
675  let param_descs: Vec<_> = static_params
676    .iter()
677    .map(|x| LitStr::new(&x.desc, proc_macro2::Span::call_site()))
678    .collect();
679  let param_types: Vec<_> = static_params
680    .iter()
681    .map(|x| {
682      if let Expr::Array(arr) = &x.types {
683        let tmp_id: Ident = Ident::new(
684          &format!("{}_{}_TYPES", struct_name_upper, x.name.to_uppercase()),
685          proc_macro2::Span::call_site(),
686        );
687        array_initializers.push(quote! { static ref #tmp_id: shards::types::Types = vec!#arr; });
688        syn::parse_quote! { #tmp_id }
689      } else {
690        x.types.clone()
691      }
692    })
693    .collect();
694
695  let params_static_id: Ident = Ident::new(
696    &format!("{}_PARAMETERS", struct_name_upper),
697    proc_macro2::Span::call_site(),
698  );
699
700  // Generate warmup/cleanup calls for supported types
701  let mut warmups = Vec::new();
702  let mut cleanups_rev = Vec::new();
703  for x in &shard_fields.warmables {
704    warmups.push(x.warmup.clone());
705    cleanups_rev.push(x.cleanup.clone());
706  }
707
708  let mut composes = Vec::new();
709  for param in params {
710    match param {
711      Param::Single(single) => {
712        let var_name = &single.var_name;
713        composes.push(quote! {
714          shards::util::collect_required_variables(&data.shared, out_required, (&self.#var_name).into())?;
715        });
716      }
717      Param::Set(set) => {
718        let var_name = &set.var_name;
719        composes.push(quote! {
720          (&mut self.#var_name).compose_helper(out_required, data)?;
721        });
722      }
723    }
724  }
725
726  let build_params_id = Ident::new(
727    &format!("build_params_{}", struct_name_lower),
728    proc_macro2::Span::call_site(),
729  );
730
731  let accessors = generate_parameter_accessors(params)?;
732  let append_params = params.iter().map(|p| match p {
733    Param::Single(_) => {
734      quote! {
735        params.push(static_params[static_idx].clone());
736        static_idx += 1;
737      }
738    }
739    Param::Set(set) => {
740      let set_type = &set.type_name;
741      quote! {
742        for param in #set_type::parameters() {
743          params.push(param.clone());
744        }
745      }
746    }
747  });
748
749  let prelude = quote! {
750    fn #build_params_id() -> Vec<shards::types::ParameterInfo> {
751      let static_params : Vec<shards::types::ParameterInfo> = vec![
752        #((
753          shards::cstr!(#param_names),
754          shards::shccstr!(#param_descs),
755          &#param_types[..]
756        ).into()),*
757      ];
758      let mut params = Vec::new();
759      let mut static_idx: usize = 0;
760      #(#append_params)*
761      params
762    }
763
764    lazy_static::lazy_static! {
765      #(#array_initializers)*
766      static ref #params_static_id: shards::types::Parameters = #build_params_id();
767    }
768  };
769
770  Ok(ParamWrapperCode {
771    prelude,
772    warmups,
773    params_static_id: params_static_id,
774    cleanups_rev,
775    accessors,
776    composes,
777    shard_fields,
778  })
779}
780
781fn process_param_set_impl(struct_: syn::ItemStruct) -> Result<TokenStream, Error> {
782  let struct_id = &struct_.ident;
783  let ParamWrapperCode {
784    prelude,
785    warmups,
786    cleanups_rev,
787    accessors,
788    params_static_id,
789    composes,
790    ..
791  } = generate_param_wrapper_code(&struct_)?;
792  let cleanups = cleanups_rev.iter().rev();
793
794  Ok(quote! {
795    #prelude
796
797    impl shards::shard::ParameterSet for #struct_id {
798      fn parameters() -> &'static shards::types::Parameters {
799          &#params_static_id
800      }
801
802      fn num_params() -> usize { #params_static_id.len() }
803
804      #accessors
805
806      fn warmup_helper(&mut self, context: &shards::types::Context) -> std::result::Result<(), &'static str> {
807        #( #warmups )*
808        Ok(())
809      }
810
811      fn cleanup_helper(&mut self, context: std::option::Option<&shards::types::Context>) -> std::result::Result<(), &'static str> {
812        #( #cleanups )*
813        Ok(())
814      }
815
816      fn compose_helper(&mut self, out_required: &mut shards::types::ExposedTypes, data: &shards::types::InstanceData) -> std::result::Result<(), &'static str> {
817        #( #composes )*
818        Ok(())
819      }
820    }
821  }.into())
822}
823
824fn process_shard_helper_impl(struct_: syn::ItemStruct) -> Result<TokenStream, Error> {
825  let struct_id = &struct_.ident;
826
827  let shard_info = read_shard_info_attr(struct_id.span(), &struct_.attrs)?;
828
829  let ParamWrapperCode {
830    prelude,
831    warmups,
832    cleanups_rev,
833    accessors,
834    params_static_id,
835    composes,
836    shard_fields,
837  } = generate_param_wrapper_code(&struct_)?;
838  let cleanups = cleanups_rev.iter().rev();
839
840  let shard_name_expr = shard_info.name;
841  let shard_name = get_expr_str_lit(&shard_name_expr)?;
842  let shard_desc_expr = shard_info.desc;
843
844  let crc = crc32(format!("{}-rust-0x20200101", shard_name));
845
846  let (required_variables_opt, compose_helper) = if let Some(required) = &shard_fields.required {
847    (
848      quote! { Some(&self.#required) },
849      quote! {
850        fn compose_helper(&mut self, data: &shards::types::InstanceData) -> std::result::Result<(), &'static str> {
851          self.#required.clear();
852          let out_required = &mut self.#required;
853          #(#composes)*
854          Ok(())
855        }
856      },
857    )
858  } else {
859    (quote! { None }, quote! {})
860  };
861
862  Ok(quote! {
863    #prelude
864
865    impl shards::shard::ShardGenerated for #struct_id {
866      fn register_name() -> &'static str {
867        shards::cstr!(#shard_name_expr)
868      }
869
870      fn name(&mut self) -> &str {
871        #shard_name_expr
872      }
873
874      fn hash() -> u32
875      where
876        Self: Sized,
877      {
878        #crc
879      }
880
881      fn help(&mut self) -> shards::types::OptionalString {
882        shards::types::OptionalString(shards::shccstr!(#shard_desc_expr))
883      }
884
885      fn parameters(&mut self) -> Option<&shards::types::Parameters> {
886          Some(&#params_static_id)
887      }
888
889      #accessors
890
891      fn required_variables(&mut self) -> Option<&shards::types::ExposedTypes> {
892        #required_variables_opt
893      }
894    }
895
896    impl #struct_id {
897      #compose_helper
898
899      fn warmup_helper(&mut self, context: &shards::types::Context) -> std::result::Result<(), &'static str> {
900        #( #warmups )*
901        Ok(())
902      }
903
904
905      fn cleanup_helper(&mut self, context: std::option::Option<&shards::types::Context>) -> std::result::Result<(), &'static str> {
906        #( #cleanups )*
907        Ok(())
908      }
909    }
910  }.into())
911}
912
913#[proc_macro_derive(
914  shard,
915  attributes(shard_info, shard_param, shard_param_set, shard_required, shard_warmup)
916)]
917pub fn derive_shard(struct_def: TokenStream) -> TokenStream {
918  let struct_: syn::ItemStruct = syn::parse_macro_input!(struct_def as syn::ItemStruct);
919
920  match process_shard_helper_impl(struct_) {
921    Ok(result) => {
922      // eprintln!("derive_shard:\n{}", result);
923      result
924    }
925    Err(err) => err.to_compile_error(),
926  }
927}
928
929#[proc_macro_derive(param_set, attributes(shard_param, shard_param_set, shard_warmup))]
930pub fn derive_param_set(struct_def: TokenStream) -> TokenStream {
931  let struct_: syn::ItemStruct = syn::parse_macro_input!(struct_def as syn::ItemStruct);
932
933  match process_param_set_impl(struct_) {
934    Ok(result) => {
935      // eprintln!("derive_param_set:\n{}", result);
936      result
937    }
938    Err(err) => err.to_compile_error(),
939  }
940}
941
942fn generate_impl_wrapper(impl_: syn::ItemImpl) -> Result<TokenStream, Error> {
943  let struct_ty = impl_.self_ty.as_ref();
944
945  let mut have_impls: HashSet<String> = HashSet::new();
946  for item in &impl_.items {
947    if let ImplItem::Fn(fn_item) = item {
948      let fn_name = fn_item.sig.ident.to_string();
949      if IMPLS_TO_CHECK_SET.contains(fn_name.as_str()) {
950        have_impls.insert(fn_name);
951      }
952    }
953  }
954
955  // Generate hasXXX() -> bool functions for all optional functions
956  let impls = IMPLS_TO_CHECK.iter().map(|x| {
957    let has_fn_id = Ident::new(&format!("has_{}", x), proc_macro2::Span::call_site());
958    let have_function = syn::LitBool::new(have_impls.contains(*x), proc_macro2::Span::call_site());
959    quote! { fn #has_fn_id() -> bool { #have_function } }
960  });
961
962  Ok(
963    quote! {
964      #[allow(non_snake_case)]
965      impl shards::shard::ShardGeneratedOverloads for #struct_ty {
966        #(#impls)*
967      }
968
969      #[allow(non_snake_case)]
970      #impl_
971    }
972    .into(),
973  )
974}
975
976#[proc_macro_attribute]
977pub fn shard_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
978  let impl_: syn::ItemImpl = syn::parse_macro_input!(item);
979  match generate_impl_wrapper(impl_) {
980    Ok(result) => {
981      // eprintln!("shard_impl:\n{}", result);
982      result
983    }
984    Err(err) => err.to_compile_error(),
985  }
986}