1use crate::errors::{PgQueueError, Result};
2use serde::{de::DeserializeOwned, Serialize};
3use sqlx::PgPool;
4
/// A validated queue identifier.
///
/// Values can only be obtained through `QueueName::new`, which checks the raw
/// name before it is ever interpolated into SQL as part of a table name.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct QueueName {
    /// The raw, already-validated queue name (without the table prefix).
    name: String,
}
24
25impl QueueName {
26 pub fn new(name: impl Into<String>) -> Result<Self> {
27 let name = name.into();
28 if name.is_empty()
29 || !name
30 .chars()
31 .all(|c| c.is_ascii_alphanumeric() || c == '_')
32 {
33 return Err(PgQueueError::InvalidQueueName(name));
34 }
35 Ok(Self { name })
36 }
37
38 pub fn table_name(&self) -> String {
40 format!("queue_{}", self.name)
41 }
42
43 pub fn channel_name(&self) -> String {
45 self.table_name()
46 }
47
48 pub fn name(&self) -> &str {
50 &self.name
51 }
52}
53
54impl std::fmt::Display for QueueName {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.table_name())
57 }
58}
59
/// Lifecycle state of a job row, as written to the queue table's `status` column.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JobStatus {
    /// Waiting to be claimed by `QueueRepository::pop`.
    Pending,
    /// Claimed by `pop` and not yet completed.
    Processing,
    /// Marked done by `QueueRepository::complete`.
    Completed,
}
67
68impl JobStatus {
69 pub fn as_str(&self) -> &'static str {
70 match self {
71 Self::Pending => "pending",
72 Self::Processing => "processing",
73 Self::Completed => "completed",
74 }
75 }
76}
77
78impl std::fmt::Display for JobStatus {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.write_str(self.as_str())
81 }
82}
83
/// A job claimed from a queue: its row id together with its decoded payload.
#[derive(Debug)]
pub struct Job<T> {
    /// Primary key of the job's row; pass it back to `complete` or `fail`.
    pub id: i64,
    /// The row's JSON payload deserialized into `T`.
    pub payload: T,
}
90
91#[derive(Clone)]
93pub struct QueueRepository {
94 pool: PgPool,
95}
96
97impl QueueRepository {
98 pub fn new(pool: PgPool) -> Self {
99 Self { pool }
100 }
101
102 pub async fn push<T: Serialize>(&self, queue: &QueueName, payload: &T) -> Result<i64> {
104 let json = serde_json::to_value(payload)?;
105
106 let row: (i64,) = sqlx::query_as(&format!(
107 "INSERT INTO {} (payload) VALUES ($1) RETURNING id",
108 queue.table_name()
109 ))
110 .bind(json)
111 .fetch_one(&self.pool)
112 .await?;
113
114 Ok(row.0)
115 }
116
117 pub async fn pop<T: DeserializeOwned>(&self, queue: &QueueName) -> Result<Option<Job<T>>> {
120 let table = queue.table_name();
121
122 let row: Option<(i64, serde_json::Value)> = sqlx::query_as(&format!(
123 r#"
124 UPDATE {table} SET status = '{processing}', processed_at = NOW()
125 WHERE id = (
126 SELECT id FROM {table} WHERE status = '{pending}'
127 ORDER BY created_at FOR UPDATE SKIP LOCKED LIMIT 1
128 )
129 RETURNING id, payload
130 "#,
131 table = table,
132 processing = JobStatus::Processing,
133 pending = JobStatus::Pending,
134 ))
135 .fetch_optional(&self.pool)
136 .await?;
137
138 match row {
139 Some((id, payload)) => {
140 let parsed: T = serde_json::from_value(payload)?;
141 Ok(Some(Job {
142 id,
143 payload: parsed,
144 }))
145 }
146 None => Ok(None),
147 }
148 }
149
150 pub async fn complete(&self, queue: &QueueName, job_id: i64) -> Result<()> {
152 sqlx::query(&format!(
153 "UPDATE {} SET status = '{}' WHERE id = $1",
154 queue.table_name(),
155 JobStatus::Completed,
156 ))
157 .bind(job_id)
158 .execute(&self.pool)
159 .await?;
160
161 Ok(())
162 }
163
164 pub async fn fail(&self, queue: &QueueName, job_id: i64) -> Result<()> {
166 sqlx::query(&format!(
167 "UPDATE {} SET status = '{}', processed_at = NULL WHERE id = $1",
168 queue.table_name(),
169 JobStatus::Pending,
170 ))
171 .bind(job_id)
172 .execute(&self.pool)
173 .await?;
174
175 Ok(())
176 }
177
178 pub async fn pending_count(&self, queue: &QueueName) -> Result<i64> {
180 let row: (i64,) = sqlx::query_as(&format!(
181 "SELECT COUNT(*) FROM {} WHERE status = '{}'",
182 queue.table_name(),
183 JobStatus::Pending,
184 ))
185 .fetch_one(&self.pool)
186 .await?;
187
188 Ok(row.0)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a queue name the test expects to be accepted.
    fn valid(raw: &str) -> QueueName {
        QueueName::new(raw).expect("test fixture should be a valid queue name")
    }

    #[test]
    fn test_queue_name_valid() {
        let orders = valid("orders");
        assert_eq!(orders.name(), "orders");
        assert_eq!(orders.table_name(), "queue_orders");
    }

    #[test]
    fn test_queue_name_rejects_empty() {
        let result = QueueName::new("");
        assert!(result.is_err());
    }

    #[test]
    fn test_queue_name_rejects_sql_injection() {
        // Each candidate contains at least one character outside [A-Za-z0-9_].
        let hostile_names = ["x; DROP TABLE users; --", "name with spaces", "bad'name"];
        for &raw in hostile_names.iter() {
            assert!(QueueName::new(raw).is_err(), "accepted {:?}", raw);
        }
    }

    #[test]
    fn test_queue_name_allows_underscores() {
        let underscored = valid("my_queue_123");
        assert_eq!(underscored.table_name(), "queue_my_queue_123");
    }

    #[test]
    fn test_queue_name_channel() {
        let emails = valid("emails");
        let channel = emails.channel_name();
        assert_eq!(channel, "queue_emails");
        assert_eq!(channel, emails.table_name());
    }

    #[test]
    fn test_queue_name_display() {
        let rendered = valid("tasks").to_string();
        assert_eq!(rendered, "queue_tasks");
    }

    #[test]
    fn test_queue_name_equality() {
        let (first, second, other) = (valid("jobs"), valid("jobs"), valid("other"));
        assert_eq!(first, second);
        assert_ne!(first, other);
    }

    #[test]
    fn test_job_status_as_str() {
        let expected = [
            (JobStatus::Pending, "pending"),
            (JobStatus::Processing, "processing"),
            (JobStatus::Completed, "completed"),
        ];
        for &(status, text) in expected.iter() {
            assert_eq!(status.as_str(), text);
        }
    }

    #[test]
    fn test_job_status_display() {
        assert_eq!(JobStatus::Pending.to_string(), "pending");
    }
}