aster_bench/eval_suites/
factory.rs1pub use super::Evaluation;
2use regex::Regex;
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::sync::{OnceLock, RwLock};
6
7type EvaluationConstructor = fn() -> Box<dyn Evaluation>;
8type Registry = &'static RwLock<HashMap<&'static str, EvaluationConstructor>>;
9
10static EVAL_REGISTRY: OnceLock<RwLock<HashMap<&'static str, EvaluationConstructor>>> =
12 OnceLock::new();
13
14fn eval_registry() -> Registry {
16 EVAL_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
17}
18
19pub fn register_eval(selector: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
21 let registry = eval_registry();
22 if let Ok(mut map) = registry.write() {
23 map.insert(selector, constructor);
24 }
25}
26
27pub struct EvaluationSuite;
28
29impl EvaluationSuite {
30 pub fn from(selector: &str) -> Option<Box<dyn Evaluation>> {
31 let registry = eval_registry();
32 let map = registry
33 .read()
34 .expect("Failed to read the benchmark evaluation registry.");
35
36 let constructor = map.get(selector)?;
37 let instance = constructor();
38
39 Some(instance)
40 }
41
42 pub fn registered_evals() -> Vec<&'static str> {
43 let registry = eval_registry();
44 let map = registry
45 .read()
46 .expect("Failed to read the benchmark evaluation registry.");
47
48 let evals: Vec<_> = map.keys().copied().collect();
49 evals
50 }
51 pub fn select(selectors: Vec<String>) -> HashMap<String, Vec<&'static str>> {
52 let eval_name_pattern = Regex::new(r":\w+$").unwrap();
53 let grouped_by_suite: HashMap<String, Vec<&'static str>> =
54 EvaluationSuite::registered_evals()
55 .into_iter()
56 .filter(|&eval| selectors.is_empty() || matches_any_selectors(eval, &selectors))
57 .fold(HashMap::new(), |mut suites, eval| {
58 let suite = match eval_name_pattern.replace(eval, "") {
59 Cow::Borrowed(s) => s.to_string(),
60 Cow::Owned(s) => s,
61 };
62 suites.entry(suite).or_default().push(eval);
63 suites
64 });
65
66 grouped_by_suite
67 }
68
69 pub fn available_selectors() -> HashMap<String, usize> {
70 let mut counts: HashMap<String, usize> = HashMap::new();
71 for selector in EvaluationSuite::registered_evals() {
72 let parts = selector.split(":").collect::<Vec<_>>();
73 for i in 0..parts.len() {
74 let sel = parts[..i + 1].join(":");
75 *counts.entry(sel).or_insert(0) += 1;
76 }
77 }
78 counts
79 }
80}
81
/// Returns `true` when `eval`, or any of its `:`-separated ancestors, equals
/// one of `selectors`.
///
/// For `eval = "a:b:c"` the candidates are `a:b:c`, `a:b` and `a`. An empty
/// `eval` never matches, and neither does an empty `selectors` slice.
fn matches_any_selectors(eval: &str, selectors: &[String]) -> bool {
    selectors.iter().any(|selector| {
        let mut level_up = eval;
        while !level_up.is_empty() {
            if level_up == selector.as_str() {
                return true;
            }
            // Drop the last `:segment`. Splitting on the final `:` always
            // shortens the string, so this terminates even when a segment
            // contains non-word characters (e.g. `a:b-c`), where the previous
            // `:\w+$` regex failed to match and looped forever.
            match level_up.rsplit_once(':') {
                Some((parent, _)) => level_up = parent,
                None => break,
            }
        }
        false
    })
}
103
/// Registers an evaluation type with the benchmark evaluation registry.
///
/// Expands (via `paste`) to a `#[ctor::ctor]` function named
/// `__register_evaluation_<Type>`. That function derives the evaluation's
/// selector from the invoking source file's path. It then passes the selector,
/// together with a closure that calls `<Type>::new()`, to
/// `eval_suites::factory::register_eval`.
///
/// The selector is built from the file path with its extension removed. Only
/// the components after the first `eval_suites` directory are kept, joined with
/// `:`. For example, `.../eval_suites/core/example.rs` becomes `core:example`.
///
/// NOTE(review): if the file is not under an `eval_suites` directory, every
/// component is skipped and the selector is the empty string. Two invocations
/// in the same file produce the same selector, so one registration replaces the
/// other. Confirm both are intended.
#[macro_export]
macro_rules! register_evaluation {
    ($evaluation_type:ty) => {
        paste::paste! {
            #[ctor::ctor]
            #[allow(non_snake_case)]
            fn [<__register_evaluation_ $evaluation_type>]() {
                // Path of the file that invoked the macro, minus its extension,
                // so the last selector segment is the file stem.
                let mut path = std::path::PathBuf::from(file!());
                path.set_extension("");
                let eval_suites_dir = "eval_suites";
                let eval_selector = {
                    // Keep only the components below `eval_suites/` and join them with `:`.
                    let s = path.components()
                        .skip_while(|comp| comp.as_os_str() != eval_suites_dir)
                        .skip(1)
                        .map(|comp| comp.as_os_str().to_string_lossy().to_string())
                        .collect::<Vec<_>>()
                        .join(":");
                    // `register_eval` requires a `&'static str`; the string is
                    // leaked once per registration to obtain one.
                    Box::leak(s.into_boxed_str())
                };

                $crate::eval_suites::factory::register_eval(eval_selector, || {
                    Box::new(<$evaluation_type>::new())
                });
            }
        }
    };
}
130}