1use std::{
4 collections::VecDeque,
5 future::Future,
6 marker::PhantomData,
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use apalis_core::{task::Task, worker::context::WorkerContext};
12use apalis_sql::{context::SqlContext, from_row::TaskRow};
13use futures::{FutureExt, future::BoxFuture, stream::Stream};
14use libsql::Database;
15use pin_project::pin_project;
16use ulid::Ulid;
17
18use crate::{CompactType, LibsqlError, LibsqlTask, config::Config, row::LibsqlTaskRow};
19
/// Claims up to `?3` runnable jobs of type `?2` for the worker named `?1` in a single
/// `UPDATE ... RETURNING` statement.
///
/// Bind order (as used by `fetch_next`): `?1` = worker name (written to `lock_by`),
/// `?2` = queue / `job_type`, `?3` = maximum number of rows to claim.
///
/// A row is a candidate when:
/// - its `job_type` matches `?2`, and
/// - it is `Pending` and unlocked (`lock_by IS NULL`), or `Failed` with
///   `attempts < max_attempts`, and
/// - its `run_at` is unset or not later than the current unix time.
///
/// Candidates are ranked by `priority DESC`, then `run_at ASC`, then `id ASC`. The chosen
/// rows are set to `Queued`, stamped with `lock_by` / `lock_at`, and returned with the
/// columns listed in `RETURNING`.
///
/// NOTE(review): SQLite emits `RETURNING` rows in arbitrary order, so the `ORDER BY`
/// only decides *which* rows are claimed, not the order they come back in.
/// NOTE(review): `strftime('%s', 'now')` yields TEXT. The `run_at <=` comparison behaves
/// numerically only if `run_at` has INTEGER/NUMERIC affinity — confirm against the schema.
/// NOTE(review): the `RETURNING` column list presumably must match what
/// `LibsqlTaskRow::from_row` reads — keep the two in sync.
const FETCH_NEXT_SQL: &str = r#"
UPDATE Jobs
SET status = 'Queued', lock_by = ?1, lock_at = strftime('%s', 'now')
WHERE ROWID IN (
 SELECT ROWID FROM Jobs
 WHERE job_type = ?2
 AND ((status = 'Pending' AND lock_by IS NULL)
 OR (status = 'Failed' AND attempts < max_attempts))
 AND (run_at IS NULL OR run_at <= strftime('%s', 'now'))
 ORDER BY priority DESC, run_at ASC, id ASC
 LIMIT ?3
)
RETURNING job, id, job_type, status, attempts, max_attempts, run_at, last_error, lock_at, lock_by, done_at, priority, metadata
"#;
35
36pub async fn fetch_next(
38 db: &'static Database,
39 config: &Config,
40 worker: &WorkerContext,
41) -> Result<Vec<Task<CompactType, SqlContext, Ulid>>, LibsqlError> {
42 let conn = db.connect()?;
43 let job_type = config.queue().to_string();
44 let buffer_size = config.buffer_size() as i64;
45 let worker_id = worker.name().to_string();
46
47 let mut rows = conn
48 .query(
49 FETCH_NEXT_SQL,
50 libsql::params![worker_id, job_type, buffer_size],
51 )
52 .await
53 .map_err(LibsqlError::Database)?;
54
55 let mut tasks = Vec::new();
56 while let Some(row) = rows.next().await.map_err(LibsqlError::Database)? {
57 let libsql_row = LibsqlTaskRow::from_row(&row)?;
58 let task_row: TaskRow = libsql_row.try_into()?;
59 let task = task_row
60 .try_into_task_compact::<Ulid>()
61 .map_err(|e| LibsqlError::Other(e.to_string()))?;
62 tasks.push(task);
63 }
64
65 Ok(tasks)
66}
67
68enum StreamState {
70 Ready,
72 Delay(Pin<Box<tokio::time::Sleep>>),
74 Fetch(BoxFuture<'static, Result<Vec<LibsqlTask<CompactType>>, LibsqlError>>),
76 Buffered(VecDeque<LibsqlTask<CompactType>>),
78}
79
80#[pin_project]
82pub struct LibsqlPollFetcher<Decode> {
83 db: &'static Database,
84 config: Config,
85 worker: WorkerContext,
86 _marker: PhantomData<Decode>,
87 state: StreamState,
88}
89
90impl<Decode> std::fmt::Debug for LibsqlPollFetcher<Decode> {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("LibsqlPollFetcher")
93 .field("config", &self.config)
94 .field("worker", &self.worker)
95 .finish()
96 }
97}
98
99impl<Decode> Clone for LibsqlPollFetcher<Decode> {
100 fn clone(&self) -> Self {
101 Self {
102 db: self.db,
103 config: self.config.clone(),
104 worker: self.worker.clone(),
105 _marker: PhantomData,
106 state: StreamState::Ready,
107 }
108 }
109}
110
111impl<Decode> LibsqlPollFetcher<Decode> {
112 #[must_use]
114 pub fn new(db: &'static Database, config: &Config, worker: &WorkerContext) -> Self {
115 Self {
116 db,
117 config: config.clone(),
118 worker: worker.clone(),
119 _marker: PhantomData,
120 state: StreamState::Ready,
121 }
122 }
123}
124
125impl<Decode: Send + 'static> Stream for LibsqlPollFetcher<Decode> {
126 type Item = Result<Option<LibsqlTask<CompactType>>, LibsqlError>;
127
128 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 let this = self.get_mut();
130
131 loop {
132 match &mut this.state {
133 StreamState::Ready => {
134 let db = this.db;
136 let config = this.config.clone();
137 let worker = this.worker.clone();
138 let fut = async move { fetch_next(db, &config, &worker).await };
139 this.state = StreamState::Fetch(fut.boxed());
140 }
141
142 StreamState::Delay(sleep) => match Pin::new(sleep).poll(cx) {
143 Poll::Pending => return Poll::Pending,
144 Poll::Ready(()) => {
145 this.state = StreamState::Ready;
146 }
147 },
148
149 StreamState::Fetch(fut) => {
150 match fut.poll_unpin(cx) {
151 Poll::Pending => return Poll::Pending,
152 Poll::Ready(result) => match result {
153 Ok(tasks) => {
154 if tasks.is_empty() {
155 let delay = tokio::time::sleep(this.config.poll_interval());
157 this.state = StreamState::Delay(Box::pin(delay));
158 } else {
159 let buffer: VecDeque<_> = tasks.into_iter().collect();
160 this.state = StreamState::Buffered(buffer);
161 }
162 }
163 Err(e) => {
164 log::error!("Error fetching tasks: {}", e);
167 let delay = tokio::time::sleep(this.config.poll_interval());
168 this.state = StreamState::Delay(Box::pin(delay));
169 return Poll::Ready(Some(Err(e)));
170 }
171 },
172 }
173 }
174
175 StreamState::Buffered(buffer) => {
176 if let Some(task) = buffer.pop_front() {
177 if buffer.is_empty() {
178 this.state = StreamState::Ready;
179 }
180 return Poll::Ready(Some(Ok(Some(task))));
181 } else {
182 this.state = StreamState::Ready;
184 }
185 }
186 }
187 }
188 }
189}
190
191impl<Decode> LibsqlPollFetcher<Decode> {
192 pub fn take_pending(&mut self) -> VecDeque<LibsqlTask<CompactType>> {
194 match &mut self.state {
195 StreamState::Buffered(tasks) => std::mem::take(tasks),
196 _ => VecDeque::new(),
197 }
198 }
199}