use std::cmp;
use std::collections::HashSet;
use std::time::Duration;
use async_stream::stream;
use rand::{thread_rng, Rng};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
sync::Mutex,
time::sleep,
};
use tokio_stream::Stream;
use crate::{parser, Command, Message};
/// Redis pub/sub subscriber.
///
/// Channel and pattern subscriptions are remembered so that they can be replayed every
/// time the listening stream (re)connects.
#[derive(Debug)]
pub struct RedisSub {
    /// Server address handed to `TcpStream::connect` (e.g. `127.0.0.1:6379`).
    addr: String,
    /// Channels subscribed via `Command::Subscribe`.
    channels: Mutex<HashSet<String>>,
    /// Patterns subscribed via `Command::PatternSubscribe`.
    pattern_channels: Mutex<HashSet<String>>,
    /// Write half of the live connection; `None` before connecting and after a read failure.
    writer: Mutex<Option<OwnedWriteHalf>>,
}
impl RedisSub {
    /// Creates a subscriber for the Redis server at `addr` (e.g. `"127.0.0.1:6379"`).
    ///
    /// No connection is opened until [`RedisSub::listen`] is called.
    #[must_use]
    pub fn new(addr: &str) -> Self {
        Self {
            addr: addr.to_string(),
            channels: Mutex::new(HashSet::new()),
            pattern_channels: Mutex::new(HashSet::new()),
            writer: Mutex::new(None),
        }
    }

    /// Subscribes to `channel`.
    ///
    /// The channel is stored and re-subscribed on every reconnect. While there is no live
    /// connection, the command is only stored and is sent once `listen` connects.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the command to the current connection fails.
    pub async fn subscribe(&self, channel: String) -> crate::Result<()> {
        self.channels.lock().await.insert(channel.clone());
        self.send_cmd(Command::Subscribe(channel)).await
    }

    /// Unsubscribes from `channel` and forgets it, so it is not replayed on reconnect.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::NotSubscribed`] if `channel` was not subscribed, or an error
    /// if writing the command to the current connection fails.
    pub async fn unsubscribe(&self, channel: String) -> crate::Result<()> {
        if !self.channels.lock().await.remove(&channel) {
            return Err(crate::Error::NotSubscribed);
        }
        self.send_cmd(Command::Unsubscribe(channel)).await
    }

    /// Subscribes to the glob-style `channel` pattern.
    ///
    /// The pattern is stored and re-subscribed on every reconnect.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the command to the current connection fails.
    pub async fn psubscribe(&self, channel: String) -> crate::Result<()> {
        self.pattern_channels.lock().await.insert(channel.clone());
        self.send_cmd(Command::PatternSubscribe(channel)).await
    }

    /// Unsubscribes from the `channel` pattern and forgets it.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::NotSubscribed`] if the pattern was not subscribed, or an
    /// error if writing the command to the current connection fails.
    pub async fn punsubscribe(&self, channel: String) -> crate::Result<()> {
        if !self.pattern_channels.lock().await.remove(&channel) {
            return Err(crate::Error::NotSubscribed);
        }
        self.send_cmd(Command::PatternUnsubscribe(channel)).await
    }

    /// Opens a TCP connection to `self.addr` and splits it into read and write halves.
    ///
    /// With `fail_fast`, the first connection error is returned immediately. Otherwise the
    /// connection is retried up to `MAX_RETRIES` times. Retries use a quadratic backoff
    /// (1s, 4s, 9s, … capped at 64s) plus up to one second of random jitter.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::IoError`] with the last connection error once no retries are
    /// left (or immediately when `fail_fast` is set).
    pub(crate) async fn connect(
        &self,
        fail_fast: bool,
    ) -> crate::Result<(OwnedReadHalf, OwnedWriteHalf)> {
        // How many times a failed connection attempt is retried before giving up.
        const MAX_RETRIES: u64 = 8;
        // Upper bound on the backoff delay, in seconds.
        const MAX_BACKOFF_SECS: u64 = 64;

        let mut retry_count: u64 = 0;
        loop {
            let err = match TcpStream::connect(self.addr.as_str()).await {
                Ok(stream) => return Ok(stream.into_split()),
                Err(e) => e,
            };
            if fail_fast || retry_count >= MAX_RETRIES {
                return Err(crate::Error::IoError(err));
            }
            retry_count += 1;
            warn!(
                "failed to connect to redis (attempt {}/{}) {:?}",
                retry_count, MAX_RETRIES, err
            );
            // Square the retry count (not XOR) for the backoff. The random jitter keeps many
            // clients from reconnecting in lockstep. The RNG handle is a temporary, so it is
            // dropped before the `.await` below.
            let jitter: u64 = thread_rng().gen_range(0..1000);
            let timeout = cmp::min(retry_count.pow(2), MAX_BACKOFF_SECS) * 1000 + jitter;
            sleep(Duration::from_millis(timeout)).await;
        }
    }

