use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
use crate::error::IndicatorError;
use crate::indicator::Indicator;
/// Constructor that builds a boxed [`Indicator`] from string-valued parameters.
///
/// Factories are plain function pointers, so they are `Copy` and can be handed
/// out of the registry without keeping its lock held.
pub type IndicatorFactory =
    fn(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError>;

/// Name → factory table used to look up and construct indicators at runtime.
///
/// Names are normalised with ASCII lower-casing on both registration and
/// lookup, so `"SMA"`, `"sma"` and `"Sma"` all refer to the same entry. The
/// table sits behind an [`RwLock`], so both registration and lookup work
/// through a shared `&IndicatorRegistry`.
pub struct IndicatorRegistry {
    entries: RwLock<HashMap<String, IndicatorFactory>>,
}

impl IndicatorRegistry {
    /// Creates an empty registry with no indicators registered.
    ///
    /// Unlike [`registry`], this installs no built-in indicators.
    #[must_use]
    pub fn new_uninit() -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Acquires the table for reading, recovering from lock poisoning.
    ///
    /// Recovery is sound here: the guarded value is a plain `HashMap` whose only
    /// mutation is a single `insert`, so there is no multi-step invariant that a
    /// panicking writer could have left broken.
    fn read_entries(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, IndicatorFactory>> {
        self.entries
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the table for writing, recovering from lock poisoning for the
    /// same reason as [`Self::read_entries`].
    fn write_entries(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, IndicatorFactory>> {
        self.entries
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `factory` under `name` (ASCII lower-cased).
    ///
    /// A factory previously registered under the same normalised name is
    /// silently replaced.
    pub fn register(&self, name: &str, factory: IndicatorFactory) {
        // Normalise before taking the write lock to keep the critical section minimal.
        let key = name.to_ascii_lowercase();
        self.write_entries().insert(key, factory);
    }

    /// Returns the normalised (lower-case) names of all registered indicators,
    /// sorted ascending.
    ///
    /// Sorting makes the output deterministic; raw `HashMap` iteration order is not.
    #[must_use]
    pub fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = self.read_entries().keys().cloned().collect();
        // The read guard was a temporary of the statement above, so the sort runs unlocked.
        names.sort_unstable();
        names
    }

    /// Looks up the factory registered under `name` (ASCII case-insensitive).
    pub fn get(&self, name: &str) -> Option<IndicatorFactory> {
        let key = name.to_ascii_lowercase();
        let entries = self.read_entries();
        entries.get(&key).copied()
    }

    /// Builds the indicator registered under `name` by calling its factory with `params`.
    ///
    /// The factory is copied out of the table first and the lock is released
    /// before the call, so a factory may itself use the registry.
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::UnknownIndicator`] (carrying `name` as given)
    /// when nothing is registered under that name. Otherwise it returns whatever
    /// the factory returns.
    pub fn create(
        &self,
        name: &str,
        params: &HashMap<String, String>,
    ) -> Result<Box<dyn Indicator>, IndicatorError> {
        let factory = self
            .get(name)
            .ok_or_else(|| IndicatorError::UnknownIndicator {
                name: name.to_string(),
            })?;
        factory(params)
    }

    /// Returns `true` if a factory is registered under `name` (ASCII case-insensitive).
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}
pub static REGISTRY: OnceLock<IndicatorRegistry> = OnceLock::new();
pub fn registry() -> &'static IndicatorRegistry {
REGISTRY.get_or_init(|| {
let reg = IndicatorRegistry {
entries: RwLock::new(HashMap::new()),
};
crate::trend::register_all(®);
crate::momentum::register_all(®);
crate::volatility::register_all(®);
crate::volume::register_all(®);
crate::signal::register_all(®);
crate::regime::register_all(®);
reg
})
}
/// Reads `key` from `params` as a `usize`, falling back to `default` when the key is absent.
///
/// The value is parsed with `str::parse::<usize>`, so surrounding whitespace,
/// signs other than none, and fractional values are all rejected.
///
/// # Errors
///
/// Returns [`IndicatorError::InvalidParameter`] when the key is present but its
/// value is not a valid unsigned integer. The error's `value` holds the
/// string's `f64` reading when it has one (e.g. `"-3"` → `-3.0`, `"2.5"` → `2.5`)
/// and `NaN` otherwise.
pub fn param_usize<S: ::std::hash::BuildHasher>(
    params: &HashMap<String, String, S>,
    key: &str,
    default: usize,
) -> Result<usize, IndicatorError> {
    let Some(raw) = params.get(key) else {
        return Ok(default);
    };
    raw.parse::<usize>().map_err(|_| {
        // Report the closest numeric reading of the bad input to aid diagnostics.
        let reported = raw.parse::<f64>().unwrap_or(f64::NAN);
        IndicatorError::InvalidParameter {
            name: key.to_string(),
            value: reported,
        }
    })
}
/// Reads `key` from `params` as an `f64`, falling back to `default` when the key is absent.
///
/// The value is parsed with `str::parse::<f64>` and is not trimmed. Rust's
/// float parser also accepts `inf`, `infinity` and `nan` (case-insensitive),
/// so those spellings come back as non-finite numbers rather than errors.
///
/// # Errors
///
/// Returns [`IndicatorError::InvalidParameter`] with `value` set to `NaN` when
/// the key is present but its value is not a parseable float.
pub fn param_f64<S: ::std::hash::BuildHasher>(
    params: &HashMap<String, String, S>,
    key: &str,
    default: f64,
) -> Result<f64, IndicatorError> {
    params.get(key).map_or(Ok(default), |raw| {
        raw.parse::<f64>()
            .map_err(|_| IndicatorError::InvalidParameter {
                name: key.to_string(),
                value: f64::NAN,
            })
    })
}
/// Returns the string stored under `key` in `params`, or `default` when the key is absent.
///
/// Never fails. Any present value, including an empty string, is returned
/// as-is, borrowed from `params`.
pub fn param_str<'a, S: ::std::hash::BuildHasher>(
    params: &'a HashMap<String, String, S>,
    key: &str,
    default: &'a str,
) -> &'a str {
    match params.get(key) {
        Some(value) => value.as_str(),
        None => default,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Factory stub with the right signature. It always fails, because these
    /// tests exercise only registry bookkeeping, never real indicator construction.
    fn dummy_factory(_p: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
        Err(IndicatorError::UnknownIndicator {
            name: "dummy".into(),
        })
    }

    /// Builds a single-entry parameter map.
    fn one_param(key: &str, value: &str) -> HashMap<String, String> {
        [(key.to_string(), value.to_string())].into()
    }

    #[test]
    fn registry_register_and_list() {
        let reg = IndicatorRegistry::new_uninit();
        reg.register("sma", dummy_factory);
        reg.register("ema", dummy_factory);
        let mut names = reg.list();
        names.sort();
        assert_eq!(names, vec!["ema", "sma"]);
    }

    #[test]
    fn registry_names_are_case_insensitive() {
        let reg = IndicatorRegistry::new_uninit();
        reg.register("SMA", dummy_factory);
        assert!(reg.contains("sma"));
        assert!(reg.contains("Sma"));
        assert!(reg.get("SMA").is_some());
        assert!(!reg.contains("ema"));
        assert_eq!(reg.list(), vec!["sma"]);
    }

    #[test]
    fn registry_create_invokes_registered_factory() {
        let reg = IndicatorRegistry::new_uninit();
        reg.register("dummy", dummy_factory);
        // A failed lookup would echo the caller's spelling ("Dummy"). The stub
        // factory reports "dummy", so the name shows which path produced the error.
        match reg.create("Dummy", &HashMap::new()) {
            Err(IndicatorError::UnknownIndicator { name }) => assert_eq!(name, "dummy"),
            _ => panic!("expected the registered factory's error"),
        }
    }

    #[test]
    fn registry_unknown_returns_error() {
        let reg = IndicatorRegistry::new_uninit();
        let result = reg.create("no_such_indicator", &HashMap::new());
        assert!(matches!(
            result,
            Err(IndicatorError::UnknownIndicator { .. })
        ));
    }

    #[test]
    fn param_usize_default() {
        let params: HashMap<String, String> = HashMap::new();
        assert_eq!(param_usize(&params, "period", 14).unwrap(), 14);
    }

    #[test]
    fn param_usize_override() {
        let params = one_param("period", "20");
        assert_eq!(param_usize(&params, "period", 14).unwrap(), 20);
    }

    #[test]
    fn param_usize_bad_value() {
        for bad in ["abc", "-5", "2.5", ""] {
            let params = one_param("period", bad);
            assert!(
                matches!(
                    param_usize(&params, "period", 14),
                    Err(IndicatorError::InvalidParameter { .. })
                ),
                "{bad:?} should be rejected"
            );
        }
    }

    #[test]
    fn param_f64_default_override_and_bad_value() {
        let empty: HashMap<String, String> = HashMap::new();
        assert_eq!(param_f64(&empty, "mult", 2.0).unwrap(), 2.0);
        assert_eq!(
            param_f64(&one_param("mult", "2.5"), "mult", 2.0).unwrap(),
            2.5
        );
        assert!(matches!(
            param_f64(&one_param("mult", "abc"), "mult", 2.0),
            Err(IndicatorError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn param_str_default_and_override() {
        let params = one_param("source", "close");
        assert_eq!(param_str(&params, "source", "open"), "close");
        assert_eq!(param_str(&params, "missing", "open"), "open");
    }
}