kojin_redis/
result_backend.rs1use async_trait::async_trait;
2use deadpool_redis::Pool;
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use kojin_core::error::{KojinError, TaskResult};
7use kojin_core::result_backend::ResultBackend;
8use kojin_core::task_id::TaskId;
9
10use crate::config::RedisConfig;
11use crate::keys::KeyBuilder;
12
13fn backend_err(e: impl std::fmt::Display) -> KojinError {
14 KojinError::ResultBackend(e.to_string())
15}
16
17const GROUP_COMPLETE_SCRIPT: &str = r#"
20local completed_key = KEYS[1]
21local results_key = KEYS[2]
22local result = ARGV[1]
23local count = redis.call('INCR', completed_key)
24redis.call('RPUSH', results_key, result)
25return count
26"#;
27
28pub struct RedisResultBackend {
35 pool: Pool,
36 keys: KeyBuilder,
37 ttl: Duration,
38}
39
40impl RedisResultBackend {
41 pub async fn new(config: RedisConfig) -> TaskResult<Self> {
47 let cfg = deadpool_redis::Config::from_url(&config.url);
48 let pool = cfg
49 .builder()
50 .map_err(backend_err)?
51 .max_size(config.pool_size)
52 .runtime(deadpool_redis::Runtime::Tokio1)
53 .build()
54 .map_err(backend_err)?;
55
56 let _conn = pool.get().await.map_err(backend_err)?;
58
59 Ok(Self {
60 pool,
61 keys: KeyBuilder::new(config.key_prefix),
62 ttl: Duration::from_secs(86400), })
64 }
65
66 pub fn with_ttl(mut self, ttl: Duration) -> Self {
71 self.ttl = ttl;
72 self
73 }
74
75 async fn conn(&self) -> TaskResult<deadpool_redis::Connection> {
76 self.pool.get().await.map_err(backend_err)
77 }
78}
79
80#[async_trait]
81impl ResultBackend for RedisResultBackend {
82 async fn store(&self, id: &TaskId, result: &serde_json::Value) -> TaskResult<()> {
83 let mut conn = self.conn().await?;
84 let key = self.keys.result(&id.to_string());
85 let serialized = serde_json::to_string(result)?;
86 let ttl_secs = self.ttl.as_secs() as i64;
87
88 redis::cmd("SET")
89 .arg(&key)
90 .arg(&serialized)
91 .arg("EX")
92 .arg(ttl_secs)
93 .query_async::<()>(&mut *conn)
94 .await
95 .map_err(backend_err)?;
96
97 Ok(())
98 }
99
100 async fn get(&self, id: &TaskId) -> TaskResult<Option<serde_json::Value>> {
101 let mut conn = self.conn().await?;
102 let key = self.keys.result(&id.to_string());
103 let result: Option<String> = conn.get(&key).await.map_err(backend_err)?;
104
105 match result {
106 Some(s) => Ok(Some(serde_json::from_str(&s)?)),
107 None => Ok(None),
108 }
109 }
110
111 async fn wait(&self, id: &TaskId, timeout: Duration) -> TaskResult<serde_json::Value> {
112 let deadline = tokio::time::Instant::now() + timeout;
113 loop {
114 if let Some(result) = self.get(id).await? {
115 return Ok(result);
116 }
117 if tokio::time::Instant::now() >= deadline {
118 return Err(KojinError::Timeout(timeout));
119 }
120 tokio::time::sleep(Duration::from_millis(100)).await;
121 }
122 }
123
124 async fn delete(&self, id: &TaskId) -> TaskResult<()> {
125 let mut conn = self.conn().await?;
126 let key = self.keys.result(&id.to_string());
127 conn.del::<_, ()>(&key).await.map_err(backend_err)?;
128 Ok(())
129 }
130
131 async fn init_group(&self, group_id: &str, total: u32) -> TaskResult<()> {
132 let mut conn = self.conn().await?;
133 let key = self.keys.group_total(group_id);
134 let ttl_secs = self.ttl.as_secs() as i64;
135
136 redis::cmd("SET")
137 .arg(&key)
138 .arg(total)
139 .arg("EX")
140 .arg(ttl_secs)
141 .query_async::<()>(&mut *conn)
142 .await
143 .map_err(backend_err)?;
144
145 Ok(())
146 }
147
148 async fn complete_group_member(
149 &self,
150 group_id: &str,
151 _task_id: &TaskId,
152 result: &serde_json::Value,
153 ) -> TaskResult<u32> {
154 let mut conn = self.conn().await?;
155 let completed_key = self.keys.group_completed(group_id);
156 let results_key = self.keys.group_results(group_id);
157 let serialized = serde_json::to_string(result)?;
158
159 let script = redis::Script::new(GROUP_COMPLETE_SCRIPT);
160 let count: u32 = script
161 .key(&completed_key)
162 .key(&results_key)
163 .arg(&serialized)
164 .invoke_async(&mut *conn)
165 .await
166 .map_err(backend_err)?;
167
168 Ok(count)
169 }
170
171 async fn get_group_results(&self, group_id: &str) -> TaskResult<Vec<serde_json::Value>> {
172 let mut conn = self.conn().await?;
173 let results_key = self.keys.group_results(group_id);
174 let items: Vec<String> = conn
175 .lrange(&results_key, 0, -1)
176 .await
177 .map_err(backend_err)?;
178
179 items
180 .into_iter()
181 .map(|s| serde_json::from_str(&s).map_err(Into::into))
182 .collect()
183 }
184}
185
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::redis::Redis;

    /// Starts a Redis 7 container and connects a backend to it using the
    /// `test` key prefix.
    ///
    /// The container handle is returned alongside the backend so each test
    /// can keep it in scope for as long as the backend is used.
    async fn setup_backend() -> (RedisResultBackend, testcontainers::ContainerAsync<Redis>) {
        let container = Redis::default().with_tag("7").start().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();
        let url = format!("redis://127.0.0.1:{port}");
        let config = RedisConfig::new(url).with_prefix("test");
        let backend = RedisResultBackend::new(config).await.unwrap();
        (backend, container)
    }

    /// A stored value can be read back unchanged.
    #[tokio::test]
    async fn store_and_get() {
        let (backend, _container) = setup_backend().await;
        let task = TaskId::new();
        let expected = serde_json::json!({"result": 42});

        backend.store(&task, &expected).await.unwrap();

        let fetched = backend.get(&task).await.unwrap();
        assert_eq!(fetched, Some(expected));
    }

    /// Reading an id that was never stored yields `None`.
    #[tokio::test]
    async fn get_missing() {
        let (backend, _container) = setup_backend().await;
        let unknown = TaskId::new();

        let fetched = backend.get(&unknown).await.unwrap();
        assert_eq!(fetched, None);
    }

    /// A deleted result is no longer returned by `get`.
    #[tokio::test]
    async fn delete_result() {
        let (backend, _container) = setup_backend().await;
        let task = TaskId::new();

        backend.store(&task, &serde_json::json!(1)).await.unwrap();
        backend.delete(&task).await.unwrap();

        let fetched = backend.get(&task).await.unwrap();
        assert_eq!(fetched, None);
    }

    /// Each completed member bumps the count by one and adds one result.
    #[tokio::test]
    async fn group_completion() {
        let (backend, _container) = setup_backend().await;
        backend.init_group("g1", 3).await.unwrap();

        for expected in 1..=3u32 {
            let member = TaskId::new();
            let count = backend
                .complete_group_member("g1", &member, &serde_json::json!(expected))
                .await
                .unwrap();
            assert_eq!(count, expected);
        }

        let results = backend.get_group_results("g1").await.unwrap();
        assert_eq!(results.len(), 3);
    }
}