1use std::cmp;
2use std::collections::HashSet;
3use std::time::Duration;
4
5use async_stream::stream;
6use rand::{thread_rng, Rng};
7use tokio::{
8 io::{AsyncReadExt, AsyncWriteExt},
9 net::{
10 tcp::{OwnedReadHalf, OwnedWriteHalf},
11 TcpStream,
12 },
13 sync::Mutex,
14 time::sleep,
15};
16use tokio_stream::Stream;
17
18use crate::{parser, Command, Message};
19
20#[derive(Debug)]
23pub struct RedisSub {
24 addr: String,
26 channels: Mutex<HashSet<String>>,
28 pattern_channels: Mutex<HashSet<String>>,
30 writer: Mutex<Option<OwnedWriteHalf>>,
32}
33
34impl RedisSub {
35 #[must_use]
38 pub fn new(addr: &str) -> Self {
39 Self {
40 addr: addr.to_string(),
41 channels: Mutex::new(HashSet::new()),
42 pattern_channels: Mutex::new(HashSet::new()),
43 writer: Mutex::new(None),
44 }
45 }
46
47 pub async fn subscribe(&self, channel: String) -> crate::Result<()> {
52 self.channels.lock().await.insert(channel.clone());
53
54 self.send_cmd(Command::Subscribe(channel)).await
55 }
56
57 pub async fn unsubscribe(&self, channel: String) -> crate::Result<()> {
62 if !self.channels.lock().await.remove(&channel) {
63 return Err(crate::Error::NotSubscribed);
64 }
65
66 self.send_cmd(Command::Unsubscribe(channel)).await
67 }
68
69 pub async fn psubscribe(&self, channel: String) -> crate::Result<()> {
74 self.pattern_channels.lock().await.insert(channel.clone());
75
76 self.send_cmd(Command::PatternSubscribe(channel)).await
77 }
78
79 pub async fn punsubscribe(&self, channel: String) -> crate::Result<()> {
84 if !self.pattern_channels.lock().await.remove(&channel) {
85 return Err(crate::Error::NotSubscribed);
86 }
87
88 self.send_cmd(Command::PatternUnsubscribe(channel)).await
89 }
90
91 pub(crate) async fn connect(
100 &self,
101 fail_fast: bool,
102 ) -> crate::Result<(OwnedReadHalf, OwnedWriteHalf)> {
103 let mut retry_count = 0;
104
105 loop {
106 let jitter = thread_rng().gen_range(0..1000);
108 match TcpStream::connect(self.addr.as_str()).await {
110 Ok(stream) => return Ok(stream.into_split()),
111 Err(e) if fail_fast => return Err(crate::Error::IoError(e)),
112 Err(e) if retry_count <= 7 => {
113 warn!(
115 "failed to connect to redis (attempt {}/8) {:?}",
116 retry_count, e
117 );
118 retry_count += 1;
119 let timeout = cmp::min(retry_count ^ 2, 64) * 1000 + jitter;
120 sleep(Duration::from_millis(timeout)).await;
121 continue;
122 }
123 Err(e) => {
124 return Err(crate::Error::IoError(e));
127 }
128 };
129 }
130 }
131
132 async fn subscribe_stored(&self) -> crate::Result<()> {
133 for channel in self.channels.lock().await.iter() {
134 self.send_cmd(Command::Subscribe(channel.to_string()))
135 .await?;
136 }
137
138 for channel in self.pattern_channels.lock().await.iter() {
139 self.send_cmd(Command::PatternSubscribe(channel.to_string()))
140 .await?;
141 }
142
143 Ok(())
144 }
145
146 pub async fn listen(&self) -> crate::Result<impl Stream<Item = Message> + '_> {
153 self.connect(true).await?;
154
155 Ok(Box::pin(stream! {
156 loop {
157 let (mut read, write) = match self.connect(false).await {
158 Ok(t) => t,
159 Err(e) => {
160 warn!("failed to connect to server: {:?}", e);
161 continue;
162 }
163 };
164
165 {
167 debug!("updating stored Redis TCP writer");
168 let mut stored_writer = self.writer.lock().await;
169 *stored_writer = Some(write);
170 }
171
172 debug!("subscribing to stored channels after connect");
174 if let Err(e) = self.subscribe_stored().await {
175 warn!("failed to subscribe to stored channels on connection, trying connection again... (err {:?})", e);
176 continue;
177 }
178
179 yield Message::Connected;
181
182 let mut buf = [0; 64 * 1024];
184 let mut unread_buf = String::new();
185
186 'inner: loop {
187 debug!("reading incoming data");
188 let res = match read.read(&mut buf).await {
190 Ok(0) => Err(crate::Error::ZeroBytesRead),
191 Ok(n) => Ok(n),
192 Err(e) => Err(crate::Error::from(e)),
193 };
194
195 let n = match res {
197 Ok(n) => n,
198 Err(e) => {
199 *self.writer.lock().await = None;
200 yield Message::Disconnected(e);
201 break 'inner;
202 }
203 };
204
205 let buf_data = match std::str::from_utf8(&buf[..n]) {
206 Ok(d) => d,
207 Err(e) => {
208 yield Message::Error(e.into());
209 continue;
210 }
211 };
212
213 unread_buf.push_str(buf_data);
215 let parsed = parser::parse(&mut unread_buf);
217
218 for res in parsed {
220 debug!("new message");
221 match Message::from_response(res) {
223 Ok(msg) => yield msg,
224 Err(e) => {
225 warn!("failed to parse message: {:?}", e);
226 continue;
227 },
228 };
229 }
230 }
231 }
232 }))
233 }
234
235 async fn send_cmd(&self, command: Command) -> crate::Result<()> {
237 if let Some(writer) = &mut *self.writer.lock().await {
238 writer.writable().await?;
239
240 debug!("sending command {:?} to redis", &command);
241 writer.write_all(command.to_string().as_bytes()).await?;
242 }
243
244 Ok(())
245 }
246}
247
// Integration tests: they talk to a Redis server at 127.0.0.1:6379.
#[cfg(test)]
mod tests {
    use super::*;
    use redis::AsyncCommands;
    use tokio_stream::StreamExt;

