// apalis_sqlite/queries/wait_for.rs

use std::{collections::HashSet, str::FromStr, vec};

use apalis_core::{
    backend::{BackendExt, TaskResult, WaitForCompletion},
    task::{status::Status, task_id::TaskId},
};
use futures::{StreamExt, stream::BoxStream};
use serde::de::DeserializeOwned;
use ulid::Ulid;

use crate::{CompactType, SqliteStorage};
12
13#[derive(Debug)]
14struct ResultRow {
15 pub id: Option<String>,
16 pub status: Option<String>,
17 pub result: Option<String>,
18}
19
20impl<O: 'static + Send, Args, F, Decode> WaitForCompletion<O> for SqliteStorage<Args, Decode, F>
21where
22 Self: BackendExt<IdType = Ulid, Codec = Decode, Error = sqlx::Error, Compact = CompactType>,
23 Result<O, String>: DeserializeOwned,
24{
25 type ResultStream = BoxStream<'static, Result<TaskResult<O>, Self::Error>>;
26 fn wait_for(
27 &self,
28 task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
29 ) -> Self::ResultStream {
30 let pool = self.pool.clone();
31 let ids: HashSet<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
32
33 let stream = futures::stream::unfold(ids, move |mut remaining_ids| {
34 let pool = pool.clone();
35 async move {
36 if remaining_ids.is_empty() {
37 return None;
38 }
39
40 let ids_vec: Vec<String> = remaining_ids.iter().cloned().collect();
41 let ids_vec = serde_json::to_string(&ids_vec).unwrap();
42 let rows = sqlx::query_file_as!(
43 ResultRow,
44 "queries/backend/fetch_completed_tasks.sql",
45 ids_vec
46 )
47 .fetch_all(&pool)
48 .await
49 .ok()?;
50
51 if rows.is_empty() {
52 apalis_core::timer::sleep(std::time::Duration::from_millis(500)).await;
53 return Some((futures::stream::iter(vec![]), remaining_ids));
54 }
55
56 let mut results = Vec::new();
57 for row in rows {
58 let task_id = row.id.clone().unwrap();
59 remaining_ids.remove(&task_id);
60 let result: Result<O, String> =
63 serde_json::from_str(&row.result.unwrap()).unwrap();
64 results.push(Ok(TaskResult::new(
65 TaskId::from_str(&task_id).ok()?,
66 Status::from_str(&row.status.unwrap()).ok()?,
67 result,
68 )));
69 }
70
71 Some((futures::stream::iter(results), remaining_ids))
72 }
73 });
74 stream.flatten().boxed()
75 }
76
77 fn check_status(
79 &self,
80 task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
81 ) -> impl Future<Output = Result<Vec<TaskResult<O>>, Self::Error>> + Send {
82 let pool = self.pool.clone();
83 let ids: Vec<String> = task_ids.into_iter().map(|id| id.to_string()).collect();
84
85 async move {
86 let ids = serde_json::to_string(&ids).unwrap();
87 let rows =
88 sqlx::query_file_as!(ResultRow, "queries/backend/fetch_completed_tasks.sql", ids)
89 .fetch_all(&pool)
90 .await?;
91
92 let mut results = Vec::new();
93 for row in rows {
94 let task_id = TaskId::from_str(&row.id.unwrap())
95 .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?;
96
97 let result: Result<O, String> = serde_json::from_str(&row.result.unwrap())
98 .map_err(|_| sqlx::Error::Protocol("Failed to decode result".into()))?;
99
100 results.push(TaskResult::new(
101 task_id,
102 row.status
103 .unwrap()
104 .parse()
105 .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?,
106 result,
107 ));
108 }
109
110 Ok(results)
111 }
112 }
113}