apalis_postgres/queries/wait_for.rs
use std::{collections::HashSet, future::Future, str::FromStr};

use apalis_core::{
    backend::{Backend, TaskResult, WaitForCompletion},
    task::{status::Status, task_id::TaskId},
};
use apalis_sql::context::SqlContext;
use futures::{StreamExt, stream::BoxStream};
use serde::de::DeserializeOwned;
use ulid::Ulid;

use crate::{CompactType, PostgresStorage};

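/// Shape of a row returned by `queries/backend/fetch_completed_tasks.sql`.
/// Every column is `Option` because the sqlx query macros cannot prove them
/// non-null for this query, so NULLs have to be handled at decode time.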
#[derive(Debug)]
pub struct TaskResultRow {
    pub id: Option<String>,
    pub status: Option<String>,
    pub result: Option<serde_json::Value>,
}

impl<O: 'static + Send, Args, F, Decode> WaitForCompletion<O>
    for PostgresStorage<Args, CompactType, Decode, F>
where
    PostgresStorage<Args, CompactType, Decode, F>:
        Backend<Context = SqlContext, Compact = CompactType, IdType = Ulid, Error = sqlx::Error>,
    Result<O, String>: DeserializeOwned,
{
    type ResultStream = BoxStream<'static, Result<TaskResult<O>, Self::Error>>;
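
    /// Waits for the given tasks to complete by polling Postgres, yielding
    /// each task's result as it is observed. Ids that have been yielded are
    /// dropped from the pending set; the stream ends once the set is empty.
    ///
    /// A minimal usage sketch (the `storage` handle and `task_id` are
    /// assumed to come from earlier pushes):
    ///
    /// ```ignore
    /// let mut completed = storage.wait_for([task_id]);
    /// while let Some(outcome) = completed.next().await {
    ///     let task_result = outcome?; // TaskResult<O>
    ///     // inspect `task_result` here
    /// }
    /// ```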
    fn wait_for(
        &self,
        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
    ) -> Self::ResultStream {
        let pool = self.pool.clone();
        let ids: HashSet<String> = task_ids.into_iter().map(|id| id.to_string()).collect();

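        // Poll in a loop via `unfold`: the state is the set of ids still
        // pending, and each turn yields a (possibly empty) batch of results.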
        let stream = futures::stream::unfold(ids, move |mut remaining_ids| {
            let pool = pool.clone();
            async move {
                if remaining_ids.is_empty() {
                    return None;
                }

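                // Bind the pending ids as a single JSON array parameter; the
                // query file is expected to filter on it.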
                let ids_vec: Vec<String> = remaining_ids.iter().cloned().collect();
                // Serializing a Vec<String> to JSON cannot fail.
                let ids_vec = serde_json::to_value(&ids_vec).unwrap();
                let rows = match sqlx::query_file_as!(
                    TaskResultRow,
                    "queries/backend/fetch_completed_tasks.sql",
                    ids_vec
                )
                .fetch_all(&pool)
                .await
                {
                    Ok(rows) => rows,
                    // Surface database errors on the stream instead of ending
                    // it silently; clearing the set stops further polling.
                    Err(e) => {
                        return Some((futures::stream::iter(vec![Err(e)]), HashSet::new()));
                    }
                };

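                // Nothing has completed yet: back off for a fixed 500ms and
                // yield an empty batch so the poll loop keeps running.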
                if rows.is_empty() {
                    apalis_core::timer::sleep(std::time::Duration::from_millis(500)).await;
                    return Some((futures::stream::iter(vec![]), remaining_ids));
                }

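                // Decode each completed row. Malformed rows become Err items
                // on the stream rather than panics, and their ids are still
                // removed so they are not polled again.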
                let mut results = Vec::new();
                for row in rows {
                    let task_id = row.id.unwrap_or_default();
                    remaining_ids.remove(&task_id);
                    let decoded: Result<TaskResult<O>, sqlx::Error> = (|| {
                        let id = TaskId::from_str(&task_id)
                            .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?;
                        let status = Status::from_str(row.status.as_deref().unwrap_or(""))
                            .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?;
                        let result: Result<O, String> =
                            serde_json::from_value(row.result.unwrap_or_default()).map_err(
                                |_| sqlx::Error::Protocol("Failed to decode result".into()),
                            )?;
                        Ok(TaskResult::new(id, status, result))
                    })();
                    results.push(decoded);
                }

                Some((futures::stream::iter(results), remaining_ids))
            }
        });
        stream.flatten().boxed()
    }

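    /// One-shot, non-blocking counterpart to `wait_for`: fetches whatever
    /// results are available right now and returns them without polling.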
    fn check_status(
        &self,
        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
    ) -> impl Future<Output = Result<Vec<TaskResult<O>>, Self::Error>> + Send {
        let pool = self.pool.clone();
        let ids: Vec<String> = task_ids.into_iter().map(|id| id.to_string()).collect();

        async move {
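            // Serializing a Vec<String> cannot fail, so the unwrap is safe.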
            let ids = serde_json::to_value(&ids).unwrap();
            let rows = sqlx::query_file_as!(
                TaskResultRow,
                "queries/backend/fetch_completed_tasks.sql",
                ids
            )
            .fetch_all(&pool)
            .await?;

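            // NULL columns and undecodable payloads surface as protocol
            // errors instead of panicking on unwrap.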
            let mut results = Vec::new();
            for row in rows {
                let id = row
                    .id
                    .ok_or_else(|| sqlx::Error::Protocol("Missing task ID".into()))?;
                let task_id = TaskId::from_str(&id)
                    .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?;

                let status: Status = row
                    .status
                    .ok_or_else(|| sqlx::Error::Protocol("Missing status".into()))?
                    .parse()
                    .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?;

                let result: Result<O, String> = serde_json::from_value(
                    row.result
                        .ok_or_else(|| sqlx::Error::Protocol("Missing result".into()))?,
                )
                .map_err(|_| sqlx::Error::Protocol("Failed to decode result".into()))?;

                results.push(TaskResult::new(task_id, status, result));
            }

            Ok(results)
        }
    }
}