Skip to main content

aster_bench/eval_suites/
factory.rs

1pub use super::Evaluation;
2use regex::Regex;
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::sync::{OnceLock, RwLock};
6
7type EvaluationConstructor = fn() -> Box<dyn Evaluation>;
8type Registry = &'static RwLock<HashMap<&'static str, EvaluationConstructor>>;
9
10// Use std::sync::RwLock for interior mutability
11static EVAL_REGISTRY: OnceLock<RwLock<HashMap<&'static str, EvaluationConstructor>>> =
12    OnceLock::new();
13
14/// Initialize the registry if it hasn't been initialized
15fn eval_registry() -> Registry {
16    EVAL_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
17}
18
19/// Register a new evaluation version
20pub fn register_eval(selector: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
21    let registry = eval_registry();
22    if let Ok(mut map) = registry.write() {
23        map.insert(selector, constructor);
24    }
25}
26
27pub struct EvaluationSuite;
28
29impl EvaluationSuite {
30    pub fn from(selector: &str) -> Option<Box<dyn Evaluation>> {
31        let registry = eval_registry();
32        let map = registry
33            .read()
34            .expect("Failed to read the benchmark evaluation registry.");
35
36        let constructor = map.get(selector)?;
37        let instance = constructor();
38
39        Some(instance)
40    }
41
42    pub fn registered_evals() -> Vec<&'static str> {
43        let registry = eval_registry();
44        let map = registry
45            .read()
46            .expect("Failed to read the benchmark evaluation registry.");
47
48        let evals: Vec<_> = map.keys().copied().collect();
49        evals
50    }
51    pub fn select(selectors: Vec<String>) -> HashMap<String, Vec<&'static str>> {
52        let eval_name_pattern = Regex::new(r":\w+$").unwrap();
53        let grouped_by_suite: HashMap<String, Vec<&'static str>> =
54            EvaluationSuite::registered_evals()
55                .into_iter()
56                .filter(|&eval| selectors.is_empty() || matches_any_selectors(eval, &selectors))
57                .fold(HashMap::new(), |mut suites, eval| {
58                    let suite = match eval_name_pattern.replace(eval, "") {
59                        Cow::Borrowed(s) => s.to_string(),
60                        Cow::Owned(s) => s,
61                    };
62                    suites.entry(suite).or_default().push(eval);
63                    suites
64                });
65
66        grouped_by_suite
67    }
68
69    pub fn available_selectors() -> HashMap<String, usize> {
70        let mut counts: HashMap<String, usize> = HashMap::new();
71        for selector in EvaluationSuite::registered_evals() {
72            let parts = selector.split(":").collect::<Vec<_>>();
73            for i in 0..parts.len() {
74                let sel = parts[..i + 1].join(":");
75                *counts.entry(sel).or_insert(0) += 1;
76            }
77        }
78        counts
79    }
80}
81
/// Return `true` if any entry of `selectors` equals `eval` or one of its ancestors.
///
/// Matching works on whole colon-delimited components and never on partial
/// words. For example, `core` matches `core:developer:list_files` but not
/// `corex:developer`. The ancestors of `a:b:c` are `a:b` and `a`.
fn matches_any_selectors(eval: &str, selectors: &[String]) -> bool {
    selectors.iter().any(|selector| {
        // Walk from the full selector up through its ancestors, dropping the
        // last `:component` each step. Splitting on the last `:` always
        // shortens the string, so components with non-word characters (e.g.
        // `-`) or empty components cannot stall the loop. A `:\w+$` regex
        // would stall on them.
        let mut level = eval;
        while !level.is_empty() {
            if level == selector.as_str() {
                return true;
            }
            match level.rsplit_once(':') {
                Some((parent, _)) => level = parent,
                None => break,
            }
        }
        false
    })
}
103
/// Register an evaluation type with the global registry at load time.
///
/// Expands to a `#[ctor::ctor]` function (its name is built with `paste`) that
/// derives the evaluation's selector from the path of the source file that
/// invoked the macro. The selector is the path components after the
/// `eval_suites` directory, with the file extension removed, joined with `:`.
/// For example, `.../eval_suites/suite/sub/my_eval.rs` is registered as
/// `suite:sub:my_eval`.
///
/// The type must provide an associated `new()` whose value can be boxed into
/// the `Box<dyn Evaluation>` that `register_eval` expects. `paste::` and
/// `ctor::` are resolved at the call site, so the invoking crate must depend
/// on both.
#[macro_export]
macro_rules! register_evaluation {
    ($evaluation_type:ty) => {
        paste::paste! {
            // `ctor` runs this function at load time, so no explicit call is needed.
            #[ctor::ctor]
            #[allow(non_snake_case)]
            fn [<__register_evaluation_ $evaluation_type>]() {
                // `file!()` is the path of the file that invoked the macro.
                let mut path = std::path::PathBuf::from(file!());
                // Drop the file extension so the last component is the eval name.
                path.set_extension("");
                let eval_suites_dir = "eval_suites";
                let eval_selector = {
                    // Keep only the components below `eval_suites/`, joined with `:`.
                    // NOTE(review): if the path has no `eval_suites` component,
                    // this yields an empty selector.
                    let s = path.components()
                        .skip_while(|comp| comp.as_os_str() != eval_suites_dir)
                        .skip(1)
                        .map(|comp| comp.as_os_str().to_string_lossy().to_string())
                        .collect::<Vec<_>>()
                        .join(":");
                    // Leak to obtain the `&'static str` that `register_eval` requires.
                    // This happens once per generated registration function.
                    Box::leak(s.into_boxed_str())
                };

                $crate::eval_suites::factory::register_eval(eval_selector, || {
                    Box::new(<$evaluation_type>::new())
                });
            }
        }
    };
}
130}