apalis_libsql/
fetcher.rs

//! Fetcher implementation for polling tasks from a libSQL database.
2
3use std::{
4    collections::VecDeque,
5    future::Future,
6    marker::PhantomData,
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use apalis_core::{task::Task, worker::context::WorkerContext};
12use apalis_sql::{context::SqlContext, from_row::TaskRow};
13use futures::{FutureExt, future::BoxFuture, stream::Stream};
14use libsql::Database;
15use pin_project::pin_project;
16use ulid::Ulid;
17
18use crate::{CompactType, LibsqlError, LibsqlTask, config::Config, row::LibsqlTaskRow};
19
/// SQL query to fetch the next batch of tasks (atomic lock via UPDATE ... RETURNING).
///
/// Bind parameters:
/// - `?1`: worker id, written to `lock_by`
/// - `?2`: job type (queue) to claim tasks from
/// - `?3`: maximum number of tasks to claim (`LIMIT`)
///
/// A row is claimed when it belongs to the queue and is either `Pending` and unlocked, or
/// `Failed` with attempts remaining. Its `run_at` must also be unset or already due.
/// Claimed rows are set to `Queued`, locked by the worker, and stamped with the current
/// Unix time in seconds (`strftime('%s', 'now')`).
///
/// NOTE(review): SQLite documents the order of `RETURNING` output rows as arbitrary. The
/// `ORDER BY` in the subquery therefore decides *which* rows are claimed, not the order
/// in which they are returned to the caller.
const FETCH_NEXT_SQL: &str = r#"
UPDATE Jobs
SET status = 'Queued', lock_by = ?1, lock_at = strftime('%s', 'now')
WHERE ROWID IN (
    SELECT ROWID FROM Jobs
    WHERE job_type = ?2
        AND ((status = 'Pending' AND lock_by IS NULL) 
             OR (status = 'Failed' AND attempts < max_attempts))
        AND (run_at IS NULL OR run_at <= strftime('%s', 'now'))
    ORDER BY priority DESC, run_at ASC, id ASC
    LIMIT ?3
)
RETURNING job, id, job_type, status, attempts, max_attempts, run_at, last_error, lock_at, lock_by, done_at, priority, metadata
"#;
35
36/// Fetch the next batch of tasks from the database
37pub async fn fetch_next(
38    db: &'static Database,
39    config: &Config,
40    worker: &WorkerContext,
41) -> Result<Vec<Task<CompactType, SqlContext, Ulid>>, LibsqlError> {
42    let conn = db.connect()?;
43    let job_type = config.queue().to_string();
44    let buffer_size = config.buffer_size() as i64;
45    let worker_id = worker.name().to_string();
46
47    let mut rows = conn
48        .query(
49            FETCH_NEXT_SQL,
50            libsql::params![worker_id, job_type, buffer_size],
51        )
52        .await
53        .map_err(LibsqlError::Database)?;
54
55    let mut tasks = Vec::new();
56    while let Some(row) = rows.next().await.map_err(LibsqlError::Database)? {
57        let libsql_row = LibsqlTaskRow::from_row(&row)?;
58        let task_row: TaskRow = libsql_row.try_into()?;
59        let task = task_row
60            .try_into_task_compact::<Ulid>()
61            .map_err(|e| LibsqlError::Other(e.to_string()))?;
62        tasks.push(task);
63    }
64
65    Ok(tasks)
66}
67
/// State machine for the polling stream.
///
/// Transitions performed by `poll_next`:
/// - `Ready` -> `Fetch`
/// - `Fetch` -> `Buffered` (non-empty batch) or `Delay` (empty batch or error)
/// - `Delay` -> `Ready`
/// - `Buffered` -> `Ready` once the last task has been yielded
enum StreamState {
    /// Ready to fetch; the next poll starts a database fetch.
    Ready,
    /// Waiting for `poll_interval` to elapse before the next fetch. Entered after an
    /// empty batch or a fetch error.
    Delay(Pin<Box<tokio::time::Sleep>>),
    /// Currently fetching from the database.
    Fetch(BoxFuture<'static, Result<Vec<LibsqlTask<CompactType>>, LibsqlError>>),
    /// Fetched tasks waiting to be yielded, front first, in the order the fetch returned them.
    Buffered(VecDeque<LibsqlTask<CompactType>>),
}
79
/// Polling-based fetcher for retrieving tasks from a libSQL backend.
///
/// As a `Stream`, it repeatedly claims batches of tasks via `fetch_next` and yields them
/// one at a time. When a batch comes back empty or a fetch fails, it sleeps for
/// `config.poll_interval()`.
///
/// No field is marked `#[pin]`. `poll_next` calls `Pin::get_mut`, which requires
/// `Self: Unpin`. pin-project only lets `#[pin]` fields affect the generated `Unpin` impl,
/// so this type should be `Unpin` regardless of `Decode`. Confirm before removing the
/// attribute.
#[pin_project]
pub struct LibsqlPollFetcher<Decode> {
    /// Database handle; a new connection is opened for every fetch.
    db: &'static Database,
    /// Backend configuration; provides the queue, buffer size and poll interval.
    config: Config,
    /// Worker whose name is recorded as `lock_by` on claimed tasks.
    worker: WorkerContext,
    /// Associates the fetcher with a `Decode` type without storing a value of it.
    _marker: PhantomData<Decode>,
    /// Current position in the fetch/delay/buffer state machine.
    state: StreamState,
}
89
90impl<Decode> std::fmt::Debug for LibsqlPollFetcher<Decode> {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("LibsqlPollFetcher")
93            .field("config", &self.config)
94            .field("worker", &self.worker)
95            .finish()
96    }
97}
98
99impl<Decode> Clone for LibsqlPollFetcher<Decode> {
100    fn clone(&self) -> Self {
101        Self {
102            db: self.db,
103            config: self.config.clone(),
104            worker: self.worker.clone(),
105            _marker: PhantomData,
106            state: StreamState::Ready,
107        }
108    }
109}
110
111impl<Decode> LibsqlPollFetcher<Decode> {
112    /// Create a new LibsqlPollFetcher
113    #[must_use]
114    pub fn new(db: &'static Database, config: &Config, worker: &WorkerContext) -> Self {
115        Self {
116            db,
117            config: config.clone(),
118            worker: worker.clone(),
119            _marker: PhantomData,
120            state: StreamState::Ready,
121        }
122    }
123}
124
125impl<Decode: Send + 'static> Stream for LibsqlPollFetcher<Decode> {
126    type Item = Result<Option<LibsqlTask<CompactType>>, LibsqlError>;
127
128    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        let this = self.get_mut();
130
131        loop {
132            match &mut this.state {
133                StreamState::Ready => {
134                    // Start a fetch operation
135                    let db = this.db;
136                    let config = this.config.clone();
137                    let worker = this.worker.clone();
138                    let fut = async move { fetch_next(db, &config, &worker).await };
139                    this.state = StreamState::Fetch(fut.boxed());
140                }
141
142                StreamState::Delay(sleep) => match Pin::new(sleep).poll(cx) {
143                    Poll::Pending => return Poll::Pending,
144                    Poll::Ready(()) => {
145                        this.state = StreamState::Ready;
146                    }
147                },
148
149                StreamState::Fetch(fut) => {
150                    match fut.poll_unpin(cx) {
151                        Poll::Pending => return Poll::Pending,
152                        Poll::Ready(result) => match result {
153                            Ok(tasks) => {
154                                if tasks.is_empty() {
155                                    // No tasks available, wait before polling again
156                                    let delay = tokio::time::sleep(this.config.poll_interval());
157                                    this.state = StreamState::Delay(Box::pin(delay));
158                                } else {
159                                    let buffer: VecDeque<_> = tasks.into_iter().collect();
160                                    this.state = StreamState::Buffered(buffer);
161                                }
162                            }
163                            Err(e) => {
164                                // Log the error and transition to delay state for retry
165                                // Stream continues running even after errors (they transition to Delay state)
166                                log::error!("Error fetching tasks: {}", e);
167                                let delay = tokio::time::sleep(this.config.poll_interval());
168                                this.state = StreamState::Delay(Box::pin(delay));
169                                return Poll::Ready(Some(Err(e)));
170                            }
171                        },
172                    }
173                }
174
175                StreamState::Buffered(buffer) => {
176                    if let Some(task) = buffer.pop_front() {
177                        if buffer.is_empty() {
178                            this.state = StreamState::Ready;
179                        }
180                        return Poll::Ready(Some(Ok(Some(task))));
181                    } else {
182                        // Buffer is empty, transition back to ready state
183                        this.state = StreamState::Ready;
184                    }
185                }
186            }
187        }
188    }
189}
190
191impl<Decode> LibsqlPollFetcher<Decode> {
192    /// Take pending tasks from the fetcher's buffer
193    pub fn take_pending(&mut self) -> VecDeque<LibsqlTask<CompactType>> {
194        match &mut self.state {
195            StreamState::Buffered(tasks) => std::mem::take(tasks),
196            _ => VecDeque::new(),
197        }
198    }
199}