1use std::{
2 collections::HashMap,
3 marker::PhantomData,
4 sync::{Arc, Mutex},
5};
6
7use apalis_core::backend::shared::MakeShared;
8use event_listener::Event;
9use redis::{
10 AsyncConnectionConfig, Client, PushInfo, RedisError, Value, aio::MultiplexedConnection,
11};
12
13use crate::{RedisStorage, config::RedisConfig, sink::RedisSink};
14
/// Shared Redis backend from which [`RedisStorage`] instances are created (see its
/// [`MakeShared`] implementation).
///
/// Every storage made from one value receives a clone of the same
/// [`MultiplexedConnection`]. [`SharedRedisStorage::new`] subscribes that connection to
/// the `tasks:*:available` pattern, and incoming `available` signals are routed to the
/// per-namespace wake-up events kept in `registry`.
#[derive(Debug, Clone)]
pub struct SharedRedisStorage {
    /// Multiplexed connection that every created storage receives a clone of.
    conn: MultiplexedConnection,
    /// Namespace -> [`Event`] notified when an `available` signal for that namespace
    /// arrives. Wrapped in `Arc<Mutex<..>>` because the push-message callback installed
    /// in `new` reads it while `make_shared_with_config` writes to it.
    registry: Arc<Mutex<HashMap<String, Arc<Event>>>>,
}
22fn parse_channel_info(push: &PushInfo) -> Option<(String, String, String)> {
23 if let Some(Value::BulkString(channel_bytes)) = push.data.get(1) {
24 if let Ok(channel_str) = std::str::from_utf8(channel_bytes) {
25 let parts: Vec<&str> = channel_str.split(':').collect();
26 if parts.len() >= 4 {
27 let namespace = parts[1].to_owned();
28 let action = parts[2].to_owned();
29 let signal = parts[3].to_string();
30 return Some((namespace, action, signal));
31 }
32 }
33 }
34 None
35}
36
37impl SharedRedisStorage {
38 pub async fn new(client: Client) -> Result<Self, RedisError> {
40 let registry: Arc<Mutex<HashMap<String, Arc<Event>>>> =
41 Arc::new(Mutex::new(HashMap::new()));
42 let r2 = registry.clone();
43 let config = AsyncConnectionConfig::new().set_push_sender(move |msg| {
44 let Ok(registry) = r2.lock() else {
45 return Err(redis::aio::SendError);
46 };
47 if let Some((namespace, _, signal_kind)) = parse_channel_info(&msg) {
48 if signal_kind == "available" {
49 registry.get(&namespace).map(|f| f.notify(usize::MAX));
50 }
51 }
52 Ok(())
53 });
54 let mut conn = client
55 .get_multiplexed_async_connection_with_config(&config)
56 .await?;
57 conn.psubscribe("tasks:*:available").await?;
58 Ok(SharedRedisStorage { conn, registry })
59 }
60}
61
62impl<Args> MakeShared<Args> for SharedRedisStorage {
63 type Backend = RedisStorage<Args, MultiplexedConnection>;
64 type Config = RedisConfig;
65
66 type MakeError = RedisError;
67
68 fn make_shared(&mut self) -> Result<RedisStorage<Args, MultiplexedConnection>, Self::MakeError>
69 where
70 Self::Config: Default,
71 {
72 let config = RedisConfig::default().set_namespace(std::any::type_name::<Args>());
73 Self::make_shared_with_config(self, config)
74 }
75
76 fn make_shared_with_config(
77 &mut self,
78 config: Self::Config,
79 ) -> Result<RedisStorage<Args, MultiplexedConnection>, Self::MakeError> {
80 let poller = Arc::new(Event::new());
81 self.registry
82 .lock()
83 .unwrap()
84 .insert(config.get_namespace().to_string(), poller.clone());
85 let conn = self.conn.clone();
86 let sink = RedisSink::new(&conn, &config);
87 Ok(RedisStorage {
88 conn,
89 job_type: PhantomData,
90 config,
91 codec: PhantomData,
92 poller,
93 sink,
94 })
95 }
96}