use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, trace, warn};
use uuid::Uuid;
use crate::bus::EventBus;
use crate::event::AstridEvent;
/// Boxed, thread-safe predicate over an [`AstridEvent`].
///
/// [`FilterSubscriber`] stores one of these when a filter is attached through
/// `with_filter`; the predicate returns `true` for events that should be
/// delivered.
pub type EventFilter = Box<dyn Fn(&AstridEvent) -> bool + Send + Sync>;
/// A receiver of [`AstridEvent`]s dispatched through a [`SubscriberRegistry`].
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn EventSubscriber>`.
pub trait EventSubscriber: Send + Sync {
/// Handles a single event.
///
/// [`SubscriberRegistry::notify`] calls this only for events that
/// [`accepts`](Self::accepts) returned `true` for, and catches a panic raised
/// here so the remaining subscribers are still notified.
fn on_event(&self, event: &AstridEvent, bus: &EventBus);
/// Reports whether this subscriber wants to receive `event`.
///
/// The default implementation accepts every event.
fn accepts(&self, event: &AstridEvent) -> bool {
// Explicitly discard the argument; presumably this keeps the descriptive
// parameter name without triggering an unused-variable warning.
let _ = event;
true
}
/// Returns the name used to identify this subscriber in log output.
///
/// The default implementation returns `"anonymous"`.
// The default body returns a literal, but the return type stays `&str`
// rather than `&'static str` so implementors can hand back a borrowed field.
#[allow(clippy::unnecessary_literal_bound)]
fn name(&self) -> &str {
"anonymous"
}
}
/// Opaque handle identifying one registration in a [`SubscriberRegistry`].
///
/// Returned by [`SubscriberRegistry::register`] and passed back to
/// [`SubscriberRegistry::unregister`]. Internally it wraps a [`Uuid`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubscriberId(Uuid);
impl SubscriberId {
    /// Produces a fresh identifier backed by a newly generated version-4 UUID.
    #[must_use]
    fn new() -> Self {
        let raw = Uuid::new_v4();
        Self(raw)
    }
}
/// Registry of event subscribers keyed by [`SubscriberId`].
///
/// The subscriber map is held as an `Arc` inside an `RwLock`. Mutating
/// operations build a modified copy of the map and swap the `Arc`; `notify`
/// clones the `Arc` and releases the lock before calling any subscriber.
#[derive(Default)]
pub struct SubscriberRegistry {
/// Current snapshot of the subscriber map; replaced wholesale on each change.
subscribers: RwLock<Arc<HashMap<SubscriberId, Arc<dyn EventSubscriber>>>>,
}
impl std::fmt::Debug for SubscriberRegistry {
    /// Formats the registry as `SubscriberRegistry { subscriber_count: N }`.
    ///
    /// Subscribers themselves are trait objects without a `Debug` bound, so
    /// only their count is shown. If the lock is poisoned the count is
    /// reported as `0` rather than panicking inside a formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let subscriber_count = match self.subscribers.read() {
            Ok(snapshot) => snapshot.len(),
            Err(_) => usize::default(),
        };
        let mut builder = f.debug_struct("SubscriberRegistry");
        builder.field("subscriber_count", &subscriber_count);
        builder.finish()
    }
}
impl SubscriberRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            subscribers: RwLock::new(Arc::new(HashMap::new())),
        }
    }

    /// Applies `update_fn` to a private copy of the subscriber map and
    /// publishes that copy only when the closure reports a change.
    ///
    /// Returns whatever `update_fn` returned: `true` if the map was replaced,
    /// `false` if the copy was discarded.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    fn update_registry<F>(&self, update_fn: F) -> bool
    where
        F: FnOnce(&mut HashMap<SubscriberId, Arc<dyn EventSubscriber>>) -> bool,
    {
        // `_old_map` is bound outside the inner scope on purpose: the write
        // guard is released at the end of that scope, while the previous map
        // (and any subscriber it was the last owner of) is dropped only when
        // this function returns, i.e. after the lock is no longer held.
        let (changed, _old_map) = {
            let mut subs = self.subscribers.write().expect("lock poisoned");
            let mut new_map = HashMap::clone(&subs);
            if update_fn(&mut new_map) {
                let old = std::mem::replace(&mut *subs, Arc::new(new_map));
                (true, Some(old))
            } else {
                (false, None)
            }
        };
        changed
    }

    /// Adds `subscriber` to the registry and returns the id needed to remove
    /// it again with [`unregister`](Self::unregister).
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn register(&self, subscriber: Arc<dyn EventSubscriber>) -> SubscriberId {
        let id = SubscriberId::new();
        // Capture the name first: `subscriber` is moved into the map below.
        let name = subscriber.name().to_string();
        self.update_registry(|map| {
            map.insert(id, subscriber);
            true
        });
        debug!(subscriber_name = %name, "Subscriber registered");
        id
    }

    /// Removes the subscriber registered under `id`.
    ///
    /// Returns `true` if a subscriber was removed and `false` if `id` was not
    /// present.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn unregister(&self, id: SubscriberId) -> bool {
        let removed = self.update_registry(|map| map.remove(&id).is_some());
        if removed {
            debug!("Subscriber unregistered");
        }
        removed
    }

    /// Delivers `event` to every registered subscriber that accepts it.
    ///
    /// The subscriber map is snapshotted up front and the lock released before
    /// any subscriber code runs, so subscribers may register or unregister
    /// from within their callbacks; such changes affect later calls only.
    /// Subscribers are visited in the map's (unspecified) iteration order.
    ///
    /// A panic raised by a subscriber, whether in its `accepts` filter or in
    /// `on_event`, is caught and logged, and the remaining subscribers are
    /// still notified.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn notify(&self, event: &AstridEvent, bus: &EventBus) {
        let subs = {
            let guard = self.subscribers.read().expect("lock poisoned");
            Arc::clone(&*guard)
        };
        for (id, subscriber) in subs.iter() {
            // The filter runs inside the same unwind boundary as the handler:
            // a panicking `accepts` is just as much a subscriber fault as a
            // panicking `on_event` and must not abort delivery to the others.
            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                if subscriber.accepts(event) {
                    trace!(
                        subscriber_name = %subscriber.name(),
                        event_type = %event.event_type(),
                        "Notifying subscriber"
                    );
                    subscriber.on_event(event, bus);
                }
            }));
            if let Err(payload) = result {
                // Panic payloads are usually `&'static str` or `String`;
                // anything else is logged without a message.
                let panic_msg = payload
                    .downcast_ref::<&str>()
                    .map(|msg| (*msg).to_string())
                    .or_else(|| payload.downcast_ref::<String>().cloned());
                warn!(
                    subscriber_id = ?id,
                    subscriber_name = %subscriber.name(),
                    panic_msg = ?panic_msg,
                    "Subscriber panicked"
                );
            }
        }
    }

    /// Returns the number of registered subscribers.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    #[must_use]
    pub fn len(&self) -> usize {
        self.subscribers.read().expect("lock poisoned").len()
    }

    /// Returns `true` if no subscribers are registered.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.subscribers.read().expect("lock poisoned").is_empty()
    }

    /// Removes every subscriber from the registry.
    ///
    /// When the registry is already empty the current map is left in place
    /// rather than being swapped for a new empty one.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn clear(&self) {
        self.update_registry(|map| {
            if map.is_empty() {
                false
            } else {
                map.clear();
                true
            }
        });
        debug!("All subscribers cleared");
    }
}
/// An [`EventSubscriber`] assembled from a closure and an optional filter.
///
/// The handler closure is run for each delivered event. When a filter is
/// present, only events for which it returns `true` are accepted; without a
/// filter every event is accepted.
pub struct FilterSubscriber<F>
where
F: Fn(&AstridEvent) + Send + Sync,
{
/// Name reported through `EventSubscriber::name`.
name: String,
/// Optional acceptance predicate; `None` means accept everything.
filter: Option<EventFilter>,
/// Callback invoked from `EventSubscriber::on_event`; it receives the event only.
handler: F,
}
impl<F> FilterSubscriber<F>
where
F: Fn(&AstridEvent) + Send + Sync,
{
pub fn new(name: impl Into<String>, handler: F) -> Self {
Self {
name: name.into(),
filter: None,
handler,
}
}
#[must_use]
pub fn with_filter<P>(mut self, predicate: P) -> Self
where
P: Fn(&AstridEvent) -> bool + Send + Sync + 'static,
{
self.filter = Some(Box::new(predicate));
self
}
}
impl<F> EventSubscriber for FilterSubscriber<F>
where
    F: Fn(&AstridEvent) + Send + Sync,
{
    /// Forwards `event` to the wrapped handler; the bus is not used.
    fn on_event(&self, event: &AstridEvent, _bus: &EventBus) {
        let handler = &self.handler;
        handler(event);
    }

    /// Accepts everything when no filter is set; otherwise defers to the
    /// filter predicate.
    fn accepts(&self, event: &AstridEvent) -> bool {
        let Some(predicate) = &self.filter else {
            return true;
        };
        predicate(event)
    }

    /// Returns the name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventMetadata;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that counts how many events it has been handed.
    struct CountingSubscriber {
        name: String,
        count: AtomicUsize,
    }

    impl CountingSubscriber {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_owned(),
                count: AtomicUsize::new(0),
            }
        }

        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl EventSubscriber for CountingSubscriber {
        fn on_event(&self, _event: &AstridEvent, _bus: &EventBus) {
            self.count.fetch_add(1, Ordering::SeqCst);
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Builds the `RuntimeStarted` event shared by several tests.
    fn runtime_started() -> AstridEvent {
        AstridEvent::RuntimeStarted {
            metadata: EventMetadata::new("test"),
            version: "0.1.0".to_string(),
        }
    }

    #[test]
    fn test_registry_register_unregister() {
        let registry = SubscriberRegistry::new();
        assert!(registry.is_empty());

        let id = registry.register(Arc::new(CountingSubscriber::new("test")));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        assert!(registry.unregister(id));
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_notify() {
        let bus = EventBus::new();
        let registry = bus.registry();
        let counter = Arc::new(CountingSubscriber::new("test"));
        registry.register(Arc::clone(&counter) as Arc<dyn EventSubscriber>);

        let event = runtime_started();
        // Each notification must bump the counter by exactly one.
        for expected in 1..=2 {
            registry.notify(&event, &bus);
            assert_eq!(counter.count(), expected);
        }
    }

    #[test]
    fn test_registry_multiple_subscribers() {
        let bus = EventBus::new();
        let registry = bus.registry();
        let first = Arc::new(CountingSubscriber::new("sub1"));
        let second = Arc::new(CountingSubscriber::new("sub2"));
        registry.register(Arc::clone(&first) as Arc<dyn EventSubscriber>);
        registry.register(Arc::clone(&second) as Arc<dyn EventSubscriber>);

        registry.notify(&runtime_started(), &bus);

        assert_eq!(first.count(), 1);
        assert_eq!(second.count(), 1);
    }

    #[test]
    fn test_filter_subscriber() {
        let received = Arc::new(AtomicUsize::new(0));
        let handler_counter = Arc::clone(&received);
        let subscriber = FilterSubscriber::new("security_only", move |_event| {
            handler_counter.fetch_add(1, Ordering::SeqCst);
        })
        .with_filter(AstridEvent::is_security_event);

        let bus = EventBus::new();
        let registry = bus.registry();
        registry.register(Arc::new(subscriber));

        // Rejected by the security filter: the handler must not run.
        registry.notify(&runtime_started(), &bus);
        assert_eq!(received.load(Ordering::SeqCst), 0);

        // Accepted by the security filter: the handler runs once.
        let granted = AstridEvent::CapabilityGranted {
            metadata: EventMetadata::new("test"),
            capability_id: Uuid::new_v4(),
            resource: "test".to_string(),
            action: "execute".to_string(),
        };
        registry.notify(&granted, &bus);
        assert_eq!(received.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_registry_clear() {
        let registry = SubscriberRegistry::new();
        registry.register(Arc::new(CountingSubscriber::new("sub1")));
        registry.register(Arc::new(CountingSubscriber::new("sub2")));
        assert_eq!(registry.len(), 2);

        registry.clear();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_unregister_nonexistent() {
        let registry = SubscriberRegistry::new();
        let unknown_id = SubscriberId::new();
        assert!(!registry.unregister(unknown_id));
    }
}