// celery/backend/redis.rs

1use async_trait::async_trait;
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use super::{ResultBackend, TaskMeta};
7use crate::error::BackendError;
8
9/// Redis-backed task result store compatible with Celery's default key layout.
10pub struct RedisBackend {
11    pool: Pool,
12    key_prefix: String,
13    result_ttl: Option<Duration>,
14}
15
16impl RedisBackend {
17    pub fn new(redis_url: &str) -> Result<Self, BackendError> {
18        let cfg = Config::from_url(redis_url);
19        let pool = cfg
20            .create_pool(Some(Runtime::Tokio1))
21            .map_err(|err| BackendError::PoolCreationError(err.to_string()))?;
22        Ok(Self {
23            pool,
24            key_prefix: "celery-task-meta".into(),
25            result_ttl: None,
26        })
27    }
28
29    pub fn with_key_prefix(mut self, prefix: impl Into<String>) -> Self {
30        self.key_prefix = prefix.into();
31        self
32    }
33
34    pub fn with_result_ttl(mut self, ttl: Duration) -> Self {
35        self.result_ttl = Some(ttl);
36        self
37    }
38
39    fn key_for(&self, task_id: &str) -> String {
40        format!("{}-{}", self.key_prefix, task_id)
41    }
42}
43
44#[async_trait]
45impl ResultBackend for RedisBackend {
46    async fn store_task_meta(&self, meta: TaskMeta) -> Result<(), BackendError> {
47        let payload = serde_json::to_string(&meta)?;
48        let mut conn = self.pool.get().await?;
49        let key = self.key_for(&meta.task_id);
50
51        if let Some(ttl) = self.result_ttl {
52            conn.set_ex::<_, _, ()>(key, payload, ttl.as_secs()).await?;
53        } else {
54            conn.set::<_, _, ()>(key, payload).await?;
55        }
56        Ok(())
57    }
58
59    async fn get_task_meta(&self, task_id: &str) -> Result<Option<TaskMeta>, BackendError> {
60        let mut conn = self.pool.get().await?;
61        let key = self.key_for(task_id);
62        let raw: Option<String> = conn.get(key).await?;
63        match raw {
64            Some(json) => Ok(Some(serde_json::from_str(&json)?)),
65            None => Ok(None),
66        }
67    }
68
69    async fn forget(&self, task_id: &str) -> Result<(), BackendError> {
70        let mut conn = self.pool.get().await?;
71        let key = self.key_for(task_id);
72        let _: () = conn.del(key).await?;
73        Ok(())
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{TaskMeta, TaskState};
    use async_trait::async_trait;

    /// In-process stand-in for a result backend: serialized task metadata kept in a
    /// map guarded by an async mutex.
    #[derive(Default)]
    struct InMemoryBackend {
        store: tokio::sync::Mutex<std::collections::HashMap<String, String>>,
    }

    impl InMemoryBackend {
        /// Builds the map key for `task_id`, using the same layout the Redis backend
        /// produces with its default prefix.
        fn key_for(task_id: &str) -> String {
            format!("celery-task-meta-{}", task_id)
        }
    }

    #[async_trait]
    impl ResultBackend for InMemoryBackend {
        async fn store_task_meta(&self, meta: TaskMeta) -> Result<(), BackendError> {
            let key = Self::key_for(&meta.task_id);
            let json = serde_json::to_string(&meta)?;
            self.store.lock().await.insert(key, json);
            Ok(())
        }

        async fn get_task_meta(&self, task_id: &str) -> Result<Option<TaskMeta>, BackendError> {
            let key = Self::key_for(task_id);
            let guard = self.store.lock().await;
            // Report a deserialization failure as an error, as the Redis backend does,
            // rather than panicking inside the mock.
            let parsed: Option<TaskMeta> = guard
                .get(&key)
                .map(|json| serde_json::from_str(json))
                .transpose()?;
            Ok(parsed)
        }

        async fn forget(&self, task_id: &str) -> Result<(), BackendError> {
            let key = Self::key_for(task_id);
            self.store.lock().await.remove(&key);
            Ok(())
        }
    }

    /// The mock is only a faithful stand-in if its keys match what `RedisBackend`
    /// generates by default. Creating the pool does not require a running server.
    #[test]
    fn redis_key_matches_mock_key_layout() {
        let backend = RedisBackend::new("redis://127.0.0.1:6379/0").unwrap();
        assert_eq!(backend.key_for("abc"), InMemoryBackend::key_for("abc"));
    }

    #[tokio::test]
    async fn mock_backend_roundtrip() {
        let backend = InMemoryBackend::default();

        let meta = TaskMeta {
            task_id: "abc".into(),
            status: TaskState::Success,
            result: None,
            traceback: None,
            children: vec![],
            date_done: None,
            retries: None,
            eta: None,
            meta: None,
        };

        backend.store_task_meta(meta.clone()).await.unwrap();
        let stored = backend.get_task_meta("abc").await.unwrap();
        assert_eq!(stored, Some(meta));
        backend.forget("abc").await.unwrap();
        assert_eq!(backend.get_task_meta("abc").await.unwrap(), None);
    }
}