celery/task/
async_result.rs

1use std::sync::Arc;
2use tokio::time::{sleep, Duration, Instant};
3
4use crate::backend::{ResultBackend, TaskMeta, TaskState};
5use crate::error::BackendError;
6
7/// An [`AsyncResult`] is a handle for the result of a task.
8#[derive(Clone)]
9pub struct AsyncResult {
10    pub task_id: String,
11    backend: Option<Arc<dyn ResultBackend>>,
12    poll_interval: Duration,
13}
14
15impl std::fmt::Debug for AsyncResult {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.debug_struct("AsyncResult")
18            .field("task_id", &self.task_id)
19            .finish()
20    }
21}
22
23impl AsyncResult {
24    pub fn new(task_id: &str) -> Self {
25        Self {
26            task_id: task_id.into(),
27            backend: None,
28            poll_interval: Duration::from_millis(200),
29        }
30    }
31
32    pub(crate) fn with_backend(task_id: &str, backend: Option<Arc<dyn ResultBackend>>) -> Self {
33        Self {
34            task_id: task_id.into(),
35            backend,
36            poll_interval: Duration::from_millis(200),
37        }
38    }
39
40    /// Returns the task identifier.
41    pub fn task_id(&self) -> &str {
42        &self.task_id
43    }
44
45    /// Returns the current backend state for this task.
46    pub async fn state(&self) -> Result<TaskState, BackendError> {
47        Ok(self.fetch_meta().await?.status)
48    }
49
50    /// Returns whether the task finished successfully or failed.
51    pub async fn ready(&self) -> Result<bool, BackendError> {
52        Ok(self.state().await?.is_ready())
53    }
54
55    /// Blocks until the task finishes and returns the result serialized as `T`.
56    ///
57    /// If `timeout` is provided the method returns [`BackendError::Timeout`] when the
58    /// interval elapses.
59    pub async fn get<T>(&self, timeout: Option<Duration>) -> Result<T, BackendError>
60    where
61        T: serde::de::DeserializeOwned,
62    {
63        let start = Instant::now();
64        loop {
65            let meta = self.fetch_meta().await?;
66            match meta.status {
67                TaskState::Success => {
68                    let value = meta.result.unwrap_or_default();
69                    return Ok(serde_json::from_value(value)?);
70                }
71                TaskState::Failure => {
72                    let message = meta
73                        .meta
74                        .as_ref()
75                        .and_then(|meta| meta.get("exc_message"))
76                        .and_then(|v| v.as_str())
77                        .map(|s| s.to_string())
78                        .or_else(|| meta.result.and_then(|r| r.as_str().map(|s| s.to_string())))
79                        .unwrap_or_else(|| "task failed".into());
80                    return Err(BackendError::TaskFailed(message));
81                }
82                TaskState::Retry | TaskState::Pending | TaskState::Started => {
83                    if let Some(duration) = timeout {
84                        if start.elapsed() >= duration {
85                            return Err(BackendError::Timeout);
86                        }
87                    }
88                    sleep(self.poll_interval).await;
89                }
90            }
91        }
92    }
93
94    async fn fetch_meta(&self) -> Result<TaskMeta, BackendError> {
95        let backend = self.backend.as_ref().ok_or(BackendError::NotConfigured)?;
96        let task_id = self.task_id.clone();
97        let meta = backend.get_task_meta(&task_id).await?;
98        Ok(meta.unwrap_or_else(|| TaskMeta::pending(&task_id)))
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory `ResultBackend` storing task metadata keyed by task id.
    struct MockBackend {
        store: Mutex<HashMap<String, TaskMeta>>,
    }

    impl MockBackend {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                store: Mutex::new(HashMap::new()),
            })
        }
    }

    #[async_trait]
    impl ResultBackend for MockBackend {
        async fn store_task_meta(&self, meta: TaskMeta) -> Result<(), BackendError> {
            self.store.lock().await.insert(meta.task_id.clone(), meta);
            Ok(())
        }

        async fn get_task_meta(&self, task_id: &str) -> Result<Option<TaskMeta>, BackendError> {
            Ok(self.store.lock().await.get(task_id).cloned())
        }

        async fn forget(&self, task_id: &str) -> Result<(), BackendError> {
            self.store.lock().await.remove(task_id);
            Ok(())
        }
    }

    /// Builds metadata for `task_id` with the given status and result; every other field
    /// is left empty.
    fn task_meta(task_id: &str, status: TaskState, result: Option<serde_json::Value>) -> TaskMeta {
        TaskMeta {
            task_id: task_id.into(),
            status,
            result,
            traceback: None,
            children: vec![],
            date_done: None,
            retries: None,
            eta: None,
            meta: None,
        }
    }

    #[tokio::test]
    async fn state_defaults_to_pending() {
        let backend = MockBackend::new();
        let result = AsyncResult::with_backend("abc", Some(backend as Arc<_>));
        assert_eq!(result.state().await.unwrap(), TaskState::Pending);
    }

    #[tokio::test]
    async fn state_without_backend_is_not_configured() {
        let result = AsyncResult::new("abc");
        assert!(matches!(
            result.state().await,
            Err(BackendError::NotConfigured)
        ));
    }

    #[tokio::test]
    async fn ready_reflects_success() {
        let backend = MockBackend::new();
        backend
            .store_task_meta(task_meta(
                "abc",
                TaskState::Success,
                Some(serde_json::json!(123)),
            ))
            .await
            .unwrap();
        let result = AsyncResult::with_backend("abc", Some(backend.clone()));
        assert!(result.ready().await.unwrap());
    }

    #[tokio::test]
    async fn get_waits_until_success() {
        let backend = MockBackend::new();
        let result = AsyncResult::with_backend("abc", Some(backend.clone()));
        let handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(250)).await;
            backend
                .store_task_meta(task_meta(
                    "abc",
                    TaskState::Success,
                    Some(serde_json::json!({"value": 10})),
                ))
                .await
                .unwrap();
        });

        let value: serde_json::Value = result.get(Some(Duration::from_secs(1))).await.unwrap();
        assert_eq!(value, serde_json::json!({"value": 10}));
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn get_reports_failure_message_from_result() {
        let backend = MockBackend::new();
        backend
            .store_task_meta(task_meta(
                "abc",
                TaskState::Failure,
                Some(serde_json::json!("boom")),
            ))
            .await
            .unwrap();
        let result = AsyncResult::with_backend("abc", Some(backend.clone()));
        let err = result.get::<serde_json::Value>(None).await.unwrap_err();
        assert!(matches!(err, BackendError::TaskFailed(ref message) if message == "boom"));
    }

    #[tokio::test]
    async fn get_times_out_while_pending() {
        let backend = MockBackend::new();
        let result = AsyncResult::with_backend("abc", Some(backend as Arc<_>));
        let err = result
            .get::<serde_json::Value>(Some(Duration::from_millis(20)))
            .await
            .unwrap_err();
        assert!(matches!(err, BackendError::Timeout));
    }
}