1use hyperinfer_core::{ChatRequest, ChatResponse};
15use redis::{aio::ConnectionManager, AsyncCommands};
16use sha2::{Digest, Sha256};
17use std::sync::Arc;
18use tokio::sync::Mutex;
19use tracing::{debug, warn};
20
/// Time-to-live, in seconds, applied to cache entries unless overridden
/// via `ExactMatchCache::with_ttl` (five minutes).
pub const DEFAULT_TTL_SECS: u64 = 5 * 60;
23
24#[derive(Clone)]
26pub struct ExactMatchCache {
27 conn: Option<Arc<Mutex<ConnectionManager>>>,
28 ttl_secs: u64,
29 namespace: String,
31}
32
33impl ExactMatchCache {
34 pub async fn new(redis_url: &str, namespace: &str) -> Self {
37 match redis::Client::open(redis_url) {
38 Ok(client) => match ConnectionManager::new(client).await {
39 Ok(mgr) => {
40 debug!("ExactMatchCache: connected to Redis");
41 Self {
42 conn: Some(Arc::new(Mutex::new(mgr))),
43 ttl_secs: DEFAULT_TTL_SECS,
44 namespace: namespace.to_string(),
45 }
46 }
47 Err(e) => {
48 warn!(
49 "ExactMatchCache: Redis connection failed: {}; cache disabled",
50 e
51 );
52 Self {
53 conn: None,
54 ttl_secs: DEFAULT_TTL_SECS,
55 namespace: namespace.to_string(),
56 }
57 }
58 },
59 Err(e) => {
60 warn!("ExactMatchCache: invalid Redis URL: {}; cache disabled", e);
61 Self {
62 conn: None,
63 ttl_secs: DEFAULT_TTL_SECS,
64 namespace: namespace.to_string(),
65 }
66 }
67 }
68 }
69
70 pub fn with_ttl(mut self, secs: u64) -> Self {
72 self.ttl_secs = secs;
73 self
74 }
75
76 pub fn cache_key(&self, request: &ChatRequest) -> Option<String> {
78 let mut normalized_request = request.clone();
80 normalized_request.stream = None;
81
82 match serde_json::to_string(&normalized_request) {
83 Ok(json) => {
84 let mut hasher = Sha256::new();
85 hasher.update(json.as_bytes());
86 let hash = hex::encode(hasher.finalize());
87 Some(format!("hyperinfer:cache:{}:{}", self.namespace, hash))
88 }
89 Err(e) => {
90 warn!("Cache key serialisation error: {}", e);
91 None
92 }
93 }
94 }
95
96 pub async fn get(&self, request: &ChatRequest) -> Option<ChatResponse> {
100 let conn = self.conn.as_ref()?;
101 let key = self.cache_key(request)?;
102
103 let mut guard = conn.lock().await;
104 let raw: Option<String> = guard.get(&key).await.ok()?;
105 drop(guard);
106
107 let raw = raw?;
108 match serde_json::from_str::<ChatResponse>(&raw) {
109 Ok(resp) => {
110 debug!("Cache HIT for key {}", key);
111 Some(resp)
112 }
113 Err(e) => {
114 warn!("Cache deserialisation error: {}", e);
115 None
116 }
117 }
118 }
119
120 pub async fn set(&self, request: &ChatRequest, response: &ChatResponse) {
124 let conn = match self.conn.as_ref() {
125 Some(c) => c,
126 None => return,
127 };
128
129 let key = match self.cache_key(request) {
130 Some(k) => k,
131 None => return,
132 };
133 let raw = match serde_json::to_string(response) {
134 Ok(s) => s,
135 Err(e) => {
136 warn!("Cache serialisation error: {}", e);
137 return;
138 }
139 };
140
141 let mut guard = conn.lock().await;
142 let result: redis::RedisResult<()> = guard.set_ex(&key, &raw, self.ttl_secs).await;
143 drop(guard);
144
145 if let Err(e) = result {
146 warn!("Cache write error: {}", e);
147 } else {
148 debug!("Cache SET key {} ttl={}s", key, self.ttl_secs);
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use hyperinfer_core::{
        types::{ChatMessage, Choice, MessageRole, Usage},
        ChatRequest, ChatResponse,
    };

    /// Builds a cache with no Redis connection, for tests that only exercise
    /// key derivation or configuration.
    fn disabled_cache() -> ExactMatchCache {
        ExactMatchCache {
            conn: None,
            ttl_secs: DEFAULT_TTL_SECS,
            namespace: "test-ns".to_string(),
        }
    }

    /// A minimal single-message user request for `model`.
    fn sample_request(model: &str) -> ChatRequest {
        ChatRequest {
            model: model.to_string(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: "hello".to_string(),
            }],
            max_tokens: Some(100),
            temperature: None,
            stream: None,
            stop: None,
        }
    }

    /// A minimal single-choice assistant response.
    fn sample_response() -> ChatResponse {
        ChatResponse {
            id: "resp-test".to_string(),
            model: "gpt-4".to_string(),
            choices: vec![Choice {
                message: ChatMessage {
                    role: MessageRole::Assistant,
                    content: "Hi there!".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: 0,
            }],
            usage: Usage {
                input_tokens: 5,
                output_tokens: 10,
            },
        }
    }

    #[test]
    fn test_cache_key_deterministic() {
        let req = sample_request("gpt-4");
        let cache = disabled_cache();
        let k1 = cache.cache_key(&req);
        let k2 = cache.cache_key(&req);
        assert_eq!(k1, k2);
        assert!(k1.unwrap().starts_with("hyperinfer:cache:test-ns:"));
    }

    #[test]
    fn test_cache_key_different_models() {
        let cache = disabled_cache();
        let k1 = cache.cache_key(&sample_request("gpt-4"));
        let k2 = cache.cache_key(&sample_request("claude-3"));
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_key_different_messages() {
        let cache = disabled_cache();
        let mut r1 = sample_request("gpt-4");
        let mut r2 = sample_request("gpt-4");
        r1.messages[0].content = "hello".to_string();
        r2.messages[0].content = "goodbye".to_string();
        assert_ne!(cache.cache_key(&r1), cache.cache_key(&r2));
    }

    #[test]
    fn test_cache_key_ignores_stream() {
        let cache = disabled_cache();
        let mut r1 = sample_request("gpt-4");
        r1.stream = Some(true);

        let mut r2 = sample_request("gpt-4");
        r2.stream = Some(false);

        let mut r3 = sample_request("gpt-4");
        r3.stream = None;

        let k1 = cache.cache_key(&r1);
        let k2 = cache.cache_key(&r2);
        let k3 = cache.cache_key(&r3);

        assert_eq!(k1, k2);
        assert_eq!(k2, k3);
    }

    #[tokio::test]
    async fn test_cache_disabled_get_returns_none() {
        let cache = ExactMatchCache::new("redis://invalid-host:1", "test-ns").await;
        // Guard against the test passing for the wrong reason: it must exercise
        // the disabled path, not a live connection that happens to miss.
        assert!(cache.conn.is_none(), "cache should be disabled for an unreachable host");
        let req = sample_request("gpt-4");
        let result = cache.get(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_cache_disabled_set_no_panic() {
        let cache = ExactMatchCache::new("redis://invalid-host:1", "test-ns").await;
        assert!(cache.conn.is_none(), "cache should be disabled for an unreachable host");
        let req = sample_request("gpt-4");
        let resp = sample_response();
        cache.set(&req, &resp).await;
    }

    #[test]
    fn test_with_ttl() {
        let cache = disabled_cache().with_ttl(60);
        assert_eq!(cache.ttl_secs, 60);
    }

    #[tokio::test]
    async fn test_cache_deserialisation_error() {
        use testcontainers::{core::IntoContainerPort, runners::AsyncRunner, GenericImage};
        use testcontainers_modules::redis::REDIS_PORT;

        let container_result = GenericImage::new("redis", "7.2.4")
            .with_exposed_port(REDIS_PORT.tcp())
            .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
                "Ready to accept connections",
            ))
            .start()
            .await;

        // Outside CI a missing Docker daemon should not fail the suite; in CI it must.
        let container = match container_result {
            Ok(c) => c,
            Err(e) => {
                let is_ci = std::env::var("CI").map(|v| v == "true").unwrap_or(false);
                if is_ci {
                    panic!(
                        "FATAL: testcontainers failed to start Redis in CI environment: {}. \
                         This indicates a test infrastructure issue that must be resolved.",
                        e
                    );
                } else {
                    println!(
                        "Skipping test: testcontainers failed to start Redis ({})",
                        e
                    );
                    return;
                }
            }
        };

        let port = container
            .get_host_port_ipv4(REDIS_PORT)
            .await
            .expect("Failed to get port");
        let redis_url = format!("redis://127.0.0.1:{}", port);

        let cache = ExactMatchCache::new(&redis_url, "test-ns-malformed").await;
        // Without this, a failed connection would make `get` return `None` and
        // the test would pass without ever reaching the deserialisation path.
        assert!(
            cache.conn.is_some(),
            "cache should connect to the test Redis container"
        );

        let req = sample_request("gpt-4");
        let key = cache.cache_key(&req).unwrap();

        // Write a malformed value directly, bypassing `ExactMatchCache::set`.
        let client = redis::Client::open(redis_url.as_str()).unwrap();
        let mut conn = client.get_multiplexed_async_connection().await.unwrap();
        let _: () = redis::cmd("SET")
            .arg(&key)
            .arg("not valid json")
            .query_async(&mut conn)
            .await
            .unwrap();

        let result = cache.get(&req).await;
        assert!(
            result.is_none(),
            "Deserialization error should result in a cache miss (None)"
        );
    }
}