1use std::collections::HashMap;
23use std::sync::{OnceLock, RwLock};
24
25use crate::error::IndicatorError;
26use crate::indicator::Indicator;
27
28pub type IndicatorFactory =
35 fn(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError>;
36
37pub struct IndicatorRegistry {
44 entries: RwLock<HashMap<String, IndicatorFactory>>,
45}
46
47impl IndicatorRegistry {
48 pub fn new_uninit() -> Self {
49 Self {
52 entries: RwLock::new(HashMap::new()),
53 }
54 }
55
56 pub fn register(&self, name: &str, factory: IndicatorFactory) {
60 let mut map = self.entries.write().expect("registry write lock poisoned");
61 map.insert(name.to_ascii_lowercase(), factory);
62 }
63
64 pub fn list(&self) -> Vec<String> {
68 let map = self.entries.read().expect("registry read lock poisoned");
69 map.keys().cloned().collect()
70 }
71
72 pub fn get(&self, name: &str) -> Option<IndicatorFactory> {
76 let map = self.entries.read().expect("registry read lock poisoned");
77 map.get(&name.to_ascii_lowercase()).copied()
78 }
79
80 pub fn create(
89 &self,
90 name: &str,
91 params: &HashMap<String, String>,
92 ) -> Result<Box<dyn Indicator>, IndicatorError> {
93 let factory = self
94 .get(name)
95 .ok_or_else(|| IndicatorError::UnknownIndicator {
96 name: name.to_string(),
97 })?;
98 factory(params)
99 }
100
101 pub fn contains(&self, name: &str) -> bool {
105 self.get(name).is_some()
106 }
107}
108
109pub static REGISTRY: OnceLock<IndicatorRegistry> = OnceLock::new();
118
119pub fn registry() -> &'static IndicatorRegistry {
121 REGISTRY.get_or_init(|| {
122 let reg = IndicatorRegistry {
123 entries: RwLock::new(HashMap::new()),
124 };
125 crate::trend::register_all(®);
127 crate::momentum::register_all(®);
128 crate::volatility::register_all(®);
129 crate::volume::register_all(®);
130 crate::signal::register_all(®);
131 crate::regime::register_all(®);
132 reg
133 })
134}
135
136pub fn param_usize(
142 params: &HashMap<String, String>,
143 key: &str,
144 default: usize,
145) -> Result<usize, IndicatorError> {
146 match params.get(key) {
147 None => Ok(default),
148 Some(s) => s.parse::<usize>().map_err(|_| IndicatorError::InvalidParameter {
149 name: key.to_string(),
150 value: s.parse::<f64>().unwrap_or(f64::NAN),
151 }),
152 }
153}
154
155pub fn param_f64(
157 params: &HashMap<String, String>,
158 key: &str,
159 default: f64,
160) -> Result<f64, IndicatorError> {
161 match params.get(key) {
162 None => Ok(default),
163 Some(s) => s.parse::<f64>().map_err(|_| IndicatorError::InvalidParameter {
164 name: key.to_string(),
165 value: f64::NAN,
166 }),
167 }
168}
169
/// Returns the string stored under `key`, or `default` when the key is absent.
///
/// The result borrows from either `params` or `default`; nothing is copied.
pub fn param_str<'a>(params: &'a HashMap<String, String>, key: &str, default: &'a str) -> &'a str {
    match params.get(key) {
        Some(value) => value.as_str(),
        None => default,
    }
}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Factory stand-in that always fails; the tests only need a function
    /// with the right signature to store in a registry.
    fn dummy_factory(_p: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
        Err(IndicatorError::UnknownIndicator {
            name: "dummy".into(),
        })
    }

    /// Builds a parameter map holding a single `key = value` entry.
    fn single_param(key: &str, value: &str) -> HashMap<String, String> {
        HashMap::from([(key.to_string(), value.to_string())])
    }

    #[test]
    fn registry_register_and_list() {
        let reg = IndicatorRegistry::new_uninit();
        reg.register("sma", dummy_factory);
        reg.register("ema", dummy_factory);
        // Sort locally so the assertion does not depend on map iteration order.
        let mut names = reg.list();
        names.sort();
        assert_eq!(names, vec!["ema", "sma"]);
    }

    #[test]
    fn registry_lookup_is_case_insensitive() {
        let reg = IndicatorRegistry::new_uninit();
        reg.register("SMA", dummy_factory);
        assert!(reg.contains("sma"));
        assert!(reg.contains("Sma"));
        assert!(reg.get("sMa").is_some());
        assert_eq!(reg.list(), vec!["sma"]);
    }

    #[test]
    fn registry_unknown_returns_error() {
        let reg = IndicatorRegistry::new_uninit();
        let err = reg
            .create("no_such_indicator", &HashMap::new())
            .unwrap_err();
        assert!(matches!(err, IndicatorError::UnknownIndicator { .. }));
    }

    #[test]
    fn param_usize_default() {
        let params: HashMap<String, String> = HashMap::new();
        assert_eq!(param_usize(&params, "period", 14).unwrap(), 14);
    }

    #[test]
    fn param_usize_override() {
        let params = single_param("period", "20");
        assert_eq!(param_usize(&params, "period", 14).unwrap(), 20);
    }

    #[test]
    fn param_usize_bad_value() {
        let params = single_param("period", "abc");
        assert!(param_usize(&params, "period", 14).is_err());
    }

    #[test]
    fn param_f64_default_override_and_bad_value() {
        let empty: HashMap<String, String> = HashMap::new();
        assert_eq!(param_f64(&empty, "alpha", 0.5).unwrap(), 0.5);
        assert_eq!(
            param_f64(&single_param("alpha", "0.25"), "alpha", 0.5).unwrap(),
            0.25
        );
        assert!(param_f64(&single_param("alpha", "abc"), "alpha", 0.5).is_err());
    }

    #[test]
    fn param_str_default_and_override() {
        let empty: HashMap<String, String> = HashMap::new();
        assert_eq!(param_str(&empty, "source", "close"), "close");
        let params = single_param("source", "hl2");
        assert_eq!(param_str(&params, "source", "close"), "hl2");
    }
}