Skip to main content

pg_queue/
queue.rs

1use crate::errors::{PgQueueError, Result};
2use serde::{de::DeserializeOwned, Serialize};
3use sqlx::PgPool;
4
5/// A named queue backed by a PostgreSQL table.
6///
7/// Each `QueueName` maps to a table called `queue_{name}`.
8/// Create the table using the `pg_queue_create_queue()` SQL function
9/// from `migrations/setup.sql`.
10///
11/// Names must be non-empty and contain only ASCII alphanumerics or underscores.
12///
13/// # Example
14/// ```
15/// use pg_queue::QueueName;
16///
17/// let emails = QueueName::new("emails").unwrap();
18/// assert_eq!(emails.table_name(), "queue_emails");
19/// ```
20#[derive(Debug, Clone, PartialEq, Eq, Hash)]
21pub struct QueueName {
22    name: String,
23}
24
25impl QueueName {
26    pub fn new(name: impl Into<String>) -> Result<Self> {
27        let name = name.into();
28        if name.is_empty()
29            || !name
30                .chars()
31                .all(|c| c.is_ascii_alphanumeric() || c == '_')
32        {
33            return Err(PgQueueError::InvalidQueueName(name));
34        }
35        Ok(Self { name })
36    }
37
38    /// Returns the backing table name: `queue_{name}`
39    pub fn table_name(&self) -> String {
40        format!("queue_{}", self.name)
41    }
42
43    /// Returns the NOTIFY channel name (same as table name by convention)
44    pub fn channel_name(&self) -> String {
45        self.table_name()
46    }
47
48    /// Returns the raw queue name
49    pub fn name(&self) -> &str {
50        &self.name
51    }
52}
53
54impl std::fmt::Display for QueueName {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        write!(f, "{}", self.table_name())
57    }
58}
59
/// Status values a job can hold while it moves through a queue.
///
/// Every variant has a fixed lowercase text form, exposed by
/// [`JobStatus::as_str`] and by the `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JobStatus {
    Pending,
    Processing,
    Completed,
}

impl JobStatus {
    /// Returns the lowercase text form of this status
    /// (`"pending"`, `"processing"` or `"completed"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            JobStatus::Completed => "completed",
            JobStatus::Processing => "processing",
            JobStatus::Pending => "pending",
        }
    }
}

