1pub use deadpool_redis::redis;
2use deadpool_redis::PoolError;
3use deadpool_redis::Runtime;
4use redis::FromRedisValue;
5use redis::RedisError;
6use redis::ToRedisArgs;
7use std::sync::Arc;
8mod config;
9pub mod lock;
10pub mod work_queue;
11use futures::stream::FuturesUnordered;
12use futures::StreamExt;
13use redis::AsyncCommands;
14use redis::RedisResult;
15use serde::de::DeserializeOwned;
16use serde::Serialize;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::atomic::AtomicBool;
20use std::sync::atomic::Ordering;
21use std::time::Duration;
22use tokio::runtime::Builder;
23use tokio::sync::RwLock;
24use tokio::task::LocalSet;
25use work_queue::Item;
26use work_queue::KeyPrefix;
27use work_queue::WorkQueue;
28
29pub use crate::config::Config as RedisConfig;
30use crate::lock::Lock;
31
32#[derive(Debug, thiserror::Error)]
33pub enum CacheError {
34 #[error(transparent)]
35 Pool(#[from] PoolError),
36 #[error(transparent)]
37 Redis(#[from] RedisError),
38 #[error("failed to fetch: {0}")]
39 Failure(String),
40}
41
42#[derive(serde::Serialize, serde::Deserialize)]
43pub struct Json<T>(T);
44
45impl<T> FromRedisValue for Json<T>
46where
47 T: DeserializeOwned,
48{
49 fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
50 if let redis::Value::SimpleString(s) = v {
51 serde_json::from_str(s).map_err(From::from)
52 } else {
53 Err(redis::RedisError::from((
54 redis::ErrorKind::TypeError,
55 "expected simple string value",
56 )))
57 }
58 }
59}
60
61impl<T> ToRedisArgs for Json<T>
62where
63 T: Serialize,
64{
65 fn write_redis_args<W>(&self, out: &mut W)
66 where
67 W: ?Sized + redis::RedisWrite,
68 {
69 let v = serde_json::to_string(&self.0).unwrap_or_default();
70 v.write_redis_args(out);
71 }
72}
73
74pub struct Inner {
75 config: RedisConfig,
76 client: redis::Client,
77 pool: deadpool_redis::Pool,
78}
79
80#[derive(Clone)]
81pub struct Redis {
82 inner: Arc<Inner>,
83}
84
85impl AsRef<deadpool_redis::Pool> for Redis {
86 fn as_ref(&self) -> &deadpool_redis::Pool {
87 &self.inner.pool
88 }
89}
90
91impl Redis {
92 pub fn new() -> anyhow::Result<Self> {
93 let config = RedisConfig::builder().build()?;
94 let client = redis::Client::open(config.address())?;
95 let redis_cfg = deadpool_redis::Config::from_url(config.address());
96 let pool = redis_cfg.create_pool(Some(Runtime::Tokio1))?;
97 Ok(Self {
98 inner: Arc::new(Inner {
99 config,
100 client,
101 pool,
102 }),
103 })
104 }
105
106 pub fn config(&self) -> &RedisConfig {
107 &self.inner.config
108 }
109
110 pub fn client(&self) -> &redis::Client {
111 &self.inner.client
112 }
113
114 pub fn pool(&self) -> Arc<deadpool_redis::Pool> {
115 Arc::new(self.inner.pool.clone())
116 }
117
118 pub async fn connect(&self) -> Result<deadpool_redis::Connection, deadpool_redis::PoolError> {
119 self.inner.pool.get().await
120 }
121
122 pub async fn cleanup(&self) -> anyhow::Result<()> {
123 let mut con = self.connect().await?;
124 let _: redis::Value = redis::cmd("FLUSHALL").query_async(&mut con).await?;
125 Ok(())
126 }
127
128 pub async fn lock(
129 &self,
130 key: &str,
131 ttl: usize,
132 retry_count: u32,
133 retry_delay: u32,
134 ) -> Result<Lock, lock::Error> {
135 let mut con = self.connect().await?;
136 lock::lock(&mut con, key, ttl, retry_count, retry_delay).await
137 }
138
139 pub async fn unlock(&self, key: &str, lock_id: &str) -> Result<i64, lock::Error> {
140 let mut con = self.connect().await?;
141 lock::unlock(&mut con, key, lock_id).await
142 }
143}
144
145pub async fn mutex_run<S, O, E, F>(lock_name: S, redis: &Redis, f: F) -> Result<O, E>
154where
155 S: AsRef<str>,
156 F: std::future::Future<Output = Result<O, E>>,
157 E: From<self::lock::Error>,
158{
159 let lock = redis.lock(lock_name.as_ref(), 5000, 20, 250).await?;
160
161 let result = f.await;
162
163 redis.unlock(lock_name.as_ref(), &lock.id).await?;
164
165 result
166}
167
/// Implements `AsRef<Redis>` for a storage type whose `inner.redis` field holds a
/// [`Redis`] handle.
///
/// The expansion names the type through `$crate`, so it resolves to this crate's
/// `Redis` however the invoking crate depends on it. The previous hard-coded
/// `qm::redis::Redis` path required a `qm` facade crate in scope. `AsRef` is fully
/// qualified so a local item named `AsRef` cannot shadow it.
#[macro_export]
macro_rules! redis {
    ($storage:ty) => {
        impl ::core::convert::AsRef<$crate::Redis> for $storage {
            fn as_ref(&self) -> &$crate::Redis {
                &self.inner.redis
            }
        }
    };
}
178
179pub type RunningWorkers =
180 FuturesUnordered<Pin<Box<dyn Future<Output = String> + Send + Sync + 'static>>>;
181
182pub type ExecItemFuture = Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>>;
183
184pub struct WorkerContext<Ctx>
185where
186 Ctx: Clone + Send + Sync + 'static,
187{
188 ctx: Ctx,
189 pub worker_id: usize,
190 pub queue: Arc<WorkQueue>,
191 pub client: Arc<redis::Client>,
192 pub item: Item,
193}
194
195impl<Ctx> WorkerContext<Ctx>
196where
197 Ctx: Clone + Send + Sync + 'static,
198{
199 pub fn ctx(&self) -> &Ctx {
200 &self.ctx
201 }
202 pub async fn complete(&self) -> anyhow::Result<()> {
203 let mut con = self.client.get_multiplexed_async_connection().await?;
204 self.queue.complete(&mut con, &self.item).await?;
205 Ok(())
206 }
207}
208
209async fn add(
210 is_running: Arc<AtomicBool>,
211 instances: Arc<RwLock<Option<RunningWorkers>>>,
212 fut: Pin<Box<dyn Future<Output = String> + Send + Sync + 'static>>,
213) {
214 if !is_running.load(Ordering::SeqCst) {
215 return;
216 }
217 instances.write().await.as_mut().unwrap().push(fut);
218}
219
220#[async_trait::async_trait]
221pub trait Work<Ctx, T>: Send + Sync
222where
223 Ctx: Clone + Send + Sync + 'static,
224 T: DeserializeOwned + Send + Sync,
225{
226 async fn run(&self, ctx: WorkerContext<Ctx>, item: T) -> anyhow::Result<()>;
227}
228
229async fn run_recovery_worker<Ctx, T>(
230 client: Arc<redis::Client>,
231 is_running: Arc<AtomicBool>,
232 worker: Arc<AsyncWorker<Ctx, T>>,
233) -> anyhow::Result<()>
234where
235 Ctx: Clone + Send + Sync + 'static,
236 T: DeserializeOwned + Send + Sync,
237{
238 tracing::info!("start {} worker recovery", worker.prefix);
239 let mut con = client.get_multiplexed_async_connection().await?;
240 loop {
241 if !is_running.load(Ordering::SeqCst) {
242 break;
243 }
244 tokio::time::sleep(Duration::from_secs(10)).await;
245 worker.recover(&mut con).await?;
246 }
247 Ok(())
248}
249
250async fn run_worker_queue<Ctx, T>(
251 ctx: Ctx,
252 client: Arc<redis::Client>,
253 is_running: Arc<AtomicBool>,
254 worker: Arc<AsyncWorker<Ctx, T>>,
255 worker_id: usize,
256) -> anyhow::Result<()>
257where
258 Ctx: Clone + Send + Sync + 'static,
259 T: DeserializeOwned + Send + Sync,
260{
261 tracing::info!("start {} worker #{worker_id} queue", worker.prefix);
262 let request_queue = Arc::new(WorkQueue::new(KeyPrefix::new(worker.prefix.clone())));
263 let mut con = client.get_multiplexed_async_connection().await?;
264 loop {
265 if !is_running.load(Ordering::SeqCst) {
266 break;
267 }
268 if let Some(item) = request_queue
269 .lease(
270 &mut con,
271 Some(Duration::from_secs(worker.timeout)),
272 Duration::from_secs(worker.lease_duration),
273 )
274 .await?
275 {
276 if item.data.is_empty() {
277 tracing::info!("item is empty");
278 request_queue.complete(&mut con, &item).await?;
279 continue;
280 }
281 if let Ok(request) = serde_json::from_slice::<T>(&item.data).inspect_err(|_| {
282 tracing::error!(
283 "invalid request item on worker {} #{worker_id} Item: {}",
284 worker.prefix,
285 String::from_utf8_lossy(&item.data)
286 );
287 }) {
288 if let Some(work) = worker.work.as_ref() {
289 work.run(
290 WorkerContext {
291 ctx: ctx.clone(),
292 worker_id,
293 queue: request_queue.clone(),
294 client: client.clone(),
295 item: Item {
296 id: item.id.clone(),
297 data: Box::new([]),
298 },
299 },
300 request,
301 )
302 .await?;
303 }
304 } else {
305 request_queue.complete(&mut con, &item).await?;
306 }
307 }
308 }
309 Ok(())
310}
311
312struct WorkerInner {
313 client: Arc<redis::Client>,
314 instances: Arc<RwLock<Option<RunningWorkers>>>,
315 is_running: Arc<AtomicBool>,
316}
317
318#[derive(Clone)]
319pub struct Workers {
320 inner: Arc<WorkerInner>,
321}
322
323impl Workers {
324 pub fn new(config: &RedisConfig) -> RedisResult<Self> {
325 let client = Arc::new(redis::Client::open(config.address())?);
326 Ok(Self::new_with_client(client))
327 }
328
329 pub fn new_with_client(client: Arc<redis::Client>) -> Self {
330 Self {
331 inner: Arc::new(WorkerInner {
332 client,
333 instances: Arc::new(RwLock::new(Some(RunningWorkers::default()))),
334 is_running: Arc::new(AtomicBool::new(true)),
335 }),
336 }
337 }
338
339 pub async fn start<Ctx, T>(&self, ctx: Ctx, worker: AsyncWorker<Ctx, T>) -> anyhow::Result<()>
340 where
341 Ctx: Clone + Send + Sync + 'static,
342 T: DeserializeOwned + Send + Sync + 'static,
343 {
344 let worker = Arc::new(worker);
345 let mut con = self.inner.client.get_multiplexed_async_connection().await?;
346 worker.recover(&mut con).await?;
347 {
348 let instances = self.inner.instances.clone();
349 let client = self.inner.client.clone();
350 let worker = worker.clone();
351 let _th = std::thread::spawn(move || {
352 let rt = Builder::new_current_thread().enable_all().build().unwrap();
353 let local = LocalSet::new();
354 local.spawn_local(async move {
355 let fut_worker = worker.clone();
356 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
357 let is_running = Arc::new(AtomicBool::new(true));
358 let is_fut_running = is_running.clone();
359 add(
360 is_running.clone(),
361 instances,
362 Box::pin(async move {
363 let worker = fut_worker.clone();
364 tracing::info!("stopping {} recovery", worker.prefix);
365 is_fut_running.store(false, Ordering::SeqCst);
366 rx.await.ok();
367 " recovery".to_string()
368 }),
369 )
370 .await;
371 if let Err(err) = run_recovery_worker(client, is_running, worker).await {
372 tracing::error!("{err:#?}");
373 std::process::exit(1);
374 }
375 tx.send(()).ok();
376 });
377 rt.block_on(local);
378 });
379 }
380 for worker_id in 0..worker.num_workers {
381 let worker = worker.clone();
382 let client = self.inner.client.clone();
383 let ctx = ctx.clone();
384 let instances = self.inner.instances.clone();
385 let _th = std::thread::spawn(move || {
386 let rt = Builder::new_current_thread().enable_all().build().unwrap();
387 let local = LocalSet::new();
388 local.spawn_local(async move {
389 let fut_worker = worker.clone();
390 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
391 let is_running = Arc::new(AtomicBool::new(true));
392 let is_fut_running = is_running.clone();
393 add(
394 is_running.clone(),
395 instances,
396 Box::pin(async move {
397 let worker = fut_worker.clone();
398 tracing::info!("stopping {} #{worker_id}", worker.prefix);
399 is_fut_running.store(false, Ordering::SeqCst);
400 rx.await.ok();
401 format!("{} worker #{worker_id}", fut_worker.prefix)
402 }),
403 )
404 .await;
405 if let Err(err) =
406 run_worker_queue(ctx.clone(), client, is_running, worker, worker_id).await
407 {
408 tracing::error!("{err:#?}");
409 std::process::exit(1);
410 }
411 tx.send(()).ok();
412 });
413 rt.block_on(local);
414 });
415 }
416 Ok(())
417 }
418
419 pub async fn terminate(&self) -> anyhow::Result<()> {
420 if !self.inner.is_running.load(Ordering::SeqCst) {
421 anyhow::bail!("Workers already terminated");
422 }
423 let mut futs = self.inner.instances.write().await.take().unwrap();
424 tracing::info!("try stopping {} workers", futs.len());
425 while let Some(result) = futs.next().await {
426 tracing::info!("stopped {}", result);
427 }
428 Ok(())
429 }
430}
431
432pub struct Producer {
433 client: Arc<deadpool_redis::Pool>,
434 queue: WorkQueue,
435}
436
437impl Producer {
438 pub fn new<S>(config: &RedisConfig, prefix: S) -> anyhow::Result<Self>
439 where
440 S: Into<String>,
441 {
442 let redis_cfg = deadpool_redis::Config::from_url(config.address());
443 let redis = Arc::new(redis_cfg.create_pool(Some(Runtime::Tokio1))?);
444 Ok(Self::new_with_client(redis, prefix))
445 }
446
447 pub fn new_with_client<S>(client: Arc<deadpool_redis::Pool>, prefix: S) -> Self
448 where
449 S: Into<String>,
450 {
451 let queue = WorkQueue::new(KeyPrefix::new(prefix.into()));
452 Self { client, queue }
453 }
454
455 pub async fn add_item_with_connection<C, T>(&self, db: &mut C, data: &T) -> anyhow::Result<()>
456 where
457 C: AsyncCommands,
458 T: Serialize,
459 {
460 let item = Item::from_json_data(data)?;
461 self.queue.add_item(db, &item).await?;
462 Ok(())
463 }
464
465 pub async fn add_item<T>(&self, data: &T) -> anyhow::Result<()>
466 where
467 T: Serialize,
468 {
469 let item = Item::from_json_data(data)?;
470 let mut con = self.client.get().await?;
471 self.queue.add_item(&mut con, &item).await?;
472 Ok(())
473 }
474}
475
476pub struct AsyncWorker<Ctx, T>
477where
478 Ctx: Clone + Send + Sync + 'static,
479 T: DeserializeOwned + Send + Sync,
480{
481 prefix: String,
482 num_workers: usize,
483 timeout: u64,
484 lease_duration: u64,
485 recovery_key: String,
486 recovery_queue: WorkQueue,
487 work: Option<Box<dyn Work<Ctx, T>>>,
488}
489
490impl<Ctx, T> AsyncWorker<Ctx, T>
491where
492 Ctx: Clone + Send + Sync + 'static,
493 T: DeserializeOwned + Send + Sync,
494{
495 pub fn new<S>(prefix: S) -> Self
496 where
497 S: Into<String>,
498 {
499 let prefix = prefix.into();
500 let name = KeyPrefix::new(prefix.clone());
501 Self {
502 recovery_key: name.of(":clean"),
503 recovery_queue: WorkQueue::new(name),
504 timeout: 5,
505 lease_duration: 60,
506 num_workers: 1,
507 prefix,
508 work: None,
509 }
510 }
511
512 pub fn with_timeout(mut self, timeout: u64) -> Self {
513 self.timeout = timeout;
514 self
515 }
516
517 pub fn with_lease_duration(mut self, lease_duration: u64) -> Self {
518 self.lease_duration = lease_duration;
519 self
520 }
521
522 pub fn with_num_workers(mut self, num_workers: usize) -> Self {
523 self.num_workers = num_workers;
524 self
525 }
526
527 pub fn producer(&self, client: Arc<deadpool_redis::Pool>) -> Producer {
528 Producer {
529 client,
530 queue: WorkQueue::new(KeyPrefix::new(self.prefix.clone())),
531 }
532 }
533
534 pub async fn recover<C: AsyncCommands>(&self, db: &mut C) -> anyhow::Result<()> {
535 let l = lock::lock(db, &self.recovery_key, 3600, 36, 100).await?;
536 self.recovery_queue.recover(db).await?;
537 lock::unlock(db, &self.recovery_key, l.id).await?;
538 Ok(())
539 }
540
541 pub fn run(mut self, work: impl Work<Ctx, T> + 'static) -> Self {
542 self.work = Some(Box::new(work));
543 self
544 }
545}