apalis_redis/queries/
wait_for.rs1use apalis_core::backend::codec::Codec;
2use apalis_core::backend::{TaskResult, WaitForCompletion};
3use apalis_core::error::BoxDynError;
4use apalis_core::task::status::Status;
5use apalis_core::task::task_id::TaskId;
6use apalis_core::timer::sleep;
7use futures::stream::{self, BoxStream, StreamExt};
8use redis::aio::ConnectionLike;
9use std::collections::HashSet;
10use std::str::FromStr;
11use std::time::Duration;
12
13use crate::{RedisStorage, build_error};
14
15impl<Res, Args, Conn, Decode, Err> WaitForCompletion<Res> for RedisStorage<Args, Conn, Decode>
16where
17 Args: Unpin + Send + Sync + 'static,
18 Conn: Clone + ConnectionLike + Send + Sync + 'static,
19 Decode: Codec<Args, Compact = Vec<u8>, Error = Err>
20 + Codec<Result<Res, String>, Compact = Vec<u8>, Error = Err>
21 + Send
22 + Sync
23 + Unpin
24 + 'static
25 + Clone,
26 Err: Into<BoxDynError> + Send + 'static,
27 Res: Send + 'static,
28{
29 type ResultStream = BoxStream<'static, Result<TaskResult<Res>, Self::Error>>;
30
31 fn wait_for(
32 &self,
33 task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
34 ) -> Self::ResultStream {
35 let storage = self.clone();
36 let pending_ids: HashSet<_> = task_ids.into_iter().map(|id| id.to_string()).collect();
37
38 stream::unfold(
39 (storage, pending_ids),
40 |(storage, mut pending_ids)| async move {
41 if pending_ids.is_empty() {
42 return None;
43 }
44
45 let ids_to_check: Vec<_> = pending_ids
47 .iter()
48 .cloned()
49 .map(|t| TaskId::from_str(&t).unwrap())
50 .collect();
51
52 match storage.check_status(ids_to_check).await {
53 Ok(results) => {
54 if results.is_empty() {
55 sleep(Duration::from_millis(100)).await;
57 Some((vec![], (storage, pending_ids)))
58 } else {
59 for result in &results {
61 pending_ids.remove(&result.task_id().to_string());
62 }
63
64 Some((
65 results.into_iter().map(Ok).collect(),
66 (storage, pending_ids),
67 ))
68 }
69 }
70 Err(e) => {
71 Some((vec![Err(e)], (storage, pending_ids)))
73 }
74 }
75 },
76 )
77 .flat_map(stream::iter)
78 .boxed()
79 }
80
81 async fn check_status(
82 &self,
83 task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
84 ) -> Result<Vec<TaskResult<Res>>, Self::Error> {
85 use redis::AsyncCommands;
86 let task_ids: Vec<_> = task_ids.into_iter().collect();
87 if task_ids.is_empty() {
88 return Ok(vec![]);
89 }
90
91 let mut conn = self.conn.clone();
92 let mut results = Vec::new();
93
94 for task_id in task_ids {
95 let task_id_str = task_id.to_string();
96 let task_meta_key = format!("{}:{}", self.config.job_meta_hash(), task_id_str);
97
98 let status: Option<String> = conn.hget(&task_meta_key, "status").await?;
100
101 if let Some(status_str) = status {
102 let status = Status::from_str(&status_str)
103 .map_err(|e| build_error(e.to_string().as_str()))?;
104
105 let result_ns = format!("{}:result", self.config.job_meta_hash());
107 let serialized_result: Option<Vec<u8>> =
108 conn.hget(&result_ns, &task_id_str).await?;
109
110 if let Some(data) = serialized_result {
111 let result: Result<Res, String> = Decode::decode(&data)
113 .map_err(|e: Err| build_error(e.into().to_string().as_str()))?;
114
115 results.push(TaskResult::new(
116 TaskId::from_str(&task_id.to_string()).unwrap(),
117 status,
118 result,
119 ));
120 }
121 }
122 }
123
124 Ok(results)
125 }
126}