apalis_redis/
shared.rs

1use std::{
2    collections::HashMap,
3    marker::PhantomData,
4    sync::{Arc, Mutex},
5};
6
7use apalis_core::backend::shared::MakeShared;
8use event_listener::Event;
9use redis::{
10    AsyncConnectionConfig, Client, PushInfo, RedisError, Value, aio::MultiplexedConnection,
11};
12
13use crate::{RedisStorage, config::RedisConfig, sink::RedisSink};
14
15/// A shared Redis storage that can create multiple RedisStorage instances.
16#[derive(Debug, Clone)]
17pub struct SharedRedisStorage {
18    conn: MultiplexedConnection,
19    registry: Arc<Mutex<HashMap<String, Arc<Event>>>>,
20}
21
22fn parse_channel_info(push: &PushInfo) -> Option<(String, String, String)> {
23    if let Some(Value::BulkString(channel_bytes)) = push.data.get(1) {
24        if let Ok(channel_str) = std::str::from_utf8(channel_bytes) {
25            let parts: Vec<&str> = channel_str.split(':').collect();
26            if parts.len() >= 4 {
27                let namespace = parts[1].to_owned();
28                let action = parts[2].to_owned();
29                let signal = parts[3].to_string();
30                return Some((namespace, action, signal));
31            }
32        }
33    }
34    None
35}
36
37impl SharedRedisStorage {
38    /// Creates a new SharedRedisStorage with the given Redis client.
39    pub async fn new(client: Client) -> Result<Self, RedisError> {
40        let registry: Arc<Mutex<HashMap<String, Arc<Event>>>> =
41            Arc::new(Mutex::new(HashMap::new()));
42        let r2 = registry.clone();
43        let config = AsyncConnectionConfig::new().set_push_sender(move |msg| {
44            let Ok(registry) = r2.lock() else {
45                return Err(redis::aio::SendError);
46            };
47            if let Some((namespace, _, signal_kind)) = parse_channel_info(&msg) {
48                if signal_kind == "available" {
49                    registry.get(&namespace).map(|f| f.notify(usize::MAX));
50                }
51            }
52            Ok(())
53        });
54        let mut conn = client
55            .get_multiplexed_async_connection_with_config(&config)
56            .await?;
57        conn.psubscribe("tasks:*:available").await?;
58        Ok(SharedRedisStorage { conn, registry })
59    }
60}
61
62impl<Args> MakeShared<Args> for SharedRedisStorage {
63    type Backend = RedisStorage<Args, MultiplexedConnection>;
64    type Config = RedisConfig;
65
66    type MakeError = RedisError;
67
68    fn make_shared(&mut self) -> Result<RedisStorage<Args, MultiplexedConnection>, Self::MakeError>
69    where
70        Self::Config: Default,
71    {
72        let config = RedisConfig::default().set_namespace(std::any::type_name::<Args>());
73        Self::make_shared_with_config(self, config)
74    }
75
76    fn make_shared_with_config(
77        &mut self,
78        config: Self::Config,
79    ) -> Result<RedisStorage<Args, MultiplexedConnection>, Self::MakeError> {
80        let poller = Arc::new(Event::new());
81        self.registry
82            .lock()
83            .unwrap()
84            .insert(config.get_namespace().to_string(), poller.clone());
85        let conn = self.conn.clone();
86        let sink = RedisSink::new(&conn, &config);
87        Ok(RedisStorage {
88            conn,
89            job_type: PhantomData,
90            config,
91            codec: PhantomData,
92            poller,
93            sink,
94        })
95    }
96}