use super::error::SignalError;
use super::signal::Signal;
use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::SystemTime;
/// A signal payload captured for storage, together with identifying metadata.
///
/// The type is generic over the payload `T` and derives `Debug`, `Clone`,
/// `Serialize` and `Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredSignal<T> {
/// Identifier supplied by whoever creates the record.
pub id: u64,
/// Name of the signal this record was created for.
pub signal_name: String,
/// Wall-clock time attached to the record; `StoredSignal::new` sets it to
/// the moment of construction.
pub timestamp: SystemTime,
/// The stored payload value.
pub payload: T,
}
impl<T> StoredSignal<T> {
pub fn new(id: u64, signal_name: String, payload: T) -> Self {
Self {
id,
signal_name,
timestamp: SystemTime::now(),
payload,
}
}
}
/// Asynchronous storage backend for [`StoredSignal`] records.
///
/// Implementations must be `Send + Sync`. Every operation reports failures
/// through [`SignalError`].
#[async_trait]
pub trait SignalStore<T: Send + Sync + 'static>: Send + Sync {
/// Adds `signal` to the store.
async fn store(&self, signal: StoredSignal<T>) -> Result<(), SignalError>;
/// Looks up the record whose `id` field equals `id`; `Ok(None)` when no
/// such record is present.
async fn retrieve(&self, id: u64) -> Result<Option<StoredSignal<T>>, SignalError>;
/// Returns a page of records: at most `limit` entries after skipping the
/// first `offset`. Ordering is up to the implementation (`MemoryStore`
/// returns records in insertion order).
async fn list(&self, limit: usize, offset: usize) -> Result<Vec<StoredSignal<T>>, SignalError>;
/// Returns the number of records currently held.
async fn count(&self) -> Result<u64, SignalError>;
/// Removes every record from the store.
async fn clear(&self) -> Result<(), SignalError>;
}
/// In-memory [`SignalStore`] backed by a `VecDeque` behind a read/write lock.
///
/// The shared state lives in `Arc`s, so clones of a store operate on the same
/// underlying records.
pub struct MemoryStore<T> {
/// Stored records, oldest at the front.
signals: Arc<RwLock<VecDeque<StoredSignal<T>>>>,
/// Counter initialised to 1 and shared between clones.
/// NOTE(review): no method visible in this file reads or updates it — confirm
/// whether it is still needed.
next_id: Arc<RwLock<u64>>,
/// Upper bound on the number of retained records (`usize::MAX` when the
/// store is created with `new`).
max_size: usize,
}
impl<T> MemoryStore<T> {
    /// Creates an empty store whose size limit is `usize::MAX`, i.e. effectively
    /// unbounded.
    pub fn new() -> Self {
        Self::with_max_size(usize::MAX)
    }

    /// Creates an empty store that uses `max_size` as its retention limit.
    ///
    /// The limit is enforced by the `SignalStore::store` implementation, which
    /// evicts the oldest record when the store is full.
    pub fn with_max_size(max_size: usize) -> Self {
        let signals = Arc::new(RwLock::new(VecDeque::new()));
        let next_id = Arc::new(RwLock::new(1));
        Self {
            signals,
            next_id,
            max_size,
        }
    }

