Skip to main content

synaptic_redis/
cache.rs

1use async_trait::async_trait;
2use redis::AsyncCommands;
3use synaptic_core::{ChatResponse, SynapticError};
4
/// Settings that control how [`RedisCache`] names and expires its entries.
#[derive(Debug, Clone)]
pub struct RedisCacheConfig {
    /// String prepended to every cache key; the default is `"synaptic:cache:"`.
    pub prefix: String,
    /// Lifetime of each stored entry, in seconds. `None` means no expiry is applied.
    pub ttl: Option<u64>,
}

impl Default for RedisCacheConfig {
    /// Returns a configuration using the `"synaptic:cache:"` prefix and no TTL.
    fn default() -> Self {
        let prefix = String::from("synaptic:cache:");
        Self { prefix, ttl: None }
    }
}
22
23/// Redis-backed implementation of the [`LlmCache`](synaptic_core::LlmCache) trait.
24///
25/// Stores serialized [`ChatResponse`] values under `{prefix}{key}` with
26/// optional TTL expiration managed by Redis itself.
27pub struct RedisCache {
28    client: redis::Client,
29    config: RedisCacheConfig,
30}
31
32impl RedisCache {
33    /// Create a new `RedisCache` with an existing Redis client and configuration.
34    pub fn new(client: redis::Client, config: RedisCacheConfig) -> Self {
35        Self { client, config }
36    }
37
38    /// Create a new `RedisCache` from a Redis URL with default configuration.
39    pub fn from_url(url: &str) -> Result<Self, SynapticError> {
40        let client = redis::Client::open(url)
41            .map_err(|e| SynapticError::Cache(format!("failed to connect to Redis: {e}")))?;
42        Ok(Self {
43            client,
44            config: RedisCacheConfig::default(),
45        })
46    }
47
48    /// Create a new `RedisCache` from a Redis URL with custom configuration.
49    pub fn from_url_with_config(
50        url: &str,
51        config: RedisCacheConfig,
52    ) -> Result<Self, SynapticError> {
53        let client = redis::Client::open(url)
54            .map_err(|e| SynapticError::Cache(format!("failed to connect to Redis: {e}")))?;
55        Ok(Self { client, config })
56    }
57
58    /// Build the full Redis key for a cache entry.
59    fn redis_key(&self, key: &str) -> String {
60        format!("{}{key}", self.config.prefix)
61    }
62
63    async fn get_connection(
64        &self,
65    ) -> Result<redis::aio::MultiplexedConnection, SynapticError> {
66        self.client
67            .get_multiplexed_async_connection()
68            .await
69            .map_err(|e| SynapticError::Cache(format!("Redis connection error: {e}")))
70    }
71}
72
73/// Helper to GET a key from Redis as an `Option<String>`.
74async fn redis_get_string(
75    con: &mut redis::aio::MultiplexedConnection,
76    key: &str,
77) -> Result<Option<String>, SynapticError> {
78    let raw: Option<String> = con
79        .get(key)
80        .await
81        .map_err(|e| SynapticError::Cache(format!("Redis GET error: {e}")))?;
82    Ok(raw)
83}
84
85#[async_trait]
86impl synaptic_core::LlmCache for RedisCache {
87    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
88        let mut con = self.get_connection().await?;
89        let redis_key = self.redis_key(key);
90
91        let raw = redis_get_string(&mut con, &redis_key).await?;
92
93        match raw {
94            Some(json_str) => {
95                let response: ChatResponse = serde_json::from_str(&json_str)
96                    .map_err(|e| SynapticError::Cache(format!("JSON deserialize error: {e}")))?;
97                Ok(Some(response))
98            }
99            None => Ok(None),
100        }
101    }
102
103    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
104        let mut con = self.get_connection().await?;
105        let redis_key = self.redis_key(key);
106
107        let json_str = serde_json::to_string(response)
108            .map_err(|e| SynapticError::Cache(format!("JSON serialize error: {e}")))?;
109
110        con.set::<_, _, ()>(&redis_key, &json_str)
111            .await
112            .map_err(|e| SynapticError::Cache(format!("Redis SET error: {e}")))?;
113
114        // Apply TTL if configured
115        if let Some(ttl_secs) = self.config.ttl {
116            con.expire::<_, ()>(&redis_key, ttl_secs as i64)
117                .await
118                .map_err(|e| SynapticError::Cache(format!("Redis EXPIRE error: {e}")))?;
119        }
120
121        Ok(())
122    }
123
124    async fn clear(&self) -> Result<(), SynapticError> {
125        let mut con = self.get_connection().await?;
126        let pattern = format!("{}*", self.config.prefix);
127
128        // Collect all matching keys via SCAN, then delete them
129        let mut cursor: u64 = 0;
130        loop {
131            let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
132                .arg(cursor)
133                .arg("MATCH")
134                .arg(&pattern)
135                .arg("COUNT")
136                .arg(100)
137                .query_async(&mut con)
138                .await
139                .map_err(|e| SynapticError::Cache(format!("Redis SCAN error: {e}")))?;
140
141            if !keys.is_empty() {
142                con.del::<_, ()>(&keys)
143                    .await
144                    .map_err(|e| SynapticError::Cache(format!("Redis DEL error: {e}")))?;
145            }
146
147            cursor = next_cursor;
148            if cursor == 0 {
149                break;
150            }
151        }
152
153        Ok(())
154    }
155}