Skip to main content

pg_queue/
cache.rs

1use crate::errors::Result;
2use chrono::{DateTime, Duration, Utc};
3use serde::{de::DeserializeOwned, Serialize};
4use sqlx::PgPool;
5
6/// Cache repository using PostgreSQL UNLOGGED table with per-query TTL check
7#[derive(Clone)]
8pub struct CacheRepository {
9    pool: PgPool,
10}
11
12impl CacheRepository {
13    pub fn new(pool: PgPool) -> Self {
14        Self { pool }
15    }
16
17    /// Get a cached value, returning None if not found or expired
18    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
19        let row: Option<(serde_json::Value,)> =
20            sqlx::query_as("SELECT value FROM cache_entries WHERE key = $1 AND expires_at > NOW()")
21                .bind(key)
22                .fetch_optional(&self.pool)
23                .await?;
24
25        match row {
26            Some((value,)) => {
27                let parsed: T = serde_json::from_value(value)?;
28                Ok(Some(parsed))
29            }
30            None => Ok(None),
31        }
32    }
33
34    /// Set a cached value with TTL in seconds
35    pub async fn set<T: Serialize>(&self, key: &str, value: &T, ttl_secs: u64) -> Result<()> {
36        let expires_at = Utc::now() + Duration::seconds(ttl_secs as i64);
37        self.set_with_expiry(key, value, expires_at).await
38    }
39
40    /// Set a cached value with explicit expiration time
41    pub async fn set_with_expiry<T: Serialize>(
42        &self,
43        key: &str,
44        value: &T,
45        expires_at: DateTime<Utc>,
46    ) -> Result<()> {
47        let json = serde_json::to_value(value)?;
48
49        sqlx::query(
50            r#"
51            INSERT INTO cache_entries (key, value, expires_at)
52            VALUES ($1, $2, $3)
53            ON CONFLICT (key) DO UPDATE SET value = $2, expires_at = $3
54            "#,
55        )
56        .bind(key)
57        .bind(json)
58        .bind(expires_at)
59        .execute(&self.pool)
60        .await?;
61
62        Ok(())
63    }
64
65    /// Delete a cached entry
66    pub async fn delete(&self, key: &str) -> Result<bool> {
67        let result = sqlx::query("DELETE FROM cache_entries WHERE key = $1")
68            .bind(key)
69            .execute(&self.pool)
70            .await?;
71
72        Ok(result.rows_affected() > 0)
73    }
74
75    /// Delete all expired entries (cleanup)
76    pub async fn cleanup_expired(&self) -> Result<u64> {
77        let result = sqlx::query("DELETE FROM cache_entries WHERE expires_at <= NOW()")
78            .execute(&self.pool)
79            .await?;
80
81        Ok(result.rows_affected())
82    }
83
84    /// Get or set a cached value using a fallback function
85    pub async fn get_or_set<T, F, Fut>(&self, key: &str, ttl_secs: u64, fetch: F) -> Result<T>
86    where
87        T: Serialize + DeserializeOwned,
88        F: FnOnce() -> Fut,
89        Fut: std::future::Future<Output = Result<T>>,
90    {
91        if let Some(cached) = self.get::<T>(key).await? {
92            return Ok(cached);
93        }
94
95        let value = fetch().await?;
96        self.set(key, &value, ttl_secs).await?;
97        Ok(value)
98    }
99}
100
#[cfg(test)]
mod tests {
    /// Checks that a colon-delimited key renders with the id in the middle segment.
    #[test]
    fn test_cache_key_format() {
        let user_id = 42;
        let profile_key = format!("user:{user_id}:profile");
        assert_eq!(profile_key, "user:42:profile");
    }
}