/// Displays the status using the same text as [`JobStatus::as_str`].
impl std::fmt::Display for JobStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
83
/// A single job handed out by the queue.
///
/// `T` is the caller-chosen type the stored payload is decoded into.
#[derive(Debug)]
pub struct Job<T> {
    /// Identifier of the job's row; pass it to `QueueRepository::complete`
    /// or `QueueRepository::fail` once the job has been handled.
    pub id: i64,
    /// The decoded job payload.
    pub payload: T,
}
90
91/// Queue repository for push/pop operations using SKIP LOCKED
92#[derive(Clone)]
93pub struct QueueRepository {
94    pool: PgPool,
95}
96
97impl QueueRepository {
98    pub fn new(pool: PgPool) -> Self {
99        Self { pool }
100    }
101
102    /// Push a job to the queue
103    pub async fn push<T: Serialize>(&self, queue: &QueueName, payload: &T) -> Result<i64> {
104        let json = serde_json::to_value(payload)?;
105
106        let row: (i64,) = sqlx::query_as(&format!(
107            "INSERT INTO {} (payload) VALUES ($1) RETURNING id",
108            queue.table_name()
109        ))
110        .bind(json)
111        .fetch_one(&self.pool)
112        .await?;
113
114        Ok(row.0)
115    }
116
117    /// Pop a job from the queue using SKIP LOCKED for concurrent safety.
118    /// Returns None if no pending jobs are available.
119    pub async fn pop<T: DeserializeOwned>(&self, queue: &QueueName) -> Result<Option<Job<T>>> {
120        let table = queue.table_name();
121
122        let row: Option<(i64, serde_json::Value)> = sqlx::query_as(&format!(
123            r#"
124            UPDATE {table} SET status = '{processing}', processed_at = NOW()
125            WHERE id = (
126                SELECT id FROM {table} WHERE status = '{pending}'
127                ORDER BY created_at FOR UPDATE SKIP LOCKED LIMIT 1
128            )
129            RETURNING id, payload
130            "#,
131            table = table,
132            processing = JobStatus::Processing,
133            pending = JobStatus::Pending,
134        ))
135        .fetch_optional(&self.pool)
136        .await?;
137
138        match row {
139            Some((id, payload)) => {
140                let parsed: T = serde_json::from_value(payload)?;
141                Ok(Some(Job {
142                    id,
143                    payload: parsed,
144                }))
145            }
146            None => Ok(None),
147        }
148    }
149
150    /// Mark a job as completed
151    pub async fn complete(&self, queue: &QueueName, job_id: i64) -> Result<()> {
152        sqlx::query(&format!(
153            "UPDATE {} SET status = '{}' WHERE id = $1",
154            queue.table_name(),
155            JobStatus::Completed,
156        ))
157        .bind(job_id)
158        .execute(&self.pool)
159        .await?;
160
161        Ok(())
162    }
163
164    /// Mark a job as failed, resetting it to pending for retry
165    pub async fn fail(&self, queue: &QueueName, job_id: i64) -> Result<()> {
166        sqlx::query(&format!(
167            "UPDATE {} SET status = '{}', processed_at = NULL WHERE id = $1",
168            queue.table_name(),
169            JobStatus::Pending,
170        ))
171        .bind(job_id)
172        .execute(&self.pool)
173        .await?;
174
175        Ok(())
176    }
177
178    /// Get the count of pending jobs in a queue
179    pub async fn pending_count(&self, queue: &QueueName) -> Result<i64> {
180        let row: (i64,) = sqlx::query_as(&format!(
181            "SELECT COUNT(*) FROM {} WHERE status = '{}'",
182            queue.table_name(),
183            JobStatus::Pending,
184        ))
185        .fetch_one(&self.pool)
186        .await?;
187
188        Ok(row.0)
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a queue name from input the test expects to be accepted.
    fn queue(raw: &str) -> QueueName {
        QueueName::new(raw).expect("queue name used by the test should be valid")
    }

    #[test]
    fn test_queue_name_valid() {
        let orders = queue("orders");
        assert_eq!(orders.table_name(), "queue_orders");
        assert_eq!(orders.name(), "orders");
    }

    #[test]
    fn test_queue_name_rejects_empty() {
        let outcome = QueueName::new("");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_queue_name_rejects_sql_injection() {
        let hostile_inputs = ["x; DROP TABLE users; --", "name with spaces", "bad'name"];
        for hostile in hostile_inputs.iter() {
            assert!(
                QueueName::new(*hostile).is_err(),
                "expected {:?} to be rejected",
                hostile
            );
        }
    }

    #[test]
    fn test_queue_name_allows_underscores() {
        assert_eq!(queue("my_queue_123").table_name(), "queue_my_queue_123");
    }

    #[test]
    fn test_queue_name_channel() {
        let emails = queue("emails");
        let channel = emails.channel_name();
        assert_eq!(channel, "queue_emails");
        assert_eq!(channel, emails.table_name());
    }

    #[test]
    fn test_queue_name_display() {
        assert_eq!(queue("tasks").to_string(), "queue_tasks");
    }

    #[test]
    fn test_queue_name_equality() {
        let first = queue("jobs");
        let twin = queue("jobs");
        let different = queue("other");
        assert_eq!(first, twin);
        assert_ne!(first, different);
    }

    #[test]
    fn test_job_status_as_str() {
        let expected = [
            (JobStatus::Pending, "pending"),
            (JobStatus::Processing, "processing"),
            (JobStatus::Completed, "completed"),
        ];
        for (status, text) in expected.iter() {
            assert_eq!(status.as_str(), *text);
        }
    }

    #[test]
    fn test_job_status_display() {
        assert_eq!(JobStatus::Pending.to_string(), "pending");
    }
}