    /// Opens a regular `redis` client connection (used for publishing) and builds a
    /// `RedisSub` aimed at the same local server.
    async fn get_redis_connections() -> (redis::Client, redis::aio::Connection, RedisSub) {
        let client =
            redis::Client::open("redis://127.0.0.1/").expect("failed to create Redis client");
        let connection = client
            .get_tokio_connection()
            .await
            .expect("failed to open Redis connection");
        let redis_sub = RedisSub::new("127.0.0.1:6379");
        (client, connection, redis_sub)
    }

    /// Subscribes to a channel, checks the `Connected`, `Subscription` and
    /// `Message` sequence, then checks that unsubscribing is confirmed.
    #[tokio::test]
    async fn test_redis_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        redis_sub
            .subscribe("1234".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            // Inner scope: the stream borrows `redis_sub` and has to be dropped
            // before `redis_sub` is handed back to the test body.
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");

                let msg = tokio::time::timeout(Duration::from_millis(500), stream.next())
                    .await
                    .expect("timeout duration of 500 milliseconds was exceeded")
                    .expect("expected a Message");
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );

                let msg = tokio::time::timeout(Duration::from_millis(500), stream.next())
                    .await
                    .expect("timeout duration of 500 milliseconds was exceeded")
                    .expect("expected a Message");
                assert!(
                    msg.is_subscription(),
                    "message after connection was not `Subscription`: {:?}",
                    msg
                );

                let msg = tokio::time::timeout(Duration::from_secs(2), stream.next())
                    .await
                    .expect("timeout duration of 2 seconds was exceeded")
                    .expect("expected a Message");
                assert!(
                    msg.is_message(),
                    "message after subscription was not `Message`: {:?}",
                    msg
                );
                match msg {
                    Message::Message { channel, message } => {
                        assert_eq!(channel, "1234".to_string());
                        assert_eq!(message, "1234".to_string());
                    }
                    _ => unreachable!("already checked this is message"),
                }
            }

            redis_sub
        });

        // Give the background task time to connect and subscribe before publishing.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("1234", "1234")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip `Connected` and the subscription confirmation. The waits are bounded
        // so a stalled connection fails the test instead of hanging it.
        for _ in 0..2 {
            let _ = tokio::time::timeout(Duration::from_secs(2), stream.next())
                .await
                .expect("timeout duration of 2 seconds was exceeded")
                .expect("expected a Message");
        }
        redis_sub
            .unsubscribe("1234".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout duration of 2 seconds was exceeded")
            .expect("expected a Message");
        assert!(
            msg.is_unsubscription(),
            "message after unsubscription was not `Unsubscription`: {:?}",
            msg
        )
    }

    /// Same flow as `test_redis_sub`, using a pattern subscription.
    #[tokio::test]
    async fn test_redis_pattern_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        redis_sub
            .psubscribe("*420*".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            // Inner scope: the stream borrows `redis_sub` and has to be dropped
            // before `redis_sub` is handed back to the test body.
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");

                let msg = tokio::time::timeout(Duration::from_millis(500), stream.next())
                    .await
                    .expect("timeout duration of 500 milliseconds was exceeded")
                    .expect("expected a Message");
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );

                let msg = tokio::time::timeout(Duration::from_millis(500), stream.next())
                    .await
                    .expect("timeout duration of 500 milliseconds was exceeded")
                    .expect("expected a Message");
                assert!(
                    msg.is_pattern_subscription(),
                    "message after connection was not `PatternSubscription`: {:?}",
                    msg
                );

                let msg = tokio::time::timeout(Duration::from_secs(2), stream.next())
                    .await
                    .expect("timeout duration of 2 seconds was exceeded")
                    .expect("expected a Message");
                assert!(
                    msg.is_pattern_message(),
                    "message after subscription was not `PatternMessage`: {:?}",
                    msg
                );
                match msg {
                    Message::PatternMessage {
                        pattern,
                        channel,
                        message,
                    } => {
                        assert_eq!(pattern, "*420*".to_string());
                        assert_eq!(channel, "64209".to_string());
                        assert_eq!(message, "123456".to_string());
                    }
                    _ => unreachable!("already checked this is message"),
                }
            }

            redis_sub
        });

        // Give the background task time to connect and subscribe before publishing.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("64209", "123456")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip `Connected` and the subscription confirmation. The waits are bounded
        // so a stalled connection fails the test instead of hanging it.
        for _ in 0..2 {
            let _ = tokio::time::timeout(Duration::from_secs(2), stream.next())
                .await
                .expect("timeout duration of 2 seconds was exceeded")
                .expect("expected a Message");
        }
        redis_sub
            .punsubscribe("*420*".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout duration of 2 seconds was exceeded")
            .expect("expected a Message");
        assert!(
            msg.is_pattern_unsubscription(),
            "message after unsubscription was not `PatternUnsubscription`: {:?}",
            msg
        )
    }
}