qm_redis/
lib.rs

1pub use deadpool_redis::redis;
2use deadpool_redis::PoolError;
3use deadpool_redis::Runtime;
4use redis::FromRedisValue;
5use redis::RedisError;
6use redis::ToRedisArgs;
7use std::sync::Arc;
8mod config;
9pub mod lock;
10pub mod work_queue;
11use futures::stream::FuturesUnordered;
12use futures::StreamExt;
13use redis::AsyncCommands;
14use redis::RedisResult;
15use serde::de::DeserializeOwned;
16use serde::Serialize;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::atomic::AtomicBool;
20use std::sync::atomic::Ordering;
21use std::time::Duration;
22use tokio::runtime::Builder;
23use tokio::sync::RwLock;
24use tokio::task::LocalSet;
25use work_queue::Item;
26use work_queue::KeyPrefix;
27use work_queue::WorkQueue;
28
29pub use crate::config::Config as RedisConfig;
30use crate::lock::Lock;
31
32#[derive(Debug, thiserror::Error)]
33pub enum CacheError {
34    #[error(transparent)]
35    Pool(#[from] PoolError),
36    #[error(transparent)]
37    Redis(#[from] RedisError),
38    #[error("failed to fetch: {0}")]
39    Failure(String),
40}
41
42#[derive(serde::Serialize, serde::Deserialize)]
43pub struct Json<T>(T);
44
45impl<T> FromRedisValue for Json<T>
46where
47    T: DeserializeOwned,
48{
49    fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
50        if let redis::Value::SimpleString(s) = v {
51            serde_json::from_str(s).map_err(From::from)
52        } else {
53            Err(redis::RedisError::from((
54                redis::ErrorKind::TypeError,
55                "expected simple string value",
56            )))
57        }
58    }
59}
60
61impl<T> ToRedisArgs for Json<T>
62where
63    T: Serialize,
64{
65    fn write_redis_args<W>(&self, out: &mut W)
66    where
67        W: ?Sized + redis::RedisWrite,
68    {
69        let v = serde_json::to_string(&self.0).unwrap_or_default();
70        v.write_redis_args(out);
71    }
72}
73
/// Shared state behind a [`Redis`] handle.
pub struct Inner {
    // Configuration the handle was built from (exposed via `Redis::config`).
    config: RedisConfig,
    // Plain client opened on `config.address()`.
    client: redis::Client,
    // Tokio-backed `deadpool_redis` pool for the same address; used by `Redis::connect`.
    pool: deadpool_redis::Pool,
}
79
/// Cloneable handle to a Redis server: a client plus a connection pool.
///
/// Clones share one [`Inner`] through an `Arc`, so cloning does not open new connections.
#[derive(Clone)]
pub struct Redis {
    inner: Arc<Inner>,
}
84
85impl AsRef<deadpool_redis::Pool> for Redis {
86    fn as_ref(&self) -> &deadpool_redis::Pool {
87        &self.inner.pool
88    }
89}
90
91impl Redis {
92    pub fn new() -> anyhow::Result<Self> {
93        let config = RedisConfig::builder().build()?;
94        let client = redis::Client::open(config.address())?;
95        let redis_cfg = deadpool_redis::Config::from_url(config.address());
96        let pool = redis_cfg.create_pool(Some(Runtime::Tokio1))?;
97        Ok(Self {
98            inner: Arc::new(Inner {
99                config,
100                client,
101                pool,
102            }),
103        })
104    }
105
106    pub fn config(&self) -> &RedisConfig {
107        &self.inner.config
108    }
109
110    pub fn client(&self) -> &redis::Client {
111        &self.inner.client
112    }
113
114    pub fn pool(&self) -> Arc<deadpool_redis::Pool> {
115        Arc::new(self.inner.pool.clone())
116    }
117
118    pub async fn connect(&self) -> Result<deadpool_redis::Connection, deadpool_redis::PoolError> {
119        self.inner.pool.get().await
120    }
121
122    pub async fn cleanup(&self) -> anyhow::Result<()> {
123        let mut con = self.connect().await?;
124        let _: redis::Value = redis::cmd("FLUSHALL").query_async(&mut con).await?;
125        Ok(())
126    }
127
128    pub async fn lock(
129        &self,
130        key: &str,
131        ttl: usize,
132        retry_count: u32,
133        retry_delay: u32,
134    ) -> Result<Lock, lock::Error> {
135        let mut con = self.connect().await?;
136        lock::lock(&mut con, key, ttl, retry_count, retry_delay).await
137    }
138
139    pub async fn unlock(&self, key: &str, lock_id: &str) -> Result<i64, lock::Error> {
140        let mut con = self.connect().await?;
141        lock::unlock(&mut con, key, lock_id).await
142    }
143}
144
145/// Runs async function exclusively using Redis lock.
146///
147/// Lock will be released even if async block fails.
148///
149/// # Errors
150///
151/// This function will return an error if either `f` call triggers exception, or lock failure.
152/// Panic in async call will not release lock, but it will be released after timeout.
153pub async fn mutex_run<S, O, E, F>(lock_name: S, redis: &Redis, f: F) -> Result<O, E>
154where
155    S: AsRef<str>,
156    F: std::future::Future<Output = Result<O, E>>,
157    E: From<self::lock::Error>,
158{
159    let lock = redis.lock(lock_name.as_ref(), 5000, 20, 250).await?;
160
161    let result = f.await;
162
163    redis.unlock(lock_name.as_ref(), &lock.id).await?;
164
165    result
166}
167
/// Implements `AsRef<Redis>` for a storage type that keeps its handle at
/// `self.inner.redis`.
///
/// The expansion names this crate's `Redis` through `$crate`. It therefore resolves correctly
/// whether the invoking crate depends on this crate directly or through a re-export (such as
/// `qm::redis`). The previous hard-coded `qm::redis::Redis` path only worked via the facade.
#[macro_export]
macro_rules! redis {
    ($storage:ty) => {
        impl AsRef<$crate::Redis> for $storage {
            fn as_ref(&self) -> &$crate::Redis {
                &self.inner.redis
            }
        }
    };
}
178
179pub type RunningWorkers =
180    FuturesUnordered<Pin<Box<dyn Future<Output = String> + Send + Sync + 'static>>>;
181
182pub type ExecItemFuture = Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>>;
183
184pub struct WorkerContext<Ctx>
185where
186    Ctx: Clone + Send + Sync + 'static,
187{
188    ctx: Ctx,
189    pub worker_id: usize,
190    pub queue: Arc<WorkQueue>,
191    pub client: Arc<redis::Client>,
192    pub item: Item,
193}
194
195impl<Ctx> WorkerContext<Ctx>
196where
197    Ctx: Clone + Send + Sync + 'static,
198{
199    pub fn ctx(&self) -> &Ctx {
200        &self.ctx
201    }
202    pub async fn complete(&self) -> anyhow::Result<()> {
203        let mut con = self.client.get_multiplexed_async_connection().await?;
204        self.queue.complete(&mut con, &self.item).await?;
205        Ok(())
206    }
207}
208
209async fn add(
210    is_running: Arc<AtomicBool>,
211    instances: Arc<RwLock<Option<RunningWorkers>>>,
212    fut: Pin<Box<dyn Future<Output = String> + Send + Sync + 'static>>,
213) {
214    if !is_running.load(Ordering::SeqCst) {
215        return;
216    }
217    instances.write().await.as_mut().unwrap().push(fut);
218}
219
220#[async_trait::async_trait]
221pub trait Work<Ctx, T>: Send + Sync
222where
223    Ctx: Clone + Send + Sync + 'static,
224    T: DeserializeOwned + Send + Sync,
225{
226    async fn run(&self, ctx: WorkerContext<Ctx>, item: T) -> anyhow::Result<()>;
227}
228
229async fn run_recovery_worker<Ctx, T>(
230    client: Arc<redis::Client>,
231    is_running: Arc<AtomicBool>,
232    worker: Arc<AsyncWorker<Ctx, T>>,
233) -> anyhow::Result<()>
234where
235    Ctx: Clone + Send + Sync + 'static,
236    T: DeserializeOwned + Send + Sync,
237{
238    tracing::info!("start {} worker recovery", worker.prefix);
239    let mut con = client.get_multiplexed_async_connection().await?;
240    loop {
241        if !is_running.load(Ordering::SeqCst) {
242            break;
243        }
244        tokio::time::sleep(Duration::from_secs(10)).await;
245        worker.recover(&mut con).await?;
246    }
247    Ok(())
248}
249
250async fn run_worker_queue<Ctx, T>(
251    ctx: Ctx,
252    client: Arc<redis::Client>,
253    is_running: Arc<AtomicBool>,
254    worker: Arc<AsyncWorker<Ctx, T>>,
255    worker_id: usize,
256) -> anyhow::Result<()>
257where
258    Ctx: Clone + Send + Sync + 'static,
259    T: DeserializeOwned + Send + Sync,
260{
261    tracing::info!("start {} worker #{worker_id} queue", worker.prefix);
262    let request_queue = Arc::new(WorkQueue::new(KeyPrefix::new(worker.prefix.clone())));
263    let mut con = client.get_multiplexed_async_connection().await?;
264    loop {
265        if !is_running.load(Ordering::SeqCst) {
266            break;
267        }
268        if let Some(item) = request_queue
269            .lease(
270                &mut con,
271                Some(Duration::from_secs(worker.timeout)),
272                Duration::from_secs(worker.lease_duration),
273            )
274            .await?
275        {
276            if item.data.is_empty() {
277                tracing::info!("item is empty");
278                request_queue.complete(&mut con, &item).await?;
279                continue;
280            }
281            if let Ok(request) = serde_json::from_slice::<T>(&item.data).inspect_err(|_| {
282                tracing::error!(
283                    "invalid request item on worker {} #{worker_id} Item: {}",
284                    worker.prefix,
285                    String::from_utf8_lossy(&item.data)
286                );
287            }) {
288                if let Some(work) = worker.work.as_ref() {
289                    work.run(
290                        WorkerContext {
291                            ctx: ctx.clone(),
292                            worker_id,
293                            queue: request_queue.clone(),
294                            client: client.clone(),
295                            item: Item {
296                                id: item.id.clone(),
297                                data: Box::new([]),
298                            },
299                        },
300                        request,
301                    )
302                    .await?;
303                }
304            } else {
305                request_queue.complete(&mut con, &item).await?;
306            }
307        }
308    }
309    Ok(())
310}
311
/// Shared state behind [`Workers`].
struct WorkerInner {
    // Client each spawned worker uses to open its own multiplexed connection.
    client: Arc<redis::Client>,
    // Stop futures of started workers; becomes `None` once `terminate` has taken them.
    instances: Arc<RwLock<Option<RunningWorkers>>>,
    // Pool-level running flag checked by `terminate`.
    is_running: Arc<AtomicBool>,
}
317
/// Handle to a set of background queue workers; clones share the same state.
#[derive(Clone)]
pub struct Workers {
    inner: Arc<WorkerInner>,
}
322
323impl Workers {
324    pub fn new(config: &RedisConfig) -> RedisResult<Self> {
325        let client = Arc::new(redis::Client::open(config.address())?);
326        Ok(Self::new_with_client(client))
327    }
328
329    pub fn new_with_client(client: Arc<redis::Client>) -> Self {
330        Self {
331            inner: Arc::new(WorkerInner {
332                client,
333                instances: Arc::new(RwLock::new(Some(RunningWorkers::default()))),
334                is_running: Arc::new(AtomicBool::new(true)),
335            }),
336        }
337    }
338
339    pub async fn start<Ctx, T>(&self, ctx: Ctx, worker: AsyncWorker<Ctx, T>) -> anyhow::Result<()>
340    where
341        Ctx: Clone + Send + Sync + 'static,
342        T: DeserializeOwned + Send + Sync + 'static,
343    {
344        let worker = Arc::new(worker);
345        let mut con = self.inner.client.get_multiplexed_async_connection().await?;
346        worker.recover(&mut con).await?;
347        {
348            let instances = self.inner.instances.clone();
349            let client = self.inner.client.clone();
350            let worker = worker.clone();
351            let _th = std::thread::spawn(move || {
352                let rt = Builder::new_current_thread().enable_all().build().unwrap();
353                let local = LocalSet::new();
354                local.spawn_local(async move {
355                    let fut_worker = worker.clone();
356                    let (tx, rx) = tokio::sync::oneshot::channel::<()>();
357                    let is_running = Arc::new(AtomicBool::new(true));
358                    let is_fut_running = is_running.clone();
359                    add(
360                        is_running.clone(),
361                        instances,
362                        Box::pin(async move {
363                            let worker = fut_worker.clone();
364                            tracing::info!("stopping {} recovery", worker.prefix);
365                            is_fut_running.store(false, Ordering::SeqCst);
366                            rx.await.ok();
367                            " recovery".to_string()
368                        }),
369                    )
370                    .await;
371                    if let Err(err) = run_recovery_worker(client, is_running, worker).await {
372                        tracing::error!("{err:#?}");
373                        std::process::exit(1);
374                    }
375                    tx.send(()).ok();
376                });
377                rt.block_on(local);
378            });
379        }
380        for worker_id in 0..worker.num_workers {
381            let worker = worker.clone();
382            let client = self.inner.client.clone();
383            let ctx = ctx.clone();
384            let instances = self.inner.instances.clone();
385            let _th = std::thread::spawn(move || {
386                let rt = Builder::new_current_thread().enable_all().build().unwrap();
387                let local = LocalSet::new();
388                local.spawn_local(async move {
389                    let fut_worker = worker.clone();
390                    let (tx, rx) = tokio::sync::oneshot::channel::<()>();
391                    let is_running = Arc::new(AtomicBool::new(true));
392                    let is_fut_running = is_running.clone();
393                    add(
394                        is_running.clone(),
395                        instances,
396                        Box::pin(async move {
397                            let worker = fut_worker.clone();
398                            tracing::info!("stopping {} #{worker_id}", worker.prefix);
399                            is_fut_running.store(false, Ordering::SeqCst);
400                            rx.await.ok();
401                            format!("{} worker #{worker_id}", fut_worker.prefix)
402                        }),
403                    )
404                    .await;
405                    if let Err(err) =
406                        run_worker_queue(ctx.clone(), client, is_running, worker, worker_id).await
407                    {
408                        tracing::error!("{err:#?}");
409                        std::process::exit(1);
410                    }
411                    tx.send(()).ok();
412                });
413                rt.block_on(local);
414            });
415        }
416        Ok(())
417    }
418
419    pub async fn terminate(&self) -> anyhow::Result<()> {
420        if !self.inner.is_running.load(Ordering::SeqCst) {
421            anyhow::bail!("Workers already terminated");
422        }
423        let mut futs = self.inner.instances.write().await.take().unwrap();
424        tracing::info!("try stopping {} workers", futs.len());
425        while let Some(result) = futs.next().await {
426            tracing::info!("stopped {}", result);
427        }
428        Ok(())
429    }
430}
431
/// Enqueues JSON-serialized items onto a [`WorkQueue`].
pub struct Producer {
    // Pool used by `add_item` to check out a connection per call.
    client: Arc<deadpool_redis::Pool>,
    // Target queue, keyed by the producer's prefix.
    queue: WorkQueue,
}
436
437impl Producer {
438    pub fn new<S>(config: &RedisConfig, prefix: S) -> anyhow::Result<Self>
439    where
440        S: Into<String>,
441    {
442        let redis_cfg = deadpool_redis::Config::from_url(config.address());
443        let redis = Arc::new(redis_cfg.create_pool(Some(Runtime::Tokio1))?);
444        Ok(Self::new_with_client(redis, prefix))
445    }
446
447    pub fn new_with_client<S>(client: Arc<deadpool_redis::Pool>, prefix: S) -> Self
448    where
449        S: Into<String>,
450    {
451        let queue = WorkQueue::new(KeyPrefix::new(prefix.into()));
452        Self { client, queue }
453    }
454
455    pub async fn add_item_with_connection<C, T>(&self, db: &mut C, data: &T) -> anyhow::Result<()>
456    where
457        C: AsyncCommands,
458        T: Serialize,
459    {
460        let item = Item::from_json_data(data)?;
461        self.queue.add_item(db, &item).await?;
462        Ok(())
463    }
464
465    pub async fn add_item<T>(&self, data: &T) -> anyhow::Result<()>
466    where
467        T: Serialize,
468    {
469        let item = Item::from_json_data(data)?;
470        let mut con = self.client.get().await?;
471        self.queue.add_item(&mut con, &item).await?;
472        Ok(())
473    }
474}
475
476pub struct AsyncWorker<Ctx, T>
477where
478    Ctx: Clone + Send + Sync + 'static,
479    T: DeserializeOwned + Send + Sync,
480{
481    prefix: String,
482    num_workers: usize,
483    timeout: u64,
484    lease_duration: u64,
485    recovery_key: String,
486    recovery_queue: WorkQueue,
487    work: Option<Box<dyn Work<Ctx, T>>>,
488}
489
490impl<Ctx, T> AsyncWorker<Ctx, T>
491where
492    Ctx: Clone + Send + Sync + 'static,
493    T: DeserializeOwned + Send + Sync,
494{
495    pub fn new<S>(prefix: S) -> Self
496    where
497        S: Into<String>,
498    {
499        let prefix = prefix.into();
500        let name = KeyPrefix::new(prefix.clone());
501        Self {
502            recovery_key: name.of(":clean"),
503            recovery_queue: WorkQueue::new(name),
504            timeout: 5,
505            lease_duration: 60,
506            num_workers: 1,
507            prefix,
508            work: None,
509        }
510    }
511
512    pub fn with_timeout(mut self, timeout: u64) -> Self {
513        self.timeout = timeout;
514        self
515    }
516
517    pub fn with_lease_duration(mut self, lease_duration: u64) -> Self {
518        self.lease_duration = lease_duration;
519        self
520    }
521
522    pub fn with_num_workers(mut self, num_workers: usize) -> Self {
523        self.num_workers = num_workers;
524        self
525    }
526
527    pub fn producer(&self, client: Arc<deadpool_redis::Pool>) -> Producer {
528        Producer {
529            client,
530            queue: WorkQueue::new(KeyPrefix::new(self.prefix.clone())),
531        }
532    }
533
534    pub async fn recover<C: AsyncCommands>(&self, db: &mut C) -> anyhow::Result<()> {
535        let l = lock::lock(db, &self.recovery_key, 3600, 36, 100).await?;
536        self.recovery_queue.recover(db).await?;
537        lock::unlock(db, &self.recovery_key, l.id).await?;
538        Ok(())
539    }
540
541    pub fn run(mut self, work: impl Work<Ctx, T> + 'static) -> Self {
542        self.work = Some(Box::new(work));
543        self
544    }
545}