1use async_trait::async_trait;
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use super::{ResultBackend, TaskMeta};
7use crate::error::BackendError;
8
9pub struct RedisBackend {
11 pool: Pool,
12 key_prefix: String,
13 result_ttl: Option<Duration>,
14}
15
16impl RedisBackend {
17 pub fn new(redis_url: &str) -> Result<Self, BackendError> {
18 let cfg = Config::from_url(redis_url);
19 let pool = cfg
20 .create_pool(Some(Runtime::Tokio1))
21 .map_err(|err| BackendError::PoolCreationError(err.to_string()))?;
22 Ok(Self {
23 pool,
24 key_prefix: "celery-task-meta".into(),
25 result_ttl: None,
26 })
27 }
28
29 pub fn with_key_prefix(mut self, prefix: impl Into<String>) -> Self {
30 self.key_prefix = prefix.into();
31 self
32 }
33
34 pub fn with_result_ttl(mut self, ttl: Duration) -> Self {
35 self.result_ttl = Some(ttl);
36 self
37 }
38
39 fn key_for(&self, task_id: &str) -> String {
40 format!("{}-{}", self.key_prefix, task_id)
41 }
42}
43
44#[async_trait]
45impl ResultBackend for RedisBackend {
46 async fn store_task_meta(&self, meta: TaskMeta) -> Result<(), BackendError> {
47 let payload = serde_json::to_string(&meta)?;
48 let mut conn = self.pool.get().await?;
49 let key = self.key_for(&meta.task_id);
50
51 if let Some(ttl) = self.result_ttl {
52 conn.set_ex::<_, _, ()>(key, payload, ttl.as_secs()).await?;
53 } else {
54 conn.set::<_, _, ()>(key, payload).await?;
55 }
56 Ok(())
57 }
58
59 async fn get_task_meta(&self, task_id: &str) -> Result<Option<TaskMeta>, BackendError> {
60 let mut conn = self.pool.get().await?;
61 let key = self.key_for(task_id);
62 let raw: Option<String> = conn.get(key).await?;
63 match raw {
64 Some(json) => Ok(Some(serde_json::from_str(&json)?)),
65 None => Ok(None),
66 }
67 }
68
69 async fn forget(&self, task_id: &str) -> Result<(), BackendError> {
70 let mut conn = self.pool.get().await?;
71 let key = self.key_for(task_id);
72 let _: () = conn.del(key).await?;
73 Ok(())
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{TaskMeta, TaskState};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// Same default prefix `RedisBackend::new` uses, so the mock stores under identical keys.
    const KEY_PREFIX: &str = "celery-task-meta";

    /// In-memory stand-in for Redis used to exercise the `ResultBackend` contract.
    struct InMemoryBackend {
        store: tokio::sync::Mutex<HashMap<String, String>>,
    }

    impl InMemoryBackend {
        fn new() -> Self {
            Self {
                store: tokio::sync::Mutex::new(HashMap::new()),
            }
        }

        /// Key layout `<prefix>-<task_id>`, matching `RedisBackend::key_for`.
        fn key_for(task_id: &str) -> String {
            format!("{}-{}", KEY_PREFIX, task_id)
        }
    }

    #[async_trait]
    impl ResultBackend for InMemoryBackend {
        async fn store_task_meta(&self, meta: TaskMeta) -> Result<(), BackendError> {
            let key = Self::key_for(&meta.task_id);
            let json = serde_json::to_string(&meta)?;
            self.store.lock().await.insert(key, json);
            Ok(())
        }

        async fn get_task_meta(&self, task_id: &str) -> Result<Option<TaskMeta>, BackendError> {
            let key = Self::key_for(task_id);
            let store = self.store.lock().await;
            // Propagate decode failures like the Redis backend instead of panicking.
            let meta = store
                .get(&key)
                .map(|json| serde_json::from_str::<TaskMeta>(json))
                .transpose()?;
            Ok(meta)
        }

        async fn forget(&self, task_id: &str) -> Result<(), BackendError> {
            self.store.lock().await.remove(&Self::key_for(task_id));
            Ok(())
        }
    }

    fn sample_meta(task_id: &str) -> TaskMeta {
        TaskMeta {
            task_id: task_id.into(),
            status: TaskState::Success,
            result: None,
            traceback: None,
            children: vec![],
            date_done: None,
            retries: None,
            eta: None,
            meta: None,
        }
    }

    #[tokio::test]
    async fn mock_backend_roundtrip() {
        let backend = InMemoryBackend::new();
        let meta = sample_meta("abc");

        backend.store_task_meta(meta.clone()).await.unwrap();
        let stored = backend.get_task_meta("abc").await.unwrap();
        assert_eq!(stored, Some(meta));
        backend.forget("abc").await.unwrap();
        assert_eq!(backend.get_task_meta("abc").await.unwrap(), None);
    }

    #[tokio::test]
    async fn mock_backend_reports_corrupt_payload_as_error() {
        let backend = InMemoryBackend::new();
        backend
            .store
            .lock()
            .await
            .insert(InMemoryBackend::key_for("bad"), "not json".to_string());
        assert!(backend.get_task_meta("bad").await.is_err());
    }

    #[test]
    fn redis_backend_key_layout() {
        let backend = match RedisBackend::new("redis://127.0.0.1/") {
            Ok(backend) => backend,
            Err(_) => panic!("building a pool from a well-formed URL should not fail"),
        };
        assert_eq!(backend.key_for("abc"), InMemoryBackend::key_for("abc"));

        let backend = backend.with_key_prefix("results");
        assert_eq!(backend.key_for("abc"), "results-abc");
    }
}