// pg_taskq/worker.rs

1use futures::Future;
2use sqlx::{Pool, Postgres};
3use std::{
4    pin::Pin,
5    sync::atomic::AtomicUsize,
6    time::{Duration, SystemTime, UNIX_EPOCH},
7};
8use tokio::sync::broadcast::Receiver;
9use uuid::Uuid;
10
11use crate::{Error, Result, Task, TaskTableProvider, TaskType};
12
/// Source of the numeric suffix in worker names (`Worker.1`, `Worker.2`, ...).
///
/// Starts at 1 and is incremented once per `Worker::start` call, so each worker
/// started in this process gets a distinct number.
static COUNTER: AtomicUsize = AtomicUsize::new(1);
14
15enum LoopAction {
16    Restart,
17    DoNothing,
18    Break,
19    Error(Error),
20}
21
/// Boxed, pinned, `Send` future produced by a task-processing callback.
///
/// It resolves to `Ok(())` when the task succeeded, or to `Err(message)`; the
/// worker stores that message on the task via `Task::set_error`.
type TaskFunctionResult =
    Pin<Box<dyn Future<Output = std::result::Result<(), String>> + Send>>;
23
24pub struct Worker {
25    pool: Pool<Postgres>,
26    stop: Receiver<()>,
27    name: String,
28    tables: Box<dyn TaskTableProvider>,
29}
30
31impl Worker {
32    pub async fn start<F>(
33        pool: Pool<Postgres>,
34        tables: Box<dyn TaskTableProvider>,
35        stop: Receiver<()>,
36        supported_tasks: Vec<impl TaskType>,
37        process: F,
38    ) -> Result<()>
39    where
40        F: FnMut(Task) -> TaskFunctionResult,
41    {
42        let n = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
43        let name = format!("Worker.{n}");
44        let mut worker = Self {
45            pool,
46            tables,
47            name,
48            stop,
49        };
50        worker.run(supported_tasks, process).await
51    }
52
53    pub async fn run<F>(
54        &mut self,
55        supported_tasks: Vec<impl TaskType>,
56        mut process: F,
57    ) -> Result<()>
58    where
59        F: FnMut(Task) -> TaskFunctionResult,
60    {
61        let name = self.name.clone();
62
63        let mut listener = sqlx::postgres::PgListener::connect_with(&self.pool).await?;
64        listener.listen(self.tables.tasks_queue_name()).await?;
65
66        let mut last_status = UNIX_EPOCH;
67
68        info!("[{name}] starting");
69
70        loop {
71            // Get tasks that are ready that we haven't received a notification for
72            if last_status.elapsed().unwrap_or_default() > Duration::from_secs(60) {
73                last_status = SystemTime::now();
74                info!("[{name}] looking for tasks of type {supported_tasks:?}");
75            }
76
77            tokio::select! {
78                task = Task::load_any_waiting(&self.pool, &*self.tables, &supported_tasks) =>
79                        match self.deal_with_task_result(task, &mut process).await {
80                            LoopAction::Restart => continue,
81                            LoopAction::DoNothing => {}
82                            LoopAction::Break => break,
83                            LoopAction::Error(err) => return Err(err),
84                        },
85                _ = self.stop.recv() => {
86                    debug!("[{name}] Received STOP signal");
87                    break;
88                },
89            };
90
91            // wait for tasks becoming ready
92            trace!("[{name}] waiting for notifications...");
93
94            // let sleep_time =
95            //     (self.duration_until_rate_limit_refresh().await?).min(Duration::from_secs(30));
96            let sleep_time = Duration::from_secs(1);
97            let notification = tokio::select! {
98                notification = listener.recv() => notification,
99                _ = self.stop.recv() => {
100                    debug!("[{name}] Received STOP signal");
101                    break;
102                },
103                _ = tokio::time::sleep(sleep_time) => {
104                    continue;
105                },
106            };
107
108            let notification = match notification {
109                Err(sqlx::Error::PoolClosed) => {
110                    warn!("[{name}] pool closed");
111                    break;
112                }
113                Err(err) => {
114                    error!("[{name}] Error receiving notification {err}");
115                    return Err(err.into());
116                }
117                Ok(notification) => notification,
118            };
119
120            let id = match Uuid::parse_str(notification.payload()) {
121                Err(err) => {
122                    error!("[{name}] tasks_queue notification {notification:?} but were no able to parse task id: {err}");
123                    return Ok(());
124                }
125                Ok(id) => id,
126            };
127
128            let task = Task::load_waiting(id, &self.pool, &*self.tables, &supported_tasks).await;
129            match self.deal_with_task_result(task, &mut process).await {
130                LoopAction::Restart => continue,
131                LoopAction::DoNothing => {}
132                LoopAction::Break => break,
133                LoopAction::Error(err) => return Err(err),
134            }
135        }
136
137        info!("[{name}] stopping Worker");
138        // self.env.close().now_or_never();
139
140        Ok(())
141    }
142
143    async fn deal_with_task_result<F>(
144        &mut self,
145        task: Result<Option<Task>>,
146        process: &mut F,
147    ) -> LoopAction
148    where
149        F: FnMut(Task) -> TaskFunctionResult,
150    {
151        let name = &self.name;
152        match task {
153            Ok(Some(task)) => {
154                let id = task.id;
155                trace!("[{name}] task with id {id:?} can be processed");
156                if let Err(err) = process(task).await {
157                    error!("[{name}] Error processing task {id}: {err}");
158                    let error = serde_json::json!({"error": err.to_string()});
159                    if let Err(err) = Task::set_error(id, &self.pool, &*self.tables, error).await {
160                        error!("[{name}] Unable to set_error for {id}: {err}");
161                    }
162                }
163                LoopAction::Restart
164            }
165            Ok(None) => LoopAction::DoNothing,
166            Err(Error::Db(sqlx::error::Error::PoolClosed)) => {
167                warn!("[{name}] pool closed");
168                LoopAction::Break
169            }
170            Err(err) => {
171                error!("[{name}] unexpected error dealing with task: {err}");
172                LoopAction::Error(err)
173            }
174        }
175    }
176}