use async_trait::async_trait;
use celers_core::event::{Event, EventEmitter};
use celers_core::{CelersError, Result};
use redis::{AsyncCommands, Client};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Default Pub/Sub channel that every emitted event is published to.
const DEFAULT_CHANNEL: &str = "celeryev";
/// Default channel that additionally receives `Event::Task` events when
/// per-type channels are enabled.
const TASK_CHANNEL: &str = "celeryev.task";
/// Default channel that additionally receives `Event::Worker` events when
/// per-type channels are enabled.
const WORKER_CHANNEL: &str = "celeryev.worker";
/// Configuration for [`RedisEventEmitter`]: which Redis Pub/Sub channels events
/// are published to, and whether publishing happens at all.
#[derive(Debug, Clone)]
pub struct RedisEventConfig {
    /// Channel that every emitted event is published to.
    pub channel: String,
    /// Extra channel for `Event::Task` events; `None` disables the extra publish.
    pub task_channel: Option<String>,
    /// Extra channel for `Event::Worker` events; `None` disables the extra publish.
    pub worker_channel: Option<String>,
    /// When `true`, each event is also published to its per-type channel
    /// (`task_channel` / `worker_channel`) in addition to `channel`.
    pub publish_to_type_channels: bool,
    /// When `false`, the emitter returns `Ok(())` without serializing or
    /// contacting Redis.
    pub enabled: bool,
}
impl Default for RedisEventConfig {
fn default() -> Self {
Self {
channel: DEFAULT_CHANNEL.to_string(),
task_channel: Some(TASK_CHANNEL.to_string()),
worker_channel: Some(WORKER_CHANNEL.to_string()),
publish_to_type_channels: true,
enabled: true,
}
}
}
impl RedisEventConfig {
    /// Creates a configuration populated with the default channels; same as
    /// [`RedisEventConfig::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the channel that every event is published to.
    pub fn channel(self, channel: impl Into<String>) -> Self {
        Self {
            channel: channel.into(),
            ..self
        }
    }

    /// Sets (or, with `None`, clears) the extra channel used for task events.
    pub fn task_channel(self, channel: Option<String>) -> Self {
        Self {
            task_channel: channel,
            ..self
        }
    }

    /// Sets (or, with `None`, clears) the extra channel used for worker events.
    pub fn worker_channel(self, channel: Option<String>) -> Self {
        Self {
            worker_channel: channel,
            ..self
        }
    }

    /// Toggles publishing to the per-type task/worker channels.
    pub fn publish_to_type_channels(self, enabled: bool) -> Self {
        Self {
            publish_to_type_channels: enabled,
            ..self
        }
    }

