use async_trait::async_trait;
use dashmap::DashMap;
use serde_json::Value as JsonValue;
use std::cell::RefCell;
use std::sync::Arc;
/// Outcome returned by every event hook.
///
/// `Continue` lets processing proceed. `Veto` is only acted on by the `dispatch_before_*`
/// and `dispatch_attribute_set` dispatchers of [`EventRegistry`], which stop at the first
/// vetoing listener and return `Veto`. The `dispatch_after_*` dispatchers ignore listener
/// results and always return `Continue`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventResult {
    /// Proceed; remaining listeners are still called.
    Continue,
    /// Ask for the operation to be cancelled. NOTE(review): whether the operation is really
    /// aborted is up to whoever calls the dispatcher, which is not visible in this file.
    Veto,
}
/// Hooks around insert/update/delete (and load/refresh/expire) of model instances.
///
/// Register implementations with [`EventRegistry::register_mapper_listener`], keyed by a
/// model name. Every method has a default body returning [`EventResult::Continue`], so
/// implementors override only the hooks they need. Listeners are stored as
/// `Arc<dyn MapperEvents>`, hence the `Send + Sync` bound.
#[async_trait]
pub trait MapperEvents: Send + Sync {
    /// Called by [`EventRegistry::dispatch_before_insert`]. Returning
    /// [`EventResult::Veto`] stops further listeners and makes the dispatcher return `Veto`.
    /// `values` is forwarded unchanged from the dispatcher (presumably the values about to
    /// be written — confirm with callers).
    async fn before_insert(&self, _instance_id: &str, _values: &JsonValue) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_after_insert`], which ignores the return value.
    async fn after_insert(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_before_update`]; a `Veto` stops further
    /// listeners and is returned by the dispatcher.
    async fn before_update(&self, _instance_id: &str, _values: &JsonValue) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_after_update`], which ignores the return value.
    async fn after_update(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_before_delete`]; a `Veto` stops further
    /// listeners and is returned by the dispatcher.
    async fn before_delete(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_after_delete`], which ignores the return value.
    async fn after_delete(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Load hook. NOTE(review): `EventRegistry` in this file has no dispatcher for it, so
    /// how the result is used depends on the (unseen) caller.
    async fn load(&self, _instance_id: &str, _data: &JsonValue) -> EventResult {
        EventResult::Continue
    }
    /// Refresh hook. NOTE(review): no dispatcher for it in this file.
    async fn refresh(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Expire hook; `attribute_names` presumably lists the expired attributes.
    /// NOTE(review): no dispatcher for it in this file.
    async fn expire(&self, _instance_id: &str, _attribute_names: &[String]) -> EventResult {
        EventResult::Continue
    }
}
/// Hooks around session flush / transaction boundaries and bulk operations.
///
/// Register implementations with [`EventRegistry::register_session_listener`], keyed by a
/// session id; the same id is passed back to the hooks as `session_id`. Every method
/// defaults to returning [`EventResult::Continue`].
#[async_trait]
pub trait SessionEvents: Send + Sync {
    /// Called by [`EventRegistry::dispatch_before_flush`]; a `Veto` stops further listeners
    /// and is returned by the dispatcher. `instances` is forwarded unchanged from it.
    async fn before_flush(&self, _session_id: &str, _instances: &[String]) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_after_flush`], which ignores the return value.
    async fn after_flush(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn after_flush_postexec(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_before_commit`]; a `Veto` stops further
    /// listeners and is returned by the dispatcher.
    async fn before_commit(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_after_commit`], which ignores the return value.
    async fn after_commit(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Called by [`EventRegistry::dispatch_after_rollback`], which ignores the return value.
    async fn after_rollback(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn after_begin(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn after_soft_rollback(&self, _session_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Bulk-insert hook. NOTE(review): no dispatcher for it in this file, and unlike the
    /// other session hooks it receives no `session_id`.
    async fn before_bulk_insert(&self, _values: &[JsonValue]) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file; `count` presumably is the
    /// number of affected rows.
    async fn after_bulk_insert(&self, _count: usize) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn before_bulk_update(&self, _filter: &JsonValue, _values: &JsonValue) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn after_bulk_update(&self, _count: usize) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn before_bulk_delete(&self, _filter: &JsonValue) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn after_bulk_delete(&self, _count: usize) -> EventResult {
        EventResult::Continue
    }
}
/// Hooks for changes to individual attributes of an instance.
///
/// Register implementations with [`EventRegistry::register_attribute_listener`], keyed by a
/// `model_attr` string. Every method defaults to returning [`EventResult::Continue`].
#[async_trait]
pub trait AttributeEvents: Send + Sync {
    /// Called by [`EventRegistry::dispatch_attribute_set`]; a `Veto` stops further
    /// listeners and is returned by the dispatcher. All arguments are forwarded unchanged
    /// from the dispatcher (`old_value` is presumably `None` when there was no prior value).
    async fn set(
        &self,
        _instance_id: &str,
        _attribute: &str,
        _value: &JsonValue,
        _old_value: Option<&JsonValue>,
    ) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn append(
        &self,
        _instance_id: &str,
        _attribute: &str,
        _value: &JsonValue,
    ) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn remove(
        &self,
        _instance_id: &str,
        _attribute: &str,
        _value: &JsonValue,
    ) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn init_scalar(
        &self,
        _instance_id: &str,
        _attribute: &str,
        _value: &JsonValue,
    ) -> EventResult {
        EventResult::Continue
    }
    /// NOTE(review): no dispatcher for this hook in this file.
    async fn init_collection(&self, _instance_id: &str, _attribute: &str) -> EventResult {
        EventResult::Continue
    }
}
/// Hooks for per-instance lifecycle events (init, load, refresh, expire, (un)pickle).
///
/// Register implementations with [`EventRegistry::register_instance_listener`]. Every
/// method defaults to returning [`EventResult::Continue`].
///
/// NOTE(review): `EventRegistry` in this file has no dispatcher that calls any of these
/// hooks; confirm they are invoked elsewhere.
#[async_trait]
pub trait InstanceEvents: Send + Sync {
    /// Instance initialization hook.
    async fn init(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Instance load hook.
    async fn load(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Instance refresh hook.
    async fn refresh(&self, _instance_id: &str) -> EventResult {
        EventResult::Continue
    }
    /// Refresh-during-flush hook; `flush_context` is an opaque string here.
    async fn refresh_flush(&self, _instance_id: &str, _flush_context: &str) -> EventResult {
        EventResult::Continue
    }
    /// Expire hook; `attrs` presumably lists the expired attribute names.
    async fn expire(&self, _instance_id: &str, _attrs: &[String]) -> EventResult {
        EventResult::Continue
    }
    /// Serialization hook; `state_dict` is forwarded as-is.
    async fn pickle(&self, _instance_id: &str, _state_dict: &JsonValue) -> EventResult {
        EventResult::Continue
    }
    /// Deserialization hook; `state_dict` is forwarded as-is.
    async fn unpickle(&self, _instance_id: &str, _state_dict: &JsonValue) -> EventResult {
        EventResult::Continue
    }
}
/// Type-erased handle to a listener of any of the four listener kinds.
///
/// Cloning only clones the inner `Arc`. NOTE(review): not referenced elsewhere in this
/// file; `EventRegistry` keeps each kind in its own map instead.
#[derive(Clone)]
pub enum EventListener {
    /// A [`MapperEvents`] listener.
    Mapper(Arc<dyn MapperEvents>),
    /// A [`SessionEvents`] listener.
    Session(Arc<dyn SessionEvents>),
    /// An [`AttributeEvents`] listener.
    Attribute(Arc<dyn AttributeEvents>),
    /// An [`InstanceEvents`] listener.
    Instance(Arc<dyn InstanceEvents>),
}
/// Concurrent registry of event listeners, grouped by kind and by a string key.
///
/// Each map holds, per key, the listeners registered for that key in registration order.
/// All methods take `&self`; the `DashMap`s provide the interior mutability.
pub struct EventRegistry {
    /// Mapper listeners keyed by the `model` passed to `register_mapper_listener`.
    mapper_listeners: DashMap<String, Vec<Arc<dyn MapperEvents>>>,
    /// Session listeners keyed by the `session_id` passed to `register_session_listener`.
    session_listeners: DashMap<String, Vec<Arc<dyn SessionEvents>>>,
    /// Attribute listeners keyed by the `model_attr` passed to `register_attribute_listener`
    /// (format of the key is not enforced here).
    attribute_listeners: DashMap<String, Vec<Arc<dyn AttributeEvents>>>,
    /// Instance listeners keyed by `instance_id`. NOTE(review): no dispatcher in this file
    /// reads this map.
    instance_listeners: DashMap<String, Vec<Arc<dyn InstanceEvents>>>,
}
impl EventRegistry {
pub fn new() -> Self {
Self {
mapper_listeners: DashMap::new(),
session_listeners: DashMap::new(),
attribute_listeners: DashMap::new(),
instance_listeners: DashMap::new(),
}
}
pub fn register_mapper_listener(&self, model: String, listener: Arc<dyn MapperEvents>) {
self.mapper_listeners
.entry(model)
.or_default()
.push(listener);
}
pub fn register_session_listener(&self, session_id: String, listener: Arc<dyn SessionEvents>) {
self.session_listeners
.entry(session_id)
.or_default()
.push(listener);
}
pub fn register_attribute_listener(
&self,
model_attr: String,
listener: Arc<dyn AttributeEvents>,
) {
self.attribute_listeners
.entry(model_attr)
.or_default()
.push(listener);
}
pub fn register_instance_listener(
&self,
instance_id: String,
listener: Arc<dyn InstanceEvents>,
) {
self.instance_listeners
.entry(instance_id)
.or_default()
.push(listener);
}
pub async fn dispatch_before_insert(
&self,
model: &str,
instance_id: &str,
values: &JsonValue,
) -> EventResult {
if let Some(listeners) = self.mapper_listeners.get(model) {
for listener in listeners.value() {
match listener.before_insert(instance_id, values).await {
EventResult::Veto => return EventResult::Veto,
EventResult::Continue => continue,
}
}
}
EventResult::Continue
}
pub async fn dispatch_after_insert(&self, model: &str, instance_id: &str) -> EventResult {
if let Some(listeners) = self.mapper_listeners.get(model) {
for listener in listeners.value() {
listener.after_insert(instance_id).await;
}
}
EventResult::Continue
}
pub async fn dispatch_before_update(
&self,
model: &str,
instance_id: &str,
values: &JsonValue,
) -> EventResult {
if let Some(listeners) = self.mapper_listeners.get(model) {
for listener in listeners.value() {
match listener.before_update(instance_id, values).await {
EventResult::Veto => return EventResult::Veto,
EventResult::Continue => continue,
}
}
}
EventResult::Continue
}
pub async fn dispatch_after_update(&self, model: &str, instance_id: &str) -> EventResult {
if let Some(listeners) = self.mapper_listeners.get(model) {
for listener in listeners.value() {
listener.after_update(instance_id).await;
}
}
EventResult::Continue
}
pub async fn dispatch_before_delete(&self, model: &str, instance_id: &str) -> EventResult {
if let Some(listeners) = self.mapper_listeners.get(model) {
for listener in listeners.value() {
match listener.before_delete(instance_id).await {
EventResult::Veto => return EventResult::Veto,
EventResult::Continue => continue,
}
}
}
EventResult::Continue
}
pub async fn dispatch_after_delete(&self, model: &str, instance_id: &str) -> EventResult {
if let Some(listeners) = self.mapper_listeners.get(model) {
for listener in listeners.value() {
listener.after_delete(instance_id).await;
}
}
EventResult::Continue
}
pub async fn dispatch_before_flush(
&self,
session_id: &str,
instances: &[String],
) -> EventResult {
if let Some(listeners) = self.session_listeners.get(session_id) {
for listener in listeners.value() {
match listener.before_flush(session_id, instances).await {
EventResult::Veto => return EventResult::Veto,
EventResult::Continue => continue,
}
}
}
EventResult::Continue
}
pub async fn dispatch_after_flush(&self, session_id: &str) -> EventResult {
if let Some(listeners) = self.session_listeners.get(session_id) {
for listener in listeners.value() {
listener.after_flush(session_id).await;
}
}
EventResult::Continue
}
pub async fn dispatch_before_commit(&self, session_id: &str) -> EventResult {
if let Some(listeners) = self.session_listeners.get(session_id) {
for listener in listeners.value() {
match listener.before_commit(session_id).await {
EventResult::Veto => return EventResult::Veto,
EventResult::Continue => continue,
}
}
}
EventResult::Continue
}
pub async fn dispatch_after_commit(&self, session_id: &str) -> EventResult {
if let Some(listeners) = self.session_listeners.get(session_id) {
for listener in listeners.value() {
listener.after_commit(session_id).await;
}
}
EventResult::Continue
}
pub async fn dispatch_after_rollback(&self, session_id: &str) -> EventResult {
if let Some(listeners) = self.session_listeners.get(session_id) {
for listener in listeners.value() {
listener.after_rollback(session_id).await;
}
}
EventResult::Continue
}
pub async fn dispatch_attribute_set(
&self,
model_attr: &str,
instance_id: &str,
attribute: &str,
value: &JsonValue,
old_value: Option<&JsonValue>,
) -> EventResult {
if let Some(listeners) = self.attribute_listeners.get(model_attr) {
for listener in listeners.value() {
match listener.set(instance_id, attribute, value, old_value).await {
EventResult::Veto => return EventResult::Veto,
EventResult::Continue => continue,
}
}
}
EventResult::Continue
}
pub fn clear(&self) {
self.mapper_listeners.clear();
self.session_listeners.clear();
self.attribute_listeners.clear();
self.instance_listeners.clear();
}
pub fn mapper_listener_count(&self) -> usize {
self.mapper_listeners.len()
}
pub fn session_listener_count(&self) -> usize {
self.session_listeners.len()
}
}
impl Default for EventRegistry {
fn default() -> Self {
Self::new()
}
}
impl Clone for EventRegistry {
fn clone(&self) -> Self {
Self {
mapper_listeners: self.mapper_listeners.clone(),
session_listeners: self.session_listeners.clone(),
attribute_listeners: self.attribute_listeners.clone(),
instance_listeners: self.instance_listeners.clone(),
}
}
}
thread_local! {
    /// Registry installed as "active" on the current thread by `with_event_registry` or
    /// `set_active_registry`; `None` outside any such scope. Being thread-local, it is not
    /// visible from other threads (including other worker threads of an async runtime).
    static ACTIVE_REGISTRY: RefCell<Option<Arc<EventRegistry>>> = const { RefCell::new(None) };
}
pub fn with_event_registry<F, R>(registry: Arc<EventRegistry>, f: F) -> R
where
F: FnOnce() -> R,
{
ACTIVE_REGISTRY.with(|r| {
let prev = r.borrow_mut().replace(registry);
let result = f();
*r.borrow_mut() = prev;
result
})
}
/// Returns the registry active on this thread, or `None` outside any
/// `with_event_registry` / `set_active_registry` scope.
///
/// The result is a new `Arc` handle, so it stays usable after the installing scope ends.
pub fn get_active_registry() -> Option<Arc<EventRegistry>> {
    ACTIVE_REGISTRY.with(|slot| slot.borrow().as_ref().map(Arc::clone))
}
/// RAII guard returned by `set_active_registry`.
///
/// While the guard is alive, the registry passed to `set_active_registry` is this thread's
/// active registry. Dropping the guard reinstates whatever was active before.
///
/// Caveats:
/// - Drop guards in reverse order of creation (LIFO). If an outer guard is dropped first,
///   it reinstates its predecessor; the later inner drop then overwrites that with the
///   outer registry.
/// - The restore writes the thread-local of the thread that *drops* the guard. Keep the
///   guard on the thread that created it; for example, do not hold it across an `.await`
///   on a multi-threaded runtime.
#[must_use = "dropping the guard immediately restores the previous active registry"]
pub struct ActiveRegistryGuard {
    /// Registry that was active before this guard was created; reinstated on drop.
    prev: Option<Arc<EventRegistry>>,
}

impl Drop for ActiveRegistryGuard {
    fn drop(&mut self) {
        let prev = self.prev.take();
        // `try_with` instead of `with`: if the guard is dropped while thread-local storage is
        // being torn down, `with` would panic inside `drop`. The thread is exiting at that
        // point, so skipping the restore is harmless.
        let _ = ACTIVE_REGISTRY.try_with(|slot| *slot.borrow_mut() = prev);
    }
}
/// Installs `registry` as this thread's active registry until the returned guard is dropped.
///
/// Bind the guard to a named variable (`let _guard = set_active_registry(r);`). Writing
/// `let _ = ...` or ignoring the result drops it at once, which immediately reinstates the
/// previous registry. Prefer `with_event_registry` when the scope is a single closure.
#[must_use = "the previous registry is restored as soon as the guard is dropped"]
pub fn set_active_registry(registry: Arc<EventRegistry>) -> ActiveRegistryGuard {
    let prev = ACTIVE_REGISTRY.with(|r| r.borrow_mut().replace(registry));
    ActiveRegistryGuard { prev }
}
#[cfg(feature = "di")]
mod di_support {
    use super::EventRegistry;
    use async_trait::async_trait;
    use reinhardt_di::{DiResult, Injectable, InjectionContext};

    // Resolves an `EventRegistry` for dependency injection: a clone of the context's
    // `EventRegistry` singleton when one exists, otherwise a fresh empty registry.
    //
    // NOTE(review): `EventRegistry::clone` copies the listener maps, so listeners registered
    // on the injected value afterwards are not visible through the singleton (and vice
    // versa). Confirm this is intended.
    #[async_trait]
    impl Injectable for EventRegistry {
        async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
            let registry = match ctx.get_singleton::<EventRegistry>() {
                Some(shared) => (*shared).clone(),
                None => EventRegistry::new(),
            };
            Ok(registry)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mapper listener counting `before_insert`, `after_insert` and `before_update` calls;
    /// vetoes `before_insert` when `should_veto` is set.
    struct TestMapperListener {
        before_insert_count: Arc<AtomicUsize>,
        after_insert_count: Arc<AtomicUsize>,
        before_update_count: Arc<AtomicUsize>,
        should_veto: bool,
    }

    impl TestMapperListener {
        /// Listener reporting `before_insert` calls into `before_insert_count`; its other
        /// counters are not observed by the caller.
        fn counting(before_insert_count: Arc<AtomicUsize>, should_veto: bool) -> Arc<Self> {
            Arc::new(Self {
                before_insert_count,
                after_insert_count: Arc::new(AtomicUsize::new(0)),
                before_update_count: Arc::new(AtomicUsize::new(0)),
                should_veto,
            })
        }
    }

    #[async_trait]
    impl MapperEvents for TestMapperListener {
        async fn before_insert(&self, _instance_id: &str, _values: &JsonValue) -> EventResult {
            self.before_insert_count.fetch_add(1, Ordering::SeqCst);
            if self.should_veto {
                EventResult::Veto
            } else {
                EventResult::Continue
            }
        }
        async fn after_insert(&self, _instance_id: &str) -> EventResult {
            self.after_insert_count.fetch_add(1, Ordering::SeqCst);
            EventResult::Continue
        }
        async fn before_update(&self, _instance_id: &str, _values: &JsonValue) -> EventResult {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            EventResult::Continue
        }
    }

    /// Session listener counting `before_flush` and `after_commit` calls.
    struct TestSessionListener {
        before_flush_count: Arc<AtomicUsize>,
        after_commit_count: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl SessionEvents for TestSessionListener {
        async fn before_flush(&self, _session_id: &str, _instances: &[String]) -> EventResult {
            self.before_flush_count.fetch_add(1, Ordering::SeqCst);
            EventResult::Continue
        }
        async fn after_commit(&self, _session_id: &str) -> EventResult {
            self.after_commit_count.fetch_add(1, Ordering::SeqCst);
            EventResult::Continue
        }
    }

    #[tokio::test]
    async fn test_mapper_event_dispatch() {
        let registry = EventRegistry::new();
        let before_insert = Arc::new(AtomicUsize::new(0));
        let after_insert = Arc::new(AtomicUsize::new(0));
        let before_update = Arc::new(AtomicUsize::new(0));
        let listener = Arc::new(TestMapperListener {
            before_insert_count: before_insert.clone(),
            after_insert_count: after_insert.clone(),
            before_update_count: before_update.clone(),
            should_veto: false,
        });
        registry.register_mapper_listener("User".to_string(), listener);

        let values = serde_json::json!({"name": "John"});
        let result = registry
            .dispatch_before_insert("User", "user-1", &values)
            .await;
        assert_eq!(result, EventResult::Continue);
        assert_eq!(before_insert.load(Ordering::SeqCst), 1);

        let result = registry.dispatch_after_insert("User", "user-1").await;
        assert_eq!(result, EventResult::Continue);
        assert_eq!(after_insert.load(Ordering::SeqCst), 1);

        let update_values = serde_json::json!({"name": "Jane"});
        let result = registry
            .dispatch_before_update("User", "user-1", &update_values)
            .await;
        assert_eq!(result, EventResult::Continue);
        assert_eq!(before_update.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mapper_event_veto() {
        let registry = EventRegistry::new();
        let before_insert = Arc::new(AtomicUsize::new(0));
        registry.register_mapper_listener(
            "User".to_string(),
            TestMapperListener::counting(before_insert.clone(), true),
        );
        let values = serde_json::json!({"name": "John"});
        let result = registry
            .dispatch_before_insert("User", "user-1", &values)
            .await;
        assert_eq!(result, EventResult::Veto);
        assert_eq!(before_insert.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_veto_short_circuits_later_listeners() {
        let registry = EventRegistry::new();
        let vetoer_calls = Arc::new(AtomicUsize::new(0));
        let later_calls = Arc::new(AtomicUsize::new(0));
        registry.register_mapper_listener(
            "User".to_string(),
            TestMapperListener::counting(vetoer_calls.clone(), true),
        );
        registry.register_mapper_listener(
            "User".to_string(),
            TestMapperListener::counting(later_calls.clone(), false),
        );
        let values = serde_json::json!({"name": "John"});
        let result = registry
            .dispatch_before_insert("User", "user-1", &values)
            .await;
        assert_eq!(result, EventResult::Veto);
        assert_eq!(vetoer_calls.load(Ordering::SeqCst), 1);
        // Listeners after the vetoing one must not run.
        assert_eq!(later_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_dispatch_without_listeners_continues() {
        let registry = EventRegistry::new();
        let values = serde_json::json!({});
        assert_eq!(
            registry
                .dispatch_before_insert("Unknown", "x-1", &values)
                .await,
            EventResult::Continue
        );
        assert_eq!(
            registry.dispatch_before_commit("no-such-session").await,
            EventResult::Continue
        );
    }

    #[tokio::test]
    async fn test_session_event_dispatch() {
        let registry = EventRegistry::new();
        let before_flush = Arc::new(AtomicUsize::new(0));
        let after_commit = Arc::new(AtomicUsize::new(0));
        let listener = Arc::new(TestSessionListener {
            before_flush_count: before_flush.clone(),
            after_commit_count: after_commit.clone(),
        });
        registry.register_session_listener("session-1".to_string(), listener);
        let instances = vec!["user-1".to_string(), "user-2".to_string()];
        let result = registry
            .dispatch_before_flush("session-1", &instances)
            .await;
        assert_eq!(result, EventResult::Continue);
        assert_eq!(before_flush.load(Ordering::SeqCst), 1);
        let result = registry.dispatch_after_commit("session-1").await;
        assert_eq!(result, EventResult::Continue);
        assert_eq!(after_commit.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_orm_events_multiple_listeners() {
        let registry = EventRegistry::new();
        let count1 = Arc::new(AtomicUsize::new(0));
        let count2 = Arc::new(AtomicUsize::new(0));
        registry.register_mapper_listener(
            "User".to_string(),
            TestMapperListener::counting(count1.clone(), false),
        );
        registry.register_mapper_listener(
            "User".to_string(),
            TestMapperListener::counting(count2.clone(), false),
        );
        let values = serde_json::json!({"name": "John"});
        let result = registry
            .dispatch_before_insert("User", "user-1", &values)
            .await;
        assert_eq!(result, EventResult::Continue);
        assert_eq!(count1.load(Ordering::SeqCst), 1);
        assert_eq!(count2.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_clear_listeners() {
        let registry = EventRegistry::new();
        registry.register_mapper_listener(
            "User".to_string(),
            TestMapperListener::counting(Arc::new(AtomicUsize::new(0)), false),
        );
        assert_eq!(registry.mapper_listener_count(), 1);
        registry.clear();
        assert_eq!(registry.mapper_listener_count(), 0);
    }

    #[tokio::test]
    async fn test_active_registry_scoping() {
        assert!(get_active_registry().is_none());
        let registry = Arc::new(EventRegistry::new());
        registry.register_mapper_listener(
            "ScopedTest".to_string(),
            TestMapperListener::counting(Arc::new(AtomicUsize::new(0)), false),
        );
        with_event_registry(registry.clone(), || {
            let active = get_active_registry().expect("registry should be active in scope");
            assert!(Arc::ptr_eq(&active, &registry));
            assert_eq!(active.mapper_listener_count(), 1);
        });
        assert!(get_active_registry().is_none());
    }

    #[tokio::test]
    async fn test_nested_registry_scoping() {
        let outer_registry = Arc::new(EventRegistry::new());
        let inner_registry = Arc::new(EventRegistry::new());
        // Identity checks (not listener counts) so the test can tell the registries apart.
        with_event_registry(outer_registry.clone(), || {
            let active = get_active_registry().expect("outer registry should be active");
            assert!(Arc::ptr_eq(&active, &outer_registry));
            with_event_registry(inner_registry.clone(), || {
                let active = get_active_registry().expect("inner registry should be active");
                assert!(Arc::ptr_eq(&active, &inner_registry));
            });
            let active = get_active_registry().expect("outer registry should be restored");
            assert!(Arc::ptr_eq(&active, &outer_registry));
        });
        assert!(get_active_registry().is_none());
    }
}