use std::{
collections::HashMap,
fmt,
sync::{Arc, RwLock},
};
use redis::{AsyncCommands, FromRedisValue, PushInfo, RedisResult, aio::MultiplexedConnection};
use tokio::sync::mpsc;
use super::{ChanItem, Driver, MessageStream};
pub use redis as redis_client;
#[derive(Debug)]
pub struct RedisError(redis::RedisError);
impl From<redis::RedisError> for RedisError {
fn from(e: redis::RedisError) -> Self {
Self(e)
}
}
impl fmt::Display for RedisError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl std::error::Error for RedisError {}
type HandlerMap = HashMap<String, mpsc::Sender<ChanItem>>;
#[derive(Clone)]
pub struct RedisDriver {
handlers: Arc<RwLock<HandlerMap>>,
conn: MultiplexedConnection,
}
#[cfg(feature = "redis-cluster")]
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[derive(Clone)]
pub struct ClusterDriver {
handlers: Arc<RwLock<HandlerMap>>,
conn: redis::cluster_async::ClusterConnection,
}
fn read_msg(msg: redis::PushInfo) -> RedisResult<Option<(String, Vec<u8>)>> {
match msg.kind {
redis::PushKind::Message | redis::PushKind::SMessage => {
if msg.data.len() < 2 {
return Ok(None);
}
let mut iter = msg.data.into_iter();
let channel: String = FromRedisValue::from_redis_value(iter.next().unwrap())?;
let message = FromRedisValue::from_redis_value(iter.next().unwrap())?;
Ok(Some((channel, message)))
}
_ => Ok(None),
}
}
fn handle_msg(msg: PushInfo, handlers: Arc<RwLock<HandlerMap>>) {
match read_msg(msg) {
Ok(Some((chan, msg))) => {
if let Some(tx) = handlers.read().unwrap().get(&chan) {
if let Err(e) = tx.try_send((chan, msg)) {
tracing::warn!("redis pubsub channel full {e}");
}
} else {
tracing::warn!(chan, "no handler for channel");
}
}
Ok(_) => {}
Err(e) => {
tracing::error!("error reading message from redis: {e}");
}
}
}
impl RedisDriver {
pub async fn new(client: &redis::Client) -> Result<Self, redis::RedisError> {
let handlers = Arc::new(RwLock::new(HashMap::new()));
let handlers_clone = handlers.clone();
let config = redis::AsyncConnectionConfig::new().set_push_sender(move |msg| {
handle_msg(msg, handlers_clone.clone());
Ok::<(), std::convert::Infallible>(())
});
let conn = client
.get_multiplexed_async_connection_with_config(&config)
.await?;
Ok(Self { conn, handlers })
}
}
#[cfg(feature = "redis-cluster")]
impl ClusterDriver {
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
pub async fn new(client: &redis::cluster::ClusterClient) -> Result<Self, redis::RedisError> {
let handlers = Arc::new(RwLock::new(HashMap::new()));
let handlers_clone = handlers.clone();
let config = redis::cluster::ClusterConfig::new().set_push_sender(move |msg| {
handle_msg(msg, handlers_clone.clone());
Ok::<(), std::convert::Infallible>(())
});
let conn = client.get_async_connection_with_config(config).await?;
Ok(Self { conn, handlers })
}
}
impl Driver for RedisDriver {
type Error = RedisError;
async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
self.conn
.clone()
.publish::<_, _, redis::Value>(chan, val)
.await?;
Ok(())
}
async fn subscribe(
&self,
chan: String,
size: usize,
) -> Result<MessageStream<ChanItem>, Self::Error> {
self.conn.clone().subscribe(chan.as_str()).await?;
let (tx, rx) = mpsc::channel(size);
self.handlers.write().unwrap().insert(chan, tx);
Ok(MessageStream::new(rx))
}
async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
self.handlers.write().unwrap().remove(&chan);
self.conn.clone().unsubscribe(chan).await?;
Ok(())
}
async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
let mut conn = self.conn.clone();
let (_, count): (String, u16) = redis::cmd("PUBSUB")
.arg("NUMSUB")
.arg(chan)
.query_async(&mut conn)
.await?;
Ok(count)
}
}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
impl Driver for ClusterDriver {
type Error = RedisError;
async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
self.conn
.clone()
.spublish::<_, _, redis::Value>(chan, val)
.await?;
Ok(())
}
async fn subscribe(
&self,
chan: String,
size: usize,
) -> Result<MessageStream<ChanItem>, Self::Error> {
self.conn.clone().ssubscribe(chan.as_str()).await?;
let (tx, rx) = mpsc::channel(size);
self.handlers.write().unwrap().insert(chan, tx);
Ok(MessageStream::new(rx))
}
async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
self.handlers.write().unwrap().remove(&chan);
self.conn.clone().sunsubscribe(chan).await?;
Ok(())
}
async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
let mut conn = self.conn.clone();
let (_, count): (String, u16) = redis::cmd("PUBSUB")
.arg("SHARDNUMSUB")
.arg(chan)
.query_async(&mut conn)
.await?;
Ok(count)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn watch_handle_message() {
let mut handlers = HashMap::new();
let (tx, mut rx) = mpsc::channel(1);
handlers.insert("test".to_string(), tx);
let msg = redis::PushInfo {
kind: redis::PushKind::Message,
data: vec![
redis::Value::BulkString("test".into()),
redis::Value::BulkString("foo".into()),
],
};
super::handle_msg(msg, Arc::new(RwLock::new(handlers)));
let (chan, data) = rx.try_recv().unwrap();
assert_eq!(chan, "test");
assert_eq!(data, "foo".as_bytes());
}
#[test]
fn watch_handler_pattern() {
let mut handlers = HashMap::new();
let (tx1, mut rx1) = mpsc::channel(1);
handlers.insert("test-response#namespace#uid#".to_string(), tx1);
let msg = redis::PushInfo {
kind: redis::PushKind::Message,
data: vec![
redis::Value::BulkString("test-response#namespace#uid#".into()),
redis::Value::BulkString("foo".into()),
],
};
super::handle_msg(msg, Arc::new(RwLock::new(handlers)));
let (chan, data) = rx1.try_recv().unwrap();
assert_eq!(chan, "test-response#namespace#uid#");
assert_eq!(data, "foo".as_bytes());
}
}