// postgres_queue/lib.rs

//! A library for managing and executing tasks in a PostgreSQL-backed queue.
//!
//! This library provides a simple way to define, enqueue, and process tasks in a concurrent
//! and fault-tolerant manner using a PostgreSQL database as the task queue.
//!
//! # Example
//!
//! ```rust
//! use postgres_queue::{TaskRegistry, TaskData, TaskError, connect, initialize_database};
//! use chrono::{Utc, Duration};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let database_url = "postgres://user:password@localhost/dbname";
//!     let pool = connect(database_url).await?;
//!     initialize_database(&pool).await?;
//!
//!     let mut task_registry = TaskRegistry::new();
//!     task_registry.register_task("my_task".to_string(), my_task_handler);
//!
//!     let task_data = serde_json::json!({ "message": "Hello, world!" });
//!     let run_at = Utc::now() + Duration::seconds(10);
//!     let client = pool.get().await?;
//!     let task_id = postgres_queue::enqueue(&client, "my_task", task_data.clone(), run_at, None).await?;
//!
//!     task_registry.run(&pool, 4).await?;
//!
//!     Ok(())
//! }
//!
//! async fn my_task_handler(task_id: i32, task_data: TaskData) -> Result<(), TaskError> {
//!     println!("Task {}: {:?}", task_id, task_data);
//!     Ok(())
//! }
//! ```
use chrono::{DateTime, Utc};
use deadpool_postgres::{Client, Config, Pool, PoolError, Runtime};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio::time::sleep;
use url::Url;
48
/// A type alias for Task ID (`i32`, matching the `id SERIAL` column).
pub type TaskId = i32;

/// A type alias for Task Data: an arbitrary JSON payload stored in the
/// `task_data JSONB` column and handed to the task handler.
pub type TaskData = JsonValue;

/// A type alias for Task Status. The queue uses the string values
/// "queued", "processing", "completed" and "failed".
pub type TaskStatus = String;

/// A type alias for Task Handler: a boxed async callback that receives the
/// task's id and JSON payload and resolves to `Ok(())` on success or a
/// `TaskError` on failure. Both the closure and the future it returns must be
/// `Send + Sync` so the handler map can be shared across worker tasks.
pub type TaskHandler = Box<
    dyn Fn(
            TaskId,
            TaskData,
        )
            -> Pin<Box<dyn std::future::Future<Output = Result<(), TaskError>> + Send + Sync>>
        + Send
        + Sync,
