apalis_postgres/queries/wait_for.rs

1use std::{collections::HashSet, str::FromStr, vec};
2
3use apalis_core::{
4    backend::{Backend, TaskResult, WaitForCompletion},
5    task::{status::Status, task_id::TaskId},
6};
7use apalis_sql::context::SqlContext;
8use futures::{StreamExt, stream::BoxStream};
9use serde::de::DeserializeOwned;
10use ulid::Ulid;
11
12use crate::{CompactType, PostgresStorage};
13
14#[derive(Debug)]
15pub struct TaskResultRow {
16    pub id: Option<String>,
17    pub status: Option<String>,
18    pub result: Option<serde_json::Value>,
19}
20
21impl<O: 'static + Send, Args, F, Decode> WaitForCompletion<O>
22    for PostgresStorage<Args, CompactType, Decode, F>
23where
24    PostgresStorage<Args, CompactType, Decode, F>:
25        Backend<Context = SqlContext, Compact = CompactType, IdType = Ulid, Error = sqlx::Error>,
26    Result<O, String>: DeserializeOwned,
27{
28    type ResultStream = BoxStream<'static, Result<TaskResult<O>, Self::Error>>;
29    fn wait_for(
30        &self,
31        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
32    ) -> Self::ResultStream {
33        let pool = self.pool.clone();
34        let ids: HashSet<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
35
36        let stream = futures::stream::unfold(ids, move |mut remaining_ids| {
37            let pool = pool.clone();
38            async move {
39                if remaining_ids.is_empty() {
40                    return None;
41                }
42
43                let ids_vec: Vec<String> = remaining_ids.iter().cloned().collect();
44                let ids_vec = serde_json::to_value(&ids_vec).unwrap();
45                let rows = sqlx::query_file_as!(
46                    TaskResultRow,
47                    "queries/backend/fetch_completed_tasks.sql",
48                    ids_vec
49                )
50                .fetch_all(&pool)
51                .await
52                .ok()?;
53
54                if rows.is_empty() {
55                    apalis_core::timer::sleep(std::time::Duration::from_millis(500)).await;
56                    return Some((futures::stream::iter(vec![]), remaining_ids));
57                }
58
59                let mut results = Vec::new();
60                for row in rows {
61                    let task_id = row.id.clone().unwrap();
62                    remaining_ids.remove(&task_id);
63                    // Here we would normally decode the output O from the row
64                    // For simplicity, we assume O is String and the output is stored in row.output
65                    let result: Result<O, String> =
66                        serde_json::from_value(row.result.unwrap()).unwrap();
67                    results.push(Ok(TaskResult::new(
68                        TaskId::from_str(&task_id).ok()?,
69                        Status::from_str(&row.status.unwrap()).ok()?,
70                        result,
71                    )));
72                }
73
74                Some((futures::stream::iter(results), remaining_ids))
75            }
76        });
77        stream.flatten().boxed()
78    }
79
80    // Implementation of check_status
81    fn check_status(
82        &self,
83        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
84    ) -> impl Future<Output = Result<Vec<TaskResult<O>>, Self::Error>> + Send {
85        let pool = self.pool.clone();
86        let ids: Vec<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
87
88        async move {
89            let ids = serde_json::to_value(&ids).unwrap();
90            let rows = sqlx::query_file_as!(
91                TaskResultRow,
92                "queries/backend/fetch_completed_tasks.sql",
93                ids
94            )
95            .fetch_all(&pool)
96            .await?;
97
98            let mut results = Vec::new();
99            for row in rows {
100                let task_id = TaskId::from_str(&row.id.unwrap())
101                    .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?;
102
103                let result: Result<O, String> = serde_json::from_value(row.result.unwrap())
104                    .map_err(|_| sqlx::Error::Protocol("Failed to decode result".into()))?;
105
106                results.push(TaskResult::new(
107                    task_id,
108                    row.status
109                        .unwrap()
110                        .parse()
111                        .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?,
112                    result,
113                ));
114            }
115
116            Ok(results)
117        }
118    }
119}