Skip to main content

synaptic_postgres/
cache.rs

1use async_trait::async_trait;
2use sqlx::PgPool;
3use synaptic_core::{validate_table_name, ChatResponse, SynapticError};
4
/// Settings that control how a [`PgCache`] stores and expires entries.
///
/// Construct with [`PgCacheConfig::new`] and optionally chain
/// [`with_ttl`](PgCacheConfig::with_ttl) to enable expiration.
#[derive(Debug, Clone)]
pub struct PgCacheConfig {
    /// PostgreSQL table that holds the cached LLM responses.
    pub table_name: String,
    /// Time-to-live in seconds, if any. When present, lookups ignore rows that
    /// are at least this old; when `None`, entries never expire.
    pub ttl: Option<u64>,
}

impl PgCacheConfig {
    /// Build a configuration for `table_name` with expiration disabled.
    pub fn new(table_name: impl Into<String>) -> Self {
        let table_name = table_name.into();
        Self {
            table_name,
            ttl: None,
        }
    }

    /// Return this configuration with its time-to-live set to `seconds`.
    pub fn with_ttl(self, seconds: u64) -> Self {
        Self {
            ttl: Some(seconds),
            ..self
        }
    }
}
30
31/// PostgreSQL-backed implementation of the [`LlmCache`](synaptic_core::LlmCache) trait.
32///
33/// Stores serialized [`ChatResponse`] values in a PostgreSQL table with optional
34/// TTL expiration. Call [`initialize`](PgCache::initialize) once after construction
35/// to create the backing table (idempotent).
36///
37/// # Example
38///
39/// ```rust,no_run
40/// use sqlx::postgres::PgPoolOptions;
41/// use synaptic_postgres::{PgCache, PgCacheConfig};
42///
43/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
44/// let pool = PgPoolOptions::new()
45///     .max_connections(5)
46///     .connect("postgres://user:pass@localhost/mydb")
47///     .await?;
48///
49/// let config = PgCacheConfig::new("llm_cache").with_ttl(3600);
50/// let cache = PgCache::new(pool, config);
51/// cache.initialize().await?;
52/// # Ok(())
53/// # }
54/// ```
55pub struct PgCache {
56    pool: PgPool,
57    config: PgCacheConfig,
58}
59
60impl PgCache {
61    /// Create a new `PgCache` from an existing connection pool and config.
62    pub fn new(pool: PgPool, config: PgCacheConfig) -> Self {
63        Self { pool, config }
64    }
65
66    /// Ensure the backing table exists.
67    ///
68    /// This is idempotent and safe to call on every application startup.
69    pub async fn initialize(&self) -> Result<(), SynapticError> {
70        validate_table_name(&self.config.table_name)?;
71
72        let create_table = format!(
73            r#"CREATE TABLE IF NOT EXISTS {table} (
74                key        TEXT PRIMARY KEY,
75                value      TEXT NOT NULL,
76                created_at BIGINT NOT NULL DEFAULT (EXTRACT(EPOCH FROM now())::BIGINT)
77            )"#,
78            table = self.config.table_name,
79        );
80
81        sqlx::query(&create_table)
82            .execute(&self.pool)
83            .await
84            .map_err(|e| SynapticError::Cache(format!("failed to create table: {e}")))?;
85
86        Ok(())
87    }
88
89    /// Return a reference to the underlying connection pool.
90    pub fn pool(&self) -> &PgPool {
91        &self.pool
92    }
93
94    /// Return a reference to the configuration.
95    pub fn config(&self) -> &PgCacheConfig {
96        &self.config
97    }
98}
99
100#[async_trait]
101impl synaptic_core::LlmCache for PgCache {
102    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
103        validate_table_name(&self.config.table_name)?;
104
105        let json_str: Option<String> = if let Some(ttl) = self.config.ttl {
106            let sql = format!(
107                "SELECT value FROM {table} WHERE key = $1 AND created_at + $2 > EXTRACT(EPOCH FROM now())::BIGINT",
108                table = self.config.table_name,
109            );
110            sqlx::query_scalar(&sql)
111                .bind(key)
112                .bind(ttl as i64)
113                .fetch_optional(&self.pool)
114                .await
115                .map_err(|e| SynapticError::Cache(format!("query error: {e}")))?
116        } else {
117            let sql = format!(
118                "SELECT value FROM {table} WHERE key = $1",
119                table = self.config.table_name,
120            );
121            sqlx::query_scalar(&sql)
122                .bind(key)
123                .fetch_optional(&self.pool)
124                .await
125                .map_err(|e| SynapticError::Cache(format!("query error: {e}")))?
126        };
127
128        match json_str {
129            Some(s) => {
130                let response: ChatResponse = serde_json::from_str(&s)
131                    .map_err(|e| SynapticError::Cache(format!("JSON deserialize error: {e}")))?;
132                Ok(Some(response))
133            }
134            None => Ok(None),
135        }
136    }
137
138    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
139        validate_table_name(&self.config.table_name)?;
140
141        let value = serde_json::to_string(response)
142            .map_err(|e| SynapticError::Cache(format!("JSON serialize error: {e}")))?;
143
144        let sql = format!(
145            r#"INSERT INTO {table} (key, value, created_at)
146               VALUES ($1, $2, EXTRACT(EPOCH FROM now())::BIGINT)
147               ON CONFLICT (key) DO UPDATE
148               SET value = EXCLUDED.value,
149                   created_at = EXCLUDED.created_at"#,
150            table = self.config.table_name,
151        );
152
153        sqlx::query(&sql)
154            .bind(key)
155            .bind(&value)
156            .execute(&self.pool)
157            .await
158            .map_err(|e| SynapticError::Cache(format!("insert error: {e}")))?;
159
160        Ok(())
161    }
162
163    async fn clear(&self) -> Result<(), SynapticError> {
164        validate_table_name(&self.config.table_name)?;
165
166        let sql = format!("DELETE FROM {table}", table = self.config.table_name);
167
168        sqlx::query(&sql)
169            .execute(&self.pool)
170            .await
171            .map_err(|e| SynapticError::Cache(format!("delete error: {e}")))?;
172
173        Ok(())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_construction() {
        let cfg = PgCacheConfig::new("my_cache");
        assert_eq!(cfg.table_name, "my_cache");
        assert_eq!(cfg.ttl, None);
    }

    #[test]
    fn config_with_ttl() {
        let cfg = PgCacheConfig::new("my_cache").with_ttl(3600);
        assert_eq!(cfg.table_name, "my_cache");
        assert_eq!(cfg.ttl, Some(3600));
    }

    #[test]
    fn validate_table_name_accepts_valid_names() {
        let valid = ["llm_cache", "my_cache", "public.llm_cache", "schema1.cache2"];
        for name in valid {
            assert!(validate_table_name(name).is_ok());
        }
    }

    #[test]
    fn validate_table_name_rejects_sql_injection() {
        // Statement separators, comment markers, quotes and empty names must all fail.
        let hostile = ["cache; DROP TABLE users", "cache--comment", "cache'malicious", ""];
        assert!(hostile.iter().all(|name| validate_table_name(name).is_err()));
    }
}