1use std::time::Duration;
2
3use chrono::Utc;
4use sqlx::PgPool;
5
6use forge_core::error::{ForgeError, Result};
7
8pub struct KvStore {
18 pool: PgPool,
19 namespace: &'static str,
20}
21
22impl KvStore {
23 pub fn new(pool: PgPool, namespace: &'static str) -> Self {
24 Self { pool, namespace }
25 }
26
27 fn prefixed_key(&self, key: &str) -> String {
28 format!("{}:{}", self.namespace, key)
29 }
30
31 pub async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
33 let full_key = self.prefixed_key(key);
34 let row = sqlx::query_scalar!(
35 r#"
36 SELECT value
37 FROM forge_kv
38 WHERE key = $1
39 AND (expires_at IS NULL OR expires_at > NOW())
40 "#,
41 full_key,
42 )
43 .fetch_optional(&self.pool)
44 .await
45 .map_err(ForgeError::Database)?;
46
47 Ok(row)
48 }
49
50 pub async fn set(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
52 let full_key = self.prefixed_key(key);
53 let expires_at = ttl.map(|d| Utc::now() + d);
54 sqlx::query!(
55 r#"
56 INSERT INTO forge_kv (key, value, expires_at, updated_at)
57 VALUES ($1, $2, $3, NOW())
58 ON CONFLICT (key)
59 DO UPDATE SET value = $2, expires_at = $3, updated_at = NOW()
60 "#,
61 full_key,
62 value,
63 expires_at,
64 )
65 .execute(&self.pool)
66 .await
67 .map_err(ForgeError::Database)?;
68
69 Ok(())
70 }
71
72 pub async fn set_if_absent(
78 &self,
79 key: &str,
80 value: &[u8],
81 ttl: Option<Duration>,
82 ) -> Result<bool> {
83 let full_key = self.prefixed_key(key);
84 let expires_at = ttl.map(|d| Utc::now() + d);
85 #[allow(clippy::disallowed_methods)]
88 let rows = sqlx::query(
89 r#"
90 INSERT INTO forge_kv (key, value, expires_at, updated_at)
91 VALUES ($1, $2, $3, NOW())
92 ON CONFLICT (key) DO UPDATE
93 SET value = $2, expires_at = $3, updated_at = NOW()
94 WHERE forge_kv.expires_at IS NOT NULL AND forge_kv.expires_at <= NOW()
95 "#,
96 )
97 .bind(&full_key)
98 .bind(value)
99 .bind(expires_at)
100 .execute(&self.pool)
101 .await
102 .map_err(ForgeError::Database)?
103 .rows_affected();
104
105 Ok(rows > 0)
106 }
107
108 pub async fn delete(&self, key: &str) -> Result<bool> {
110 let full_key = self.prefixed_key(key);
111 let result = sqlx::query!("DELETE FROM forge_kv WHERE key = $1", full_key)
112 .execute(&self.pool)
113 .await
114 .map_err(ForgeError::Database)?;
115
116 Ok(result.rows_affected() > 0)
117 }
118
119 pub async fn increment(&self, key: &str, delta: i64, ttl: Option<Duration>) -> Result<i64> {
127 let full_key = self.prefixed_key(key);
128 let expires_at = ttl.map(|d| Utc::now() + d);
129 #[allow(clippy::disallowed_methods)]
132 let row: (i64,) = sqlx::query_as(
133 r#"
134 INSERT INTO forge_kv_counters (key, value, expires_at, updated_at)
135 VALUES ($1, $2, $3, NOW())
136 ON CONFLICT (key)
137 DO UPDATE SET
138 value = CASE
139 WHEN forge_kv_counters.expires_at IS NOT NULL AND forge_kv_counters.expires_at <= NOW()
140 THEN $2
141 ELSE forge_kv_counters.value + $2
142 END,
143 expires_at = COALESCE($3, forge_kv_counters.expires_at),
144 updated_at = NOW()
145 RETURNING value
146 "#,
147 )
148 .bind(&full_key)
149 .bind(delta)
150 .bind(expires_at)
151 .fetch_one(&self.pool)
152 .await
153 .map_err(ForgeError::Database)?;
154
155 Ok(row.0)
156 }
157
158 pub async fn cleanup_expired(&self) -> Result<u64> {
160 let kv_deleted = sqlx::query!(
161 "DELETE FROM forge_kv WHERE expires_at IS NOT NULL AND expires_at <= NOW()"
162 )
163 .execute(&self.pool)
164 .await
165 .map_err(ForgeError::Database)?
166 .rows_affected();
167
168 let counter_deleted = sqlx::query!(
169 "DELETE FROM forge_kv_counters WHERE expires_at IS NOT NULL AND expires_at <= NOW()"
170 )
171 .execute(&self.pool)
172 .await
173 .map_err(ForgeError::Database)?
174 .rows_affected();
175
176 Ok(kv_deleted + counter_deleted)
177 }
178}
179
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Pool that never opens a connection. The tests below only exercise
    /// `prefixed_key`, which performs no I/O, so a lazy pool is enough.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("connect_lazy never fails for a syntactically valid URL")
    }

    #[tokio::test]
    async fn prefixed_key_combines_namespace_and_key() {
        let limiter = KvStore::new(lazy_pool(), "ratelimit");

        let nested = limiter.prefixed_key("user:42");
        let empty = limiter.prefixed_key("");

        assert_eq!(nested, "ratelimit:user:42");
        assert_eq!(empty, "ratelimit:");
    }

    #[tokio::test]
    async fn prefixed_key_isolates_namespaces() {
        let shared_pool = lazy_pool();
        let first = KvStore::new(shared_pool.clone(), "subsystem_a");
        let second = KvStore::new(shared_pool, "subsystem_b");

        let (left, right) = (first.prefixed_key("shared"), second.prefixed_key("shared"));
        assert_ne!(left, right);
    }
}
211
#[cfg(all(test, feature = "testcontainers"))]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};

    /// TTL short enough that a test can cheaply wait it out.
    const SHORT_TTL: Duration = Duration::from_millis(50);
    /// Sleep long enough that anything written with `SHORT_TTL` has expired.
    const OUTLIVE_SHORT_TTL: Duration = Duration::from_millis(150);
    /// TTL long enough that nothing expires during a test run.
    const LONG_TTL: Duration = Duration::from_secs(3600);
    /// Reads a raw counter value straight from the counters table.
    const COUNTER_VALUE_SQL: &str = "SELECT value FROM forge_kv_counters WHERE key = $1";

    /// Creates an isolated database with the system schema applied.
    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let base = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        let db = base
            .isolated(test_name)
            .await
            .expect("Failed to create isolated db");
        let schema = crate::pg::migration::get_all_system_sql();
        db.run_sql(&schema)
            .await
            .expect("Failed to apply system schema");
        db
    }

    /// Builds a store on the isolated database's pool.
    fn store(db: &IsolatedTestDb, namespace: &'static str) -> KvStore {
        KvStore::new(db.pool().clone(), namespace)
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_key() {
        let db = setup_db("kv_missing").await;
        let kv = store(&db, "test");

        let fetched = kv.get("nope").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn set_then_get_roundtrips_bytes() {
        let db = setup_db("kv_roundtrip").await;
        let kv = store(&db, "test");
        let payload: &[u8] = b"hello, world";

        kv.set("greeting", payload, None).await.unwrap();

        let fetched = kv.get("greeting").await.unwrap();
        assert_eq!(fetched.as_deref(), Some(payload));
    }

    #[tokio::test]
    async fn set_overwrites_existing_value() {
        let db = setup_db("kv_overwrite").await;
        let kv = store(&db, "test");

        for value in [&b"v1"[..], &b"v2"[..]] {
            kv.set("k", value, None).await.unwrap();
        }

        let fetched = kv.get("k").await.unwrap();
        assert_eq!(fetched.as_deref(), Some(&b"v2"[..]));
    }

    #[tokio::test]
    async fn expired_key_returns_none_before_cleanup() {
        let db = setup_db("kv_expired_read").await;
        let kv = store(&db, "test");

        kv.set("k", b"v", Some(SHORT_TTL)).await.unwrap();
        tokio::time::sleep(OUTLIVE_SHORT_TTL).await;

        let fetched = kv.get("k").await.unwrap();
        assert!(fetched.is_none(), "expired key must not be returned");
    }

    #[tokio::test]
    async fn delete_returns_true_when_key_existed() {
        let db = setup_db("kv_delete_existing").await;
        let kv = store(&db, "test");

        kv.set("k", b"v", None).await.unwrap();

        let removed = kv.delete("k").await.unwrap();
        assert!(removed);
        let fetched = kv.get("k").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn delete_returns_false_when_key_missing() {
        let db = setup_db("kv_delete_missing").await;
        let kv = store(&db, "test");

        let removed = kv.delete("never_existed").await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn set_if_absent_inserts_when_missing() {
        let db = setup_db("kv_sia_insert").await;
        let kv = store(&db, "test");

        let claimed = kv.set_if_absent("lock", b"owner", None).await.unwrap();
        assert!(claimed);

        let held = kv.get("lock").await.unwrap();
        assert_eq!(held.as_deref(), Some(&b"owner"[..]));
    }

    #[tokio::test]
    async fn set_if_absent_refuses_when_present_and_fresh() {
        let db = setup_db("kv_sia_present").await;
        let kv = store(&db, "test");

        let first = kv.set_if_absent("lock", b"alice", None).await.unwrap();
        assert!(first);

        let second = kv.set_if_absent("lock", b"bob", None).await.unwrap();
        assert!(!second, "second writer must lose the race");

        let held = kv.get("lock").await.unwrap();
        assert_eq!(
            held.as_deref(),
            Some(&b"alice"[..]),
            "value must still belong to the first writer"
        );
    }

    #[tokio::test]
    async fn set_if_absent_succeeds_when_existing_value_expired() {
        let db = setup_db("kv_sia_expired").await;
        let kv = store(&db, "test");

        kv.set("lock", b"old", Some(SHORT_TTL)).await.unwrap();
        tokio::time::sleep(OUTLIVE_SHORT_TTL).await;

        let reclaimed = kv.set_if_absent("lock", b"new", None).await.unwrap();
        assert!(reclaimed, "expired key must be reclaimable");

        let held = kv.get("lock").await.unwrap();
        assert_eq!(held.as_deref(), Some(&b"new"[..]));
    }

    #[tokio::test]
    async fn increment_creates_counter_at_delta() {
        let db = setup_db("kv_inc_create").await;
        let kv = store(&db, "test");

        let created = kv.increment("hits", 5, None).await.unwrap();
        assert_eq!(created, 5);
    }

    #[tokio::test]
    async fn increment_accumulates_across_calls() {
        let db = setup_db("kv_inc_accum").await;
        let kv = store(&db, "test");

        // (delta applied, running total expected afterwards)
        for (delta, expected) in [(3, 3), (7, 10), (-4, 6)] {
            let total = kv.increment("hits", delta, None).await.unwrap();
            assert_eq!(total, expected);
        }
    }

    #[tokio::test]
    async fn increment_preserves_existing_ttl_when_none_passed() {
        let db = setup_db("kv_inc_ttl_preserve").await;
        let kv = store(&db, "test");

        kv.increment("hits", 1, Some(LONG_TTL)).await.unwrap();
        kv.increment("hits", 1, None).await.unwrap();

        let (expires_at,): (Option<chrono::DateTime<Utc>>,) =
            sqlx::query_as("SELECT expires_at FROM forge_kv_counters WHERE key = $1")
                .bind("test:hits")
                .fetch_one(db.pool())
                .await
                .unwrap();
        assert!(
            expires_at.is_some(),
            "TTL must survive a None increment, got {:?}",
            expires_at
        );
    }

    #[tokio::test]
    async fn increment_resets_when_existing_counter_expired() {
        let db = setup_db("kv_inc_reset_expired").await;
        let kv = store(&db, "test");

        kv.increment("hits", 100, Some(SHORT_TTL)).await.unwrap();
        tokio::time::sleep(OUTLIVE_SHORT_TTL).await;

        let after_reset = kv.increment("hits", 5, None).await.unwrap();
        assert_eq!(after_reset, 5, "expired counter must reset, not accumulate");
    }

    #[tokio::test]
    async fn cleanup_expired_removes_expired_keys_and_counters() {
        let db = setup_db("kv_cleanup").await;
        let kv = store(&db, "test");

        kv.set("fresh", b"keep", Some(LONG_TTL)).await.unwrap();
        kv.set("stale", b"drop", Some(SHORT_TTL)).await.unwrap();
        kv.increment("counter_fresh", 1, Some(LONG_TTL)).await.unwrap();
        kv.increment("counter_stale", 1, Some(SHORT_TTL)).await.unwrap();

        tokio::time::sleep(OUTLIVE_SHORT_TTL).await;

        let removed = kv.cleanup_expired().await.unwrap();
        assert_eq!(removed, 2, "cleanup must touch both stale rows");

        assert!(kv.get("fresh").await.unwrap().is_some());
        assert!(kv.get("stale").await.unwrap().is_none());

        let fresh_counter: i64 = sqlx::query_scalar(COUNTER_VALUE_SQL)
            .bind("test:counter_fresh")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(fresh_counter, 1);

        let stale_counter: Option<i64> = sqlx::query_scalar(COUNTER_VALUE_SQL)
            .bind("test:counter_stale")
            .fetch_optional(db.pool())
            .await
            .unwrap();
        assert!(stale_counter.is_none(), "stale counter row must be deleted");
    }

    #[tokio::test]
    async fn namespaced_stores_do_not_see_each_others_keys() {
        let db = setup_db("kv_namespace_isolation").await;
        let owner = store(&db, "subsys_a");
        let outsider = store(&db, "subsys_b");

        owner.set("shared", b"only-a", None).await.unwrap();

        let seen_by_owner = owner.get("shared").await.unwrap();
        assert_eq!(seen_by_owner.as_deref(), Some(&b"only-a"[..]));

        let seen_by_outsider = outsider.get("shared").await.unwrap();
        assert!(
            seen_by_outsider.is_none(),
            "namespace b must not see namespace a's key"
        );
    }
}
445}