Skip to main content

forge_runtime/kv/
store.rs

1use std::time::Duration;
2
3use chrono::Utc;
4use sqlx::PgPool;
5
6use forge_core::error::{ForgeError, Result};
7
8/// PostgreSQL-backed key-value store.
9///
10/// Provides a simple get/set/delete/set_if_absent/increment API over
11/// `forge_kv` and `forge_kv_counters` tables. All operations are atomic.
12/// TTLs are enforced both at read time (expired keys return `None`) and
13/// via periodic cleanup.
14///
15/// Keys are automatically namespaced with the configured prefix to prevent
16/// collisions between different subsystems sharing the same database.
17pub struct KvStore {
18    pool: PgPool,
19    namespace: &'static str,
20}
21
22impl KvStore {
23    pub fn new(pool: PgPool, namespace: &'static str) -> Self {
24        Self { pool, namespace }
25    }
26
27    fn prefixed_key(&self, key: &str) -> String {
28        format!("{}:{}", self.namespace, key)
29    }
30
31    /// Get a value by key. Returns `None` if the key doesn't exist or is expired.
32    pub async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
33        let full_key = self.prefixed_key(key);
34        let row = sqlx::query_scalar!(
35            r#"
36            SELECT value
37            FROM forge_kv
38            WHERE key = $1
39              AND (expires_at IS NULL OR expires_at > NOW())
40            "#,
41            full_key,
42        )
43        .fetch_optional(&self.pool)
44        .await
45        .map_err(ForgeError::Database)?;
46
47        Ok(row)
48    }
49
50    /// Set a key to a value. Overwrites any existing value.
51    pub async fn set(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
52        let full_key = self.prefixed_key(key);
53        let expires_at = ttl.map(|d| Utc::now() + d);
54        sqlx::query!(
55            r#"
56            INSERT INTO forge_kv (key, value, expires_at, updated_at)
57            VALUES ($1, $2, $3, NOW())
58            ON CONFLICT (key)
59            DO UPDATE SET value = $2, expires_at = $3, updated_at = NOW()
60            "#,
61            full_key,
62            value,
63            expires_at,
64        )
65        .execute(&self.pool)
66        .await
67        .map_err(ForgeError::Database)?;
68
69        Ok(())
70    }
71
72    /// Set a key only if it doesn't already exist (or is expired).
73    /// Returns `true` if the key was set, `false` if it already existed.
74    ///
75    /// Uses `ON CONFLICT DO UPDATE ... WHERE` to atomically treat expired rows
76    /// as absent within a single statement (no CTE snapshot isolation issues).
77    pub async fn set_if_absent(
78        &self,
79        key: &str,
80        value: &[u8],
81        ttl: Option<Duration>,
82    ) -> Result<bool> {
83        let full_key = self.prefixed_key(key);
84        let expires_at = ttl.map(|d| Utc::now() + d);
85        // ON CONFLICT WHERE treats expired rows as absent atomically.
86        // Convert to query!() after next `cargo sqlx prepare`.
87        #[allow(clippy::disallowed_methods)]
88        let rows = sqlx::query(
89            r#"
90            INSERT INTO forge_kv (key, value, expires_at, updated_at)
91            VALUES ($1, $2, $3, NOW())
92            ON CONFLICT (key) DO UPDATE
93                SET value = $2, expires_at = $3, updated_at = NOW()
94                WHERE forge_kv.expires_at IS NOT NULL AND forge_kv.expires_at <= NOW()
95            "#,
96        )
97        .bind(&full_key)
98        .bind(value)
99        .bind(expires_at)
100        .execute(&self.pool)
101        .await
102        .map_err(ForgeError::Database)?
103        .rows_affected();
104
105        Ok(rows > 0)
106    }
107
108    /// Delete a key. Returns `true` if the key existed.
109    pub async fn delete(&self, key: &str) -> Result<bool> {
110        let full_key = self.prefixed_key(key);
111        let result = sqlx::query!("DELETE FROM forge_kv WHERE key = $1", full_key)
112            .execute(&self.pool)
113            .await
114            .map_err(ForgeError::Database)?;
115
116        Ok(result.rows_affected() > 0)
117    }
118
119    /// Atomically increment a counter by `delta`. Creates the counter at 0 if
120    /// it doesn't exist. Returns the new value. When `ttl` is `None`, an
121    /// existing counter's TTL is preserved (pass `Some` to override it).
122    /// Expired counters are treated as non-existent (value resets to delta).
123    ///
124    /// Uses `ON CONFLICT DO UPDATE ... WHERE` to handle expired rows atomically
125    /// without CTE snapshot isolation issues.
126    pub async fn increment(&self, key: &str, delta: i64, ttl: Option<Duration>) -> Result<i64> {
127        let full_key = self.prefixed_key(key);
128        let expires_at = ttl.map(|d| Utc::now() + d);
129        // Expired counters reset to delta rather than accumulating.
130        // Convert to query_scalar!() after next `cargo sqlx prepare`.
131        #[allow(clippy::disallowed_methods)]
132        let row: (i64,) = sqlx::query_as(
133            r#"
134            INSERT INTO forge_kv_counters (key, value, expires_at, updated_at)
135            VALUES ($1, $2, $3, NOW())
136            ON CONFLICT (key)
137            DO UPDATE SET
138                value = CASE
139                    WHEN forge_kv_counters.expires_at IS NOT NULL AND forge_kv_counters.expires_at <= NOW()
140                    THEN $2
141                    ELSE forge_kv_counters.value + $2
142                END,
143                expires_at = COALESCE($3, forge_kv_counters.expires_at),
144                updated_at = NOW()
145            RETURNING value
146            "#,
147        )
148        .bind(&full_key)
149        .bind(delta)
150        .bind(expires_at)
151        .fetch_one(&self.pool)
152        .await
153        .map_err(ForgeError::Database)?;
154
155        Ok(row.0)
156    }
157
158    /// Remove expired keys from both tables. Returns total rows cleaned up.
159    pub async fn cleanup_expired(&self) -> Result<u64> {
160        let kv_deleted = sqlx::query!(
161            "DELETE FROM forge_kv WHERE expires_at IS NOT NULL AND expires_at <= NOW()"
162        )
163        .execute(&self.pool)
164        .await
165        .map_err(ForgeError::Database)?
166        .rows_affected();
167
168        let counter_deleted = sqlx::query!(
169            "DELETE FROM forge_kv_counters WHERE expires_at IS NOT NULL AND expires_at <= NOW()"
170        )
171        .execute(&self.pool)
172        .await
173        .map_err(ForgeError::Database)?
174        .rows_affected();
175
176        Ok(kv_deleted + counter_deleted)
177    }
178}
179
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Returns a pool that never opens a connection. These tests only check
    /// key formatting, so no database needs to be reachable.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("connect_lazy never fails for a syntactically valid URL")
    }

    #[tokio::test]
    async fn prefixed_key_combines_namespace_and_key() {
        let limiter = KvStore::new(lazy_pool(), "ratelimit");

        assert_eq!(limiter.prefixed_key("user:42"), "ratelimit:user:42");
        assert_eq!(limiter.prefixed_key(""), "ratelimit:");
    }

    #[tokio::test]
    async fn prefixed_key_isolates_namespaces() {
        let shared_pool = lazy_pool();

        // One logical key, two namespaces: the physical keys must differ,
        // which is the whole point of namespacing.
        let first = KvStore::new(shared_pool.clone(), "subsystem_a");
        let second = KvStore::new(shared_pool, "subsystem_b");
        let (key_a, key_b) = (first.prefixed_key("shared"), second.prefixed_key("shared"));
        assert_ne!(key_a, key_b);
    }
}
211
#[cfg(all(test, feature = "testcontainers"))]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};

    /// Creates an isolated database named after `test_name` and applies the
    /// full system schema to it.
    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let base = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        let db = base
            .isolated(test_name)
            .await
            .expect("Failed to create isolated db");
        let system_sql = crate::pg::migration::get_all_system_sql();
        db.run_sql(&system_sql)
            .await
            .expect("Failed to apply system schema");
        db
    }

    /// Reads a counter row's raw `expires_at`, bypassing `KvStore`.
    ///
    /// Panics if no row exists for the (already prefixed) `key`.
    async fn counter_expires_at(
        db: &IsolatedTestDb,
        key: &str,
    ) -> Option<chrono::DateTime<Utc>> {
        let row: (Option<chrono::DateTime<Utc>>,) =
            sqlx::query_as("SELECT expires_at FROM forge_kv_counters WHERE key = $1")
                .bind(key)
                .fetch_one(db.pool())
                .await
                .unwrap();
        row.0
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_key() {
        let db = setup_db("kv_missing").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        assert!(kv.get("nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn set_then_get_roundtrips_bytes() {
        let db = setup_db("kv_roundtrip").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        kv.set("greeting", b"hello, world", None).await.unwrap();
        let got = kv.get("greeting").await.unwrap();
        assert_eq!(got.as_deref(), Some(&b"hello, world"[..]));
    }

    #[tokio::test]
    async fn set_overwrites_existing_value() {
        let db = setup_db("kv_overwrite").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        kv.set("k", b"v1", None).await.unwrap();
        kv.set("k", b"v2", None).await.unwrap();
        assert_eq!(kv.get("k").await.unwrap().as_deref(), Some(&b"v2"[..]));
    }

    #[tokio::test]
    async fn expired_key_returns_none_before_cleanup() {
        let db = setup_db("kv_expired_read").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        // Set with a tiny TTL, sleep past it.
        kv.set("k", b"v", Some(Duration::from_millis(50)))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(
            kv.get("k").await.unwrap().is_none(),
            "expired key must not be returned"
        );
    }

    #[tokio::test]
    async fn delete_returns_true_when_key_existed() {
        let db = setup_db("kv_delete_existing").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        kv.set("k", b"v", None).await.unwrap();
        assert!(kv.delete("k").await.unwrap());
        assert!(kv.get("k").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn delete_returns_false_when_key_missing() {
        let db = setup_db("kv_delete_missing").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        assert!(!kv.delete("never_existed").await.unwrap());
    }

    #[tokio::test]
    async fn set_if_absent_inserts_when_missing() {
        let db = setup_db("kv_sia_insert").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        let claimed = kv.set_if_absent("lock", b"owner", None).await.unwrap();
        assert!(claimed);
        assert_eq!(
            kv.get("lock").await.unwrap().as_deref(),
            Some(&b"owner"[..])
        );
    }

    #[tokio::test]
    async fn set_if_absent_refuses_when_present_and_fresh() {
        let db = setup_db("kv_sia_present").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        assert!(kv.set_if_absent("lock", b"alice", None).await.unwrap());
        let second = kv.set_if_absent("lock", b"bob", None).await.unwrap();
        assert!(!second, "second writer must lose the race");
        assert_eq!(
            kv.get("lock").await.unwrap().as_deref(),
            Some(&b"alice"[..]),
            "value must still belong to the first writer"
        );
    }

    #[tokio::test]
    async fn set_if_absent_succeeds_when_existing_value_expired() {
        let db = setup_db("kv_sia_expired").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        kv.set("lock", b"old", Some(Duration::from_millis(50)))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(
            kv.set_if_absent("lock", b"new", None).await.unwrap(),
            "expired key must be reclaimable"
        );
        assert_eq!(kv.get("lock").await.unwrap().as_deref(), Some(&b"new"[..]));
    }

    #[tokio::test]
    async fn increment_creates_counter_at_delta() {
        let db = setup_db("kv_inc_create").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        let v = kv.increment("hits", 5, None).await.unwrap();
        assert_eq!(v, 5);
    }

    #[tokio::test]
    async fn increment_accumulates_across_calls() {
        let db = setup_db("kv_inc_accum").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        assert_eq!(kv.increment("hits", 3, None).await.unwrap(), 3);
        assert_eq!(kv.increment("hits", 7, None).await.unwrap(), 10);
        assert_eq!(kv.increment("hits", -4, None).await.unwrap(), 6);
    }

    #[tokio::test]
    async fn increment_preserves_existing_ttl_when_none_passed() {
        let db = setup_db("kv_inc_ttl_preserve").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        // Seed with a 1h TTL.
        kv.increment("hits", 1, Some(Duration::from_secs(3600)))
            .await
            .unwrap();
        let seeded = counter_expires_at(&db, "test:hits").await;
        assert!(seeded.is_some(), "seeding with a TTL must set expires_at");

        // Increment without specifying TTL. The stored expiry must stay exactly
        // the same, not merely remain non-NULL.
        kv.increment("hits", 1, None).await.unwrap();
        let after = counter_expires_at(&db, "test:hits").await;
        assert_eq!(
            after, seeded,
            "TTL must survive a None increment unchanged"
        );
    }

    #[tokio::test]
    async fn increment_resets_when_existing_counter_expired() {
        let db = setup_db("kv_inc_reset_expired").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        kv.increment("hits", 100, Some(Duration::from_millis(50)))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        // Expired counter should reset to delta, not accumulate.
        let v = kv.increment("hits", 5, None).await.unwrap();
        assert_eq!(v, 5, "expired counter must reset, not accumulate");
    }

    #[tokio::test]
    async fn cleanup_expired_removes_expired_keys_and_counters() {
        let db = setup_db("kv_cleanup").await;
        let kv = KvStore::new(db.pool().clone(), "test");
        kv.set("fresh", b"keep", Some(Duration::from_secs(3600)))
            .await
            .unwrap();
        kv.set("stale", b"drop", Some(Duration::from_millis(50)))
            .await
            .unwrap();
        kv.increment("counter_fresh", 1, Some(Duration::from_secs(3600)))
            .await
            .unwrap();
        kv.increment("counter_stale", 1, Some(Duration::from_millis(50)))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;
        let removed = kv.cleanup_expired().await.unwrap();
        assert_eq!(removed, 2, "cleanup must touch both stale rows");

        assert!(kv.get("fresh").await.unwrap().is_some());
        assert!(kv.get("stale").await.unwrap().is_none());

        let fresh_counter: i64 =
            sqlx::query_scalar("SELECT value FROM forge_kv_counters WHERE key = $1")
                .bind("test:counter_fresh")
                .fetch_one(db.pool())
                .await
                .unwrap();
        assert_eq!(fresh_counter, 1);

        let stale_exists: Option<i64> =
            sqlx::query_scalar("SELECT value FROM forge_kv_counters WHERE key = $1")
                .bind("test:counter_stale")
                .fetch_optional(db.pool())
                .await
                .unwrap();
        assert!(stale_exists.is_none(), "stale counter row must be deleted");
    }

    #[tokio::test]
    async fn namespaced_stores_do_not_see_each_others_keys() {
        let db = setup_db("kv_namespace_isolation").await;
        let a = KvStore::new(db.pool().clone(), "subsys_a");
        let b = KvStore::new(db.pool().clone(), "subsys_b");

        a.set("shared", b"only-a", None).await.unwrap();
        assert_eq!(
            a.get("shared").await.unwrap().as_deref(),
            Some(&b"only-a"[..])
        );
        assert!(
            b.get("shared").await.unwrap().is_none(),
            "namespace b must not see namespace a's key"
        );
    }
}