Skip to main content

hyperinfer_client/
cache.rs

1//! Redis-backed exact-match response cache.
2//!
//! The cache key is a SHA-256 hash of the JSON serialization of the
//! [`ChatRequest`] with its `stream` flag cleared, so requests that differ only
//! in whether they stream share a cache entry.  Because the request struct
//! derives [`Serialize`], serde emits its fields in declaration order and the
//! hash is deterministic for identical requests.
7//!
8//! Cache entries expire after [`DEFAULT_TTL_SECS`] seconds; callers can
9//! override this via [`ExactMatchCache::with_ttl`].
10//!
//! The cache degrades gracefully: if Redis is unavailable, `get` returns `None`
//! and `set` does nothing, without surfacing errors to the caller.
13
14use hyperinfer_core::{ChatRequest, ChatResponse};
15use redis::{aio::ConnectionManager, AsyncCommands};
16use sha2::{Digest, Sha256};
17use std::sync::Arc;
18use tokio::sync::Mutex;
19use tracing::{debug, warn};
20
/// Lifetime applied to a cached response when no override is given, in
/// seconds (five minutes).
pub const DEFAULT_TTL_SECS: u64 = 5 * 60;
23
24/// Exact-match Redis cache for [`ChatResponse`] values.
25#[derive(Clone)]
26pub struct ExactMatchCache {
27    conn: Option<Arc<Mutex<ConnectionManager>>>,
28    ttl_secs: u64,
29    /// Namespace for cache keys to avoid cross-client collisions.
30    namespace: String,
31}
32
33impl ExactMatchCache {
34    /// Connect to Redis at `redis_url`.  On failure the cache is disabled and
35    /// all operations become no-ops.
36    pub async fn new(redis_url: &str, namespace: &str) -> Self {
37        match redis::Client::open(redis_url) {
38            Ok(client) => match ConnectionManager::new(client).await {
39                Ok(mgr) => {
40                    debug!("ExactMatchCache: connected to Redis");
41                    Self {
42                        conn: Some(Arc::new(Mutex::new(mgr))),
43                        ttl_secs: DEFAULT_TTL_SECS,
44                        namespace: namespace.to_string(),
45                    }
46                }
47                Err(e) => {
48                    warn!(
49                        "ExactMatchCache: Redis connection failed: {}; cache disabled",
50                        e
51                    );
52                    Self {
53                        conn: None,
54                        ttl_secs: DEFAULT_TTL_SECS,
55                        namespace: namespace.to_string(),
56                    }
57                }
58            },
59            Err(e) => {
60                warn!("ExactMatchCache: invalid Redis URL: {}; cache disabled", e);
61                Self {
62                    conn: None,
63                    ttl_secs: DEFAULT_TTL_SECS,
64                    namespace: namespace.to_string(),
65                }
66            }
67        }
68    }
69
70    /// Override the cache TTL.  Returns `self` for chaining.
71    pub fn with_ttl(mut self, secs: u64) -> Self {
72        self.ttl_secs = secs;
73        self
74    }
75
76    /// Compute the cache key for `request`.
77    pub fn cache_key(&self, request: &ChatRequest) -> Option<String> {
78        // Clone and normalize to ignore streaming preference
79        let mut normalized_request = request.clone();
80        normalized_request.stream = None;
81
82        match serde_json::to_string(&normalized_request) {
83            Ok(json) => {
84                let mut hasher = Sha256::new();
85                hasher.update(json.as_bytes());
86                let hash = hex::encode(hasher.finalize());
87                Some(format!("hyperinfer:cache:{}:{}", self.namespace, hash))
88            }
89            Err(e) => {
90                warn!("Cache key serialisation error: {}", e);
91                None
92            }
93        }
94    }
95
96    /// Attempt to retrieve a cached [`ChatResponse`] for `request`.
97    ///
98    /// Returns `None` on cache miss, Redis error, or deserialisation failure.
99    pub async fn get(&self, request: &ChatRequest) -> Option<ChatResponse> {
100        let conn = self.conn.as_ref()?;
101        let key = self.cache_key(request)?;
102
103        let mut guard = conn.lock().await;
104        let raw: Option<String> = guard.get(&key).await.ok()?;
105        drop(guard);
106
107        let raw = raw?;
108        match serde_json::from_str::<ChatResponse>(&raw) {
109            Ok(resp) => {
110                debug!("Cache HIT for key {}", key);
111                Some(resp)
112            }
113            Err(e) => {
114                warn!("Cache deserialisation error: {}", e);
115                None
116            }
117        }
118    }
119
120    /// Store `response` in the cache under the key derived from `request`.
121    ///
122    /// Silently ignores serialisation and Redis errors.
123    pub async fn set(&self, request: &ChatRequest, response: &ChatResponse) {
124        let conn = match self.conn.as_ref() {
125            Some(c) => c,
126            None => return,
127        };
128
129        let key = match self.cache_key(request) {
130            Some(k) => k,
131            None => return,
132        };
133        let raw = match serde_json::to_string(response) {
134            Ok(s) => s,
135            Err(e) => {
136                warn!("Cache serialisation error: {}", e);
137                return;
138            }
139        };
140
141        let mut guard = conn.lock().await;
142        let result: redis::RedisResult<()> = guard.set_ex(&key, &raw, self.ttl_secs).await;
143        drop(guard);
144
145        if let Err(e) = result {
146            warn!("Cache write error: {}", e);
147        } else {
148            debug!("Cache SET key {} ttl={}s", key, self.ttl_secs);
149        }
150    }
151}
152
153// ── Tests ─────────────────────────────────────────────────────────────────────
154
#[cfg(test)]
mod tests {
    use super::*;
    use hyperinfer_core::{
        types::{ChatMessage, Choice, MessageRole, Usage},
        ChatRequest, ChatResponse,
    };

