1use crate::errors::Result;
2use chrono::{DateTime, Duration, Utc};
3use serde::{de::DeserializeOwned, Serialize};
4use sqlx::PgPool;
5
6#[derive(Clone)]
8pub struct CacheRepository {
9 pool: PgPool,
10}
11
12impl CacheRepository {
13 pub fn new(pool: PgPool) -> Self {
14 Self { pool }
15 }
16
17 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
19 let row: Option<(serde_json::Value,)> =
20 sqlx::query_as("SELECT value FROM cache_entries WHERE key = $1 AND expires_at > NOW()")
21 .bind(key)
22 .fetch_optional(&self.pool)
23 .await?;
24
25 match row {
26 Some((value,)) => {
27 let parsed: T = serde_json::from_value(value)?;
28 Ok(Some(parsed))
29 }
30 None => Ok(None),
31 }
32 }
33
34 pub async fn set<T: Serialize>(&self, key: &str, value: &T, ttl_secs: u64) -> Result<()> {
36 let expires_at = Utc::now() + Duration::seconds(ttl_secs as i64);
37 self.set_with_expiry(key, value, expires_at).await
38 }
39
40 pub async fn set_with_expiry<T: Serialize>(
42 &self,
43 key: &str,
44 value: &T,
45 expires_at: DateTime<Utc>,
46 ) -> Result<()> {
47 let json = serde_json::to_value(value)?;
48
49 sqlx::query(
50 r#"
51 INSERT INTO cache_entries (key, value, expires_at)
52 VALUES ($1, $2, $3)
53 ON CONFLICT (key) DO UPDATE SET value = $2, expires_at = $3
54 "#,
55 )
56 .bind(key)
57 .bind(json)
58 .bind(expires_at)
59 .execute(&self.pool)
60 .await?;
61
62 Ok(())
63 }
64
65 pub async fn delete(&self, key: &str) -> Result<bool> {
67 let result = sqlx::query("DELETE FROM cache_entries WHERE key = $1")
68 .bind(key)
69 .execute(&self.pool)
70 .await?;
71
72 Ok(result.rows_affected() > 0)
73 }
74
75 pub async fn cleanup_expired(&self) -> Result<u64> {
77 let result = sqlx::query("DELETE FROM cache_entries WHERE expires_at <= NOW()")
78 .execute(&self.pool)
79 .await?;
80
81 Ok(result.rows_affected())
82 }
83
84 pub async fn get_or_set<T, F, Fut>(&self, key: &str, ttl_secs: u64, fetch: F) -> Result<T>
86 where
87 T: Serialize + DeserializeOwned,
88 F: FnOnce() -> Fut,
89 Fut: std::future::Future<Output = Result<T>>,
90 {
91 if let Some(cached) = self.get::<T>(key).await? {
92 return Ok(cached);
93 }
94
95 let value = fetch().await?;
96 self.set(key, &value, ttl_secs).await?;
97 Ok(value)
98 }
99}
100
#[cfg(test)]
mod tests {
    /// Pins the `user:<id>:profile` cache-key layout produced by `format!`.
    #[test]
    fn test_cache_key_format() {
        let user_id = 42;
        let expected = "user:42:profile";
        let actual = format!("user:{}:profile", user_id);
        assert_eq!(actual, expected);
    }
}