1use chrono::{DateTime, Utc};
36use deadpool_postgres::{Client, Config, Pool, PoolError, Runtime};
37use serde::{Deserialize, Serialize};
38use serde_json::Value as JsonValue;
39use std::collections::HashMap;
40use std::future::Future;
41use std::pin::Pin;
42use std::sync::Arc;
43use std::time::Duration;
44use thiserror::Error;
45use tokio::task::JoinHandle;
46use tokio::time::sleep;
47use url::Url;
48
/// Identifier of a task row (`task_queue.id`, a Postgres `SERIAL`, hence `i32`).
pub type TaskId = i32;

/// Arbitrary JSON payload stored with each task (`task_queue.task_data`, JSONB).
pub type TaskData = JsonValue;

/// Task lifecycle status stored as text: 'queued', 'processing', 'completed', or 'failed'.
pub type TaskStatus = String;

/// Type-erased async task handler: given a task's id and JSON payload,
/// returns a boxed future resolving to `Ok(())` on success or a `TaskError`
/// on failure. Boxed so handlers with different concrete future types can
/// live in one registry map; `Send + Sync` so workers can share it.
pub type TaskHandler = Box<
    dyn Fn(
            TaskId,
            TaskData,
        )
            -> Pin<Box<dyn std::future::Future<Output = Result<(), TaskError>> + Send + Sync>>
        + Send
        + Sync,