    /// Toggles event emission as a whole.
    pub fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
/// [`EventEmitter`] that serializes events to JSON and publishes them over
/// Redis Pub/Sub, according to a [`RedisEventConfig`].
pub struct RedisEventEmitter {
    /// Client used to open the connection on demand.
    client: Client,
    /// Channel routing and the on/off switch.
    config: RedisEventConfig,
    /// Lazily created connection. It is `None` until the first publish needs
    /// one, then reused by later publishes.
    conn: Arc<RwLock<Option<redis::aio::MultiplexedConnection>>>,
}
impl RedisEventEmitter {
pub fn new(url: &str) -> std::result::Result<Self, crate::BackendError> {
let client = Client::open(url).map_err(|e| {
crate::BackendError::Connection(format!("Failed to create Redis client: {}", e))
})?;
Ok(Self {
client,
config: RedisEventConfig::default(),
conn: Arc::new(RwLock::new(None)),
})
}
pub fn with_config(
url: &str,
config: RedisEventConfig,
) -> std::result::Result<Self, crate::BackendError> {
let client = Client::open(url).map_err(|e| {
crate::BackendError::Connection(format!("Failed to create Redis client: {}", e))
})?;
Ok(Self {
client,
config,
conn: Arc::new(RwLock::new(None)),
})
}
async fn get_connection(
&self,
) -> std::result::Result<redis::aio::MultiplexedConnection, crate::BackendError> {
{
let conn_guard = self.conn.read().await;
if let Some(ref conn) = *conn_guard {
return Ok(conn.clone());
}
}
let conn = self.client.get_multiplexed_async_connection().await?;
{
let mut conn_guard = self.conn.write().await;
*conn_guard = Some(conn.clone());
}
Ok(conn)
}
async fn publish(&self, channel: &str, event_json: &str) -> Result<()> {
let mut conn = self
.get_connection()
.await
.map_err(|e| CelersError::Other(format!("Redis connection error: {}", e)))?;
conn.publish::<_, _, ()>(channel, event_json)
.await
.map_err(|e| CelersError::Other(format!("Redis publish error: {}", e)))?;
Ok(())
}
pub fn config(&self) -> &RedisEventConfig {
&self.config
}
pub fn is_active(&self) -> bool {
self.config.enabled
}
pub fn channel(&self) -> &str {
&self.config.channel
}
}
#[async_trait]
impl EventEmitter for RedisEventEmitter {
async fn emit(&self, event: Event) -> Result<()> {
if !self.config.enabled {
return Ok(());
}
let event_json = serde_json::to_string(&event)
.map_err(|e| CelersError::Other(format!("Event serialization error: {}", e)))?;
self.publish(&self.config.channel, &event_json).await?;
if self.config.publish_to_type_channels {
match &event {
Event::Task(_) => {
if let Some(ref task_channel) = self.config.task_channel {
self.publish(task_channel, &event_json).await?;
}
}
Event::Worker(_) => {
if let Some(ref worker_channel) = self.config.worker_channel {
self.publish(worker_channel, &event_json).await?;
}
}
}
}
Ok(())
}
async fn emit_batch(&self, events: Vec<Event>) -> Result<()> {
if !self.config.enabled || events.is_empty() {
return Ok(());
}
let mut conn = self
.get_connection()
.await
.map_err(|e| CelersError::Other(format!("Redis connection error: {}", e)))?;
let mut pipe = redis::pipe();
for event in &events {
let event_json = serde_json::to_string(event)
.map_err(|e| CelersError::Other(format!("Event serialization error: {}", e)))?;
pipe.publish(&self.config.channel, &event_json);
if self.config.publish_to_type_channels {
match event {
Event::Task(_) => {
if let Some(ref task_channel) = self.config.task_channel {
pipe.publish(task_channel, &event_json);
}
}
Event::Worker(_) => {
if let Some(ref worker_channel) = self.config.worker_channel {
pipe.publish(worker_channel, &event_json);
}
}
}
}
}
pipe.query_async::<()>(&mut conn)
.await
.map_err(|e| CelersError::Other(format!("Redis pipeline error: {}", e)))?;
Ok(())
}
fn is_enabled(&self) -> bool {
self.config.enabled
}
}
/// Subscribes to Redis Pub/Sub channels and decodes the JSON events published
/// on them into [`Event`] values.
pub struct RedisEventReceiver {
    /// Client used to open the Pub/Sub connection.
    client: Client,
    /// Channels subscribed to when receiving.
    channels: Vec<String>,
}
impl RedisEventReceiver {
    /// Creates a receiver for the Redis server at `url` that listens on the
    /// default `celeryev` channel only.
    ///
    /// # Errors
    /// Returns `BackendError::Connection` if `url` cannot be turned into a Redis client.
    pub fn new(url: &str) -> std::result::Result<Self, crate::BackendError> {
        Self::with_channels(url, vec![DEFAULT_CHANNEL.to_string()])
    }

    /// Creates a receiver for the Redis server at `url` that listens on `channels`.
    ///
    /// # Errors
    /// Returns `BackendError::Connection` if `url` cannot be turned into a Redis client.
    pub fn with_channels(
        url: &str,
        channels: Vec<String>,
    ) -> std::result::Result<Self, crate::BackendError> {
        let client = Client::open(url).map_err(|e| {
            crate::BackendError::Connection(format!("Failed to create Redis client: {}", e))
        })?;
        Ok(Self { client, channels })
    }

