kojin_postgres/
result_backend.rs1use async_trait::async_trait;
2use sqlx::PgPool;
3use std::time::Duration;
4
5use kojin_core::error::{KojinError, TaskResult};
6use kojin_core::result_backend::ResultBackend;
7use kojin_core::task_id::TaskId;
8
9fn backend_err(e: impl std::fmt::Display) -> KojinError {
10 KojinError::ResultBackend(e.to_string())
11}
12
13pub struct PostgresResultBackend {
21 pool: PgPool,
22 ttl: Duration,
23}
24
25impl PostgresResultBackend {
26 pub fn new(pool: PgPool) -> Self {
31 Self {
32 pool,
33 ttl: Duration::from_secs(86400), }
35 }
36
37 pub async fn connect(url: &str) -> TaskResult<Self> {
41 let pool = PgPool::connect(url).await.map_err(backend_err)?;
42 Ok(Self::new(pool))
43 }
44
45 pub fn with_ttl(mut self, ttl: Duration) -> Self {
51 self.ttl = ttl;
52 self
53 }
54
55 pub async fn migrate(&self) -> TaskResult<()> {
58 sqlx::query(
59 r#"
60 CREATE TABLE IF NOT EXISTS kojin_results (
61 task_id TEXT PRIMARY KEY,
62 result JSONB NOT NULL,
63 created_at TIMESTAMPTZ DEFAULT NOW(),
64 expires_at TIMESTAMPTZ
65 )
66 "#,
67 )
68 .execute(&self.pool)
69 .await
70 .map_err(backend_err)?;
71
72 sqlx::query(
73 r#"
74 CREATE TABLE IF NOT EXISTS kojin_groups (
75 group_id TEXT NOT NULL,
76 task_id TEXT NOT NULL,
77 result JSONB,
78 completed BOOLEAN DEFAULT FALSE,
79 PRIMARY KEY (group_id, task_id)
80 )
81 "#,
82 )
83 .execute(&self.pool)
84 .await
85 .map_err(backend_err)?;
86
87 sqlx::query(
89 r#"
90 CREATE INDEX IF NOT EXISTS idx_kojin_results_expires
91 ON kojin_results (expires_at)
92 WHERE expires_at IS NOT NULL
93 "#,
94 )
95 .execute(&self.pool)
96 .await
97 .map_err(backend_err)?;
98
99 Ok(())
100 }
101}
102
103#[async_trait]
104impl ResultBackend for PostgresResultBackend {
105 async fn store(&self, id: &TaskId, result: &serde_json::Value) -> TaskResult<()> {
106 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(self.ttl.as_secs() as i64);
107
108 sqlx::query(
109 r#"
110 INSERT INTO kojin_results (task_id, result, expires_at)
111 VALUES ($1, $2, $3)
112 ON CONFLICT (task_id) DO UPDATE SET result = $2, expires_at = $3
113 "#,
114 )
115 .bind(id.to_string())
116 .bind(result)
117 .bind(expires_at)
118 .execute(&self.pool)
119 .await
120 .map_err(backend_err)?;
121
122 Ok(())
123 }
124
125 async fn get(&self, id: &TaskId) -> TaskResult<Option<serde_json::Value>> {
126 let row: Option<(serde_json::Value,)> = sqlx::query_as(
127 r#"
128 SELECT result FROM kojin_results
129 WHERE task_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
130 "#,
131 )
132 .bind(id.to_string())
133 .fetch_optional(&self.pool)
134 .await
135 .map_err(backend_err)?;
136
137 Ok(row.map(|(r,)| r))
138 }
139
140 async fn wait(&self, id: &TaskId, timeout: Duration) -> TaskResult<serde_json::Value> {
141 let deadline = tokio::time::Instant::now() + timeout;
142 loop {
143 if let Some(result) = self.get(id).await? {
144 return Ok(result);
145 }
146 if tokio::time::Instant::now() >= deadline {
147 return Err(KojinError::Timeout(timeout));
148 }
149 tokio::time::sleep(Duration::from_millis(100)).await;
150 }
151 }
152
153 async fn delete(&self, id: &TaskId) -> TaskResult<()> {
154 sqlx::query("DELETE FROM kojin_results WHERE task_id = $1")
155 .bind(id.to_string())
156 .execute(&self.pool)
157 .await
158 .map_err(backend_err)?;
159 Ok(())
160 }
161
162 async fn init_group(&self, group_id: &str, total: u32) -> TaskResult<()> {
163 for i in 0..total {
166 let placeholder_id = format!("{group_id}:placeholder:{i}");
167 sqlx::query(
168 r#"
169 INSERT INTO kojin_groups (group_id, task_id, completed)
170 VALUES ($1, $2, FALSE)
171 ON CONFLICT (group_id, task_id) DO NOTHING
172 "#,
173 )
174 .bind(group_id)
175 .bind(&placeholder_id)
176 .execute(&self.pool)
177 .await
178 .map_err(backend_err)?;
179 }
180 Ok(())
181 }
182
183 async fn complete_group_member(
184 &self,
185 group_id: &str,
186 task_id: &TaskId,
187 result: &serde_json::Value,
188 ) -> TaskResult<u32> {
189 sqlx::query(
191 r#"
192 INSERT INTO kojin_groups (group_id, task_id, result, completed)
193 VALUES ($1, $2, $3, TRUE)
194 ON CONFLICT (group_id, task_id) DO UPDATE SET result = $3, completed = TRUE
195 "#,
196 )
197 .bind(group_id)
198 .bind(task_id.to_string())
199 .bind(result)
200 .execute(&self.pool)
201 .await
202 .map_err(backend_err)?;
203
204 let (count,): (i64,) = sqlx::query_as(
206 r#"
207 SELECT COUNT(*) FROM kojin_groups
208 WHERE group_id = $1 AND completed = TRUE AND result IS NOT NULL
209 "#,
210 )
211 .bind(group_id)
212 .fetch_one(&self.pool)
213 .await
214 .map_err(backend_err)?;
215
216 Ok(count as u32)
217 }
218
219 async fn get_group_results(&self, group_id: &str) -> TaskResult<Vec<serde_json::Value>> {
220 let rows: Vec<(serde_json::Value,)> = sqlx::query_as(
221 r#"
222 SELECT result FROM kojin_groups
223 WHERE group_id = $1 AND completed = TRUE AND result IS NOT NULL
224 ORDER BY task_id
225 "#,
226 )
227 .bind(group_id)
228 .fetch_all(&self.pool)
229 .await
230 .map_err(backend_err)?;
231
232 Ok(rows.into_iter().map(|(r,)| r).collect())
233 }
234}
235
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::postgres::Postgres;

    /// Starts a Postgres 16 container, connects to it and runs the migrations.
    ///
    /// The container handle is returned alongside the backend so the caller
    /// can keep it alive for the duration of the test (presumably the
    /// container is torn down when the handle is dropped; confirm against
    /// the testcontainers docs).
    async fn setup_backend() -> (
        PostgresResultBackend,
        testcontainers::ContainerAsync<Postgres>,
    ) {
        let pg = Postgres::default().with_tag("16").start().await.unwrap();
        let host_port = pg.get_host_port_ipv4(5432).await.unwrap();
        let dsn = format!("postgres://postgres:postgres@127.0.0.1:{host_port}/postgres");

        let backend = PostgresResultBackend::connect(&dsn).await.unwrap();
        backend.migrate().await.unwrap();
        (backend, pg)
    }

    /// A stored value is returned unchanged by `get`.
    #[tokio::test]
    async fn store_and_get() {
        let (backend, _container) = setup_backend().await;
        let task = TaskId::new();
        let payload = serde_json::json!({"result": 42});

        backend.store(&task, &payload).await.unwrap();

        assert_eq!(backend.get(&task).await.unwrap(), Some(payload));
    }

    /// `get` on an id that was never stored yields `None`.
    #[tokio::test]
    async fn get_missing() {
        let (backend, _container) = setup_backend().await;
        let unknown = TaskId::new();

        let found = backend.get(&unknown).await.unwrap();

        assert_eq!(found, None);
    }

    /// A result is no longer visible once it has been deleted.
    #[tokio::test]
    async fn delete_result() {
        let (backend, _container) = setup_backend().await;
        let task = TaskId::new();

        backend.store(&task, &serde_json::json!(1)).await.unwrap();
        backend.delete(&task).await.unwrap();

        let found = backend.get(&task).await.unwrap();
        assert_eq!(found, None);
    }

    /// Completing three members one by one reports counts 1, 2, 3 and all
    /// three results are retrievable afterwards.
    #[tokio::test]
    async fn group_completion() {
        let (backend, _container) = setup_backend().await;
        backend.init_group("g1", 3).await.unwrap();

        let members = [TaskId::new(), TaskId::new(), TaskId::new()];

        for (member, expected) in members.iter().zip(1u32..) {
            let completed = backend
                .complete_group_member("g1", member, &serde_json::json!(expected))
                .await
                .unwrap();
            assert_eq!(completed, expected);
        }

        let results = backend.get_group_results("g1").await.unwrap();
        assert_eq!(results.len(), 3);
    }
}