crate::ix!();
pub fn collect_variant_probs(
variants: &Punctuated<Variant, Comma>,
) -> (Vec<Ident>, Vec<f64>, Vec<Fields>) {
let mut variant_idents = Vec::new();
let mut probs = Vec::new();
let mut variant_fields = Vec::new();
for variant in variants {
let variant_ident = variant.ident.clone();
let prob = extract_probability_from_attributes(&variant.attrs).unwrap_or(1.0);
let fields = variant.fields.clone();
variant_idents.push(variant_ident);
probs.push(prob);
variant_fields.push(fields);
}
(variant_idents, probs, variant_fields)
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{parse_quote, punctuated::Punctuated, token::Comma, Variant};

    /// Mixed unit / tuple / struct variants. Checks that:
    /// - identifiers come back in declaration order;
    /// - annotated variants keep their `p` value and the unannotated one falls
    ///   back to 1.0;
    /// - each variant's fields are carried over intact.
    #[test]
    fn test_collect_variant_probs() {
        let variants: Punctuated<Variant, Comma> = parse_quote! {
            #[rand_construct(p = 0.8)]
            UnitVariant,
            UnnamedVariant(i32, String),
            #[rand_construct(p = 0.3)]
            NamedVariant { x: f64, y: bool }
        };

        let (variant_idents, probs, variant_fields) = collect_variant_probs(&variants);

        // The three outputs are parallel: exactly one entry per input variant.
        assert_eq!(variant_idents.len(), 3);
        assert_eq!(probs.len(), 3);
        assert_eq!(variant_fields.len(), 3);

        let expected_idents: Vec<String> = vec![
            "UnitVariant".to_string(),
            "UnnamedVariant".to_string(),
            "NamedVariant".to_string(),
        ];
        let ident_strings: Vec<String> = variant_idents.iter().map(|id| id.to_string()).collect();
        assert_eq!(ident_strings, expected_idents);

        let expected_probs: Vec<f64> = vec![0.8, 1.0, 0.3];
        assert_eq!(probs, expected_probs);

        assert!(matches!(variant_fields[0], Fields::Unit));

        // Tuple-style fields must keep their arity (i32, String).
        match &variant_fields[1] {
            Fields::Unnamed(unnamed) => assert_eq!(unnamed.unnamed.len(), 2),
            _ => panic!("expected tuple-style fields for UnnamedVariant"),
        }

        // Struct-style fields must keep their names, in order.
        match &variant_fields[2] {
            Fields::Named(named) => {
                let names: Vec<String> = named
                    .named
                    .iter()
                    .map(|field| {
                        field
                            .ident
                            .as_ref()
                            .expect("a named field always carries an ident")
                            .to_string()
                    })
                    .collect();
                assert_eq!(names, vec!["x".to_string(), "y".to_string()]);
            }
            _ => panic!("expected struct-style fields for NamedVariant"),
        }
    }

    /// An empty variant list yields three empty vectors.
    #[test]
    fn test_collect_variant_probs_empty() {
        let variants: Punctuated<Variant, Comma> = Punctuated::new();
        let (variant_idents, probs, variant_fields) = collect_variant_probs(&variants);
        assert!(variant_idents.is_empty());
        assert!(probs.is_empty());
        assert!(variant_fields.is_empty());
    }
}