Skip to main content

kojin_postgres/
result_backend.rs

1use async_trait::async_trait;
2use sqlx::PgPool;
3use std::time::Duration;
4
5use kojin_core::error::{KojinError, TaskResult};
6use kojin_core::result_backend::ResultBackend;
7use kojin_core::task_id::TaskId;
8
9fn backend_err(e: impl std::fmt::Display) -> KojinError {
10    KojinError::ResultBackend(e.to_string())
11}
12
13/// PostgreSQL-backed result storage.
14///
15/// Results are stored in the `kojin_results` table as JSONB with a configurable
16/// TTL (default 24 hours). Group state is tracked in `kojin_groups`.
17///
18/// **Important:** call [`migrate()`](Self::migrate) before first use to create
19/// the required tables and indexes.
20pub struct PostgresResultBackend {
21    pool: PgPool,
22    ttl: Duration,
23}
24
25impl PostgresResultBackend {
26    /// Create a new PostgreSQL result backend from an existing connection pool.
27    ///
28    /// The default result TTL is 24 hours; override with [`with_ttl`](Self::with_ttl).
29    /// You must call [`migrate()`](Self::migrate) before storing results.
30    pub fn new(pool: PgPool) -> Self {
31        Self {
32            pool,
33            ttl: Duration::from_secs(86400), // 24h default
34        }
35    }
36
37    /// Connect to PostgreSQL by URL and create the backend.
38    ///
39    /// This is a convenience wrapper around `PgPool::connect` + [`new`](Self::new).
40    pub async fn connect(url: &str) -> TaskResult<Self> {
41        let pool = PgPool::connect(url).await.map_err(backend_err)?;
42        Ok(Self::new(pool))
43    }
44
45    /// Override the result TTL (time-to-live).
46    ///
47    /// Expired results are not automatically deleted; they are filtered out at
48    /// read time. Run periodic cleanup queries against `kojin_results.expires_at`
49    /// if storage reclamation is needed. Defaults to 24 hours if not called.
50    pub fn with_ttl(mut self, ttl: Duration) -> Self {
51        self.ttl = ttl;
52        self
53    }
54
55    /// Run migrations to create the required tables (`kojin_results`, `kojin_groups`)
56    /// and indexes. Safe to call multiple times — all statements use `IF NOT EXISTS`.
57    pub async fn migrate(&self) -> TaskResult<()> {
58        sqlx::query(
59            r#"
60            CREATE TABLE IF NOT EXISTS kojin_results (
61                task_id TEXT PRIMARY KEY,
62                result JSONB NOT NULL,
63                created_at TIMESTAMPTZ DEFAULT NOW(),
64                expires_at TIMESTAMPTZ
65            )
66            "#,
67        )
68        .execute(&self.pool)
69        .await
70        .map_err(backend_err)?;
71
72        sqlx::query(
73            r#"
74            CREATE TABLE IF NOT EXISTS kojin_groups (
75                group_id TEXT NOT NULL,
76                task_id TEXT NOT NULL,
77                result JSONB,
78                completed BOOLEAN DEFAULT FALSE,
79                PRIMARY KEY (group_id, task_id)
80            )
81            "#,
82        )
83        .execute(&self.pool)
84        .await
85        .map_err(backend_err)?;
86
87        // Index for cleanup of expired results
88        sqlx::query(
89            r#"
90            CREATE INDEX IF NOT EXISTS idx_kojin_results_expires
91            ON kojin_results (expires_at)
92            WHERE expires_at IS NOT NULL
93            "#,
94        )
95        .execute(&self.pool)
96        .await
97        .map_err(backend_err)?;
98
99        Ok(())
100    }
101}
102
103#[async_trait]
104impl ResultBackend for PostgresResultBackend {
105    async fn store(&self, id: &TaskId, result: &serde_json::Value) -> TaskResult<()> {
106        let expires_at = chrono::Utc::now() + chrono::Duration::seconds(self.ttl.as_secs() as i64);
107
108        sqlx::query(
109            r#"
110            INSERT INTO kojin_results (task_id, result, expires_at)
111            VALUES ($1, $2, $3)
112            ON CONFLICT (task_id) DO UPDATE SET result = $2, expires_at = $3
113            "#,
114        )
115        .bind(id.to_string())
116        .bind(result)
117        .bind(expires_at)
118        .execute(&self.pool)
119        .await
120        .map_err(backend_err)?;
121
122        Ok(())
123    }
124
125    async fn get(&self, id: &TaskId) -> TaskResult<Option<serde_json::Value>> {
126        let row: Option<(serde_json::Value,)> = sqlx::query_as(
127            r#"
128            SELECT result FROM kojin_results
129            WHERE task_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
130            "#,
131        )
132        .bind(id.to_string())
133        .fetch_optional(&self.pool)
134        .await
135        .map_err(backend_err)?;
136
137        Ok(row.map(|(r,)| r))
138    }
139
140    async fn wait(&self, id: &TaskId, timeout: Duration) -> TaskResult<serde_json::Value> {
141        let deadline = tokio::time::Instant::now() + timeout;
142        loop {
143            if let Some(result) = self.get(id).await? {
144                return Ok(result);
145            }
146            if tokio::time::Instant::now() >= deadline {
147                return Err(KojinError::Timeout(timeout));
148            }
149            tokio::time::sleep(Duration::from_millis(100)).await;
150        }
151    }
152
153    async fn delete(&self, id: &TaskId) -> TaskResult<()> {
154        sqlx::query("DELETE FROM kojin_results WHERE task_id = $1")
155            .bind(id.to_string())
156            .execute(&self.pool)
157            .await
158            .map_err(backend_err)?;
159        Ok(())
160    }
161
162    async fn init_group(&self, group_id: &str, total: u32) -> TaskResult<()> {
163        // Pre-create placeholder rows for the group
164        // We use a single INSERT with generate_series for efficiency
165        for i in 0..total {
166            let placeholder_id = format!("{group_id}:placeholder:{i}");
167            sqlx::query(
168                r#"
169                INSERT INTO kojin_groups (group_id, task_id, completed)
170                VALUES ($1, $2, FALSE)
171                ON CONFLICT (group_id, task_id) DO NOTHING
172                "#,
173            )
174            .bind(group_id)
175            .bind(&placeholder_id)
176            .execute(&self.pool)
177            .await
178            .map_err(backend_err)?;
179        }
180        Ok(())
181    }
182
183    async fn complete_group_member(
184        &self,
185        group_id: &str,
186        task_id: &TaskId,
187        result: &serde_json::Value,
188    ) -> TaskResult<u32> {
189        // Upsert the actual task result
190        sqlx::query(
191            r#"
192            INSERT INTO kojin_groups (group_id, task_id, result, completed)
193            VALUES ($1, $2, $3, TRUE)
194            ON CONFLICT (group_id, task_id) DO UPDATE SET result = $3, completed = TRUE
195            "#,
196        )
197        .bind(group_id)
198        .bind(task_id.to_string())
199        .bind(result)
200        .execute(&self.pool)
201        .await
202        .map_err(backend_err)?;
203
204        // Count completed members (only those with completed = TRUE and actual results)
205        let (count,): (i64,) = sqlx::query_as(
206            r#"
207            SELECT COUNT(*) FROM kojin_groups
208            WHERE group_id = $1 AND completed = TRUE AND result IS NOT NULL
209            "#,
210        )
211        .bind(group_id)
212        .fetch_one(&self.pool)
213        .await
214        .map_err(backend_err)?;
215
216        Ok(count as u32)
217    }
218
219    async fn get_group_results(&self, group_id: &str) -> TaskResult<Vec<serde_json::Value>> {
220        let rows: Vec<(serde_json::Value,)> = sqlx::query_as(
221            r#"
222            SELECT result FROM kojin_groups
223            WHERE group_id = $1 AND completed = TRUE AND result IS NOT NULL
224            ORDER BY task_id
225            "#,
226        )
227        .bind(group_id)
228        .fetch_all(&self.pool)
229        .await
230        .map_err(backend_err)?;
231
232        Ok(rows.into_iter().map(|(r,)| r).collect())
233    }
234}
235
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::postgres::Postgres;

    /// Start a Postgres 16 container and return a migrated backend connected to it.
    ///
    /// The container handle is returned alongside the backend; callers must keep
    /// it alive for as long as the backend is in use.
    async fn setup_backend() -> (
        PostgresResultBackend,
        testcontainers::ContainerAsync<Postgres>,
    ) {
        let container = Postgres::default().with_tag("16").start().await.unwrap();
        let host_port = container.get_host_port_ipv4(5432).await.unwrap();
        let backend = PostgresResultBackend::connect(&format!(
            "postgres://postgres:postgres@127.0.0.1:{host_port}/postgres"
        ))
        .await
        .unwrap();
        backend.migrate().await.unwrap();
        (backend, container)
    }

    #[tokio::test]
    async fn store_and_get() {
        let (backend, _container) = setup_backend().await;
        let task = TaskId::new();
        let expected = serde_json::json!({"result": 42});

        backend.store(&task, &expected).await.unwrap();
        assert_eq!(backend.get(&task).await.unwrap(), Some(expected));
    }

    #[tokio::test]
    async fn get_missing() {
        let (backend, _container) = setup_backend().await;
        let unknown = TaskId::new();
        assert_eq!(backend.get(&unknown).await.unwrap(), None);
    }

    #[tokio::test]
    async fn delete_result() {
        let (backend, _container) = setup_backend().await;
        let task = TaskId::new();

        backend.store(&task, &serde_json::json!(1)).await.unwrap();
        backend.delete(&task).await.unwrap();

        assert_eq!(backend.get(&task).await.unwrap(), None);
    }

    #[tokio::test]
    async fn group_completion() {
        let (backend, _container) = setup_backend().await;
        backend.init_group("g1", 3).await.unwrap();

        // Each completion should report the running count of finished members.
        for n in 1..=3u32 {
            let completed = backend
                .complete_group_member("g1", &TaskId::new(), &serde_json::json!(n))
                .await
                .unwrap();
            assert_eq!(completed, n);
        }

        let results = backend.get_group_results("g1").await.unwrap();
        assert_eq!(results.len(), 3);
    }
}