    /// Minimal single-user-message request targeting `model`.
    fn sample_request(model: &str) -> ChatRequest {
        ChatRequest {
            model: model.to_string(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: "hello".to_string(),
            }],
            max_tokens: Some(100),
            temperature: None,
            stream: None,
            stop: None,
        }
    }

    /// Fixed assistant response used as a cache payload.
    fn sample_response() -> ChatResponse {
        ChatResponse {
            id: "resp-test".to_string(),
            model: "gpt-4".to_string(),
            choices: vec![Choice {
                message: ChatMessage {
                    role: MessageRole::Assistant,
                    content: "Hi there!".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: 0,
            }],
            usage: Usage {
                input_tokens: 5,
                output_tokens: 10,
            },
        }
    }

    /// Cache in `namespace` with no Redis connection.
    ///
    /// This is enough to exercise key derivation and TTL configuration
    /// without any network I/O.
    fn disabled_cache(namespace: &str) -> ExactMatchCache {
        ExactMatchCache {
            conn: None,
            ttl_secs: DEFAULT_TTL_SECS,
            namespace: namespace.to_string(),
        }
    }

    #[test]
    fn test_cache_key_deterministic() {
        let req = sample_request("gpt-4");
        let cache = disabled_cache("test-ns");
        let k1 = cache.cache_key(&req);
        let k2 = cache.cache_key(&req);
        assert_eq!(k1, k2);
        assert!(k1.unwrap().starts_with("hyperinfer:cache:test-ns:"));
    }

    #[test]
    fn test_cache_key_different_models() {
        let cache = disabled_cache("test-ns");
        let k1 = cache.cache_key(&sample_request("gpt-4"));
        let k2 = cache.cache_key(&sample_request("claude-3"));
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_key_different_messages() {
        let cache = disabled_cache("test-ns");
        let mut r1 = sample_request("gpt-4");
        let mut r2 = sample_request("gpt-4");
        r1.messages[0].content = "hello".to_string();
        r2.messages[0].content = "goodbye".to_string();
        assert_ne!(cache.cache_key(&r1), cache.cache_key(&r2));
    }

    #[test]
    fn test_cache_key_different_namespaces() {
        // The same request must not collide across client namespaces.
        let req = sample_request("gpt-4");
        let k1 = disabled_cache("ns-a").cache_key(&req);
        let k2 = disabled_cache("ns-b").cache_key(&req);
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_key_ignores_stream() {
        let cache = disabled_cache("test-ns");
        let mut r1 = sample_request("gpt-4");
        r1.stream = Some(true);

        let mut r2 = sample_request("gpt-4");
        r2.stream = Some(false);

        let mut r3 = sample_request("gpt-4");
        r3.stream = None;

        let k1 = cache.cache_key(&r1);
        let k2 = cache.cache_key(&r2);
        let k3 = cache.cache_key(&r3);

        assert_eq!(k1, k2);
        assert_eq!(k2, k3);
    }

    #[tokio::test]
    async fn test_new_with_malformed_url_disables_cache() {
        // `redis::Client::open` rejects this string while parsing it, so no
        // network I/O is attempted.
        let cache = ExactMatchCache::new("not a redis url", "test-ns").await;
        assert!(cache.conn.is_none());
        assert!(cache.get(&sample_request("gpt-4")).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_disabled_get_returns_none() {
        // The URL parses, but the host is unreachable, so the initial
        // connection fails and the cache is disabled.
        let cache = ExactMatchCache::new("redis://invalid-host:1", "test-ns").await;
        let req = sample_request("gpt-4");
        let result = cache.get(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_cache_disabled_set_no_panic() {
        let cache = ExactMatchCache::new("redis://invalid-host:1", "test-ns").await;
        let req = sample_request("gpt-4");
        let resp = sample_response();
        // Should not panic.
        cache.set(&req, &resp).await;
    }

    #[test]
    fn test_with_ttl() {
        // The builder must store the custom TTL. A disabled cache avoids
        // needing the async constructor in a sync test.
        let cache = disabled_cache("test-ns").with_ttl(60);
        assert_eq!(cache.ttl_secs, 60);
    }

    #[tokio::test]
    async fn test_cache_deserialisation_error() {
        // Spin up a real Redis with testcontainers. This exercises `get()`/`set()`
        // against a live ConnectionManager, including the deserialisation error arm.
        use testcontainers::{core::IntoContainerPort, runners::AsyncRunner, GenericImage};
        use testcontainers_modules::redis::REDIS_PORT;

        let container_result = GenericImage::new("redis", "7.2.4")
            .with_exposed_port(REDIS_PORT.tcp())
            .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
                "Ready to accept connections",
            ))
            .start()
            .await;

        // Some Docker socket configurations can fail: fail the test explicitly
        // in CI, but skip it gracefully when running locally.
        let container = match container_result {
            Ok(c) => c,
            Err(e) => {
                let is_ci = std::env::var("CI").map(|v| v == "true").unwrap_or(false);
                if is_ci {
                    panic!(
                        "FATAL: testcontainers failed to start Redis in CI environment: {}. \
                         This indicates a test infrastructure issue that must be resolved.",
                        e
                    );
                } else {
                    println!(
                        "Skipping test: testcontainers failed to start Redis ({})",
                        e
                    );
                    return;
                }
            }
        };

        let port = container
            .get_host_port_ipv4(REDIS_PORT)
            .await
            .expect("Failed to get port");
        let redis_url = format!("redis://127.0.0.1:{}", port);

        let cache = ExactMatchCache::new(&redis_url, "test-ns-malformed").await;
        // Guard against a vacuous pass: a disabled cache returns `None` from
        // `get` no matter what is stored in Redis.
        assert!(
            cache.conn.is_some(),
            "cache should have connected to the test Redis container"
        );

        let req = sample_request("gpt-4");
        let key = cache.cache_key(&req).unwrap();

        // Happy path: a stored response round-trips through Redis.
        cache.set(&req, &sample_response()).await;
        let hit = cache
            .get(&req)
            .await
            .expect("expected a cache hit after set");
        assert_eq!(hit.id, "resp-test");

        // Overwrite the entry with malformed JSON, bypassing the cache.
        let client = redis::Client::open(redis_url.as_str()).unwrap();
        let mut conn = client.get_multiplexed_async_connection().await.unwrap();
        let _: () = redis::cmd("SET")
            .arg(&key)
            .arg("not valid json")
            .query_async(&mut conn)
            .await
            .unwrap();

        // The get should return None and not panic.
        let result = cache.get(&req).await;
        assert!(
            result.is_none(),
            "Deserialization error should result in a cache miss (None)"
        );
    }
}