use once_cell::sync::Lazy;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::Trigger;
/// Factory stored in the registry; calling it produces an `Arc<dyn Trigger>`.
type TriggerConstructor = Box<dyn Fn() -> Arc<dyn Trigger> + Send + Sync>;

/// Shared handle to the name -> constructor map, guarded by a `parking_lot` read/write lock.
type GlobalTriggerRegistry = Arc<RwLock<HashMap<String, TriggerConstructor>>>;

/// Process-wide trigger registry; starts out empty and is created on first access.
static GLOBAL_TRIGGER_REGISTRY: Lazy<GlobalTriggerRegistry> =
    Lazy::new(GlobalTriggerRegistry::default);
/// Registers `constructor` under `name`, replacing any constructor already stored there.
///
/// The constructor is invoked every time the trigger is fetched through [`get_trigger`]
/// or [`get_all_triggers`]. The global write lock is held only for the map insertion.
/// Logging and dropping a displaced constructor both happen after the lock is released.
/// Neither can therefore stall concurrent readers or re-enter the registry while it is locked.
pub fn register_trigger_constructor<F>(name: impl Into<String>, constructor: F)
where
    F: Fn() -> Arc<dyn Trigger> + Send + Sync + 'static,
{
    let name = name.into();
    // The write guard is a temporary, so it is released at the end of this statement.
    // `previous` (the displaced constructor, if any) outlives the guard and is dropped
    // at the end of the function, outside the lock.
    let previous = GLOBAL_TRIGGER_REGISTRY
        .write()
        .insert(name.clone(), Box::new(constructor));
    if previous.is_some() {
        tracing::debug!("Replaced existing trigger constructor: {}", name);
    } else {
        tracing::debug!("Registered trigger constructor: {}", name);
    }
}
/// Registers a cloneable trigger under the name it reports via `Trigger::name`.
///
/// Each lookup receives a fresh clone of `trigger` wrapped in a new `Arc`.
pub fn register_trigger<T: Trigger + Clone + 'static>(trigger: T) {
    let key = trigger.name().to_owned();
    let make = move || -> Arc<dyn Trigger> { Arc::new(trigger.clone()) };
    register_trigger_constructor(key, make);
}
/// Looks up `name` and, if a constructor is registered, invokes it to build a trigger.
///
/// Returns `None` when nothing is registered under `name`. The constructor runs while the
/// registry's read guard is held. It must not call functions in this module that take the
/// write lock.
pub fn get_trigger(name: &str) -> Option<Arc<dyn Trigger>> {
    let guard = GLOBAL_TRIGGER_REGISTRY.read();
    let make = guard.get(name)?;
    Some(make())
}
pub fn global_trigger_registry() -> GlobalTriggerRegistry {
GLOBAL_TRIGGER_REGISTRY.clone()
}
/// Returns the names of all registered triggers, sorted in ascending order.
///
/// Names are copied out under the read lock and sorted only after the lock is released.
/// The result is deterministic instead of following `HashMap` iteration order.
pub fn list_triggers() -> Vec<String> {
    // The read guard is a temporary and is released at the end of this statement.
    let mut names: Vec<String> = GLOBAL_TRIGGER_REGISTRY.read().keys().cloned().collect();
    // Map keys are unique, so an unstable sort yields the same result as a stable one.
    names.sort_unstable();
    names
}
/// Builds one trigger from every registered constructor, ordered by trigger name.
///
/// Sorting by the registered name gives callers a deterministic sequence rather than
/// `HashMap` iteration order. Constructors run while the registry's read guard is held.
/// They must not call functions in this module that take the write lock.
pub fn get_all_triggers() -> Vec<Arc<dyn Trigger>> {
    let registry = GLOBAL_TRIGGER_REGISTRY.read();
    let mut entries: Vec<_> = registry.iter().collect();
    // Names are unique map keys, so an unstable sort cannot reorder equal elements.
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    entries
        .into_iter()
        .map(|(_, constructor)| constructor())
        .collect()
}
pub fn deregister_trigger(name: &str) -> bool {
let mut registry = GLOBAL_TRIGGER_REGISTRY.write();
registry.remove(name).is_some()
}
/// Reports whether a constructor is currently registered under `name`.
pub fn is_trigger_registered(name: &str) -> bool {
    GLOBAL_TRIGGER_REGISTRY.read().contains_key(name)
}
/// Test-only helper that removes every registered trigger constructor.
#[cfg(test)]
pub fn clear_triggers() {
    GLOBAL_TRIGGER_REGISTRY.write().clear();
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trigger::{TriggerError, TriggerResult};
    use async_trait::async_trait;
    use serial_test::serial;
    use std::time::Duration;

    /// Minimal trigger for registry tests: it carries only a name and always
    /// reports `TriggerResult::Skip` when polled.
    #[derive(Debug, Clone)]
    struct TestTrigger {
        name: String,
    }

    impl TestTrigger {
        fn new(name: &str) -> Self {
            let name = name.to_owned();
            Self { name }
        }
    }

    #[async_trait]
    impl Trigger for TestTrigger {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn poll_interval(&self) -> Duration {
            Duration::from_secs(1)
        }

        fn allow_concurrent(&self) -> bool {
            false
        }

        async fn poll(&self) -> Result<TriggerResult, TriggerError> {
            Ok(TriggerResult::Skip)
        }
    }

    // Every test touches the shared global registry, so each one is `#[serial]` and uses
    // names unique to that test.

    #[test]
    #[serial]
    fn test_register_and_get_trigger() {
        let key = "test_register_and_get_trigger_unique_12345";
        register_trigger(TestTrigger::new(key));

        assert!(is_trigger_registered(key));
        assert!(!is_trigger_registered("definitely_nonexistent_trigger_xyz"));

        let fetched = get_trigger(key).expect("registered trigger should be retrievable");
        assert_eq!(fetched.name(), key);
    }

    #[test]
    #[serial]
    fn test_register_constructor() {
        let key = "test_register_constructor_unique_12345";
        register_trigger_constructor(key, move || Arc::new(TestTrigger::new(key)));

        let built = get_trigger(key).expect("constructor should produce a trigger");
        assert_eq!(built.name(), key);
    }

    #[test]
    #[serial]
    fn test_list_triggers() {
        let first = "test_list_triggers_a_unique_12345";
        let second = "test_list_triggers_b_unique_12345";
        for key in [first, second] {
            register_trigger(TestTrigger::new(key));
        }

        let listed = list_triggers();
        for key in [first, second] {
            assert!(listed.iter().any(|n| n == key));
        }
    }

    #[test]
    #[serial]
    fn test_get_all_triggers() {
        let one = "test_get_all_triggers_1_unique_12345";
        let two = "test_get_all_triggers_2_unique_12345";
        register_trigger(TestTrigger::new(one));
        register_trigger(TestTrigger::new(two));

        let everything = get_all_triggers();
        let has = |wanted: &str| everything.iter().any(|t| t.name() == wanted);
        assert!(has(one));
        assert!(has(two));
    }

    #[test]
    #[serial]
    fn test_deregister_trigger() {
        let key = "test_deregister_trigger_unique_12345";
        register_trigger(TestTrigger::new(key));
        assert!(is_trigger_registered(key));

        // First removal succeeds; the second finds nothing left to remove.
        assert!(deregister_trigger(key));
        assert!(!is_trigger_registered(key));
        assert!(!deregister_trigger(key));
    }

    #[test]
    #[serial]
    fn test_register_deregister_roundtrip() {
        let key = "test_roundtrip_unique_12345";
        register_trigger_constructor(key, move || Arc::new(TestTrigger::new(key)));
        assert!(is_trigger_registered(key));

        let fetched = get_trigger(key).expect("registered trigger should be retrievable");
        assert_eq!(fetched.name(), key);

        assert!(deregister_trigger(key));
        assert!(!is_trigger_registered(key));
        assert!(get_trigger(key).is_none());
        assert!(!list_triggers().iter().any(|n| n == key));
    }

    #[test]
    #[serial]
    fn test_clear_triggers() {
        let key = "test_clear_triggers_unique_12345";
        register_trigger(TestTrigger::new(key));
        assert!(
            is_trigger_registered(key),
            "Trigger should be registered after register_trigger"
        );

        clear_triggers();
        assert!(
            !is_trigger_registered(key),
            "Trigger should not be registered after clear_triggers"
        );
    }
}