1extern crate proc_macro;
2use std::{boxed, collections::HashSet};
3
4use convert_case::Casing;
5use itertools::Itertools;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::quote;
9use syn::{
10 punctuated::Punctuated, token::Comma, Expr, Field, Ident, ImplItem, Lit, LitInt, LitStr, Meta,
11};
12
13enum Error {
15 CompileError(proc_macro2::TokenStream),
16 Generic(boxed::Box<dyn std::error::Error>),
17}
18
19impl From<&str> for Error {
20 fn from(value: &str) -> Self {
21 Error::Generic(value.into())
22 }
23}
24
25impl From<String> for Error {
26 fn from(value: String) -> Self {
27 Error::Generic(value.into())
28 }
29}
30
31impl From<syn::Error> for Error {
32 fn from(value: syn::Error) -> Self {
33 Error::CompileError(value.into_compile_error())
34 }
35}
36
37impl Error {
38 fn to_compile_error2(self) -> proc_macro2::TokenStream {
39 match self {
40 Error::CompileError(stream) => stream,
41 Error::Generic(err) => syn::Error::new(Span::call_site(), err.to_string()).to_compile_error(),
42 }
43 }
44 fn to_compile_error(self) -> proc_macro::TokenStream {
45 self.to_compile_error2().into()
46 }
47 fn extended(self, e: Error) -> Self {
48 let mut stream = self.to_compile_error2();
49 stream.extend(e.to_compile_error2());
50 Error::CompileError(stream)
51 }
52}
53
54lazy_static::lazy_static! {
55 static ref IMPLS_TO_CHECK: Vec<&'static str> = vec![
56 "compose",
57 "warmup",
58 "mutate",
59 "crossover",
60 "get_state",
61 "set_state",
62 "reset_state",
63 ];
64 static ref IMPLS_TO_CHECK_SET : HashSet<&'static str> = HashSet::from_iter(IMPLS_TO_CHECK.iter().cloned());
65}
66
67struct ParamSingle {
68 name: String,
69 var_name: syn::Ident,
70 desc: String,
71 types: syn::Expr,
72}
73
74struct ParamSet {
75 type_name: syn::Type,
76 var_name: syn::Ident,
77 has_custom_interface: bool,
78}
79
80enum Param {
81 Single(ParamSingle),
82 Set(ParamSet),
83}
84
85fn get_field_name(fld: &Field) -> String {
86 if let Some(id) = &fld.ident {
87 id.to_string()
88 } else {
89 "".to_string()
90 }
91}
92
93fn get_expr_str_lit(expr: &Expr) -> Result<String, Error> {
94 if let syn::Expr::Lit(lit) = expr {
95 if let syn::Lit::Str(str) = &lit.lit {
96 Ok(str.value())
97 } else {
98 Err("Value must be a string literal".into())
99 }
100 } else {
101 Err("Value must be a string literal".into())
102 }
103}
104
105fn get_expr_bool_lit(expr: &Expr) -> Result<bool, Error> {
106 if let syn::Expr::Lit(lit) = expr {
107 if let syn::Lit::Bool(str) = &lit.lit {
108 Ok(str.value())
109 } else {
110 Err("Value must be a bool literal".into())
111 }
112 } else {
113 Err("Value must be a bool literal".into())
114 }
115}
116
117struct EnumInfoAttr {
118 id: Expr,
119 name: Expr,
120 desc: Expr,
121}
122
123fn read_enum_info_attr(attrs: &Vec<syn::Attribute>) -> Result<EnumInfoAttr, Error> {
124 for attr in attrs {
125 if attr.path().is_ident("enum_info") {
126 let args = attr.parse_args_with(Punctuated::<syn::Expr, Comma>::parse_terminated)?;
127 return if let Some((id, name, desc)) =
128 args.into_pairs().map(|x| x.into_value()).collect_tuple()
129 {
130 Ok(EnumInfoAttr { id, name, desc })
131 } else {
132 Err("shards_enum attribute must have 3 arguments: (Id, Name, Description)".into())
133 };
134 }
135 }
136 Err("Missing shards_enum attribute".into())
137}
138
139fn read_enum_value_attr(attrs: &Vec<syn::Attribute>) -> Result<Option<LitStr>, Error> {
140 for attr in attrs {
141 if attr.path().is_ident("enum_value") {
142 return Ok(Some(attr.parse_args()?));
143 }
144 }
145 Ok(None)
146}
147
148fn generate_enum_wrapper(enum_: syn::ItemEnum) -> Result<TokenStream, Error> {
149 let vis = enum_.vis;
150 let enum_id = enum_.ident;
151 let enum_name = enum_id.to_string();
152
153 let mut value_ids = Vec::new();
154 let mut value_str_ids = Vec::new();
155 let mut value_desc_lits = Vec::new();
156 let mut value_name_lits = Vec::new();
157
158 let shards_enum_attr = read_enum_info_attr(&enum_.attrs)?;
159
160 for var in &enum_.variants {
161 let var_name = var.ident.to_string();
162 let desc_lit = read_enum_value_attr(&var.attrs)?;
163
164 value_ids.push(var.ident.clone());
165 value_str_ids.push(Ident::new(
166 &format!("{}_str", var_name),
167 proc_macro2::Span::call_site(),
168 ));
169 value_name_lits.push(LitStr::new(&var_name, proc_macro2::Span::call_site()));
170
171 if let Some(lit) = desc_lit {
172 value_desc_lits.push(lit);
173 } else {
174 value_desc_lits.push(LitStr::new("", proc_macro2::Span::call_site()));
175 }
176 }
177
178 let enum_info_id = Ident::new(
179 &format!("{}EnumInfo", enum_name),
180 proc_macro2::Span::call_site(),
181 );
182
183 let enum_name_upper = enum_name.to_uppercase();
184
185 let enum_info_instance_id = Ident::new(
186 &format!("{}_ENUM_INFO", enum_name_upper),
187 proc_macro2::Span::call_site(),
188 );
189
190 let typedef_id = Ident::new(
191 &format!("{}_TYPE", enum_name_upper),
192 proc_macro2::Span::call_site(),
193 );
194
195 let typedef_vec_id = Ident::new(
196 &format!("{}_TYPES", enum_name_upper),
197 proc_macro2::Span::call_site(),
198 );
199
200 let enum_id_expr = shards_enum_attr.id;
201 let enum_name_expr = shards_enum_attr.name;
202 let enum_desc_expr = shards_enum_attr.desc;
203
204 Ok(
205 quote! {
206 #vis struct #enum_info_id {
207 name: &'static str,
208 help: shards::types::OptionalString,
209 enum_type: shards::types::Type,
210 labels: shards::types::Strings,
211 values: Vec<i32>,
212 descriptions: shards::types::OptionalStrings,
213 }
214
215 lazy_static::lazy_static! {
216 #vis static ref #enum_info_instance_id: #enum_info_id = #enum_info_id::new();
217 #vis static ref #typedef_id: shards::types::Type = #enum_info_instance_id.enum_type;
218 #vis static ref #typedef_vec_id: shards::types::Types = vec![*#typedef_id];
219 }
220
221 impl shards::core::EnumRegister for #enum_id {
222 fn register() {
223 let e = unsafe { &#enum_info_instance_id.enum_type.details.enumeration };
224 shards::core::register_enum_internal(e.vendorId, e.typeId, (&*#enum_info_instance_id).into());
225 }
226 }
227
228 #[allow(non_upper_case_globals)]
229 impl<'a> #enum_info_id {
230 #(
231 pub const #value_ids: shards::SHEnum = #enum_id::#value_ids as i32;
232 )*
233 #(
234 pub const #value_str_ids: &'static str = shards::cstr!(#value_name_lits);
235 )*
236
237 fn new() -> Self {
238 let mut labels = shards::types::Strings::new();
239 #(
240 labels.push(Self::#value_str_ids);
241 )*
242
243 let mut descriptions = shards::types::OptionalStrings::new();
244 #(
245 descriptions.push(shards::types::OptionalString(shards::shccstr!(#value_desc_lits)));
246 )*
247
248 Self {
249 name: shards::cstr!(#enum_name_expr),
250 help: shards::types::OptionalString(shards::shccstr!(#enum_desc_expr)),
251 enum_type: shards::types::Type::enumeration(shards::types::FRAG_CC, shards::fourCharacterCode(*#enum_id_expr)),
252 labels,
253 values: vec![#(Self::#value_ids,)*],
254 descriptions,
255 }
256 }
257 }
258
259 impl TryFrom<i32> for #enum_id {
260 type Error = &'static str;
261 fn try_from(value: i32) -> Result<Self, Self::Error> {
262 match value {
263 #(#enum_info_id::#value_ids => Ok(#enum_id::#value_ids),)*
264 _ => Err("Invalid enum value"),
265 }
266 }
267 }
268
269 impl From<#enum_id> for i32 {
270 fn from(value: #enum_id) -> Self {
271 match value {
272 #(#enum_id::#value_ids => #enum_info_id::#value_ids,)*
273 }
274 }
275 }
276
277 impl From<#enum_id> for shards::types::Var {
278 fn from(value: #enum_id) -> Self {
279 let e = unsafe { &#typedef_id.details.enumeration };
280 Self {
281 valueType: shards::SHType_Enum,
282 payload: shards::SHVarPayload {
283 __bindgen_anon_1: shards::SHVarPayload__bindgen_ty_1 {
284 __bindgen_anon_3: shards::SHVarPayload__bindgen_ty_1__bindgen_ty_3 {
285 enumValue: value.into(),
286 enumVendorId: e.vendorId,
287 enumTypeId: e.typeId,
288 },
289 },
290 },
291 ..Default::default()
292 }
293 }
294 }
295
296 impl TryFrom<&shards::types::Var> for #enum_id {
297 type Error = &'static str;
298 fn try_from(value: &shards::types::Var) -> Result<Self, Self::Error> {
299 if value.valueType != shards::SHType_Enum {
300 return Err("Value is not an enum");
301 }
302
303 let e = unsafe { &value.payload.__bindgen_anon_1.__bindgen_anon_3 } ;
304 let e1 = unsafe { &#typedef_id.details.enumeration };
305 if e.enumVendorId != e1.vendorId {
306 return Err("Enum vendor id does not match");
307 }
308 if e.enumTypeId != e1.typeId {
309 return Err("Enum type id does not match");
310 }
311 e.enumValue.try_into()
312 }
313 }
314
315 impl From<&#enum_info_id> for shards::shardsc::SHEnumInfo {
316 fn from(info: &#enum_info_id) -> Self {
317 Self {
318 name: info.name.as_ptr() as *const std::os::raw::c_char,
319 help: info.help.0,
320 labels: info.labels.s,
321 values: shards::shardsc::SHEnums {
322 elements: (&info.values).as_ptr() as *mut i32,
323 len: info.values.len() as u32,
324 cap: 0
325 },
326 descriptions: (&info.descriptions).into(),
327 }
328 }
329 }
330 }
331 .into(),
332 )
333}
334
335#[proc_macro_derive(shards_enum, attributes(enum_info, enum_value))]
336pub fn derive_shards_enum(enum_def: TokenStream) -> TokenStream {
337 let enum_: syn::ItemEnum = syn::parse_macro_input!(enum_def);
338
339 match generate_enum_wrapper(enum_) {
340 Ok(result) => {
341 result
343 }
344 Err(err) => err.to_compile_error(),
345 }
346}
347
348fn parse_param_single(fld: &syn::Field, attr: &syn::Attribute) -> Result<ParamSingle, Error> {
349 let Meta::List(list) = &attr.meta else {
350 panic!("Param attribute must be a list");
351 };
352 let args = list
353 .parse_args_with(Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated)
354 .expect("Expected parsing");
355
356 if let Some((name, desc, types)) = args.into_pairs().map(|x| x.into_value()).collect_tuple() {
357 let name = get_expr_str_lit(&name)?;
358 let desc = get_expr_str_lit(&desc)?;
359 Ok(ParamSingle {
360 name,
361 var_name: fld.ident.clone().expect("Expected field name"),
362 desc,
363 types,
364 })
365 } else {
366 Err(
367 syn::Error::new(
368 attr.bracket_token.span.open(),
369 "Param attribute must have 3 arguments: (Name, Description, [Type1, Type2,...]/Types)",
370 )
371 .into(),
372 )
373 }
374}
375
376fn crc32(name: String) -> u32 {
377 let crc = crc::Crc::<u32>::new(&crc::CRC_32_BZIP2);
378 let checksum = crc.checksum(name.as_bytes());
379 checksum
380}
381
382struct Warmable {
383 warmup: proc_macro2::TokenStream,
384 cleanup: proc_macro2::TokenStream,
385}
386
387fn default_warmable(fld: &Field) -> Warmable {
388 let ident: &Ident = fld.ident.as_ref().expect("Expected field name");
389 Warmable {
390 warmup: quote! {self.#ident.warmup(context)?;},
391 cleanup: quote! {self.#ident.cleanup(context);},
392 }
393}
394
395fn to_warmable(
396 fld: &Field,
397 is_param_set: bool,
398 param_set_has_custom_interface: bool,
399) -> Option<Warmable> {
400 let rust_type = &fld.ty;
401 let ident: &Ident = fld.ident.as_ref().expect("Expected field name");
402 if let syn::Type::Path(p) = &rust_type {
403 let last_type_id = &p.path.segments.last().expect("Empty path").ident;
404 if last_type_id == "ParamVar" {
405 return Some(Warmable {
406 warmup: quote! {self.#ident.warmup(context);},
407 cleanup: quote! {self.#ident.cleanup(context);},
408 });
409 } else if is_param_set {
410 if param_set_has_custom_interface {
411 return Some(default_warmable(fld));
412 } else {
413 return Some(Warmable {
414 warmup: quote! {self.#ident.warmup_helper(context)?;},
415 cleanup: quote! {self.#ident.cleanup_helper(context)?;},
416 });
417 }
418 } else if last_type_id == "ShardsVar" {
419 return Some(default_warmable(fld));
420 }
421 }
422 return None;
423}
424
425#[derive(Default)]
426struct ShardFields {
427 params: Vec<Param>,
428 required: Option<syn::Ident>,
429 warmables: Vec<Warmable>,
430}
431
432fn parse_param_set_has_custom_interface(attr: &syn::Attribute) -> Result<bool, Error> {
433 let Meta::List(list) = &attr.meta else {
434 return Ok(false);
435 };
436
437 let args = list
438 .parse_args_with(Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated)
439 .expect("Expected parsing");
440
441 if let Some((b,)) = args.into_pairs().map(|x| x.into_value()).collect_tuple() {
442 let b = get_expr_bool_lit(&b)?;
443 return Ok(b);
444 }
445
446 Ok(false)
447}
448
449fn parse_shard_fields<'a>(
450 fields: impl IntoIterator<Item = &'a Field>,
451) -> Result<ShardFields, Error> {
452 let mut result = ShardFields::default();
453 for fld in fields {
454 let name: String = get_field_name(&fld);
455
456 for attr in &fld.attrs {
457 if attr.path().is_ident("shard_param") {
458 match parse_param_single(&fld, &attr) {
459 Ok(param) => {
460 result.params.push(Param::Single(param));
461 if let Some(warmable) = to_warmable(&fld, false, false) {
462 result.warmables.push(warmable);
463 }
464 }
465 Err(e) => {
466 return Err(e.extended(format!("Failed to parse param for field {}", name).into()))
467 }
468 }
469 } else if attr.path().is_ident("shard_param_set") {
470 let has_custom_interface = parse_param_set_has_custom_interface(&attr)?;
471
472 let param_set_ty = fld.ty.clone();
473 result.params.push(Param::Set(ParamSet {
474 type_name: param_set_ty,
475 var_name: fld.ident.clone().expect("Expected field name"),
476 has_custom_interface,
477 }));
478 result
479 .warmables
480 .push(to_warmable(fld, true, has_custom_interface).unwrap());
481 } else if attr.path().is_ident("shard_required") {
482 result.required = Some(fld.ident.as_ref().expect("Expected field name").clone());
483 } else if attr.path().is_ident("shard_warmup") {
484 if let Some(warmable) = to_warmable(&fld, false, false) {
485 result.warmables.push(warmable);
486 }
487 }
488 }
489 }
490 Ok(result)
491}
492
493struct ShardInfoAttr {
494 name: Expr,
495 desc: Expr,
496}
497
498fn read_shard_info_attr(
499 err_span: Span,
500 attrs: &Vec<syn::Attribute>,
501) -> Result<ShardInfoAttr, Error> {
502 for attr in attrs {
503 if attr.path().is_ident("shard_info") {
504 let args: Punctuated<Expr, Comma> =
505 attr.parse_args_with(Punctuated::<syn::Expr, Comma>::parse_terminated)?;
506 return if let Some((name, desc)) = args.into_pairs().map(|x| x.into_value()).collect_tuple() {
507 if let Expr::Lit(syn::ExprLit {
509 lit: Lit::Str(lit_str),
510 ..
511 }) = &desc
512 {
513 if lit_str.value().trim().is_empty() {
514 return Err("Description must not be empty".into());
515 }
516 } else {
517 return Err("Description must be a string literal".into());
518 }
519 Ok(ShardInfoAttr { name, desc })
520 } else {
521 Err("shard_info attribute must have 2 arguments: (Name, Description)".into())
522 };
523 }
524 }
525 Err(syn::Error::new(err_span, "Missing shard_info attribute").into())
526}
527
528struct ParameterAccessor {
529 get: proc_macro2::TokenStream,
530 set: proc_macro2::TokenStream,
531}
532
533fn generate_parameter_accessor(
534 offset_id: Ident,
535 in_id: Ident,
536 p: &Param,
537) -> Result<ParameterAccessor, Error> {
538 match p {
539 Param::Single(single) => {
540 let var_name = &single.var_name;
541 Ok(ParameterAccessor {
542 get: quote! {
543 if #in_id == #offset_id {
544 return (&self.#var_name).into();
545 }
546 #offset_id += 1;
547 },
548 set: quote! {
549 if(#in_id == #offset_id) {
550 return self.#var_name.set_param(value);
551 }
552 #offset_id += 1;
553 },
554 })
555 }
556 Param::Set(set) => {
557 let var_name = &set.var_name;
558 let set_type = &set.type_name;
559 Ok(ParameterAccessor {
560 get: quote! {
561 let local_id = #in_id - #offset_id;
562 if local_id >= 0 && local_id < (#set_type::num_params() as i32) {
563 return (&mut self.#var_name).get_param(local_id);
564 }
565 #offset_id += #set_type::num_params() as i32;
566 },
567 set: quote! {
568 let local_id = #in_id - #offset_id;
569 if local_id >= 0 && local_id < (#set_type::num_params() as i32) {
570 return (&mut self.#var_name).set_param(local_id, value);
571 }
572 #offset_id += #set_type::num_params() as i32;
573 },
574 })
575 }
576 }
577}
578
579fn generate_parameter_accessors(params: &Vec<Param>) -> Result<proc_macro2::TokenStream, Error> {
580 let static_params = params
581 .iter()
582 .filter_map(|p| match p {
583 Param::Single(single) => Some(single),
584 Param::Set(_) => None,
585 })
586 .collect::<Vec<_>>();
587 let is_complex = static_params.len() != params.len();
588
589 Ok(if is_complex {
590 let offset_id = Ident::new("offset", proc_macro2::Span::call_site());
591 let in_id = Ident::new("index", proc_macro2::Span::call_site());
592 let accessors = params
593 .iter()
594 .map(|p| generate_parameter_accessor(offset_id.clone(), in_id.clone(), p))
595 .collect::<Result<Vec<_>, _>>()?;
596 let getters = accessors.iter().map(|x| &x.get);
597 let setters = accessors.iter().map(|x| &x.set);
598
599 quote! {
600 fn set_param(&mut self, #in_id: i32, value: &shards::types::Var) -> std::result::Result<(), &'static str> {
601 let mut #offset_id: i32 = 0;
602 #(#setters)*
603 Err("Invalid parameter index")
604 }
605
606 fn get_param(&mut self, #in_id: i32) -> shards::types::Var {
607 let mut #offset_id: i32 = 0;
608 #(#getters)*
609 shards::types::Var::default()
610 }
611 }
612 } else {
613 let params_idents: Vec<_> = static_params.iter().map(|x| &x.var_name).collect();
614 let params_indices: Vec<_> = (0..static_params.len())
615 .map(|x| LitInt::new(&format!("{}", x), proc_macro2::Span::call_site()))
616 .collect();
617
618 quote! {
619 fn set_param(&mut self, index: i32, value: &shards::types::Var) -> std::result::Result<(), &'static str> {
620 match index {
621 #(
622 #params_indices => self.#params_idents.set_param(value),
623 )*
624 _ => Err("Invalid parameter index"),
625 }
626 }
627
628 fn get_param(&mut self, index: i32) -> shards::types::Var {
629 match index {
630 #(
631 #params_indices => (&self.#params_idents).into(),
632 )*
633 _ => shards::types::Var::default(),
634 }
635 }
636 }
637 })
638}
639
640struct ParamWrapperCode {
641 prelude: proc_macro2::TokenStream,
642 warmups: Vec<proc_macro2::TokenStream>,
644 cleanups_rev: Vec<proc_macro2::TokenStream>,
646 accessors: proc_macro2::TokenStream,
648 params_static_id: Ident,
650 composes: Vec<proc_macro2::TokenStream>,
651 shard_fields: ShardFields,
652}
653
654fn generate_param_wrapper_code(struct_: &syn::ItemStruct) -> Result<ParamWrapperCode, Error> {
655 let struct_id = &struct_.ident;
656 let struct_name_upper = struct_id.to_string().to_uppercase();
657 let struct_name_lower = struct_id.to_string().to_case(convert_case::Case::Snake);
658
659 let shard_fields = parse_shard_fields(&struct_.fields)?;
660 let params = &shard_fields.params;
661
662 let static_params = params
663 .iter()
664 .filter_map(|p| match p {
665 Param::Single(single) => Some(single),
666 Param::Set(_) => None,
667 })
668 .collect::<Vec<_>>();
669
670 let mut array_initializers = Vec::new();
671 let param_names: Vec<_> = static_params
672 .iter()
673 .map(|x| LitStr::new(&x.name, proc_macro2::Span::call_site()))
674 .collect();
675 let param_descs: Vec<_> = static_params
676 .iter()
677 .map(|x| LitStr::new(&x.desc, proc_macro2::Span::call_site()))
678 .collect();
679 let param_types: Vec<_> = static_params
680 .iter()
681 .map(|x| {
682 if let Expr::Array(arr) = &x.types {
683 let tmp_id: Ident = Ident::new(
684 &format!("{}_{}_TYPES", struct_name_upper, x.name.to_uppercase()),
685 proc_macro2::Span::call_site(),
686 );
687 array_initializers.push(quote! { static ref #tmp_id: shards::types::Types = vec!#arr; });
688 syn::parse_quote! { #tmp_id }
689 } else {
690 x.types.clone()
691 }
692 })
693 .collect();
694
695 let params_static_id: Ident = Ident::new(
696 &format!("{}_PARAMETERS", struct_name_upper),
697 proc_macro2::Span::call_site(),
698 );
699
700 let mut warmups = Vec::new();
702 let mut cleanups_rev = Vec::new();
703 for x in &shard_fields.warmables {
704 warmups.push(x.warmup.clone());
705 cleanups_rev.push(x.cleanup.clone());
706 }
707
708 let mut composes = Vec::new();
709 for param in params {
710 match param {
711 Param::Single(single) => {
712 let var_name = &single.var_name;
713 composes.push(quote! {
714 shards::util::collect_required_variables(&data.shared, out_required, (&self.#var_name).into())?;
715 });
716 }
717 Param::Set(set) => {
718 let var_name = &set.var_name;
719 composes.push(quote! {
720 (&mut self.#var_name).compose_helper(out_required, data)?;
721 });
722 }
723 }
724 }
725
726 let build_params_id = Ident::new(
727 &format!("build_params_{}", struct_name_lower),
728 proc_macro2::Span::call_site(),
729 );
730
731 let accessors = generate_parameter_accessors(params)?;
732 let append_params = params.iter().map(|p| match p {
733 Param::Single(_) => {
734 quote! {
735 params.push(static_params[static_idx].clone());
736 static_idx += 1;
737 }
738 }
739 Param::Set(set) => {
740 let set_type = &set.type_name;
741 quote! {
742 for param in #set_type::parameters() {
743 params.push(param.clone());
744 }
745 }
746 }
747 });
748
749 let prelude = quote! {
750 fn #build_params_id() -> Vec<shards::types::ParameterInfo> {
751 let static_params : Vec<shards::types::ParameterInfo> = vec![
752 #((
753 shards::cstr!(#param_names),
754 shards::shccstr!(#param_descs),
755 &#param_types[..]
756 ).into()),*
757 ];
758 let mut params = Vec::new();
759 let mut static_idx: usize = 0;
760 #(#append_params)*
761 params
762 }
763
764 lazy_static::lazy_static! {
765 #(#array_initializers)*
766 static ref #params_static_id: shards::types::Parameters = #build_params_id();
767 }
768 };
769
770 Ok(ParamWrapperCode {
771 prelude,
772 warmups,
773 params_static_id: params_static_id,
774 cleanups_rev,
775 accessors,
776 composes,
777 shard_fields,
778 })
779}
780
781fn process_param_set_impl(struct_: syn::ItemStruct) -> Result<TokenStream, Error> {
782 let struct_id = &struct_.ident;
783 let ParamWrapperCode {
784 prelude,
785 warmups,
786 cleanups_rev,
787 accessors,
788 params_static_id,
789 composes,
790 ..
791 } = generate_param_wrapper_code(&struct_)?;
792 let cleanups = cleanups_rev.iter().rev();
793
794 Ok(quote! {
795 #prelude
796
797 impl shards::shard::ParameterSet for #struct_id {
798 fn parameters() -> &'static shards::types::Parameters {
799 &#params_static_id
800 }
801
802 fn num_params() -> usize { #params_static_id.len() }
803
804 #accessors
805
806 fn warmup_helper(&mut self, context: &shards::types::Context) -> std::result::Result<(), &'static str> {
807 #( #warmups )*
808 Ok(())
809 }
810
811 fn cleanup_helper(&mut self, context: std::option::Option<&shards::types::Context>) -> std::result::Result<(), &'static str> {
812 #( #cleanups )*
813 Ok(())
814 }
815
816 fn compose_helper(&mut self, out_required: &mut shards::types::ExposedTypes, data: &shards::types::InstanceData) -> std::result::Result<(), &'static str> {
817 #( #composes )*
818 Ok(())
819 }
820 }
821 }.into())
822}
823
824fn process_shard_helper_impl(struct_: syn::ItemStruct) -> Result<TokenStream, Error> {
825 let struct_id = &struct_.ident;
826
827 let shard_info = read_shard_info_attr(struct_id.span(), &struct_.attrs)?;
828
829 let ParamWrapperCode {
830 prelude,
831 warmups,
832 cleanups_rev,
833 accessors,
834 params_static_id,
835 composes,
836 shard_fields,
837 } = generate_param_wrapper_code(&struct_)?;
838 let cleanups = cleanups_rev.iter().rev();
839
840 let shard_name_expr = shard_info.name;
841 let shard_name = get_expr_str_lit(&shard_name_expr)?;
842 let shard_desc_expr = shard_info.desc;
843
844 let crc = crc32(format!("{}-rust-0x20200101", shard_name));
845
846 let (required_variables_opt, compose_helper) = if let Some(required) = &shard_fields.required {
847 (
848 quote! { Some(&self.#required) },
849 quote! {
850 fn compose_helper(&mut self, data: &shards::types::InstanceData) -> std::result::Result<(), &'static str> {
851 self.#required.clear();
852 let out_required = &mut self.#required;
853 #(#composes)*
854 Ok(())
855 }
856 },
857 )
858 } else {
859 (quote! { None }, quote! {})
860 };
861
862 Ok(quote! {
863 #prelude
864
865 impl shards::shard::ShardGenerated for #struct_id {
866 fn register_name() -> &'static str {
867 shards::cstr!(#shard_name_expr)
868 }
869
870 fn name(&mut self) -> &str {
871 #shard_name_expr
872 }
873
874 fn hash() -> u32
875 where
876 Self: Sized,
877 {
878 #crc
879 }
880
881 fn help(&mut self) -> shards::types::OptionalString {
882 shards::types::OptionalString(shards::shccstr!(#shard_desc_expr))
883 }
884
885 fn parameters(&mut self) -> Option<&shards::types::Parameters> {
886 Some(&#params_static_id)
887 }
888
889 #accessors
890
891 fn required_variables(&mut self) -> Option<&shards::types::ExposedTypes> {
892 #required_variables_opt
893 }
894 }
895
896 impl #struct_id {
897 #compose_helper
898
899 fn warmup_helper(&mut self, context: &shards::types::Context) -> std::result::Result<(), &'static str> {
900 #( #warmups )*
901 Ok(())
902 }
903
904
905 fn cleanup_helper(&mut self, context: std::option::Option<&shards::types::Context>) -> std::result::Result<(), &'static str> {
906 #( #cleanups )*
907 Ok(())
908 }
909 }
910 }.into())
911}
912
913#[proc_macro_derive(
914 shard,
915 attributes(shard_info, shard_param, shard_param_set, shard_required, shard_warmup)
916)]
917pub fn derive_shard(struct_def: TokenStream) -> TokenStream {
918 let struct_: syn::ItemStruct = syn::parse_macro_input!(struct_def as syn::ItemStruct);
919
920 match process_shard_helper_impl(struct_) {
921 Ok(result) => {
922 result
924 }
925 Err(err) => err.to_compile_error(),
926 }
927}
928
929#[proc_macro_derive(param_set, attributes(shard_param, shard_param_set, shard_warmup))]
930pub fn derive_param_set(struct_def: TokenStream) -> TokenStream {
931 let struct_: syn::ItemStruct = syn::parse_macro_input!(struct_def as syn::ItemStruct);
932
933 match process_param_set_impl(struct_) {
934 Ok(result) => {
935 result
937 }
938 Err(err) => err.to_compile_error(),
939 }
940}
941
942fn generate_impl_wrapper(impl_: syn::ItemImpl) -> Result<TokenStream, Error> {
943 let struct_ty = impl_.self_ty.as_ref();
944
945 let mut have_impls: HashSet<String> = HashSet::new();
946 for item in &impl_.items {
947 if let ImplItem::Fn(fn_item) = item {
948 let fn_name = fn_item.sig.ident.to_string();
949 if IMPLS_TO_CHECK_SET.contains(fn_name.as_str()) {
950 have_impls.insert(fn_name);
951 }
952 }
953 }
954
955 let impls = IMPLS_TO_CHECK.iter().map(|x| {
957 let has_fn_id = Ident::new(&format!("has_{}", x), proc_macro2::Span::call_site());
958 let have_function = syn::LitBool::new(have_impls.contains(*x), proc_macro2::Span::call_site());
959 quote! { fn #has_fn_id() -> bool { #have_function } }
960 });
961
962 Ok(
963 quote! {
964 #[allow(non_snake_case)]
965 impl shards::shard::ShardGeneratedOverloads for #struct_ty {
966 #(#impls)*
967 }
968
969 #[allow(non_snake_case)]
970 #impl_
971 }
972 .into(),
973 )
974}
975
976#[proc_macro_attribute]
977pub fn shard_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
978 let impl_: syn::ItemImpl = syn::parse_macro_input!(item);
979 match generate_impl_wrapper(impl_) {
980 Ok(result) => {
981 result
983 }
984 Err(err) => err.to_compile_error(),
985 }
986}