1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 sync::{Arc, LazyLock},
5 task::{Context, Poll},
6};
7
8use apalis_core::task::Task;
9use chrono::Utc;
10use futures::{
11 FutureExt, Sink,
12 future::{BoxFuture, Shared},
13};
14use redis::{
15 RedisError, Script,
16 aio::{ConnectionLike, ConnectionManager},
17};
18use ulid::Ulid;
19
20use crate::{RedisStorage, build_error, config::RedisConfig, context::RedisContext};
21
/// In-flight batch push, as stored by the sink while a flush is in progress.
///
/// Resolves to the `(u32, u32)` pair returned by the batch-push script. The error is
/// wrapped in an [`Arc`] so the output stays cheaply cloneable, which [`Shared`]
/// requires of its output type.
type SinkFuture = Shared<BoxFuture<'static, Result<(u32, u32), Arc<RedisError>>>>;
23
/// Buffering state for pushing tasks to Redis in batches.
///
/// Tasks handed to the sink are queued in `pending`; a flush drains them into a
/// single batch-push future kept in `invoke_future` until it resolves.
#[derive(Debug)]
pub struct RedisSink<Args, Encode, Conn = ConnectionManager> {
    /// Ties the sink to its argument and codec types without storing values of them.
    _args: PhantomData<(Args, Encode)>,
    /// Copy of the configuration passed to `new`.
    // NOTE(review): the `Sink` impl for `RedisStorage` in this file flushes using the
    // storage's own `config`, not this field — confirm whether it is still needed.
    config: RedisConfig,
    /// Tasks (with byte-encoded arguments) waiting for the next flush.
    pending: Vec<Task<Vec<u8>, RedisContext, Ulid>>,
    /// Connection handle cloned at construction.
    // NOTE(review): as with `config`, the flush path in this file clones the storage's
    // `conn` instead of this one — confirm intended use.
    conn: Conn,
    /// Batch push currently in flight, if any.
    invoke_future: Option<SinkFuture>,
}
33impl<Args, Conn: Clone, Encode> RedisSink<Args, Encode, Conn> {
34 pub fn new(conn: &Conn, config: &RedisConfig) -> Self {
36 Self {
37 conn: conn.clone(),
38 config: config.clone(),
39 _args: PhantomData,
40 invoke_future: None,
41 pending: Vec::new(),
42 }
43 }
44}
45
46impl<Args, Conn: Clone, Cdc: Clone> Clone for RedisSink<Args, Cdc, Conn> {
47 fn clone(&self) -> Self {
48 Self {
49 conn: self.conn.clone(),
50 config: self.config.clone(),
51 _args: PhantomData,
52 invoke_future: None,
53 pending: Vec::new(),
54 }
55 }
56}
57
58static BATCH_PUSH_SCRIPT: LazyLock<Script> =
59 LazyLock::new(|| Script::new(include_str!("../lua/batch_push.lua")));
60
61pub async fn push_tasks<Conn: ConnectionLike>(
63 tasks: Vec<Task<Vec<u8>, RedisContext, Ulid>>,
64 config: RedisConfig,
65 mut conn: Conn,
66) -> Result<(u32, u32), Arc<RedisError>> {
67 let mut batch = BATCH_PUSH_SCRIPT.key(config.job_data_hash());
68 let mut script = batch
69 .key(config.active_jobs_list())
70 .key(config.signal_list())
71 .key(config.job_meta_hash())
72 .key(config.scheduled_jobs_set());
73 for request in tasks {
74 let task_id = request
75 .parts
76 .task_id
77 .map(|s| s.to_string())
78 .unwrap_or(Ulid::new().to_string());
79 let attempts = request.parts.attempt.current() as u32;
80 let max_attempts = request.parts.ctx.max_attempts;
81 let job = request.args;
82 let meta = serde_json::to_string(&request.parts.ctx.meta)
83 .map_err(|e| Arc::new(build_error(&e.to_string())))?;
84 let run_at = request.parts.run_at;
85
86 let run_at = if run_at - Utc::now().timestamp() as u64 > 0 {
87 run_at
88 } else {
89 0
90 };
91
92 script = script
93 .arg(task_id)
94 .arg(job)
95 .arg(attempts)
96 .arg(max_attempts)
97 .arg(meta)
98 .arg(run_at);
99 }
100
101 script
102 .invoke_async::<(u32, u32)>(&mut conn)
103 .await
104 .map_err(Arc::new)
105}
106
107impl<Args, Cdc, Conn> Sink<Task<Vec<u8>, RedisContext, Ulid>> for RedisStorage<Args, Conn, Cdc>
108where
109 Args: Unpin,
110 Conn: ConnectionLike + Unpin + Send + Clone + 'static,
111 Cdc: Unpin,
112{
113 type Error = RedisError;
114
115 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 Poll::Ready(Ok(()))
117 }
118
119 fn start_send(
120 self: Pin<&mut Self>,
121 item: Task<Vec<u8>, RedisContext, Ulid>,
122 ) -> Result<(), Self::Error> {
123 let this = Pin::get_mut(self);
124 this.sink.pending.push(item);
125 Ok(())
126 }
127
128 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129 let this = Pin::get_mut(self);
130
131 if this.sink.invoke_future.is_none() && !this.sink.pending.is_empty() {
133 let tasks: Vec<_> = this.sink.pending.drain(..).collect();
134 let fut = push_tasks(tasks, this.config.clone(), this.conn.clone());
135
136 this.sink.invoke_future = Some(fut.boxed().shared());
137 }
138
139 if let Some(fut) = &mut this.sink.invoke_future {
141 match fut.poll_unpin(cx) {
142 Poll::Pending => Poll::Pending,
143 Poll::Ready(result) => {
144 this.sink.invoke_future = None;
146
147 Poll::Ready(result.map(|_| ()).map_err(|e| Arc::into_inner(e).unwrap()))
149 }
150 }
151 } else {
152 Poll::Ready(Ok(()))
154 }
155 }
156
157 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158 Sink::<Task<Vec<u8>, RedisContext, Ulid>>::poll_flush(self, cx)
159 }
160}