Skip to main content

synaptic_redis/
cache.rs

1use async_trait::async_trait;
2use redis::AsyncCommands;
3use synaptic_core::{ChatResponse, SynapticError};
4
/// Settings that control how [`RedisCache`] names and expires its entries.
#[derive(Debug, Clone)]
pub struct RedisCacheConfig {
    /// String prepended to every cache key before it is sent to Redis.
    /// The [`Default`] implementation uses `"synaptic:cache:"`.
    pub prefix: String,
    /// Entry lifetime in seconds. `None` (the default) means entries are
    /// written without an expiration.
    pub ttl: Option<u64>,
}
13
14impl Default for RedisCacheConfig {
15    fn default() -> Self {
16        Self {
17            prefix: "synaptic:cache:".to_string(),
18            ttl: None,
19        }
20    }
21}
22
23/// Redis-backed implementation of the [`LlmCache`](synaptic_core::LlmCache) trait.
24///
25/// Stores serialized [`ChatResponse`] values under `{prefix}{key}` with
26/// optional TTL expiration managed by Redis itself.
27pub struct RedisCache {
28    client: redis::Client,
29    config: RedisCacheConfig,
30}
31
32impl RedisCache {
33    /// Create a new `RedisCache` with an existing Redis client and configuration.
34    pub fn new(client: redis::Client, config: RedisCacheConfig) -> Self {
35        Self { client, config }
36    }
37
38    /// Create a new `RedisCache` from a Redis URL with default configuration.
39    pub fn from_url(url: &str) -> Result<Self, SynapticError> {
40        let client = redis::Client::open(url)
41            .map_err(|e| SynapticError::Cache(format!("failed to connect to Redis: {e}")))?;
42        Ok(Self {
43            client,
44            config: RedisCacheConfig::default(),
45        })
46    }
47
48    /// Create a new `RedisCache` from a Redis URL with custom configuration.
49    pub fn from_url_with_config(
50        url: &str,
51        config: RedisCacheConfig,
52    ) -> Result<Self, SynapticError> {
53        let client = redis::Client::open(url)
54            .map_err(|e| SynapticError::Cache(format!("failed to connect to Redis: {e}")))?;
55        Ok(Self { client, config })
56    }
57
58    /// Build the full Redis key for a cache entry.
59    fn redis_key(&self, key: &str) -> String {
60        format!("{}{key}", self.config.prefix)
61    }
62
63    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection, SynapticError> {
64        self.client
65            .get_multiplexed_async_connection()
66            .await
67            .map_err(|e| SynapticError::Cache(format!("Redis connection error: {e}")))
68    }
69}
70
71/// Helper to GET a key from Redis as an `Option<String>`.
72async fn redis_get_string(
73    con: &mut redis::aio::MultiplexedConnection,
74    key: &str,
75) -> Result<Option<String>, SynapticError> {
76    let raw: Option<String> = con
77        .get(key)
78        .await
79        .map_err(|e| SynapticError::Cache(format!("Redis GET error: {e}")))?;
80    Ok(raw)
81}
82
83#[async_trait]
84impl synaptic_core::LlmCache for RedisCache {
85    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
86        let mut con = self.get_connection().await?;
87        let redis_key = self.redis_key(key);
88
89        let raw = redis_get_string(&mut con, &redis_key).await?;
90
91        match raw {
92            Some(json_str) => {
93                let response: ChatResponse = serde_json::from_str(&json_str)
94                    .map_err(|e| SynapticError::Cache(format!("JSON deserialize error: {e}")))?;
95                Ok(Some(response))
96            }
97            None => Ok(None),
98        }
99    }
100
101    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
102        let mut con = self.get_connection().await?;
103        let redis_key = self.redis_key(key);
104
105        let json_str = serde_json::to_string(response)
106            .map_err(|e| SynapticError::Cache(format!("JSON serialize error: {e}")))?;
107
108        con.set::<_, _, ()>(&redis_key, &json_str)
109            .await
110            .map_err(|e| SynapticError::Cache(format!("Redis SET error: {e}")))?;
111
112        // Apply TTL if configured
113        if let Some(ttl_secs) = self.config.ttl {
114            con.expire::<_, ()>(&redis_key, ttl_secs as i64)
115                .await
116                .map_err(|e| SynapticError::Cache(format!("Redis EXPIRE error: {e}")))?;
117        }
118
119        Ok(())
120    }
121
122    async fn clear(&self) -> Result<(), SynapticError> {
123        let mut con = self.get_connection().await?;
124        let pattern = format!("{}*", self.config.prefix);
125
126        // Collect all matching keys via SCAN, then delete them
127        let mut cursor: u64 = 0;
128        loop {
129            let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
130                .arg(cursor)
131                .arg("MATCH")
132                .arg(&pattern)
133                .arg("COUNT")
134                .arg(100)
135                .query_async(&mut con)
136                .await
137                .map_err(|e| SynapticError::Cache(format!("Redis SCAN error: {e}")))?;
138
139            if !keys.is_empty() {
140                con.del::<_, ()>(&keys)
141                    .await
142                    .map_err(|e| SynapticError::Cache(format!("Redis DEL error: {e}")))?;
143            }
144
145            cursor = next_cursor;
146            if cursor == 0 {
147                break;
148            }
149        }
150
151        Ok(())
152    }
153}