1use std::marker::PhantomData;
2
3use apalis_core::{
4 backend::codec::{Codec, json::JsonCodec},
5 error::BoxDynError,
6 task::Parts,
7 worker::ext::ack::Acknowledge,
8};
9use chrono::Utc;
10use futures::{FutureExt, future::BoxFuture};
11use redis::{
12 RedisError, Script,
13 aio::{ConnectionLike, ConnectionManager},
14};
15use ulid::Ulid;
16
17use crate::{build_error, config::RedisConfig, context::RedisContext};
18
19#[derive(Debug)]
21pub struct RedisAck<Conn = ConnectionManager, Encode = JsonCodec<Vec<u8>>> {
22 conn: Conn,
23 config: RedisConfig,
24 _codec: PhantomData<Encode>,
25}
26impl<Conn: Clone, Encode> RedisAck<Conn, Encode> {
27 pub fn new(conn: &Conn, config: &RedisConfig) -> Self {
29 Self {
30 conn: conn.clone(),
31 config: config.clone(),
32 _codec: PhantomData,
33 }
34 }
35}
36
37impl<Conn, Encode> Clone for RedisAck<Conn, Encode>
38where
39 Conn: Clone,
40 RedisConfig: Clone,
41{
42 fn clone(&self) -> Self {
43 Self {
44 conn: self.conn.clone(),
45 config: self.config.clone(),
46 _codec: PhantomData,
47 }
48 }
49}
50
51impl<Conn: ConnectionLike + Send + Clone + 'static, Res, Encode>
52 Acknowledge<Res, RedisContext, Ulid> for RedisAck<Conn, Encode>
53where
54 Encode: Codec<Res, Compact = Vec<u8>>,
55{
56 type Future = BoxFuture<'static, Result<(), RedisError>>;
57
58 type Error = RedisError;
59
60 fn ack(
61 &mut self,
62 res: &Result<Res, BoxDynError>,
63 parts: &Parts<RedisContext, Ulid>,
64 ) -> Self::Future {
65 let task_id = parts.task_id.unwrap().to_string();
66 let attempt = parts.attempt.current();
67 let worker_id = &parts.ctx.lock_by.as_ref().unwrap();
68 let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
69 let done_jobs_set = self.config.done_jobs_set();
70 let dead_jobs_set = self.config.dead_jobs_set();
71 let job_meta_hash = self.config.job_meta_hash();
72 let status = if res.is_ok() { "ok" } else { "err" };
73 let res = res.as_ref().map_err(|e| e.to_string().bytes().collect());
74
75 let result_data = match res {
76 Ok(res) => Encode::encode(res)
77 .map_err(|_| build_error("could not encode result"))
78 .unwrap(),
79 Err(e) => e,
80 };
81 let timestamp = Utc::now().timestamp();
82 let script = Script::new(include_str!("../lua/ack_job.lua"));
83 let mut conn = self.conn.clone();
84
85 async move {
86 let mut script = script.key(inflight_set);
87 let _ = script
88 .key(done_jobs_set)
89 .key(dead_jobs_set)
90 .key(job_meta_hash)
91 .arg(task_id)
92 .arg(timestamp)
93 .arg(result_data)
94 .arg(status)
95 .arg(attempt)
96 .invoke_async::<u32>(&mut conn)
97 .boxed()
98 .await?;
99 Ok(())
100 }
101 .boxed()
102 }
103}