    /// Sends a subscribe command for every stored channel and every stored pattern.
    ///
    /// Each set's lock is held while its commands are written. This keeps a concurrent
    /// `unsubscribe`/`punsubscribe` from interleaving and leaving the server subscribed to
    /// something that was just removed.
    async fn subscribe_stored(&self) -> crate::Result<()> {
        for channel in self.channels.lock().await.iter() {
            self.send_cmd(Command::Subscribe(channel.to_string()))
                .await?;
        }
        for channel in self.pattern_channels.lock().await.iter() {
            self.send_cmd(Command::PatternSubscribe(channel.to_string()))
                .await?;
        }
        Ok(())
    }

    /// Connects to the server and returns a stream of incoming pub/sub messages.
    ///
    /// After every (re)connect, the stored subscriptions are replayed and
    /// `Message::Connected` is yielded. If a read fails, `Message::Disconnected` is yielded
    /// and the stream reconnects. Reconnection attempts continue indefinitely.
    ///
    /// # Errors
    ///
    /// Returns an error only if the initial connection attempt fails.
    pub async fn listen(&self) -> crate::Result<impl Stream<Item = Message> + '_> {
        // Fail fast when the server is unreachable. The stream's first iteration reuses this
        // connection instead of dropping it and immediately dialing a second one.
        let mut initial = Some(self.connect(true).await?);
        Ok(Box::pin(stream! {
            loop {
                let connection = match initial.take() {
                    Some(conn) => Ok(conn),
                    None => self.connect(false).await,
                };
                let (mut read, write) = match connection {
                    Ok(t) => t,
                    Err(e) => {
                        // `connect(false)` already exhausted its own retries; start over.
                        warn!("failed to connect to server: {:?}", e);
                        continue;
                    }
                };

                debug!("updating stored Redis TCP writer");
                *self.writer.lock().await = Some(write);

                debug!("subscribing to stored channels after connect");
                if let Err(e) = self.subscribe_stored().await {
                    warn!("failed to subscribe to stored channels on connection, trying connection again... (err {:?})", e);
                    continue;
                }
                yield Message::Connected;

                let mut buf = [0; 64 * 1024];
                // Bytes of a UTF-8 character split across reads, carried into the next read.
                let mut pending: Vec<u8> = Vec::new();
                // Decoded text passed mutably to `parser::parse` (presumably so it can keep
                // an incomplete trailing frame; confirm against the parser).
                let mut unread_buf = String::new();
                'inner: loop {
                    debug!("reading incoming data");
                    let res = match read.read(&mut buf).await {
                        Ok(0) => Err(crate::Error::ZeroBytesRead),
                        Ok(n) => Ok(n),
                        Err(e) => Err(crate::Error::from(e)),
                    };
                    let n = match res {
                        Ok(n) => n,
                        Err(e) => {
                            *self.writer.lock().await = None;
                            yield Message::Disconnected(e);
                            break 'inner;
                        }
                    };

                    pending.extend_from_slice(&buf[..n]);
                    // A TCP read may end in the middle of a multi-byte character. Decode the
                    // complete prefix and keep the incomplete tail (`error_len() == None`)
                    // for the next read instead of discarding the whole chunk.
                    let valid_len = match std::str::from_utf8(&pending).map(str::len) {
                        Ok(len) => len,
                        Err(e) if e.error_len().is_none() => e.valid_up_to(),
                        Err(e) => {
                            // Genuinely invalid bytes: report and drop them.
                            pending.clear();
                            yield Message::Error(e.into());
                            continue;
                        }
                    };
                    let decoded = std::str::from_utf8(&pending[..valid_len])
                        .expect("bytes before `valid_up_to` are valid UTF-8");
                    unread_buf.push_str(decoded);
                    pending.drain(..valid_len);

                    let parsed = parser::parse(&mut unread_buf);
                    for res in parsed {
                        debug!("new message");
                        match Message::from_response(res) {
                            Ok(msg) => yield msg,
                            Err(e) => {
                                warn!("failed to parse message: {:?}", e);
                                continue;
                            },
                        };
                    }
                }
            }
        }))
    }

    /// Writes `command` to the current connection.
    ///
    /// When no writer is stored (before `listen` connects, or after a read failure), the
    /// command is skipped and `Ok(())` is returned. Stored subscriptions are replayed by
    /// `subscribe_stored` on the next connect.
    async fn send_cmd(&self, command: Command) -> crate::Result<()> {
        if let Some(writer) = &mut *self.writer.lock().await {
            writer.writable().await?;
            debug!("sending command {:?} to redis", &command);
            writer.write_all(command.to_string().as_bytes()).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use redis::AsyncCommands;
    use tokio_stream::StreamExt;

    /// Opens a `redis` client connection (used for publishing) and a `RedisSub`, both aimed
    /// at the local server. These tests require Redis listening on 127.0.0.1:6379.
    async fn get_redis_connections() -> (redis::Client, redis::aio::Connection, RedisSub) {
        let client =
            redis::Client::open("redis://127.0.0.1/").expect("failed to create Redis client");
        let connection = client
            .get_tokio_connection()
            .await
            .expect("failed to open Redis connection");
        let redis_sub = RedisSub::new("127.0.0.1:6379");
        (client, connection, redis_sub)
    }

    /// Awaits the next item of `stream`. Panics if nothing arrives within `limit` or if the
    /// stream has ended, so a misbehaving server can never hang the test.
    async fn next_message<S>(stream: &mut S, limit: Duration) -> Message
    where
        S: Stream<Item = Message> + Unpin,
    {
        tokio::time::timeout(limit, stream.next())
            .await
            .unwrap_or_else(|_| panic!("timeout duration of {:?} was exceeded", limit))
            .expect("expected a Message")
    }

    #[tokio::test]
    async fn test_redis_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;
        redis_sub
            .subscribe("1234".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");
                let msg = next_message(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );
                let msg = next_message(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_subscription(),
                    "message after connection was not `Subscription`: {:?}",
                    msg
                );
                match next_message(&mut stream, Duration::from_secs(2)).await {
                    Message::Message { channel, message } => {
                        assert_eq!(channel, "1234".to_string());
                        assert_eq!(message, "1234".to_string());
                    }
                    other => panic!("message after subscription was not `Message`: {:?}", other),
                }
            }
            redis_sub
        });
        // Give the spawned listener time to connect and subscribe before publishing.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("1234", "1234")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        // Reconnect (the subscription is replayed), then check that unsubscribing is confirmed.
        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        let msg = next_message(&mut stream, Duration::from_secs(2)).await;
        assert!(
            msg.is_connected(),
            "message after reopening stream was not `Connected`: {:?}",
            msg
        );
        let msg = next_message(&mut stream, Duration::from_secs(2)).await;
        assert!(
            msg.is_subscription(),
            "message after reconnection was not `Subscription`: {:?}",
            msg
        );
        redis_sub
            .unsubscribe("1234".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = next_message(&mut stream, Duration::from_secs(2)).await;
        assert!(
            msg.is_unsubscription(),
            "message after unsubscription was not `Unsubscription`: {:?}",
            msg
        )
    }

    #[tokio::test]
    async fn test_redis_pattern_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;
        redis_sub
            .psubscribe("*420*".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");
                let msg = next_message(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );
                let msg = next_message(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_pattern_subscription(),
                    "message after connection was not `PatternSubscription`: {:?}",
                    msg
                );
                match next_message(&mut stream, Duration::from_secs(2)).await {
                    Message::PatternMessage {
                        pattern,
                        channel,
                        message,
                    } => {
                        assert_eq!(pattern, "*420*".to_string());
                        assert_eq!(channel, "64209".to_string());
                        assert_eq!(message, "123456".to_string());
                    }
                    other => panic!(
                        "message after subscription was not `PatternMessage`: {:?}",
                        other
                    ),
                }
            }
            redis_sub
        });
        // Give the spawned listener time to connect and subscribe before publishing.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("64209", "123456")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        // Reconnect (the pattern is replayed), then check that unsubscribing is confirmed.
        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        let msg = next_message(&mut stream, Duration::from_secs(2)).await;
        assert!(
            msg.is_connected(),
            "message after reopening stream was not `Connected`: {:?}",
            msg
        );
        let msg = next_message(&mut stream, Duration::from_secs(2)).await;
        assert!(
            msg.is_pattern_subscription(),
            "message after reconnection was not `PatternSubscription`: {:?}",
            msg
        );
        redis_sub
            .punsubscribe("*420*".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = next_message(&mut stream, Duration::from_secs(2)).await;
        assert!(
            msg.is_pattern_unsubscription(),
            "message after unsubscription was not `PatternUnsubscription`: {:?}",
            msg
        )
    }
}