    /// Creates a receiver listening on three channels: the default channel, the
    /// task channel, and the worker channel.
    ///
    /// Note: with its default configuration, `RedisEventEmitter` publishes every
    /// task/worker event to both the main channel and its per-type channel. A
    /// receiver built this way will therefore see each such event twice.
    ///
    /// # Errors
    /// Returns `BackendError::Connection` if `url` cannot be turned into a Redis client.
    pub fn subscribe_all(url: &str) -> std::result::Result<Self, crate::BackendError> {
        Self::with_channels(
            url,
            vec![
                DEFAULT_CHANNEL.to_string(),
                TASK_CHANNEL.to_string(),
                WORKER_CHANNEL.to_string(),
            ],
        )
    }

    /// Opens a new Pub/Sub connection. It is not yet subscribed to anything;
    /// [`RedisEventReceiver::receive`] does the subscribing.
    ///
    /// # Errors
    /// Propagates the error if the connection cannot be established.
    pub async fn subscribe(&self) -> std::result::Result<redis::aio::PubSub, crate::BackendError> {
        let conn = self.client.get_async_pubsub().await?;
        Ok(conn)
    }

    /// Returns the channels this receiver subscribes to.
    pub fn channels(&self) -> &[String] {
        &self.channels
    }

    /// Subscribes to all configured channels and calls `handler` for every
    /// decoded event, until the message stream ends.
    ///
    /// Payloads that are not valid UTF-8 strings or not valid `Event` JSON are
    /// logged with `tracing::warn!` and skipped. If `channels` is empty, nothing
    /// is subscribed, so no messages will arrive.
    ///
    /// # Errors
    /// Returns on connection/subscription failure, or on the first error
    /// returned by `handler`.
    pub async fn receive<F, Fut>(
        &self,
        mut handler: F,
    ) -> std::result::Result<(), crate::BackendError>
    where
        F: FnMut(Event) -> Fut,
        Fut: std::future::Future<Output = std::result::Result<(), crate::BackendError>>,
    {
        use futures_util::StreamExt;
        let mut pubsub = self.subscribe().await?;
        for channel in &self.channels {
            pubsub.subscribe(channel).await?;
        }
        let mut stream = pubsub.on_message();
        while let Some(msg) = stream.next().await {
            // Treat an undecodable payload like malformed JSON: skip it
            // instead of tearing down the whole subscription.
            let payload: String = match msg.get_payload() {
                Ok(payload) => payload,
                Err(e) => {
                    tracing::warn!("Failed to decode event payload: {}", e);
                    continue;
                }
            };
            match serde_json::from_str::<Event>(&payload) {
                Ok(event) => handler(event).await?,
                Err(e) => tracing::warn!("Failed to deserialize event: {}", e),
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration uses the `celeryev*` channels and is fully enabled.
    #[test]
    fn test_redis_event_config_default() {
        let cfg = RedisEventConfig::default();
        assert_eq!(cfg.channel.as_str(), "celeryev");
        assert_eq!(cfg.task_channel.as_deref(), Some("celeryev.task"));
        assert_eq!(cfg.worker_channel.as_deref(), Some("celeryev.worker"));
        assert!(cfg.publish_to_type_channels);
        assert!(cfg.enabled);
    }

    /// Each builder method overrides exactly the field it names.
    #[test]
    fn test_redis_event_config_builder() {
        let worker = String::from("my-worker-events");
        let cfg = RedisEventConfig::new()
            .channel("my-events")
            .task_channel(None)
            .worker_channel(Some(worker.clone()))
            .publish_to_type_channels(false)
            .enabled(true);
        assert_eq!(cfg.channel.as_str(), "my-events");
        assert_eq!(cfg.task_channel, None);
        assert_eq!(cfg.worker_channel.as_ref(), Some(&worker));
        assert!(!cfg.publish_to_type_channels);
        assert!(cfg.enabled);
    }
}