use crate::base::Broker;
use crate::components::ComponentLifecycle;
use crate::error::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
/// An event delivered through a [`Subscriber`]'s event channel.
///
/// In this file only [`SubscriptionEvent::TaskCancelled`] is generated by the
/// subscriber itself (from the broker's cancellation stream); every other
/// variant reaches the channel only when a caller hands it to
/// `Subscriber::publish`.
// `Eq` is derivable because every field is `String` or `i32`; deriving it lets
// events be used wherever full equality is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubscriptionEvent {
    /// Enqueue notification for task `task_id` of type `task_type` on `queue`.
    TaskEnqueued {
        queue: String,
        task_id: String,
        task_type: String,
    },
    /// Start notification for task `task_id` of type `task_type` on `queue`.
    TaskStarted {
        queue: String,
        task_id: String,
        task_type: String,
    },
    /// Completion notification for task `task_id` of type `task_type` on `queue`.
    TaskCompleted {
        queue: String,
        task_id: String,
        task_type: String,
    },
    /// Failure notification for task `task_id`, carrying an `error` description.
    TaskFailed {
        queue: String,
        task_id: String,
        task_type: String,
        error: String,
    },
    /// Retry notification for task `task_id`, carrying its `retry_count`.
    TaskRetried {
        queue: String,
        task_id: String,
        task_type: String,
        retry_count: i32,
    },
    /// A cancellation for `task_id` received from the broker's cancellation
    /// pub/sub stream.
    TaskCancelled { task_id: String },
    /// A state-change notification for the server identified by `server_id`.
    ServerStateChanged { server_id: String, status: String },
}
/// Tunables for a [`Subscriber`].
#[derive(Debug, Clone)]
pub struct SubscriberConfig {
    /// Capacity of the bounded channel that carries [`SubscriptionEvent`]s
    /// to the receiver.
    pub buffer_size: usize,
}