>;
68
/// Errors that can occur while enqueuing, dequeuing, or executing tasks.
#[derive(Error, Debug)]
pub enum TaskError {
    /// JSON (de)serialization of a task payload failed.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// A query or transaction against Postgres failed.
    #[error("Database error: {0}")]
    DatabaseError(#[from] tokio_postgres::Error),

    /// Getting a connection from the deadpool pool failed.
    #[error("Database pool error: {0}")]
    PoolError(#[from] PoolError),

    /// Underlying I/O failure.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// The provided database URL could not be parsed.
    #[error("URL parsing error: {0}")]
    UrlError(#[from] url::ParseError),
}
87
/// Errors that can occur while establishing the connection pool (see [`connect`]).
#[derive(Error, Debug)]
pub enum ConnectionError {
    /// The database URL could not be parsed.
    #[error("URL parsing error: {0}")]
    UrlError(#[from] url::ParseError),

    /// deadpool failed to build the pool from the derived config.
    #[error("Error creating pool: {0}")]
    CreatePoolError(#[from] deadpool_postgres::CreatePoolError),
}
97
/// A row of the `task_queue` table as seen by workers.
#[derive(Debug, Deserialize, Serialize)]
pub struct Task {
    // Primary key (`task_queue.id`).
    pub id: TaskId,
    // Handler name used for dispatch in `TaskRegistry::run`.
    pub name: String,
    // JSON payload handed to the handler.
    pub data: TaskData,
    // Current lifecycle status (stored as text in the `status` column).
    pub status: TaskStatus,
    // Earliest time the task is eligible to run.
    pub run_at: DateTime<Utc>,
    // If set, the task is re-queued this long after each successful run
    // (persisted as milliseconds in the BIGINT `interval` column).
    pub interval: Option<Duration>,
}
108
/// Registry mapping task names to their handlers.
///
/// The map lives behind an `Arc` so it can be shared with the worker
/// tasks spawned by `run`; handlers must therefore all be registered
/// before workers start.
pub struct TaskRegistry {
    handlers: Arc<HashMap<String, TaskHandler>>,
}
113
114impl TaskRegistry {
115 pub fn new() -> Self {
117 Self {
118 handlers: Arc::new(HashMap::new()),
119 }
120 }
121
122 pub fn register_task<F, Fut>(&mut self, name: String, handler: F)
124 where
125 F: Fn(i32, TaskData) -> Fut + Send + Sync + 'static,
126 Fut: Future<Output = Result<(), TaskError>> + Send + Sync + 'static,
127 {
128 let wrapped_handler = move |task_id: i32, task_data: TaskData| {
129 Box::pin(handler(task_id, task_data))
130 as Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + Sync>>
131 };
132
133 Arc::get_mut(&mut self.handlers)
134 .unwrap()
135 .insert(name, Box::new(wrapped_handler));
136 }
137
138 pub fn handlers(&self) -> &Arc<HashMap<String, TaskHandler>> {
140 &self.handlers
141 }
142
143 pub async fn run(
145 &self,
146 pool: &Pool,
147 num_workers: usize,
148 ) -> Result<Vec<JoinHandle<()>>, TaskError> {
149 let mut tasks = Vec::new();
150
151 for _ in 0..num_workers {
152 let pool = pool.clone(); let handlers = self.handlers.clone();
154
155 let task = tokio::spawn(async move {
156 let mut client = pool.get().await.expect("Failed to get client");
157 loop {
158 let task_opt = dequeue(&mut client).await.expect("Failed to dequeue task");
159
160 if let Some(task) = task_opt {
161 if let Some(handler) = handlers.get(&task.name) {
162 match handler(task.id, task.data).await {
163 Ok(_) => {
164 complete_task(&client, task.id, task.interval)
165 .await
166 .expect("Failed to complete task");
167 }
168 Err(err) => {
169 let error_message = format!("{}", err);
170 fail_task(&client, task.id, &error_message)
171 .await
172 .expect("Failed to fail task");
173 }
174 }
175 } else {
176 eprintln!("No handler found for task: {}", task.name);
177 }
178 } else {
179 sleep(Duration::from_secs(1)).await;
180 }
181 }
182 });
183
184 tasks.push(task);
185 }
186
187 Ok(tasks)
188 }
189}
190
191impl Default for TaskRegistry {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197fn create_deadpool_config_from_url(url: &str) -> Result<Config, url::ParseError> {
199 let parsed_url = Url::parse(url)?;
200
201 let config = Config {
202 user: Some(parsed_url.username().to_owned()),
203 password: parsed_url.password().map(ToString::to_string),
204 host: Some(parsed_url.host_str().unwrap().to_owned()),
205 port: Some(parsed_url.port().unwrap_or(5432)),
206 dbname: Some(
207 parsed_url
208 .path_segments()
209 .map(|mut segments| segments.next().unwrap().to_owned())
210 .unwrap(),
211 ),
212 ..Default::default()
213 };
214
215 Ok(config)
221}
222
223pub async fn connect(database_url: &str) -> Result<Pool, ConnectionError> {
225 let config = create_deadpool_config_from_url(database_url)?;
226 let pool = config.create_pool(Some(Runtime::Tokio1), tokio_postgres::NoTls)?;
227 Ok(pool)
228}
229
230pub async fn initialize_database(pool: &Pool) -> Result<(), TaskError> {
232 let client = pool.get().await?;
233 client
234 .batch_execute(
235 r#"
236 CREATE TABLE IF NOT EXISTS task_queue (
237 id SERIAL PRIMARY KEY,
238 name VARCHAR NOT NULL,
239 task_data JSONB NOT NULL,
240 status VARCHAR NOT NULL DEFAULT 'queued',
241 run_at TIMESTAMPTZ NOT NULL,
242 interval BIGINT,
243 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
244 updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
245 );
246 "#,
247 )
248 .await?;
249 Ok(())
250}
251
252pub async fn enqueue(
254 client: &Client,
255 name: &str,
256 task_data: TaskData,
257 run_at: DateTime<Utc>,
258 interval: Option<Duration>,
259) -> Result<TaskId, TaskError> {
260 let task_data_json = serde_json::to_value(task_data)?;
261 let interval_ms: Option<i64> = interval.map(|i| i.as_millis() as i64);
262 let row = client
263 .query_one(
264 "INSERT INTO task_queue (task_data, name, run_at, interval) VALUES ($1, $2, $3, $4) RETURNING id",
265 &[&task_data_json, &name, &run_at, &interval_ms],
266 )
267 .await?;
268 Ok(row.get(0))
269}
270
/// Atomically claims the oldest due task, if any, and marks it 'processing'.
///
/// Runs inside a single transaction: `FOR UPDATE SKIP LOCKED` lets many
/// concurrent workers each lock a *different* row without blocking one
/// another, and the status flip commits together with the lock release,
/// so a task can never be handed to two workers.
///
/// Returns `Ok(None)` when no task is queued and due.
pub async fn dequeue(client: &mut Client) -> Result<Option<Task>, TaskError> {
    let tx = client.transaction().await?;
    // Oldest eligible task first; skip rows already locked by other workers.
    let row = tx
        .query_opt(
            "SELECT id, name, task_data, status, run_at, interval FROM task_queue WHERE status = 'queued' AND run_at <= NOW() ORDER BY run_at LIMIT 1 FOR UPDATE SKIP LOCKED",
            &[],
        )
        .await?;

    if let Some(row) = row {
        // `interval` is stored as BIGINT milliseconds (see `enqueue`).
        let interval_ms: Option<i64> = row.get(5);
        let interval = interval_ms.map(|i| Duration::from_millis(i as u64));
        let task = Task {
            id: row.get(0),
            name: row.get(1),
            data: row.get(2),
            status: row.get(3),
            run_at: row.get(4),
            interval,
        };

        // Claim the row while still holding its lock.
        tx.execute(
            "UPDATE task_queue SET status = 'processing', updated_at = NOW() WHERE id = $1",
            &[&task.id],
        )
        .await?;

        tx.commit().await?;

        Ok(Some(task))
    } else {
        Ok(None)
    }
}
307
308pub async fn complete_task(
310 client: &Client,
311 task_id: TaskId,
312 interval: Option<Duration>,
313) -> Result<(), TaskError> {
314 if let Some(interval) = interval {
315 let interval_ms = interval.as_millis() as i64; let next_run_at = Utc::now() + chrono::Duration::milliseconds(interval_ms);
317 client
318 .execute(
319 "UPDATE task_queue SET status = 'queued', updated_at = NOW(), run_at = $1 WHERE id = $2",
320 &[&next_run_at, &task_id],
321 )
322 .await?;
323 } else {
324 client
325 .execute(
326 "UPDATE task_queue SET status = 'completed', updated_at = NOW() WHERE id = $1",
327 &[&task_id],
328 )
329 .await?;
330 }
331 Ok(())
332}
333
334pub async fn fail_task(
336 client: &Client,
337 task_id: TaskId,
338 error_message: &str,
339) -> Result<(), TaskError> {
340 let error_json = serde_json::json!({ "error": error_message });
341 client
342 .execute(
343 "UPDATE task_queue SET status = 'failed', updated_at = NOW(), task_data = task_data || $1::jsonb WHERE id = $2",
344 &[&error_json, &task_id],
345 )
346 .await?;
347 Ok(())
348}