    /// Returns the retention limit this store was configured with.
    pub fn max_size(&self) -> usize {
        self.max_size
    }
}
impl<T> Default for MemoryStore<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Clone for MemoryStore<T> {
fn clone(&self) -> Self {
Self {
signals: Arc::clone(&self.signals),
next_id: Arc::clone(&self.next_id),
max_size: self.max_size,
}
}
}
#[async_trait]
impl<T: Send + Sync + Clone + 'static> SignalStore<T> for MemoryStore<T> {
    /// Appends `signal` to the back of the queue, evicting the oldest record
    /// first when the store is already at `max_size`.
    ///
    /// A store configured with a `max_size` of zero retains nothing: the
    /// signal is dropped and `Ok(())` is returned. Without this guard the
    /// eviction below would be a no-op on an empty queue and the store would
    /// permanently hold one record, exceeding its configured bound.
    async fn store(&self, signal: StoredSignal<T>) -> Result<(), SignalError> {
        if self.max_size == 0 {
            return Ok(());
        }
        let mut signals = self.signals.write();
        // The length never exceeds `max_size`, so a single eviction is enough
        // to make room for the new record.
        if signals.len() >= self.max_size {
            signals.pop_front();
        }
        signals.push_back(signal);
        Ok(())
    }

    /// Returns a clone of the first record whose `id` matches, or `Ok(None)`
    /// when there is none. This is a linear scan over the stored records.
    async fn retrieve(&self, id: u64) -> Result<Option<StoredSignal<T>>, SignalError> {
        let signals = self.signals.read();
        Ok(signals.iter().find(|s| s.id == id).cloned())
    }

    /// Returns clones of up to `limit` records in insertion order, after
    /// skipping the first `offset`. An `offset` beyond the end yields an
    /// empty vector.
    async fn list(&self, limit: usize, offset: usize) -> Result<Vec<StoredSignal<T>>, SignalError> {
        let signals = self.signals.read();
        Ok(signals.iter().skip(offset).take(limit).cloned().collect())
    }

    /// Returns the number of records currently held.
    async fn count(&self) -> Result<u64, SignalError> {
        Ok(self.signals.read().len() as u64)
    }

    /// Removes every stored record.
    async fn clear(&self) -> Result<(), SignalError> {
        self.signals.write().clear();
        Ok(())
    }
}
/// Pairs a [`Signal`] with a [`SignalStore`] so that payloads passed to
/// `send` are recorded in the store before being dispatched on the signal.
pub struct PersistentSignal<T: Send + Sync + 'static> {
/// The wrapped signal that payloads are forwarded to.
signal: Signal<T>,
/// Type-erased storage backend receiving a record for every sent payload.
store: Arc<dyn SignalStore<T>>,
/// Name written into every stored record; built in `new` as
/// `"persistent_"` followed by the payload type's name.
signal_name: String,
/// Id handed to the next stored record; starts at 1 and is shared between
/// clones of this wrapper.
next_id: Arc<RwLock<u64>>,
}
impl<T: Send + Sync + Clone + 'static> PersistentSignal<T> {
pub fn new<S>(signal: Signal<T>, store: S) -> Self
where
S: SignalStore<T> + 'static,
{
let signal_name = format!("persistent_{}", std::any::type_name::<T>());
Self {
signal,
store: Arc::new(store),
signal_name,
next_id: Arc::new(RwLock::new(1)),
}
}
pub async fn send(&self, instance: T) -> Result<(), SignalError> {
let stored_instance = instance.clone();
let id = {
let mut next_id = self.next_id.write();
let id = *next_id;
*next_id += 1;
id
};
let stored_signal = StoredSignal::new(id, self.signal_name.clone(), stored_instance);
self.store.store(stored_signal).await?;
self.signal.send(instance).await
}
pub async fn retrieve(&self, id: u64) -> Result<Option<StoredSignal<T>>, SignalError> {
self.store.retrieve(id).await
}
pub async fn list(
&self,
limit: usize,
offset: usize,
) -> Result<Vec<StoredSignal<T>>, SignalError> {
self.store.list(limit, offset).await
}
pub async fn count(&self) -> Result<u64, SignalError> {
self.store.count().await
}
pub async fn clear(&self) -> Result<(), SignalError> {
self.store.clear().await
}
pub fn signal(&self) -> &Signal<T> {
&self.signal
}
pub fn store(&self) -> Arc<dyn SignalStore<T>> {
Arc::clone(&self.store)
}
}
impl<T: Send + Sync + Clone + 'static> Clone for PersistentSignal<T> {
fn clone(&self) -> Self {
Self {
signal: self.signal.clone(),
store: Arc::clone(&self.store),
signal_name: self.signal_name.clone(),
next_id: Arc::clone(&self.next_id),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signals::SignalName;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal payload type used to exercise the stores.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestEvent {
        id: i32,
        message: String,
    }

    /// Builds a payload with the given id and message.
    fn make_event(id: i32, message: &str) -> TestEvent {
        TestEvent {
            id,
            message: message.to_string(),
        }
    }

    /// Builds a stored record named `"test"` whose payload id mirrors the
    /// record id.
    fn make_record(id: u64, message: &str) -> StoredSignal<TestEvent> {
        StoredSignal::new(id, "test".to_string(), make_event(id as i32, message))
    }

    /// A single stored record can be retrieved, counted and listed.
    #[tokio::test]
    async fn test_memory_store_basic() {
        let store = MemoryStore::new();
        let record = make_record(1, "Hello");
        store.store(record.clone()).await.unwrap();

        let fetched = store.retrieve(1).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().payload.id, 1);

        assert_eq!(store.count().await.unwrap(), 1);

        let listed = store.list(10, 0).await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].payload.message, "Hello");
    }

    /// A bounded store keeps only the most recent records.
    #[tokio::test]
    async fn test_memory_store_max_size() {
        let store = MemoryStore::with_max_size(3);
        for id in 1..=5u64 {
            let record = make_record(id, &format!("Event {}", id));
            store.store(record).await.unwrap();
        }

        assert_eq!(store.count().await.unwrap(), 3);

        let listed = store.list(10, 0).await.unwrap();
        for (slot, expected_id) in [3u64, 4, 5].iter().enumerate() {
            assert_eq!(listed[slot].id, *expected_id);
        }
    }

    /// Clearing the store removes every record.
    #[tokio::test]
    async fn test_memory_store_clear() {
        let store = MemoryStore::new();
        for id in 1..=3u64 {
            store.store(make_record(id, "test")).await.unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 3);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    /// Sending through the wrapper both notifies receivers and stores the
    /// payload under id 1.
    #[tokio::test]
    async fn test_persistent_signal_send_and_store() {
        let signal = Signal::<TestEvent>::new(SignalName::custom("test_persistent"));
        let store = MemoryStore::new();
        let persistent = PersistentSignal::new(signal.clone(), store.clone());

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_receiver = Arc::clone(&calls);
        signal.connect(move |_event| {
            let calls = Arc::clone(&calls_for_receiver);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });

        let event = make_event(42, "Test event");
        persistent.send(event.clone()).await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(store.count().await.unwrap(), 1);

        let stored = store.retrieve(1).await.unwrap();
        assert!(stored.is_some());
        assert_eq!(stored.unwrap().payload.id, 42);
    }

    /// Listing honours `limit` and `offset`.
    #[tokio::test]
    async fn test_persistent_signal_list_pagination() {
        let signal = Signal::<TestEvent>::new(SignalName::custom("test_pagination"));
        let store = MemoryStore::new();
        let persistent = PersistentSignal::new(signal, store);

        for i in 1..=10 {
            let event = make_event(i, &format!("Event {}", i));
            persistent.send(event).await.unwrap();
        }

        let first_page = persistent.list(5, 0).await.unwrap();
        assert_eq!(first_page.len(), 5);
        assert_eq!(first_page[0].payload.id, 1);

        let second_page = persistent.list(5, 5).await.unwrap();
        assert_eq!(second_page.len(), 5);
        assert_eq!(second_page[0].payload.id, 6);

        assert_eq!(persistent.count().await.unwrap(), 10);
    }

    /// Clearing through the wrapper empties the backing store.
    #[tokio::test]
    async fn test_persistent_signal_clear() {
        let signal = Signal::<TestEvent>::new(SignalName::custom("test_clear"));
        let store = MemoryStore::new();
        let persistent = PersistentSignal::new(signal, store);

        for i in 1..=5 {
            let event = make_event(i, "test");
            persistent.send(event).await.unwrap();
        }
        assert_eq!(persistent.count().await.unwrap(), 5);

        persistent.clear().await.unwrap();
        assert_eq!(persistent.count().await.unwrap(), 0);
    }
}