impl Default for SubscriberConfig {
    /// Builds a configuration whose `buffer_size` is 100.
    fn default() -> Self {
        const DEFAULT_BUFFER_SIZE: usize = 100;
        SubscriberConfig {
            buffer_size: DEFAULT_BUFFER_SIZE,
        }
    }
}
/// Consumes the broker's cancellation stream and funnels it, together with
/// any manually published events, into one bounded event channel.
pub struct Subscriber {
    /// Broker whose cancellation pub/sub stream the background task reads.
    broker: Arc<dyn Broker>,
    // The stored copy is never read after construction (`new` sizes the
    // channel from its argument), hence the lint suppression.
    #[allow(dead_code)]
    config: SubscriberConfig,
    /// Shutdown flag: set by `shutdown`, polled by the background task.
    done: Arc<AtomicBool>,
    /// Sending half, shared by the background task and `publish`.
    event_tx: mpsc::Sender<SubscriptionEvent>,
    /// Receiving half; handed out at most once via `take_receiver`.
    event_rx: Option<mpsc::Receiver<SubscriptionEvent>>,
}
impl Subscriber {
pub fn new(broker: Arc<dyn Broker>, config: SubscriberConfig) -> Self {
let (event_tx, event_rx) = mpsc::channel(config.buffer_size);
Self {
broker,
config,
done: Arc::new(AtomicBool::new(false)),
event_tx,
event_rx: Some(event_rx),
}
}
pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<SubscriptionEvent>> {
self.event_rx.take()
}
pub fn start(self: Arc<Self>) -> JoinHandle<()> {
tracing::info!("starting subscriber");
tokio::spawn(async move {
match self.broker.cancellation_pub_sub().await {
Ok(mut stream) => {
use futures::StreamExt;
loop {
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
if self.done.load(Ordering::Relaxed) {
tracing::debug!("Subscriber: shutting down");
break;
}
}
Some(result) = stream.next() => {
tracing::debug!("Subscriber: received subscription event");
match result {
Ok(task_id) => {
tracing::debug!("Subscriber: received cancellation for task {}", task_id);
let event = SubscriptionEvent::TaskCancelled { task_id };
if let Err(e) = self.event_tx.send(event).await {
tracing::warn!("Subscriber: failed to forward cancellation event: {}", e);
}
}
Err(e) => {
tracing::warn!("Subscriber: error receiving cancellation event: {}", e);
}
}
}
}
}
}
Err(e) => {
tracing::error!(
"Subscriber: failed to subscribe to cancellation events: {}",
e
);
loop {
if self.done.load(Ordering::Relaxed) {
tracing::debug!("Subscriber: shutting down");
break;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
})
}
pub async fn publish(&self, event: SubscriptionEvent) -> Result<()> {
if let Err(e) = self.event_tx.send(event).await {
tracing::warn!("Subscriber: failed to publish event: {}", e);
return Err(crate::error::Error::other(format!(
"Failed to publish event: {e}"
)));
}
Ok(())
}
pub fn shutdown(&self) {
self.done.store(true, Ordering::Relaxed);
}
pub fn is_done(&self) -> bool {
self.done.load(Ordering::Relaxed)
}
}
/// Routes the shared [`ComponentLifecycle`] interface to the subscriber's
/// inherent methods of the same names.
impl ComponentLifecycle for Subscriber {
    fn start(self: Arc<Self>) -> JoinHandle<()> {
        // Type-path lookup prefers inherent items, so this reaches the inherent
        // `start` rather than recursing into this trait method.
        Self::start(self)
    }

    fn shutdown(&self) {
        Self::shutdown(self)
    }

    fn is_done(&self) -> bool {
        Self::is_done(self)
    }
}
#[cfg(feature = "default")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{RedisBroker, RedisConnectionType};
    use std::time::Duration;

    /// Builds a broker connected to the Redis instance these tests expect at
    /// `localhost:6379`.
    async fn redis_broker() -> Arc<RedisBroker> {
        let connection = RedisConnectionType::single("redis://localhost:6379")
            .expect("valid Redis connection URL");
        Arc::new(
            RedisBroker::new(connection)
                .await
                .expect("connect to Redis at localhost:6379"),
        )
    }

    #[test]
    fn test_subscriber_config_default() {
        let config = SubscriberConfig::default();
        assert_eq!(config.buffer_size, 100);
    }

    #[tokio::test]
    async fn test_subscriber_shutdown() {
        let subscriber = Subscriber::new(redis_broker().await, SubscriberConfig::default());
        assert!(!subscriber.is_done());
        subscriber.shutdown();
        assert!(subscriber.is_done());
    }

    #[tokio::test]
    async fn test_subscriber_publish_receive() {
        let mut subscriber = Subscriber::new(redis_broker().await, SubscriberConfig::default());
        let event = SubscriptionEvent::TaskEnqueued {
            queue: "default".to_string(),
            task_id: "task123".to_string(),
            task_type: "email:send".to_string(),
        };
        subscriber.publish(event.clone()).await.unwrap();
        // Fail loudly instead of silently skipping the check when no receiver
        // is available.
        let mut rx = subscriber
            .take_receiver()
            .expect("receiver is available on the first take");
        let received = rx
            .recv()
            .await
            .expect("channel yields the published event");
        assert_eq!(received, event);
    }

    #[tokio::test]
    #[ignore]
    async fn test_subscriber_cancellation_pubsub() {
        let broker: Arc<dyn Broker> = redis_broker().await;
        let mut subscriber = Subscriber::new(Arc::clone(&broker), SubscriberConfig::default());
        let mut rx = subscriber.take_receiver().unwrap();
        let subscriber_arc = Arc::new(subscriber);
        let handle = subscriber_arc.clone().start();
        // Give the background task time to subscribe before publishing.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let task_id = "test_task_123";
        broker.publish_cancellation(task_id).await.unwrap();

        // Both a timeout and a closed channel are failures, not silent passes.
        let received = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("cancellation event arrives within 2s")
            .expect("event channel is still open");
        match received {
            SubscriptionEvent::TaskCancelled {
                task_id: received_id,
            } => assert_eq!(received_id, task_id),
            other => panic!("Expected TaskCancelled event, got {other:?}"),
        }

        subscriber_arc.shutdown();
        tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("subscriber task stops within 1s of shutdown")
            .expect("subscriber task did not panic");
    }
}