use std::collections::HashMap;
use std::sync::Arc;
use allora_core::patterns::aggregator::{
AggregationStrategy, CompletionCondition, ConcatText, EmitSignal, GroupStore, JsonArray,
};
/// Registry key for the built-in `ConcatText` aggregation strategy,
/// installed by `PatternRegistry::with_defaults`.
pub const STRATEGY_CONCAT_TEXT: &str = "allora.concat_text";

/// Registry key for the built-in `JsonArray` aggregation strategy,
/// installed by `PatternRegistry::with_defaults`.
pub const STRATEGY_JSON_ARRAY: &str = "allora.json_array";

/// Registry key for the built-in `EmitSignal` aggregation strategy,
/// installed by `PatternRegistry::with_defaults`.
pub const STRATEGY_EMIT_SIGNAL: &str = "allora.emit_signal";
/// Name-keyed lookup table for aggregator building blocks: completion
/// conditions, aggregation strategies and group stores.
///
/// Entries are stored as `Arc<dyn ...>`, so cloning the registry (or looking an
/// entry up) shares the underlying implementation rather than copying it.
#[derive(Default, Clone)]
pub struct PatternRegistry {
    /// Completion conditions keyed by registration name.
    completions: HashMap<String, Arc<dyn CompletionCondition>>,
    /// Aggregation strategies keyed by registration name.
    strategies: HashMap<String, Arc<dyn AggregationStrategy>>,
    /// Group stores keyed by registration name.
    stores: HashMap<String, Arc<dyn GroupStore>>,
}

/// `Debug` cannot be derived because the trait objects are not required to be
/// `Debug`; instead, print the registered names (sorted, so output is stable
/// across runs despite `HashMap`'s unspecified iteration order).
impl std::fmt::Debug for PatternRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn sorted_keys<V>(map: &HashMap<String, V>) -> Vec<&str> {
            let mut keys: Vec<&str> = map.keys().map(String::as_str).collect();
            keys.sort_unstable();
            keys
        }
        f.debug_struct("PatternRegistry")
            .field("completions", &sorted_keys(&self.completions))
            .field("strategies", &sorted_keys(&self.strategies))
            .field("stores", &sorted_keys(&self.stores))
            .finish()
    }
}
impl PatternRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with_defaults() -> Self {
let mut r = Self::default();
r.register_strategy(STRATEGY_CONCAT_TEXT, Arc::new(ConcatText));
r.register_strategy(STRATEGY_JSON_ARRAY, Arc::new(JsonArray));
r.register_strategy(STRATEGY_EMIT_SIGNAL, Arc::new(EmitSignal));
r
}
pub fn register_completion<N: Into<String>>(
&mut self,
name: N,
impl_: Arc<dyn CompletionCondition>,
) {
self.completions.insert(name.into(), impl_);
}
pub fn register_strategy<N: Into<String>>(
&mut self,
name: N,
impl_: Arc<dyn AggregationStrategy>,
) {
self.strategies.insert(name.into(), impl_);
}
pub fn register_store<N: Into<String>>(&mut self, name: N, impl_: Arc<dyn GroupStore>) {
self.stores.insert(name.into(), impl_);
}
pub fn completion(&self, name: &str) -> Option<Arc<dyn CompletionCondition>> {
self.completions.get(name).cloned()
}
pub fn strategy(&self, name: &str) -> Option<Arc<dyn AggregationStrategy>> {
self.strategies.get(name).cloned()
}
pub fn store(&self, name: &str) -> Option<Arc<dyn GroupStore>> {
self.stores.get(name).cloned()
}
pub fn completion_names(&self) -> Vec<&str> {
self.completions.keys().map(|s| s.as_str()).collect()
}
pub fn strategy_names(&self) -> Vec<&str> {
self.strategies.keys().map(|s| s.as_str()).collect()
}
pub fn store_names(&self) -> Vec<&str> {
self.stores.keys().map(|s| s.as_str()).collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use allora_core::Message;
    use std::time::Instant;

    /// Test-only completion condition: a group is complete once it holds at
    /// least `self.0` messages.
    struct CountAtLeast(usize);

    impl CompletionCondition for CountAtLeast {
        fn is_complete(&self, group: &[Message], _: Instant) -> bool {
            group.len() >= self.0
        }
    }

    /// Builds a group of `n` default messages.
    fn group_of(n: usize) -> Vec<Message> {
        (0..n).map(|_| Message::default()).collect()
    }

    #[test]
    fn empty_registry_returns_none_for_everything() {
        let r = PatternRegistry::new();
        assert!(r.completion("anything").is_none());
        assert!(r.strategy("anything").is_none());
        assert!(r.store("anything").is_none());
        assert!(r.completion_names().is_empty());
        assert!(r.strategy_names().is_empty());
        assert!(r.store_names().is_empty());
    }

    #[test]
    fn with_defaults_registers_all_three_built_in_strategies() {
        let r = PatternRegistry::with_defaults();
        assert!(r.strategy(STRATEGY_CONCAT_TEXT).is_some());
        assert!(r.strategy(STRATEGY_JSON_ARRAY).is_some());
        assert!(r.strategy(STRATEGY_EMIT_SIGNAL).is_some());
        // Sort locally so this check does not depend on listing order.
        let mut names = r.strategy_names();
        names.sort_unstable();
        assert_eq!(
            names,
            vec![STRATEGY_CONCAT_TEXT, STRATEGY_EMIT_SIGNAL, STRATEGY_JSON_ARRAY]
        );
        assert!(r.completion_names().is_empty());
        assert!(r.store_names().is_empty());
    }

    #[test]
    fn register_then_lookup_completion() {
        let mut r = PatternRegistry::new();
        r.register_completion("chain.test_quorum", Arc::new(CountAtLeast(3)));
        let c = r.completion("chain.test_quorum").expect("registered");
        assert!(c.is_complete(&group_of(3), Instant::now()));
        assert!(!c.is_complete(&group_of(2), Instant::now()));
    }

    #[test]
    fn registration_overwrites_existing_entry() {
        let mut r = PatternRegistry::new();
        r.register_completion("k", Arc::new(CountAtLeast(1)));
        r.register_completion("k", Arc::new(CountAtLeast(99)));
        let c = r.completion("k").expect("registered");
        // The second registration (threshold 99) must have replaced the first.
        assert!(!c.is_complete(&group_of(5), Instant::now()));
        assert!(c.is_complete(&group_of(99), Instant::now()));
        assert_eq!(r.completion_names(), vec!["k"]);
    }

    #[test]
    fn clone_is_independent_for_later_registrations() {
        let original = PatternRegistry::with_defaults();
        let mut copy = original.clone();
        copy.register_completion("only.in.copy", Arc::new(CountAtLeast(1)));
        assert!(copy.completion("only.in.copy").is_some());
        assert!(original.completion("only.in.copy").is_none());
        assert!(copy.strategy(STRATEGY_CONCAT_TEXT).is_some());
    }
}