use async_trait::async_trait;
use sqlx::PgPool;
use synaptic_core::{validate_table_name, ChatResponse, SynapticError};
/// Settings for a PostgreSQL-backed LLM cache: which table to use and how long
/// cached entries stay valid.
#[derive(Debug, Clone)]
pub struct PgCacheConfig {
    /// Name of the cache table. It is spliced directly into SQL text, which is why
    /// `PgCache` runs `validate_table_name` on it before every statement.
    pub table_name: String,
    /// Entry lifetime in seconds; `None` means entries never expire.
    pub ttl: Option<u64>,
}

impl PgCacheConfig {
    /// Builds a configuration for `table_name` with no TTL.
    pub fn new(table_name: impl Into<String>) -> Self {
        let table_name: String = table_name.into();
        PgCacheConfig {
            ttl: None,
            table_name,
        }
    }

    /// Returns this configuration with its TTL replaced by `seconds`.
    pub fn with_ttl(self, seconds: u64) -> Self {
        PgCacheConfig {
            ttl: Some(seconds),
            ..self
        }
    }
}
/// LLM response cache stored in a PostgreSQL table.
///
/// Each row holds a cache key, the response serialized as JSON text, and the
/// insertion time in Unix epoch seconds (`created_at`).
pub struct PgCache {
    /// Connection pool every cache query is executed against.
    pool: PgPool,
    /// Table name and optional TTL used to build the SQL statements.
    config: PgCacheConfig,
}
impl PgCache {
pub fn new(pool: PgPool, config: PgCacheConfig) -> Self {
Self { pool, config }
}
pub async fn initialize(&self) -> Result<(), SynapticError> {
validate_table_name(&self.config.table_name)?;
let create_table = format!(
r#"CREATE TABLE IF NOT EXISTS {table} (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
created_at BIGINT NOT NULL DEFAULT (EXTRACT(EPOCH FROM now())::BIGINT)
)"#,
table = self.config.table_name,
);
sqlx::query(&create_table)
.execute(&self.pool)
.await
.map_err(|e| SynapticError::Cache(format!("failed to create table: {e}")))?;
Ok(())
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
pub fn config(&self) -> &PgCacheConfig {
&self.config
}
}
#[async_trait]
impl synaptic_core::LlmCache for PgCache {
    /// Looks up the cached response stored under `key`.
    ///
    /// When a TTL is configured, a row whose `created_at` is `ttl` or more seconds
    /// in the past is reported as a miss. Such rows are not deleted here.
    ///
    /// # Errors
    ///
    /// Fails if the table name is rejected by `validate_table_name`, if the query
    /// fails, or if the stored JSON cannot be deserialized into a [`ChatResponse`].
    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
        validate_table_name(&self.config.table_name)?;
        let json_str: Option<String> = if let Some(ttl) = self.config.ttl {
            // Saturate instead of a bare `ttl as i64`. A TTL above i64::MAX would
            // otherwise wrap to a negative value and make every entry look expired.
            // Both casts are lossless because the value is clamped first.
            let ttl_secs = ttl.min(i64::MAX as u64) as i64;
            // `created_at > now - ttl` is equivalent to `created_at + ttl > now`, but
            // it cannot overflow BIGINT even when the TTL is saturated to i64::MAX.
            let sql = format!(
                "SELECT value FROM {table} WHERE key = $1 AND created_at > EXTRACT(EPOCH FROM now())::BIGINT - $2",
                table = self.config.table_name,
            );
            sqlx::query_scalar(&sql)
                .bind(key)
                .bind(ttl_secs)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| SynapticError::Cache(format!("query error: {e}")))?
        } else {
            let sql = format!(
                "SELECT value FROM {table} WHERE key = $1",
                table = self.config.table_name,
            );
            sqlx::query_scalar(&sql)
                .bind(key)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| SynapticError::Cache(format!("query error: {e}")))?
        };
        match json_str {
            Some(s) => {
                let response: ChatResponse = serde_json::from_str(&s)
                    .map_err(|e| SynapticError::Cache(format!("JSON deserialize error: {e}")))?;
                Ok(Some(response))
            }
            None => Ok(None),
        }
    }

    /// Stores `response` as JSON under `key`.
    ///
    /// An existing row for the same key is overwritten, and its `created_at` is
    /// reset to the current time, which restarts its TTL.
    ///
    /// # Errors
    ///
    /// Fails if the table name is rejected, if `response` cannot be serialized,
    /// or if the upsert fails.
    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
        validate_table_name(&self.config.table_name)?;
        let value = serde_json::to_string(response)
            .map_err(|e| SynapticError::Cache(format!("JSON serialize error: {e}")))?;
        let sql = format!(
            r#"INSERT INTO {table} (key, value, created_at)
VALUES ($1, $2, EXTRACT(EPOCH FROM now())::BIGINT)
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value,
created_at = EXCLUDED.created_at"#,
            table = self.config.table_name,
        );
        sqlx::query(&sql)
            .bind(key)
            .bind(&value)
            .execute(&self.pool)
            .await
            .map_err(|e| SynapticError::Cache(format!("insert error: {e}")))?;
        Ok(())
    }

    /// Deletes every row from the cache table.
    ///
    /// # Errors
    ///
    /// Fails if the table name is rejected or if the `DELETE` fails.
    async fn clear(&self) -> Result<(), SynapticError> {
        validate_table_name(&self.config.table_name)?;
        let sql = format!("DELETE FROM {table}", table = self.config.table_name);
        sqlx::query(&sql)
            .execute(&self.pool)
            .await
            .map_err(|e| SynapticError::Cache(format!("delete error: {e}")))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built config keeps the given table name and has no TTL.
    #[test]
    fn config_construction() {
        let PgCacheConfig { table_name, ttl } = PgCacheConfig::new("my_cache");
        assert_eq!(table_name, "my_cache");
        assert!(ttl.is_none());
    }

    /// `with_ttl` sets the TTL without disturbing the table name.
    #[test]
    fn config_with_ttl() {
        let PgCacheConfig { table_name, ttl } = PgCacheConfig::new("my_cache").with_ttl(3600);
        assert_eq!(table_name, "my_cache");
        assert_eq!(ttl, Some(3600));
    }

    /// Plain and schema-qualified identifiers are accepted.
    #[test]
    fn validate_table_name_accepts_valid_names() {
        for name in ["llm_cache", "my_cache", "public.llm_cache", "schema1.cache2"] {
            assert!(validate_table_name(name).is_ok());
        }
    }

    /// Names carrying SQL metacharacters, and the empty name, are rejected.
    #[test]
    fn validate_table_name_rejects_sql_injection() {
        for name in [
            "cache; DROP TABLE users",
            "cache--comment",
            "cache'malicious",
            "",
        ] {
            assert!(validate_table_name(name).is_err());
        }
    }
}