// apalis_redis/queries/wait_for.rs

1use apalis_core::backend::codec::Codec;
2use apalis_core::backend::{TaskResult, WaitForCompletion};
3use apalis_core::error::BoxDynError;
4use apalis_core::task::status::Status;
5use apalis_core::task::task_id::TaskId;
6use apalis_core::timer::sleep;
7use futures::stream::{self, BoxStream, StreamExt};
8use redis::aio::ConnectionLike;
9use std::collections::HashSet;
10use std::str::FromStr;
11use std::time::Duration;
12
13use crate::{RedisStorage, build_error};
14
15impl<Res, Args, Conn, Decode, Err> WaitForCompletion<Res> for RedisStorage<Args, Conn, Decode>
16where
17    Args: Unpin + Send + Sync + 'static,
18    Conn: Clone + ConnectionLike + Send + Sync + 'static,
19    Decode: Codec<Args, Compact = Vec<u8>, Error = Err>
20        + Codec<Result<Res, String>, Compact = Vec<u8>, Error = Err>
21        + Send
22        + Sync
23        + Unpin
24        + 'static
25        + Clone,
26    Err: Into<BoxDynError> + Send + 'static,
27    Res: Send + 'static,
28{
29    type ResultStream = BoxStream<'static, Result<TaskResult<Res>, Self::Error>>;
30
31    fn wait_for(
32        &self,
33        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
34    ) -> Self::ResultStream {
35        let storage = self.clone();
36        let pending_ids: HashSet<_> = task_ids.into_iter().map(|id| id.to_string()).collect();
37
38        stream::unfold(
39            (storage, pending_ids),
40            |(storage, mut pending_ids)| async move {
41                if pending_ids.is_empty() {
42                    return None;
43                }
44
45                // Poll for completed tasks
46                let ids_to_check: Vec<_> = pending_ids
47                    .iter()
48                    .cloned()
49                    .map(|t| TaskId::from_str(&t).unwrap())
50                    .collect();
51
52                match storage.check_status(ids_to_check).await {
53                    Ok(results) => {
54                        if results.is_empty() {
55                            // No tasks completed yet, wait before next poll
56                            sleep(Duration::from_millis(100)).await;
57                            Some((vec![], (storage, pending_ids)))
58                        } else {
59                            // Remove completed task IDs from pending set
60                            for result in &results {
61                                pending_ids.remove(&result.task_id().to_string());
62                            }
63
64                            Some((
65                                results.into_iter().map(Ok).collect(),
66                                (storage, pending_ids),
67                            ))
68                        }
69                    }
70                    Err(e) => {
71                        // Emit error and terminate stream
72                        Some((vec![Err(e)], (storage, pending_ids)))
73                    }
74                }
75            },
76        )
77        .flat_map(stream::iter)
78        .boxed()
79    }
80
81    async fn check_status(
82        &self,
83        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
84    ) -> Result<Vec<TaskResult<Res>>, Self::Error> {
85        use redis::AsyncCommands;
86        let task_ids: Vec<_> = task_ids.into_iter().collect();
87        if task_ids.is_empty() {
88            return Ok(vec![]);
89        }
90
91        let mut conn = self.conn.clone();
92        let mut results = Vec::new();
93
94        for task_id in task_ids {
95            let task_id_str = task_id.to_string();
96            let task_meta_key = format!("{}:{}", self.config.job_meta_hash(), task_id_str);
97
98            // Check if task has a status (Done or Failed)
99            let status: Option<String> = conn.hget(&task_meta_key, "status").await?;
100
101            if let Some(status_str) = status {
102                let status = Status::from_str(&status_str)
103                    .map_err(|e| build_error(e.to_string().as_str()))?;
104
105                // Fetch the serialized result
106                let result_ns = format!("{}:result", self.config.job_meta_hash());
107                let serialized_result: Option<Vec<u8>> =
108                    conn.hget(&result_ns, &task_id_str).await?;
109
110                if let Some(data) = serialized_result {
111                    // Deserialize the Result<Res, String>
112                    let result: Result<Res, String> = Decode::decode(&data)
113                        .map_err(|e: Err| build_error(e.into().to_string().as_str()))?;
114
115                    results.push(TaskResult::new(
116                        TaskId::from_str(&task_id.to_string()).unwrap(),
117                        status,
118                        result,
119                    ));
120                }
121            }
122        }
123
124        Ok(results)
125    }
126}