1use std::cmp;
2use std::collections::HashSet;
3use std::time::Duration;
4
5use async_stream::stream;
6use rand::{thread_rng, Rng};
7use tokio::{
8 io::{AsyncReadExt, AsyncWriteExt},
9 net::{
10 tcp::{OwnedReadHalf, OwnedWriteHalf},
11 TcpStream,
12 },
13 sync::Mutex,
14 time::sleep,
15};
16use tokio_stream::Stream;
17
18use crate::{parser, Command, Message};
19
20#[derive(Debug)]
23pub struct RedisSub {
24 addr: String,
26 channels: Mutex<HashSet<String>>,
28 pattern_channels: Mutex<HashSet<String>>,
30 writer: Mutex<Option<OwnedWriteHalf>>,
32}
33
34impl RedisSub {
35 #[must_use]
38 pub fn new(addr: &str) -> Self {
39 Self {
40 addr: addr.to_string(),
41 channels: Mutex::new(HashSet::new()),
42 pattern_channels: Mutex::new(HashSet::new()),
43 writer: Mutex::new(None),
44 }
45 }
46
47 pub async fn publish(&self, channel: String, message: String) -> crate::Result<()> {
52 self.send_cmd(Command::Publish(channel, message)).await
53 }
54
55 pub async fn subscribe(&self, channel: String) -> crate::Result<()> {
60 self.channels.lock().await.insert(channel.clone());
61
62 self.send_cmd(Command::Subscribe(channel)).await
63 }
64
65 pub async fn unsubscribe(&self, channel: String) -> crate::Result<()> {
70 if !self.channels.lock().await.remove(&channel) {
71 return Err(crate::Error::NotSubscribed);
72 }
73
74 self.send_cmd(Command::Unsubscribe(channel)).await
75 }
76
77 pub async fn psubscribe(&self, channel: String) -> crate::Result<()> {
82 self.pattern_channels.lock().await.insert(channel.clone());
83
84 self.send_cmd(Command::PatternSubscribe(channel)).await
85 }
86
87 pub async fn punsubscribe(&self, channel: String) -> crate::Result<()> {
92 if !self.pattern_channels.lock().await.remove(&channel) {
93 return Err(crate::Error::NotSubscribed);
94 }
95
96 self.send_cmd(Command::PatternUnsubscribe(channel)).await
97 }
98
99 pub(crate) async fn connect(
108 &self,
109 fail_fast: bool,
110 ) -> crate::Result<(OwnedReadHalf, OwnedWriteHalf)> {
111 let mut retry_count = 0;
112
113 loop {
114 let jitter = thread_rng().gen_range(0..1000);
116 match TcpStream::connect(self.addr.as_str()).await {
118 Ok(stream) => return Ok(stream.into_split()),
119 Err(e) if fail_fast => return Err(crate::Error::IoError(e)),
120 Err(e) if retry_count <= 7 => {
121 warn!(
123 "failed to connect to redis (attempt {}/8) {:?}",
124 retry_count, e
125 );
126 retry_count += 1;
127 let timeout = cmp::min(retry_count ^ 2, 64) * 1000 + jitter;
128 sleep(Duration::from_millis(timeout)).await;
129 continue;
130 }
131 Err(e) => {
132 return Err(crate::Error::IoError(e));
135 }
136 };
137 }
138 }
139
140 async fn subscribe_stored(&self) -> crate::Result<()> {
141 for channel in self.channels.lock().await.iter() {
142 self.send_cmd(Command::Subscribe(channel.to_string()))
143 .await?;
144 }
145
146 for channel in self.pattern_channels.lock().await.iter() {
147 self.send_cmd(Command::PatternSubscribe(channel.to_string()))
148 .await?;
149 }
150
151 Ok(())
152 }
153
154 pub async fn listen(&self) -> crate::Result<impl Stream<Item = Message> + '_> {
161 self.connect(true).await?;
162
163 Ok(Box::pin(stream! {
164 loop {
165 let (mut read, write) = match self.connect(false).await {
166 Ok(t) => t,
167 Err(e) => {
168 warn!("failed to connect to server: {:?}", e);
169 continue;
170 }
171 };
172
173 {
175 debug!("updating stored Redis TCP writer");
176 let mut stored_writer = self.writer.lock().await;
177 *stored_writer = Some(write);
178 }
179
180 debug!("subscribing to stored channels after connect");
182 if let Err(e) = self.subscribe_stored().await {
183 warn!("failed to subscribe to stored channels on connection, trying connection again... (err {:?})", e);
184 continue;
185 }
186
187 yield Message::Connected;
189
190 let mut buf = [0; 64 * 1024];
192 let mut unread_buf = String::new();
193
194 'inner: loop {
195 debug!("reading incoming data");
196 let res = match read.read(&mut buf).await {
198 Ok(0) => Err(crate::Error::ZeroBytesRead),
199 Ok(n) => Ok(n),
200 Err(e) => Err(crate::Error::from(e)),
201 };
202
203 let n = match res {
205 Ok(n) => n,
206 Err(e) => {
207 *self.writer.lock().await = None;
208 yield Message::Disconnected(e);
209 break 'inner;
210 }
211 };
212
213 let buf_data = match std::str::from_utf8(&buf[..n]) {
214 Ok(d) => d,
215 Err(e) => {
216 yield Message::Error(e.into());
217 continue;
218 }
219 };
220
221 unread_buf.push_str(buf_data);
223 let parsed = parser::parse(&mut unread_buf);
225
226 for res in parsed {
228 debug!("new message");
229 match Message::from_response(res) {
231 Ok(msg) => yield msg,
232 Err(e) => {
233 warn!("failed to parse message: {:?}", e);
234 continue;
235 },
236 };
237 }
238 }
239 }
240 }))
241 }
242
243 async fn send_cmd(&self, command: Command) -> crate::Result<()> {
245 if let Some(writer) = &mut *self.writer.lock().await {
246 writer.writable().await?;
247
248 debug!("sending command {:?} to redis", &command);
249 writer.write_all(command.to_string().as_bytes()).await?;
250 }
251
252 Ok(())
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use redis::AsyncCommands;
    use tokio_stream::StreamExt;

    /// Builds a `redis` crate client plus a tokio connection (used to publish)
    /// and an unconnected `RedisSub`, all targeting 127.0.0.1.
    async fn get_redis_connections() -> (redis::Client, redis::aio::Connection, RedisSub) {
        let subscriber = RedisSub::new("127.0.0.1:6379");
        let client =
            redis::Client::open("redis://127.0.0.1/").expect("failed to create Redis client");
        let publisher = client
            .get_tokio_connection()
            .await
            .expect("failed to open Redis connection");
        (client, publisher, subscriber)
    }

    /// Waits up to `limit` for the next stream item. Panics with `timeout_msg`
    /// on timeout, and with "expected a Message" if the stream has ended.
    async fn next_within<S>(stream: &mut S, limit: Duration, timeout_msg: &str) -> Message
    where
        S: tokio_stream::Stream<Item = Message> + Unpin,
    {
        tokio::time::timeout(limit, stream.next())
            .await
            .expect(timeout_msg)
            .expect("expected a Message")
    }

    #[tokio::test]
    async fn test_redis_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        redis_sub
            .subscribe("1234".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");

        let listener = tokio::spawn(async move {
            let mut stream = redis_sub
                .listen()
                .await
                .expect("failed to connect to redis");

            let first = next_within(
                &mut stream,
                Duration::from_millis(500),
                "timeout duration of 500 milliseconds was exceeded",
            )
            .await;
            assert!(
                first.is_connected(),
                "message after opening stream was not `Connected`: {:?}",
                first
            );

            let second = next_within(
                &mut stream,
                Duration::from_millis(500),
                "timeout duration of 500 milliseconds was exceeded",
            )
            .await;
            assert!(
                second.is_subscription(),
                "message after connection was not `Subscription`: {:?}",
                second
            );

            let third = next_within(
                &mut stream,
                Duration::from_secs(2),
                "timeout duration of 2 seconds was exceeded",
            )
            .await;
            assert!(
                third.is_message(),
                "message after subscription was not `Message`: {:?}",
                third
            );
            if let Message::Message { channel, message } = third {
                assert_eq!(channel, "1234".to_string());
                assert_eq!(message, "1234".to_string());
            } else {
                unreachable!("already checked this is message");
            }

            // End the stream's borrow of `redis_sub` before handing it back.
            drop(stream);
            redis_sub
        });

        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("1234", "1234")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = listener.await.expect("background future failed");

        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip the `Connected` notice and the re-subscription confirmation.
        let _ = stream.next().await;
        let _ = stream.next().await;
        redis_sub
            .unsubscribe("1234".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let last = stream.next().await.expect("expected a Message");
        assert!(
            last.is_unsubscription(),
            "message after unsubscription was not `Unsubscription`: {:?}",
            last
        );
    }

    #[tokio::test]
    pub async fn test_redis_pattern_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        redis_sub
            .psubscribe("*420*".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");

        let listener = tokio::spawn(async move {
            let mut stream = redis_sub
                .listen()
                .await
                .expect("failed to connect to redis");

            let first = next_within(
                &mut stream,
                Duration::from_millis(500),
                "timeout duration of 500 milliseconds was exceeded",
            )
            .await;
            assert!(
                first.is_connected(),
                "message after opening stream was not `Connected`: {:?}",
                first
            );

            let second = next_within(
                &mut stream,
                Duration::from_millis(500),
                "timeout duration of 500 milliseconds was exceeded",
            )
            .await;
            assert!(
                second.is_pattern_subscription(),
                "message after connection was not `PatternSubscription`: {:?}",
                second
            );

            let third = next_within(
                &mut stream,
                Duration::from_secs(2),
                "timeout duration of 2 seconds was exceeded",
            )
            .await;
            assert!(
                third.is_pattern_message(),
                "message after subscription was not `PatternMessage`: {:?}",
                third
            );
            if let Message::PatternMessage {
                pattern,
                channel,
                message,
            } = third
            {
                assert_eq!(pattern, "*420*".to_string());
                assert_eq!(channel, "64209".to_string());
                assert_eq!(message, "123456".to_string());
            } else {
                unreachable!("already checked this is message");
            }

            // End the stream's borrow of `redis_sub` before handing it back.
            drop(stream);
            redis_sub
        });

        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("64209", "123456")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = listener.await.expect("background future failed");

        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip the `Connected` notice and the re-subscription confirmation.
        let _ = stream.next().await;
        let _ = stream.next().await;
        redis_sub
            .punsubscribe("*420*".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let last = stream.next().await.expect("expected a Message");
        assert!(
            last.is_pattern_unsubscription(),
            "message after unsubscription was not `Unsubscription`: {:?}",
            last
        );
    }
}