use super::receiver_context::ReceiverContext;
use super::signal::Signal;
use async_trait::async_trait;
use std::future::Future;
use std::sync::Arc;
use super::error::SignalError;
/// Extension trait for signals whose receivers want a [`ReceiverContext`]
/// in addition to the emitted instance.
///
/// Implemented for [`Signal<T>`] in this module.
#[async_trait]
pub trait InjectableSignal<T>
where
    T: Send + Sync + 'static,
{
    /// Registers `receiver` so that it is called with the emitted instance
    /// (shared via [`Arc`]) and a [`ReceiverContext`].
    ///
    /// The future returned by `receiver` resolves to
    /// `Result<(), SignalError>`. How that result is propagated is decided
    /// by the underlying signal implementation.
    fn connect_with_context<F, Fut>(&self, receiver: F)
    where
        Fut: Future<Output = Result<(), SignalError>> + Send + 'static,
        F: Fn(Arc<T>, ReceiverContext) -> Fut + Send + Sync + 'static;
}
/// Wires context-aware receivers onto a plain [`Signal`].
///
/// Each context-aware receiver is wrapped in a one-argument closure and
/// registered through `Signal::connect`.
#[async_trait]
impl<T> InjectableSignal<T> for Signal<T>
where
    T: Send + Sync + 'static,
{
    fn connect_with_context<F, Fut>(&self, receiver: F)
    where
        Fut: Future<Output = Result<(), SignalError>> + Send + 'static,
        F: Fn(Arc<T>, ReceiverContext) -> Fut + Send + Sync + 'static,
    {
        // `connect` takes a single-argument receiver, so adapt ours by
        // supplying a freshly built `ReceiverContext::new()` on every call.
        self.connect(move |payload: Arc<T>| receiver(payload, ReceiverContext::new()));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signals::SignalName;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, Clone)]
    struct TestModel {
        id: i32,
        name: String,
    }

    /// A receiver registered via `connect_with_context` must be invoked once
    /// per `send` and must receive the emitted instance unchanged. It must
    /// also receive a context that reports no DI context.
    #[tokio::test]
    async fn test_connect_with_context() {
        let signal = Signal::new(SignalName::custom("test_context"));
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        signal.connect_with_context(move |instance: Arc<TestModel>, ctx: ReceiverContext| {
            let counter = Arc::clone(&counter_clone);
            async move {
                // Check everything before incrementing. If any assertion
                // panics, the counter stays at 0, so the outer assertion
                // fails even when the panic happens on another task.
                assert_eq!(instance.id, 1);
                assert_eq!(instance.name, "Test");
                assert!(!ctx.has_di_context());
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });

        let model = TestModel {
            id: 1,
            name: "Test".to_string(),
        };
        signal.send(model).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}