apalis_sqlite/queries/wait_for.rs

1use std::{collections::HashSet, str::FromStr, vec};
2
3use apalis_core::{
4    backend::{Backend, TaskResult, WaitForCompletion},
5    task::{status::Status, task_id::TaskId},
6};
7use futures::{StreamExt, stream::BoxStream};
8use serde::de::DeserializeOwned;
9use ulid::Ulid;
10
11use crate::{CompactType, SqliteStorage};
12
/// One row produced by `queries/backend/fetch_completed_tasks.sql`.
///
/// `sqlx::query_file_as!` fills this struct by matching the query's column
/// names to the field names below, so they must stay in sync with the SQL.
/// Each column is read as an `Option<String>`.
#[derive(Debug)]
struct ResultRow {
    /// The task id in string form; parsed back with `TaskId::from_str` by the
    /// code in this file.
    pub id: Option<String>,
    /// The task status text; parsed with `Status::from_str` / `str::parse`.
    pub status: Option<String>,
    /// The task output as JSON; decoded into a `Result<O, String>`.
    pub result: Option<String>,
}
19
20impl<O: 'static + Send, Args, F, Decode> WaitForCompletion<O> for SqliteStorage<Args, Decode, F>
21where
22    SqliteStorage<Args, Decode, F>:
23        Backend<IdType = Ulid, Codec = Decode, Error = sqlx::Error, Compact = CompactType>,
24    Result<O, String>: DeserializeOwned,
25{
26    type ResultStream = BoxStream<'static, Result<TaskResult<O>, Self::Error>>;
27    fn wait_for(
28        &self,
29        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
30    ) -> Self::ResultStream {
31        let pool = self.pool.clone();
32        let ids: HashSet<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
33
34        let stream = futures::stream::unfold(ids, move |mut remaining_ids| {
35            let pool = pool.clone();
36            async move {
37                if remaining_ids.is_empty() {
38                    return None;
39                }
40
41                let ids_vec: Vec<String> = remaining_ids.iter().cloned().collect();
42                let ids_vec = serde_json::to_string(&ids_vec).unwrap();
43                let rows = sqlx::query_file_as!(
44                    ResultRow,
45                    "queries/backend/fetch_completed_tasks.sql",
46                    ids_vec
47                )
48                .fetch_all(&pool)
49                .await
50                .ok()?;
51
52                if rows.is_empty() {
53                    apalis_core::timer::sleep(std::time::Duration::from_millis(500)).await;
54                    return Some((futures::stream::iter(vec![]), remaining_ids));
55                }
56
57                let mut results = Vec::new();
58                for row in rows {
59                    let task_id = row.id.clone().unwrap();
60                    remaining_ids.remove(&task_id);
61                    // Here we would normally decode the output O from the row
62                    // For simplicity, we assume O is String and the output is stored in row.output
63                    let result: Result<O, String> =
64                        serde_json::from_str(&row.result.unwrap()).unwrap();
65                    results.push(Ok(TaskResult::new(
66                        TaskId::from_str(&task_id).ok()?,
67                        Status::from_str(&row.status.unwrap()).ok()?,
68                        result,
69                    )));
70                }
71
72                Some((futures::stream::iter(results), remaining_ids))
73            }
74        });
75        stream.flatten().boxed()
76    }
77
78    // Implementation of check_status
79    fn check_status(
80        &self,
81        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
82    ) -> impl Future<Output = Result<Vec<TaskResult<O>>, Self::Error>> + Send {
83        let pool = self.pool.clone();
84        let ids: Vec<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
85
86        async move {
87            let ids = serde_json::to_string(&ids).unwrap();
88            let rows =
89                sqlx::query_file_as!(ResultRow, "queries/backend/fetch_completed_tasks.sql", ids)
90                    .fetch_all(&pool)
91                    .await?;
92
93            let mut results = Vec::new();
94            for row in rows {
95                let task_id = TaskId::from_str(&row.id.unwrap())
96                    .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?;
97
98                let result: Result<O, String> = serde_json::from_str(&row.result.unwrap())
99                    .map_err(|_| sqlx::Error::Protocol("Failed to decode result".into()))?;
100
101                results.push(TaskResult::new(
102                    task_id,
103                    row.status
104                        .unwrap()
105                        .parse()
106                        .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?,
107                    result,
108                ));
109            }
110
111            Ok(results)
112        }
113    }
114}