qm_redis/
lib.rs

1pub use deadpool_redis::redis;
2use deadpool_redis::Runtime;
3use std::sync::Arc;
4mod config;
5pub mod lock;
6pub mod work_queue;
7use futures::stream::FuturesUnordered;
8use futures::StreamExt;
9use redis::AsyncCommands;
10use redis::RedisResult;
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::atomic::AtomicBool;
16use std::sync::atomic::Ordering;
17use std::time::Duration;
18use tokio::runtime::Builder;
19use tokio::sync::RwLock;
20use tokio::task::LocalSet;
21use work_queue::Item;
22use work_queue::KeyPrefix;
23use work_queue::WorkQueue;
24
25pub use crate::config::Config as RedisConfig;
26use crate::lock::Lock;
27
28pub struct Inner {
29    config: RedisConfig,
30    client: redis::Client,
31    pool: deadpool_redis::Pool,
32}
33
34#[derive(Clone)]
35pub struct Redis {
36    inner: Arc<Inner>,
37}
38
39impl AsRef<deadpool_redis::Pool> for Redis {
40    fn as_ref(&self) -> &deadpool_redis::Pool {
41        &self.inner.pool
42    }
43}
44
45impl Redis {
46    pub fn new() -> anyhow::Result<Self> {
47        let config = RedisConfig::builder().build()?;
48        let client = redis::Client::open(config.address())?;
49        let redis_cfg = deadpool_redis::Config::from_url(config.address());
50        let pool = redis_cfg.create_pool(Some(Runtime::Tokio1))?;
51        Ok(Self {
52            inner: Arc::new(Inner {
53                config,
54                client,
55                pool,
56            }),
57        })
58    }
59
60    pub fn config(&self) -> &RedisConfig {
61        &self.inner.config
62    }
63
64    pub fn client(&self) -> &redis::Client {
65        &self.inner.client
66    }
67
68    pub fn pool(&self) -> Arc<deadpool_redis::Pool> {
69        Arc::new(self.inner.pool.clone())
70    }
71
72    pub async fn connect(&self) -> Result<deadpool_redis::Connection, deadpool_redis::PoolError> {
73        self.inner.pool.get().await
74    }
75
76    pub async fn cleanup(&self) -> anyhow::Result<()> {
77        let mut con = self.connect().await?;
78        let _: redis::Value = redis::cmd("FLUSHALL").query_async(&mut con).await?;
79        Ok(())
80    }
81
82    pub async fn lock(
83        &self,
84        key: &str,
85        ttl: usize,
86        retry_count: u32,
87        retry_delay: u32,
88    ) -> Result<Lock, lock::Error> {
89        let mut con = self.connect().await?;
90        lock::lock(&mut con, key, ttl, retry_count, retry_delay).await
91    }
92
93    pub async fn unlock(&self, key: &str, lock_id: &str) -> Result<i64, lock::Error> {
94        let mut con = self.connect().await?;
95        lock::unlock(&mut con, key, lock_id).await
96    }
97}
98
99/// Runs async function exclusively using Redis lock.
100///
101/// Lock will be released even if async block fails.
102///
103/// # Errors
104///
105/// This function will return an error if either `f` call triggers exception, or lock failure.
106/// Panic in async call will not release lock, but it will be released after timeout.
107pub async fn mutex_run<S, O, E, F>(lock_name: S, redis: &Redis, f: F) -> Result<O, E>
108where
109    S: AsRef<str>,
110    F: std::future::Future<Output = Result<O, E>>,
111    E: From<self::lock::Error>,
112{
113    let lock = redis.lock(lock_name.as_ref(), 5000, 20, 250).await?;
114
115    let result = f.await;
116
117    redis.unlock(lock_name.as_ref(), &lock.id).await?;
118
119    result
120}
121
/// Implements `AsRef<Redis>` for a storage type `$storage`.
///
/// The storage type must expose its [`Redis`] handle at `self.inner.redis`;
/// the generated `as_ref` simply borrows that field.
///
/// The target type is named through `$crate`, so it resolves to this crate's
/// `Redis` however the invoking crate names its dependency. `AsRef` is fully
/// qualified so a local item named `AsRef` cannot hijack the expansion.
#[macro_export]
macro_rules! redis {
    ($storage:ty) => {
        impl ::std::convert::AsRef<$crate::Redis> for $storage {
            fn as_ref(&self) -> &$crate::Redis {
                &self.inner.redis
            }
        }
    };
}
132
133pub type RunningWorkers =
134    FuturesUnordered<Pin<Box<dyn Future<Output = String> + Send + Sync + 'static>>>;
135
136pub type ExecItemFuture = Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>>;
137
138pub struct WorkerContext<Ctx>
139where
140    Ctx: Clone + Send + Sync + 'static,
141{
142    ctx: Ctx,
143    pub worker_id: usize,
144    pub queue: Arc<WorkQueue>,
145    pub client: Arc<redis::Client>,
146    pub item: Item,
147}
148
149impl<Ctx> WorkerContext<Ctx>
150where
151    Ctx: Clone + Send + Sync + 'static,
152{
153    pub fn ctx(&self) -> &Ctx {
154        &self.ctx
155    }
156    pub async fn complete(&self) -> anyhow::Result<()> {
157        let mut con = self.client.get_multiplexed_async_connection().await?;
158        self.queue.complete(&mut con, &self.item).await?;
159        Ok(())
160    }
161}
162
163async fn add(
164    is_running: Arc<AtomicBool>,
165    instances: Arc<RwLock<Option<RunningWorkers>>>,
166    fut: Pin<Box<dyn Future<Output = String> + Send + Sync + 'static>>,
167) {
168    if !is_running.load(Ordering::SeqCst) {
169        return;
170    }
171    instances.write().await.as_mut().unwrap().push(fut);
172}
173
174#[async_trait::async_trait]
175pub trait Work<Ctx, T>: Send + Sync
176where
177    Ctx: Clone + Send + Sync + 'static,
178    T: DeserializeOwned + Send + Sync,
179{
180    async fn run(&self, ctx: WorkerContext<Ctx>, item: T) -> anyhow::Result<()>;
181}
182
183async fn run_recovery_worker<Ctx, T>(
184    client: Arc<redis::Client>,
185    is_running: Arc<AtomicBool>,
186    worker: Arc<AsyncWorker<Ctx, T>>,
187) -> anyhow::Result<()>
188where
189    Ctx: Clone + Send + Sync + 'static,
190    T: DeserializeOwned + Send + Sync,
191{
192    tracing::info!("start {} worker recovery", worker.prefix);
193    let mut con = client.get_multiplexed_async_connection().await?;
194    loop {
195        if !is_running.load(Ordering::SeqCst) {
196            break;
197        }
198        tokio::time::sleep(Duration::from_secs(10)).await;
199        worker.recover(&mut con).await?;
200    }
201    Ok(())
202}
203
204async fn run_worker_queue<Ctx, T>(
205    ctx: Ctx,
206    client: Arc<redis::Client>,
207    is_running: Arc<AtomicBool>,
208    worker: Arc<AsyncWorker<Ctx, T>>,
209    worker_id: usize,
210) -> anyhow::Result<()>
211where
212    Ctx: Clone + Send + Sync + 'static,
213    T: DeserializeOwned + Send + Sync,
214{
215    tracing::info!("start {} worker #{worker_id} queue", worker.prefix);
216    let request_queue = Arc::new(WorkQueue::new(KeyPrefix::new(worker.prefix.clone())));
217    let mut con = client.get_multiplexed_async_connection().await?;
218    loop {
219        if !is_running.load(Ordering::SeqCst) {
220            break;
221        }
222        if let Some(item) = request_queue
223            .lease(
224                &mut con,
225                Some(Duration::from_secs(worker.timeout)),
226                Duration::from_secs(worker.lease_duration),
227            )
228            .await?
229        {
230            if item.data.is_empty() {
231                tracing::info!("item is empty");
232                request_queue.complete(&mut con, &item).await?;
233                continue;
234            }
235            if let Ok(request) = serde_json::from_slice::<T>(&item.data).inspect_err(|_| {
236                tracing::error!(
237                    "invalid request item on worker {} #{worker_id} Item: {}",
238                    worker.prefix,
239                    String::from_utf8_lossy(&item.data)
240                );
241            }) {
242                if let Some(work) = worker.work.as_ref() {
243                    work.run(
244                        WorkerContext {
245                            ctx: ctx.clone(),
246                            worker_id,
247                            queue: request_queue.clone(),
248                            client: client.clone(),
249                            item: Item {
250                                id: item.id.clone(),
251                                data: Box::new([]),
252                            },
253                        },
254                        request,
255                    )
256                    .await?;
257                }
258            } else {
259                request_queue.complete(&mut con, &item).await?;
260            }
261        }
262    }
263    Ok(())
264}
265
266struct WorkerInner {
267    client: Arc<redis::Client>,
268    instances: Arc<RwLock<Option<RunningWorkers>>>,
269    is_running: Arc<AtomicBool>,
270}
271
272#[derive(Clone)]
273pub struct Workers {
274    inner: Arc<WorkerInner>,
275}
276
277impl Workers {
278    pub fn new(config: &RedisConfig) -> RedisResult<Self> {
279        let client = Arc::new(redis::Client::open(config.address())?);
280        Ok(Self::new_with_client(client))
281    }
282
283    pub fn new_with_client(client: Arc<redis::Client>) -> Self {
284        Self {
285            inner: Arc::new(WorkerInner {
286                client,
287                instances: Arc::new(RwLock::new(Some(RunningWorkers::default()))),
288                is_running: Arc::new(AtomicBool::new(true)),
289            }),
290        }
291    }
292
293    pub async fn start<Ctx, T>(&self, ctx: Ctx, worker: AsyncWorker<Ctx, T>) -> anyhow::Result<()>
294    where
295        Ctx: Clone + Send + Sync + 'static,
296        T: DeserializeOwned + Send + Sync + 'static,
297    {
298        let worker = Arc::new(worker);
299        let mut con = self.inner.client.get_multiplexed_async_connection().await?;
300        worker.recover(&mut con).await?;
301        {
302            let instances = self.inner.instances.clone();
303            let client = self.inner.client.clone();
304            let worker = worker.clone();
305            let _th = std::thread::spawn(move || {
306                let rt = Builder::new_current_thread().enable_all().build().unwrap();
307                let local = LocalSet::new();
308                local.spawn_local(async move {
309                    let fut_worker = worker.clone();
310                    let (tx, rx) = tokio::sync::oneshot::channel::<()>();
311                    let is_running = Arc::new(AtomicBool::new(true));
312                    let is_fut_running = is_running.clone();
313                    add(
314                        is_running.clone(),
315                        instances,
316                        Box::pin(async move {
317                            let worker = fut_worker.clone();
318                            tracing::info!("stopping {} recovery", worker.prefix);
319                            is_fut_running.store(false, Ordering::SeqCst);
320                            rx.await.ok();
321                            " recovery".to_string()
322                        }),
323                    )
324                    .await;
325                    if let Err(err) = run_recovery_worker(client, is_running, worker).await {
326                        tracing::error!("{err:#?}");
327                        std::process::exit(1);
328                    }
329                    tx.send(()).ok();
330                });
331                rt.block_on(local);
332            });
333        }
334        for worker_id in 0..worker.num_workers {
335            let worker = worker.clone();
336            let client = self.inner.client.clone();
337            let ctx = ctx.clone();
338            let instances = self.inner.instances.clone();
339            let _th = std::thread::spawn(move || {
340                let rt = Builder::new_current_thread().enable_all().build().unwrap();
341                let local = LocalSet::new();
342                local.spawn_local(async move {
343                    let fut_worker = worker.clone();
344                    let (tx, rx) = tokio::sync::oneshot::channel::<()>();
345                    let is_running = Arc::new(AtomicBool::new(true));
346                    let is_fut_running = is_running.clone();
347                    add(
348                        is_running.clone(),
349                        instances,
350                        Box::pin(async move {
351                            let worker = fut_worker.clone();
352                            tracing::info!("stopping {} #{worker_id}", worker.prefix);
353                            is_fut_running.store(false, Ordering::SeqCst);
354                            rx.await.ok();
355                            format!("{} worker #{worker_id}", fut_worker.prefix)
356                        }),
357                    )
358                    .await;
359                    if let Err(err) =
360                        run_worker_queue(ctx.clone(), client, is_running, worker, worker_id).await
361                    {
362                        tracing::error!("{err:#?}");
363                        std::process::exit(1);
364                    }
365                    tx.send(()).ok();
366                });
367                rt.block_on(local);
368            });
369        }
370        Ok(())
371    }
372
373    pub async fn terminate(&self) -> anyhow::Result<()> {
374        if !self.inner.is_running.load(Ordering::SeqCst) {
375            anyhow::bail!("Workers already terminated");
376        }
377        let mut futs = self.inner.instances.write().await.take().unwrap();
378        tracing::info!("try stopping {} workers", futs.len());
379        while let Some(result) = futs.next().await {
380            tracing::info!("stopped {}", result);
381        }
382        Ok(())
383    }
384}
385
386pub struct Producer {
387    client: Arc<deadpool_redis::Pool>,
388    queue: WorkQueue,
389}
390
391impl Producer {
392    pub fn new<S>(config: &RedisConfig, prefix: S) -> anyhow::Result<Self>
393    where
394        S: Into<String>,
395    {
396        let redis_cfg = deadpool_redis::Config::from_url(config.address());
397        let redis = Arc::new(redis_cfg.create_pool(Some(Runtime::Tokio1))?);
398        Ok(Self::new_with_client(redis, prefix))
399    }
400
401    pub fn new_with_client<S>(client: Arc<deadpool_redis::Pool>, prefix: S) -> Self
402    where
403        S: Into<String>,
404    {
405        let queue = WorkQueue::new(KeyPrefix::new(prefix.into()));
406        Self { client, queue }
407    }
408
409    pub async fn add_item_with_connection<C, T>(&self, db: &mut C, data: &T) -> anyhow::Result<()>
410    where
411        C: AsyncCommands,
412        T: Serialize,
413    {
414        let item = Item::from_json_data(data)?;
415        self.queue.add_item(db, &item).await?;
416        Ok(())
417    }
418
419    pub async fn add_item<T>(&self, data: &T) -> anyhow::Result<()>
420    where
421        T: Serialize,
422    {
423        let item = Item::from_json_data(data)?;
424        let mut con = self.client.get().await?;
425        self.queue.add_item(&mut con, &item).await?;
426        Ok(())
427    }
428}
429
430pub struct AsyncWorker<Ctx, T>
431where
432    Ctx: Clone + Send + Sync + 'static,
433    T: DeserializeOwned + Send + Sync,
434{
435    prefix: String,
436    num_workers: usize,
437    timeout: u64,
438    lease_duration: u64,
439    recovery_key: String,
440    recovery_queue: WorkQueue,
441    work: Option<Box<dyn Work<Ctx, T>>>,
442}
443
444impl<Ctx, T> AsyncWorker<Ctx, T>
445where
446    Ctx: Clone + Send + Sync + 'static,
447    T: DeserializeOwned + Send + Sync,
448{
449    pub fn new<S>(prefix: S) -> Self
450    where
451        S: Into<String>,
452    {
453        let prefix = prefix.into();
454        let name = KeyPrefix::new(prefix.clone());
455        Self {
456            recovery_key: name.of(":clean"),
457            recovery_queue: WorkQueue::new(name),
458            timeout: 5,
459            lease_duration: 60,
460            num_workers: 1,
461            prefix,
462            work: None,
463        }
464    }
465
466    pub fn with_timeout(mut self, timeout: u64) -> Self {
467        self.timeout = timeout;
468        self
469    }
470
471    pub fn with_lease_duration(mut self, lease_duration: u64) -> Self {
472        self.lease_duration = lease_duration;
473        self
474    }
475
476    pub fn with_num_workers(mut self, num_workers: usize) -> Self {
477        self.num_workers = num_workers;
478        self
479    }
480
481    pub fn producer(&self, client: Arc<deadpool_redis::Pool>) -> Producer {
482        Producer {
483            client,
484            queue: WorkQueue::new(KeyPrefix::new(self.prefix.clone())),
485        }
486    }
487
488    pub async fn recover<C: AsyncCommands>(&self, db: &mut C) -> anyhow::Result<()> {
489        let l = lock::lock(db, &self.recovery_key, 3600, 36, 100).await?;
490        self.recovery_queue.recover(db).await?;
491        lock::unlock(db, &self.recovery_key, l.id).await?;
492        Ok(())
493    }
494
495    pub fn run(mut self, work: impl Work<Ctx, T> + 'static) -> Self {
496        self.work = Some(Box::new(work));
497        self
498    }
499}