apalis_postgres/queries/
wait_for.rs

1use std::{collections::HashSet, str::FromStr, vec};
2
3use apalis_core::{
4    backend::{BackendExt, TaskResult, WaitForCompletion},
5    task::{status::Status, task_id::TaskId},
6};
7use futures::{StreamExt, stream::BoxStream};
8use serde::de::DeserializeOwned;
9use ulid::Ulid;
10
11use crate::{CompactType, PgContext, PostgresStorage};
12
13#[derive(Debug)]
14pub struct TaskResultRow {
15    pub id: Option<String>,
16    pub status: Option<String>,
17    pub result: Option<serde_json::Value>,
18}
19
20impl<O: 'static + Send, Args, F, Decode> WaitForCompletion<O>
21    for PostgresStorage<Args, CompactType, Decode, F>
22where
23    PostgresStorage<Args, CompactType, Decode, F>:
24        BackendExt<Context = PgContext, Compact = CompactType, IdType = Ulid, Error = sqlx::Error>,
25    Result<O, String>: DeserializeOwned,
26{
27    type ResultStream = BoxStream<'static, Result<TaskResult<O, Ulid>, Self::Error>>;
28    fn wait_for(
29        &self,
30        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
31    ) -> Self::ResultStream {
32        let pool = self.pool.clone();
33        let ids: HashSet<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
34
35        let stream = futures::stream::unfold(ids, move |mut remaining_ids| {
36            let pool = pool.clone();
37            async move {
38                if remaining_ids.is_empty() {
39                    return None;
40                }
41
42                let ids_vec: Vec<String> = remaining_ids.iter().cloned().collect();
43                let ids_vec = serde_json::to_value(&ids_vec).unwrap();
44                let rows = sqlx::query_file_as!(
45                    TaskResultRow,
46                    "queries/backend/fetch_completed_tasks.sql",
47                    ids_vec
48                )
49                .fetch_all(&pool)
50                .await
51                .ok()?;
52
53                if rows.is_empty() {
54                    apalis_core::timer::sleep(std::time::Duration::from_millis(500)).await;
55                    return Some((futures::stream::iter(vec![]), remaining_ids));
56                }
57
58                let mut results = Vec::new();
59                for row in rows {
60                    let task_id = row.id.clone().unwrap();
61                    remaining_ids.remove(&task_id);
62                    // Here we would normally decode the output O from the row
63                    // For simplicity, we assume O is String and the output is stored in row.output
64                    let result: Result<O, String> =
65                        serde_json::from_value(row.result.unwrap()).unwrap();
66                    results.push(Ok(TaskResult::new(
67                        TaskId::from_str(&task_id).ok()?,
68                        Status::from_str(&row.status.unwrap()).ok()?,
69                        result,
70                    )));
71                }
72
73                Some((futures::stream::iter(results), remaining_ids))
74            }
75        });
76        stream.flatten().boxed()
77    }
78
79    // Implementation of check_status
80    fn check_status(
81        &self,
82        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
83    ) -> impl Future<Output = Result<Vec<TaskResult<O, Ulid>>, Self::Error>> + Send {
84        let pool = self.pool.clone();
85        let ids: Vec<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
86
87        async move {
88            let ids = serde_json::to_value(&ids).unwrap();
89            let rows = sqlx::query_file_as!(
90                TaskResultRow,
91                "queries/backend/fetch_completed_tasks.sql",
92                ids
93            )
94            .fetch_all(&pool)
95            .await?;
96
97            let mut results = Vec::new();
98            for row in rows {
99                let task_id = TaskId::from_str(&row.id.unwrap())
100                    .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?;
101
102                let result: Result<O, String> = serde_json::from_value(row.result.unwrap())
103                    .map_err(|_| sqlx::Error::Protocol("Failed to decode result".into()))?;
104
105                results.push(TaskResult::new(
106                    task_id,
107                    row.status
108                        .unwrap()
109                        .parse()
110                        .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?,
111                    result,
112                ));
113            }
114
115            Ok(results)
116        }
117    }
118}