reductionml_core/
reduction_registry.rs

1use std::collections::{BTreeMap, HashMap};
2use std::sync::RwLock;
3
4use once_cell::sync::Lazy;
5
6use crate::{
7    reduction_factory::ReductionFactory,
8    reductions::{
9        BinaryReductionFactory, CBAdfReductionFactory, CBExploreAdfGreedyReductionFactory,
10        CBExploreAdfSquareCBReductionFactory, CoinRegressorFactory, DebugReductionFactory,
11    },
12};
13
14pub static REDUCTION_REGISTRY: Lazy<RwLock<ReductionRegistry>> = Lazy::new(|| {
15    let mut registry: ReductionRegistry = ReductionRegistry::default();
16    registry.register(Box::<CoinRegressorFactory>::default());
17    registry.register(Box::<BinaryReductionFactory>::default());
18    registry.register(Box::<CBAdfReductionFactory>::default());
19    registry.register(Box::<CBExploreAdfGreedyReductionFactory>::default());
20    registry.register(Box::<DebugReductionFactory>::default());
21    registry.register(Box::<CBExploreAdfSquareCBReductionFactory>::default());
22    RwLock::new(registry)
23});
24
25#[derive(Default)]
26pub struct ReductionRegistry {
27    registry: BTreeMap<String, Box<dyn ReductionFactory>>,
28}
29
// SAFETY — NOTE(review): these impls are what allow `ReductionRegistry` to sit
// inside the global `Lazy<RwLock<_>>` static (`RwLock<T>` is only `Sync` when
// `T: Send + Sync`). The stored values are `Box<dyn ReductionFactory>` with no
// `Send`/`Sync` bound written here, so soundness depends on every registered
// factory actually being safe to move and share across threads.
// TODO confirm: if `ReductionFactory` can declare `Send + Sync` supertraits,
// the compiler would derive these automatically and the `unsafe` impls could go.
unsafe impl Sync for ReductionRegistry {}
unsafe impl Send for ReductionRegistry {}
33
34impl ReductionRegistry {
35    pub fn register(&mut self, factory: Box<dyn ReductionFactory>) {
36        self.registry
37            .insert(factory.typename().as_ref().to_owned(), factory);
38    }
39
40    pub fn get(&self, typename: &str) -> Option<&dyn ReductionFactory> {
41        self.registry.get(typename).map(|x| x.as_ref())
42    }
43
44    pub fn iter(&self) -> impl Iterator<Item = &dyn ReductionFactory> {
45        self.registry.values().map(|x| x.as_ref())
46    }
47}