>;
68
/// An enumeration of possible errors that can occur while working with tasks.
///
/// Each variant wraps (and `#[from]`-converts) an underlying library error so
/// callers can propagate failures with `?` at the task-queue boundary.
#[derive(Error, Debug)]
pub enum TaskError {
    /// JSON (de)serialization of task data failed.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// A statement or query failed at the database level.
    #[error("Database error: {0}")]
    DatabaseError(#[from] tokio_postgres::Error),

    /// Checking a connection out of the deadpool pool failed.
    #[error("Database pool error: {0}")]
    PoolError(#[from] PoolError),

    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// The database URL could not be parsed.
    #[error("URL parsing error: {0}")]
    UrlError(#[from] url::ParseError),
}
87
/// An enumeration of possible errors that can occur while connecting to the database.
///
/// Covers the two fallible steps of [`connect`]: parsing the URL and creating
/// the connection pool.
#[derive(Error, Debug)]
pub enum ConnectionError {
    /// The database URL could not be parsed.
    #[error("URL parsing error: {0}")]
    UrlError(#[from] url::ParseError),

    /// Building the deadpool connection pool from the config failed.
    #[error("Error creating pool: {0}")]
    CreatePoolError(#[from] deadpool_postgres::CreatePoolError),
}
97
/// A struct representing a task in the task queue.
///
/// Mirrors one row of the `task_queue` table (minus the `created_at` /
/// `updated_at` bookkeeping columns).
#[derive(Debug, Deserialize, Serialize)]
pub struct Task {
    /// Primary key (`id SERIAL`).
    pub id: TaskId,
    /// Handler name used to look up the matching registered `TaskHandler`.
    pub name: String,
    /// Arbitrary JSON payload passed to the handler.
    pub data: TaskData,
    /// Lifecycle state: "queued", "processing", "completed" or "failed".
    pub status: TaskStatus,
    /// Earliest time the task may be dequeued.
    pub run_at: DateTime<Utc>,
    /// For recurring tasks: re-queue this long after each successful run
    /// (stored in the DB as whole milliseconds in the `interval BIGINT` column).
    pub interval: Option<Duration>,
}
108
/// A struct for managing a registry of task handlers.
///
/// Handlers are keyed by task name. The map is wrapped in an `Arc` so every
/// spawned worker can hold a shared, read-only view of the same handlers.
pub struct TaskRegistry {
    /// Map from task name to its boxed async handler.
    handlers: Arc<HashMap<String, TaskHandler>>,
}
113
114impl TaskRegistry {
115    /// Creates a new TaskRegistry.
116    pub fn new() -> Self {
117        Self {
118            handlers: Arc::new(HashMap::new()),
119        }
120    }
121
122    /// Registers a task handler with the provided name.
123    pub fn register_task<F, Fut>(&mut self, name: String, handler: F)
124    where
125        F: Fn(i32, TaskData) -> Fut + Send + Sync + 'static,
126        Fut: Future<Output = Result<(), TaskError>> + Send + Sync + 'static,
127    {
128        let wrapped_handler = move |task_id: i32, task_data: TaskData| {
129            Box::pin(handler(task_id, task_data))
130                as Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + Sync>>
131        };
132
133        Arc::get_mut(&mut self.handlers)
134            .unwrap()
135            .insert(name, Box::new(wrapped_handler));
136    }
137
138    /// Returns a reference to the task handlers.
139    pub fn handlers(&self) -> &Arc<HashMap<String, TaskHandler>> {
140        &self.handlers
141    }
142
143    /// Runs the task handlers with the provided number of workers.
144    pub async fn run(
145        &self,
146        pool: &Pool,
147        num_workers: usize,
148    ) -> Result<Vec<JoinHandle<()>>, TaskError> {
149        let mut tasks = Vec::new();
150
151        for _ in 0..num_workers {
152            let pool = pool.clone(); // Clone the pool for each worker
153            let handlers = self.handlers.clone();
154
155            let task = tokio::spawn(async move {
156                let mut client = pool.get().await.expect("Failed to get client");
157                loop {
158                    let task_opt = dequeue(&mut client).await.expect("Failed to dequeue task");
159
160                    if let Some(task) = task_opt {
161                        if let Some(handler) = handlers.get(&task.name) {
162                            match handler(task.id, task.data).await {
163                                Ok(_) => {
164                                    complete_task(&client, task.id, task.interval)
165                                        .await
166                                        .expect("Failed to complete task");
167                                }
168                                Err(err) => {
169                                    let error_message = format!("{}", err);
170                                    fail_task(&client, task.id, &error_message)
171                                        .await
172                                        .expect("Failed to fail task");
173                                }
174                            }
175                        } else {
176                            eprintln!("No handler found for task: {}", task.name);
177                        }
178                    } else {
179                        sleep(Duration::from_secs(1)).await;
180                    }
181                }
182            });
183
184            tasks.push(task);
185        }
186
187        Ok(tasks)
188    }
189}
190
191impl Default for TaskRegistry {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197/// Creates a Deadpool configuration from a database URL.
198fn create_deadpool_config_from_url(url: &str) -> Result<Config, url::ParseError> {
199    let parsed_url = Url::parse(url)?;
200
201    let config = Config {
202        user: Some(parsed_url.username().to_owned()),
203        password: parsed_url.password().map(ToString::to_string),
204        host: Some(parsed_url.host_str().unwrap().to_owned()),
205        port: Some(parsed_url.port().unwrap_or(5432)),
206        dbname: Some(
207            parsed_url
208                .path_segments()
209                .map(|mut segments| segments.next().unwrap().to_owned())
210                .unwrap(),
211        ),
212        ..Default::default()
213    };
214
215    // TODO
216    // for (key, value) in parsed_url.query_pairs() {
217    //     config.options.push((key.to_owned(), value.to_owned()));
218    // }
219
220    Ok(config)
221}
222
223/// Connects to the PostgreSQL database using the provided URL.
224pub async fn connect(database_url: &str) -> Result<Pool, ConnectionError> {
225    let config = create_deadpool_config_from_url(database_url)?;
226    let pool = config.create_pool(Some(Runtime::Tokio1), tokio_postgres::NoTls)?;
227    Ok(pool)
228}
229
230/// Initializes the task queue database schema.
231pub async fn initialize_database(pool: &Pool) -> Result<(), TaskError> {
232    let client = pool.get().await?;
233    client
234        .batch_execute(
235            r#"
236            CREATE TABLE IF NOT EXISTS task_queue (
237                id SERIAL PRIMARY KEY,
238                name VARCHAR NOT NULL,
239                task_data JSONB NOT NULL,
240                status VARCHAR NOT NULL DEFAULT 'queued',
241                run_at TIMESTAMPTZ NOT NULL,
242                interval BIGINT,
243                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
244                updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
245            );
246            "#,
247        )
248        .await?;
249    Ok(())
250}
251
252/// Enqueues a task with the specified parameters.
253pub async fn enqueue(
254    client: &Client,
255    name: &str,
256    task_data: TaskData,
257    run_at: DateTime<Utc>,
258    interval: Option<Duration>,
259) -> Result<TaskId, TaskError> {
260    let task_data_json = serde_json::to_value(task_data)?;
261    let interval_ms: Option<i64> = interval.map(|i| i.as_millis() as i64);
262    let row = client
263        .query_one(
264            "INSERT INTO task_queue (task_data, name, run_at, interval) VALUES ($1, $2, $3, $4) RETURNING id",
265            &[&task_data_json, &name, &run_at, &interval_ms],
266        )
267        .await?;
268    Ok(row.get(0))
269}
270
271/// Dequeues a task from the task queue.
272pub async fn dequeue(client: &mut Client) -> Result<Option<Task>, TaskError> {
273    let tx = client.transaction().await?;
274    let row = tx
275        .query_opt(
276            "SELECT id, name, task_data, status, run_at, interval FROM task_queue WHERE status = 'queued' AND run_at <= NOW() ORDER BY run_at LIMIT 1 FOR UPDATE SKIP LOCKED",
277            &[],
278        )
279        .await?;
280
281    if let Some(row) = row {
282        let interval_ms: Option<i64> = row.get(5);
283        let interval = interval_ms.map(|i| Duration::from_millis(i as u64)); // Convert i64 to Duration
284
285        let task = Task {
286            id: row.get(0),
287            name: row.get(1),
288            data: row.get(2),
289            status: row.get(3),
290            run_at: row.get(4),
291            interval,
292        };
293
294        tx.execute(
295            "UPDATE task_queue SET status = 'processing', updated_at = NOW() WHERE id = $1",
296            &[&task.id],
297        )
298        .await?;
299
300        tx.commit().await?;
301
302        Ok(Some(task))
303    } else {
304        Ok(None)
305    }
306}
307
308/// Marks a task as complete and reschedules it if it has an interval.
309pub async fn complete_task(
310    client: &Client,
311    task_id: TaskId,
312    interval: Option<Duration>,
313) -> Result<(), TaskError> {
314    if let Some(interval) = interval {
315        let interval_ms = interval.as_millis() as i64; // Convert Duration to i64
316        let next_run_at = Utc::now() + chrono::Duration::milliseconds(interval_ms);
317        client
318            .execute(
319                "UPDATE task_queue SET status = 'queued', updated_at = NOW(), run_at = $1 WHERE id = $2",
320                &[&next_run_at, &task_id],
321            )
322            .await?;
323    } else {
324        client
325            .execute(
326                "UPDATE task_queue SET status = 'completed', updated_at = NOW() WHERE id = $1",
327                &[&task_id],
328            )
329            .await?;
330    }
331    Ok(())
332}
333
334/// Marks a task as failed and stores the error message in the task data.
335pub async fn fail_task(
336    client: &Client,
337    task_id: TaskId,
338    error_message: &str,
339) -> Result<(), TaskError> {
340    let error_json = serde_json::json!({ "error": error_message });
341    client
342        .execute(
343            "UPDATE task_queue SET status = 'failed', updated_at = NOW(), task_data = task_data || $1::jsonb WHERE id = $2",
344            &[&error_json, &task_id],
345        )
346        .await?;
347    Ok(())
348}