Skip to main content

kojin_redis/
result_backend.rs

1use async_trait::async_trait;
2use deadpool_redis::Pool;
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use kojin_core::error::{KojinError, TaskResult};
7use kojin_core::result_backend::ResultBackend;
8use kojin_core::task_id::TaskId;
9
10use crate::config::RedisConfig;
11use crate::keys::KeyBuilder;
12
13fn backend_err(e: impl std::fmt::Display) -> KojinError {
14    KojinError::ResultBackend(e.to_string())
15}
16
/// Lua script: atomically increment completed count and push result.
/// Returns the new completed count.
///
/// Inputs: `KEYS[1]` is the completed-counter key (incremented with `INCR`),
/// `KEYS[2]` is the results list key (appended to with `RPUSH`), and
/// `ARGV[1]` is the value appended to that list.
///
/// The script itself does not set an expiry on either key.
const GROUP_COMPLETE_SCRIPT: &str = r#"
local completed_key = KEYS[1]
local results_key = KEYS[2]
local result = ARGV[1]
local count = redis.call('INCR', completed_key)
redis.call('RPUSH', results_key, result)
return count
"#;
27
/// Redis-backed result storage.
///
/// Results are stored as JSON strings with a configurable TTL (default 24 hours).
/// Group operations (`complete_group_member`) use a Lua script to atomically
/// increment the completed count and push the result, preventing races when
/// multiple workers finish group members concurrently.
pub struct RedisResultBackend {
    /// Connection pool; each operation checks out a connection from it.
    pool: Pool,
    /// Builds the Redis key names used for task results and group bookkeeping.
    keys: KeyBuilder,
    /// Expiry applied to stored entries; sent to Redis in whole seconds.
    ttl: Duration,
}
39
40impl RedisResultBackend {
41    /// Create a new Redis result backend.
42    ///
43    /// Builds a `deadpool_redis` connection pool from the given config and
44    /// verifies connectivity by acquiring one connection. The default result
45    /// TTL is 24 hours; override with [`with_ttl`](Self::with_ttl).
46    pub async fn new(config: RedisConfig) -> TaskResult<Self> {
47        let cfg = deadpool_redis::Config::from_url(&config.url);
48        let pool = cfg
49            .builder()
50            .map_err(backend_err)?
51            .max_size(config.pool_size)
52            .runtime(deadpool_redis::Runtime::Tokio1)
53            .build()
54            .map_err(backend_err)?;
55
56        // Verify connection
57        let _conn = pool.get().await.map_err(backend_err)?;
58
59        Ok(Self {
60            pool,
61            keys: KeyBuilder::new(config.key_prefix),
62            ttl: Duration::from_secs(86400), // 24h default
63        })
64    }
65
66    /// Override the result TTL (time-to-live).
67    ///
68    /// Results older than this duration are automatically expired by Redis.
69    /// Defaults to 24 hours if not called.
70    pub fn with_ttl(mut self, ttl: Duration) -> Self {
71        self.ttl = ttl;
72        self
73    }
74
75    async fn conn(&self) -> TaskResult<deadpool_redis::Connection> {
76        self.pool.get().await.map_err(backend_err)
77    }
78}
79
80#[async_trait]
81impl ResultBackend for RedisResultBackend {
82    async fn store(&self, id: &TaskId, result: &serde_json::Value) -> TaskResult<()> {
83        let mut conn = self.conn().await?;
84        let key = self.keys.result(&id.to_string());
85        let serialized = serde_json::to_string(result)?;
86        let ttl_secs = self.ttl.as_secs() as i64;
87
88        redis::cmd("SET")
89            .arg(&key)
90            .arg(&serialized)
91            .arg("EX")
92            .arg(ttl_secs)
93            .query_async::<()>(&mut *conn)
94            .await
95            .map_err(backend_err)?;
96
97        Ok(())
98    }
99
100    async fn get(&self, id: &TaskId) -> TaskResult<Option<serde_json::Value>> {
101        let mut conn = self.conn().await?;
102        let key = self.keys.result(&id.to_string());
103        let result: Option<String> = conn.get(&key).await.map_err(backend_err)?;
104
105        match result {
106            Some(s) => Ok(Some(serde_json::from_str(&s)?)),
107            None => Ok(None),
108        }
109    }
110
111    async fn wait(&self, id: &TaskId, timeout: Duration) -> TaskResult<serde_json::Value> {
112        let deadline = tokio::time::Instant::now() + timeout;
113        loop {
114            if let Some(result) = self.get(id).await? {
115                return Ok(result);
116            }
117            if tokio::time::Instant::now() >= deadline {
118                return Err(KojinError::Timeout(timeout));
119            }
120            tokio::time::sleep(Duration::from_millis(100)).await;
121        }
122    }
123
124    async fn delete(&self, id: &TaskId) -> TaskResult<()> {
125        let mut conn = self.conn().await?;
126        let key = self.keys.result(&id.to_string());
127        conn.del::<_, ()>(&key).await.map_err(backend_err)?;
128        Ok(())
129    }
130
131    async fn init_group(&self, group_id: &str, total: u32) -> TaskResult<()> {
132        let mut conn = self.conn().await?;
133        let key = self.keys.group_total(group_id);
134        let ttl_secs = self.ttl.as_secs() as i64;
135
136        redis::cmd("SET")
137            .arg(&key)
138            .arg(total)
139            .arg("EX")
140            .arg(ttl_secs)
141            .query_async::<()>(&mut *conn)
142            .await
143            .map_err(backend_err)?;
144
145        Ok(())
146    }
147
148    async fn complete_group_member(
149        &self,
150        group_id: &str,
151        _task_id: &TaskId,
152        result: &serde_json::Value,
153    ) -> TaskResult<u32> {
154        let mut conn = self.conn().await?;
155        let completed_key = self.keys.group_completed(group_id);
156        let results_key = self.keys.group_results(group_id);
157        let serialized = serde_json::to_string(result)?;
158
159        let script = redis::Script::new(GROUP_COMPLETE_SCRIPT);
160        let count: u32 = script
161            .key(&completed_key)
162            .key(&results_key)
163            .arg(&serialized)
164            .invoke_async(&mut *conn)
165            .await
166            .map_err(backend_err)?;
167
168        Ok(count)
169    }
170
171    async fn get_group_results(&self, group_id: &str) -> TaskResult<Vec<serde_json::Value>> {
172        let mut conn = self.conn().await?;
173        let results_key = self.keys.group_results(group_id);
174        let items: Vec<String> = conn
175            .lrange(&results_key, 0, -1)
176            .await
177            .map_err(backend_err)?;
178
179        items
180            .into_iter()
181            .map(|s| serde_json::from_str(&s).map_err(Into::into))
182            .collect()
183    }
184}
185
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::redis::Redis;

    /// Start a Redis 7 container and connect a backend to it using the
    /// `test` key prefix. The container handle is returned alongside the
    /// backend so the caller can keep it in scope for the test's duration.
    async fn setup_backend() -> (RedisResultBackend, testcontainers::ContainerAsync<Redis>) {
        let redis = Redis::default().with_tag("7").start().await.unwrap();
        let host_port = redis.get_host_port_ipv4(6379).await.unwrap();
        let url = format!("redis://127.0.0.1:{host_port}");
        let backend = RedisResultBackend::new(RedisConfig::new(url).with_prefix("test"))
            .await
            .unwrap();
        (backend, redis)
    }

    /// A stored value can be read back unchanged.
    #[tokio::test]
    async fn store_and_get() {
        let (backend, _container) = setup_backend().await;
        let id = TaskId::new();
        let expected = serde_json::json!({"result": 42});

        backend.store(&id, &expected).await.unwrap();

        assert_eq!(backend.get(&id).await.unwrap(), Some(expected));
    }

    /// Looking up an id that was never stored yields `None`.
    #[tokio::test]
    async fn get_missing() {
        let (backend, _container) = setup_backend().await;
        let unknown = TaskId::new();

        let found = backend.get(&unknown).await.unwrap();

        assert_eq!(found, None);
    }

    /// A deleted result is no longer returned by `get`.
    #[tokio::test]
    async fn delete_result() {
        let (backend, _container) = setup_backend().await;
        let id = TaskId::new();

        backend.store(&id, &serde_json::json!(1)).await.unwrap();
        backend.delete(&id).await.unwrap();

        let found = backend.get(&id).await.unwrap();
        assert_eq!(found, None);
    }

    /// Each completed member bumps the group count by one and all pushed
    /// results are retrievable afterwards.
    #[tokio::test]
    async fn group_completion() {
        let (backend, _container) = setup_backend().await;
        backend.init_group("g1", 3).await.unwrap();

        for expected in 1..=3u32 {
            let member = TaskId::new();
            let count = backend
                .complete_group_member("g1", &member, &serde_json::json!(expected))
                .await
                .unwrap();
            assert_eq!(count, expected);
        }

        let results = backend.get_group_results("g1").await.unwrap();
        assert_eq!(results.len(), 3);
    }
}
255}