use super::error::SignalError;
use super::model_signals::{post_delete, post_save, pre_delete, pre_save};
use super::signal::Signal;
use async_trait::async_trait;
use std::any::Any;
use std::marker::PhantomData;
use std::sync::Arc;
/// Listener interface exposing one async hook per ORM lifecycle event.
///
/// Every hook receives the affected instance type-erased as
/// `Arc<dyn Any + Send + Sync>`, so an implementor has to downcast it to the
/// concrete model type it cares about. Each hook returns `Ok(())` on success
/// or a [`SignalError`] on failure.
///
/// NOTE(review): nothing in this file calls these hooks or wires them to the
/// `pre_save`/`post_save`/`pre_delete`/`post_delete` signals; the call sites
/// presumably live elsewhere in the ORM layer — confirm.
#[async_trait]
pub trait OrmEventListener: Send + Sync {
/// Hook for the "before insert" event; receives the type-erased instance.
async fn before_insert(&self, instance: Arc<dyn Any + Send + Sync>) -> Result<(), SignalError>;
/// Hook for the "after insert" event; receives the type-erased instance.
async fn after_insert(&self, instance: Arc<dyn Any + Send + Sync>) -> Result<(), SignalError>;
/// Hook for the "before update" event; receives the type-erased instance.
async fn before_update(&self, instance: Arc<dyn Any + Send + Sync>) -> Result<(), SignalError>;
/// Hook for the "after update" event; receives the type-erased instance.
async fn after_update(&self, instance: Arc<dyn Any + Send + Sync>) -> Result<(), SignalError>;
/// Hook for the "before delete" event; receives the type-erased instance.
async fn before_delete(&self, instance: Arc<dyn Any + Send + Sync>) -> Result<(), SignalError>;
/// Hook for the "after delete" event; receives the type-erased instance.
async fn after_delete(&self, instance: Arc<dyn Any + Send + Sync>) -> Result<(), SignalError>;
}
/// Marker handle tying the model signals to a single model type `T`.
///
/// The struct stores no data of its own: its only field is a
/// `PhantomData<T>`, which records the model type at compile time.
pub struct OrmSignalAdapter<T>
where
    T: Send + Sync + 'static,
{
    /// Carries the model type without holding a value of it.
    _phantom: PhantomData<T>,
}
impl<T: Send + Sync + 'static> OrmSignalAdapter<T> {
pub fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
pub fn signal_count(&self) -> usize {
pre_save::<T>().receiver_count()
+ post_save::<T>().receiver_count()
+ pre_delete::<T>().receiver_count()
+ post_delete::<T>().receiver_count()
}
}
impl<T: Send + Sync + 'static> Default for OrmSignalAdapter<T> {
fn default() -> Self {
Self::new()
}
}
/// Sends `instance` through the `pre_save` signal registered for `T`.
///
/// # Errors
///
/// Returns whatever [`SignalError`] the signal's `send` call reports.
pub async fn dispatch_pre_save<T>(instance: T) -> Result<(), SignalError>
where
    T: Send + Sync + Clone + 'static,
{
    let signal = pre_save::<T>();
    signal.send(instance).await
}
/// Sends `instance` through the `post_save` signal registered for `T`.
///
/// `_created` is accepted but not used: it is never forwarded to the signal
/// or its receivers.
/// NOTE(review): the flag presumably distinguishes an insert from an update;
/// confirm whether receivers are expected to see it.
///
/// # Errors
///
/// Returns whatever [`SignalError`] the signal's `send` call reports.
pub async fn dispatch_post_save<T>(instance: T, _created: bool) -> Result<(), SignalError>
where
    T: Send + Sync + Clone + 'static,
{
    let signal = post_save::<T>();
    signal.send(instance).await
}
/// Sends `instance` through the `pre_delete` signal registered for `T`.
///
/// # Errors
///
/// Returns whatever [`SignalError`] the signal's `send` call reports.
pub async fn dispatch_pre_delete<T>(instance: T) -> Result<(), SignalError>
where
    T: Send + Sync + Clone + 'static,
{
    let signal = pre_delete::<T>();
    signal.send(instance).await
}
/// Sends `instance` through the `post_delete` signal registered for `T`.
///
/// # Errors
///
/// Returns whatever [`SignalError`] the signal's `send` call reports.
pub async fn dispatch_post_delete<T>(instance: T) -> Result<(), SignalError>
where
    T: Send + Sync + Clone + 'static,
{
    let signal = post_delete::<T>();
    signal.send(instance).await
}
/// Looks up each of the four model signals for `T` and discards the handle.
///
/// No receiver is connected here; the function only calls the four signal
/// accessors. It is declared `async` but awaits nothing.
/// NOTE(review): presumably the accessors create the signals on first use,
/// making this a warm-up step — confirm against `model_signals`.
pub async fn connect_orm_signals<T>()
where
    T: Send + Sync + 'static,
{
    // Each handle is dropped before the next accessor runs.
    drop(pre_save::<T>());
    drop(post_save::<T>());
    drop(pre_delete::<T>());
    drop(post_delete::<T>());
}
/// Returns the four model signals for `T` as a vector.
///
/// The order is fixed: `pre_save`, `post_save`, `pre_delete`, `post_delete`.
pub fn get_orm_signals<T>() -> Vec<Signal<T>>
where
    T: Send + Sync + 'static,
{
    Vec::from([
        pre_save::<T>(),
        post_save::<T>(),
        pre_delete::<T>(),
        post_delete::<T>(),
    ])
}
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal model type used as the signal payload in these tests.
    #[derive(Debug, Clone)]
    #[allow(dead_code)] // the fields only give the payload a shape; nothing reads them
    struct TestModel {
        id: i32,
        name: String,
    }

    /// Builds the fixture instance dispatched by the tests.
    fn sample_model() -> TestModel {
        TestModel {
            id: 1,
            name: "Test".to_string(),
        }
    }

    /// Disconnects every receiver from the four `TestModel` signals.
    ///
    /// The signals are shared by all tests in the process, so each test that
    /// connects a receiver calls this on entry (to clear leftovers from a test
    /// that failed before its own cleanup) and again on exit.
    fn reset_signals() {
        pre_save::<TestModel>().disconnect_all();
        post_save::<TestModel>().disconnect_all();
        pre_delete::<TestModel>().disconnect_all();
        post_delete::<TestModel>().disconnect_all();
    }

    // Every test that connects or disconnects receivers is marked
    // `#[serial_test::serial]`: without it, one test's `disconnect_all()` can
    // remove another test's receiver between `connect` and dispatch.

    #[tokio::test]
    async fn test_orm_signal_adapter_creation() {
        let _adapter = OrmSignalAdapter::<TestModel>::new();
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_dispatch_pre_save() {
        reset_signals();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        pre_save::<TestModel>().connect(move |_instance| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });
        dispatch_pre_save(sample_model()).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        reset_signals();
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_dispatch_post_save() {
        reset_signals();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        post_save::<TestModel>().connect(move |_instance| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });
        dispatch_post_save(sample_model(), true).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        reset_signals();
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_dispatch_pre_delete() {
        reset_signals();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        pre_delete::<TestModel>().connect(move |_instance| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });
        dispatch_pre_delete(sample_model()).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        reset_signals();
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_dispatch_post_delete() {
        reset_signals();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        post_delete::<TestModel>().connect(move |_instance| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });
        dispatch_post_delete(sample_model()).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        reset_signals();
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_connect_orm_signals() {
        reset_signals();
        connect_orm_signals::<TestModel>().await;
        let adapter = OrmSignalAdapter::<TestModel>::new();
        let initial_count = adapter.signal_count();
        pre_save::<TestModel>().connect(|_| async { Ok(()) });
        assert!(adapter.signal_count() > initial_count);
        reset_signals();
    }

    #[tokio::test]
    async fn test_get_orm_signals() {
        let signals = get_orm_signals::<TestModel>();
        assert_eq!(signals.len(), 4);
        let names: Vec<_> = signals.iter().map(|s| format!("{:?}", s)).collect();
        assert_eq!(names.len(), 4);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_orm_signal_flow() {
        reset_signals();
        let events = Arc::new(Mutex::new(Vec::new()));
        let e1 = events.clone();
        pre_save::<TestModel>().connect(move |_| {
            let e = e1.clone();
            async move {
                e.lock().push("pre_save");
                Ok(())
            }
        });
        let e2 = events.clone();
        post_save::<TestModel>().connect(move |_| {
            let e = e2.clone();
            async move {
                e.lock().push("post_save");
                Ok(())
            }
        });
        let e3 = events.clone();
        pre_delete::<TestModel>().connect(move |_| {
            let e = e3.clone();
            async move {
                e.lock().push("pre_delete");
                Ok(())
            }
        });
        let e4 = events.clone();
        post_delete::<TestModel>().connect(move |_| {
            let e = e4.clone();
            async move {
                e.lock().push("post_delete");
                Ok(())
            }
        });
        let model = sample_model();
        dispatch_pre_save(model.clone()).await.unwrap();
        dispatch_post_save(model.clone(), true).await.unwrap();
        dispatch_pre_delete(model.clone()).await.unwrap();
        dispatch_post_delete(model).await.unwrap();
        {
            // Scope the guard so the lock is released before cleanup runs.
            let event_log = events.lock();
            assert_eq!(event_log.len(), 4);
            assert_eq!(event_log[0], "pre_save");
            assert_eq!(event_log[1], "post_save");
            assert_eq!(event_log[2], "pre_delete");
            assert_eq!(event_log[3], "post_delete");
        }
        reset_signals();
    }
}