use hyperinfer_core::{ChatRequest, ChatResponse};
use redis::{aio::ConnectionManager, AsyncCommands};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, warn};
/// Expiry, in seconds, applied to cache entries unless overridden via
/// `ExactMatchCache::with_ttl` (five minutes).
pub const DEFAULT_TTL_SECS: u64 = 5 * 60;
/// Redis-backed cache that returns a stored `ChatResponse` only for a
/// `ChatRequest` whose serialised form (with `stream` cleared) matches exactly.
///
/// When `conn` is `None` the cache is disabled: lookups miss and writes are
/// skipped.
#[derive(Clone)]
pub struct ExactMatchCache {
    /// Segment embedded in every key so different namespaces never share entries.
    namespace: String,
    /// Expiry, in seconds, applied to each entry written by `set`.
    ttl_secs: u64,
    /// Shared Redis connection; `None` means the cache is disabled.
    conn: Option<Arc<Mutex<ConnectionManager>>>,
}
impl ExactMatchCache {
pub async fn new(redis_url: &str, namespace: &str) -> Self {
match redis::Client::open(redis_url) {
Ok(client) => match ConnectionManager::new(client).await {
Ok(mgr) => {
debug!("ExactMatchCache: connected to Redis");
Self {
conn: Some(Arc::new(Mutex::new(mgr))),
ttl_secs: DEFAULT_TTL_SECS,
namespace: namespace.to_string(),
}
}
Err(e) => {
warn!(
"ExactMatchCache: Redis connection failed: {}; cache disabled",
e
);
Self {
conn: None,
ttl_secs: DEFAULT_TTL_SECS,
namespace: namespace.to_string(),
}
}
},
Err(e) => {
warn!("ExactMatchCache: invalid Redis URL: {}; cache disabled", e);
Self {
conn: None,
ttl_secs: DEFAULT_TTL_SECS,
namespace: namespace.to_string(),
}
}
}
}
pub fn with_ttl(mut self, secs: u64) -> Self {
self.ttl_secs = secs;
self
}
pub fn cache_key(&self, request: &ChatRequest) -> Option<String> {
let mut normalized_request = request.clone();
normalized_request.stream = None;
match serde_json::to_string(&normalized_request) {
Ok(json) => {
let mut hasher = Sha256::new();
hasher.update(json.as_bytes());
let hash = hex::encode(hasher.finalize());
Some(format!("hyperinfer:cache:{}:{}", self.namespace, hash))
}
Err(e) => {
warn!("Cache key serialisation error: {}", e);
None
}
}
}
pub async fn get(&self, request: &ChatRequest) -> Option<ChatResponse> {
let conn = self.conn.as_ref()?;
let key = self.cache_key(request)?;
let mut guard = conn.lock().await;
let raw: Option<String> = guard.get(&key).await.ok()?;
drop(guard);
let raw = raw?;
match serde_json::from_str::<ChatResponse>(&raw) {
Ok(resp) => {
debug!("Cache HIT for key {}", key);
Some(resp)
}
Err(e) => {
warn!("Cache deserialisation error: {}", e);
None
}
}
}
pub async fn set(&self, request: &ChatRequest, response: &ChatResponse) {
let conn = match self.conn.as_ref() {
Some(c) => c,
None => return,
};
let key = match self.cache_key(request) {
Some(k) => k,
None => return,
};
let raw = match serde_json::to_string(response) {
Ok(s) => s,
Err(e) => {
warn!("Cache serialisation error: {}", e);
return;
}
};
let mut guard = conn.lock().await;
let result: redis::RedisResult<()> = guard.set_ex(&key, &raw, self.ttl_secs).await;
drop(guard);
if let Err(e) = result {
warn!("Cache write error: {}", e);
} else {
debug!("Cache SET key {} ttl={}s", key, self.ttl_secs);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyperinfer_core::{
        types::{ChatMessage, Choice, MessageRole, Usage},
        ChatRequest, ChatResponse,
    };

    /// Namespace used by the tests that build their cache directly.
    const TEST_NS: &str = "test-ns";

    /// Redis URL expected to be unreachable, so `ExactMatchCache::new` falls back
    /// to a disabled cache.
    const UNREACHABLE_REDIS_URL: &str = "redis://invalid-host:1";

    /// Builds a cache with no Redis connection, bypassing `ExactMatchCache::new`.
    fn offline_cache() -> ExactMatchCache {
        ExactMatchCache {
            conn: None,
            ttl_secs: DEFAULT_TTL_SECS,
            namespace: TEST_NS.to_string(),
        }
    }

    /// A single user turn saying "hello" to `model`, capped at 100 tokens.
    fn sample_request(model: &str) -> ChatRequest {
        let greeting = ChatMessage {
            role: MessageRole::User,
            content: "hello".to_string(),
        };
        ChatRequest {
            model: model.to_string(),
            messages: vec![greeting],
            max_tokens: Some(100),
            temperature: None,
            stream: None,
            stop: None,
        }
    }

    /// A one-choice assistant reply with fixed token usage.
    fn sample_response() -> ChatResponse {
        let reply = Choice {
            message: ChatMessage {
                role: MessageRole::Assistant,
                content: "Hi there!".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: 0,
        };
        ChatResponse {
            id: "resp-test".to_string(),
            model: "gpt-4".to_string(),
            choices: vec![reply],
            usage: Usage {
                input_tokens: 5,
                output_tokens: 10,
            },
        }
    }

    /// True when the `CI` environment variable is set to exactly `"true"`.
    fn running_in_ci() -> bool {
        matches!(std::env::var("CI").as_deref(), Ok("true"))
    }

    #[test]
    fn test_cache_key_deterministic() {
        let cache = offline_cache();
        let req = sample_request("gpt-4");
        let first = cache.cache_key(&req);
        let second = cache.cache_key(&req);
        assert_eq!(first, second);
        let key = first.unwrap();
        assert!(key.starts_with("hyperinfer:cache:test-ns:"));
    }

    #[test]
    fn test_cache_key_different_models() {
        let cache = offline_cache();
        let gpt = cache.cache_key(&sample_request("gpt-4"));
        let claude = cache.cache_key(&sample_request("claude-3"));
        assert_ne!(gpt, claude);
    }

    #[test]
    fn test_cache_key_different_messages() {
        let cache = offline_cache();
        let mut hello = sample_request("gpt-4");
        hello.messages[0].content = "hello".to_string();
        let mut goodbye = sample_request("gpt-4");
        goodbye.messages[0].content = "goodbye".to_string();
        assert_ne!(cache.cache_key(&hello), cache.cache_key(&goodbye));
    }

    #[test]
    fn test_cache_key_ignores_stream() {
        let cache = offline_cache();
        let keys: Vec<Option<String>> = [Some(true), Some(false), None]
            .iter()
            .copied()
            .map(|stream| {
                let mut req = sample_request("gpt-4");
                req.stream = stream;
                cache.cache_key(&req)
            })
            .collect();
        // Every stream variant must map to the same key as its neighbour.
        for pair in keys.windows(2) {
            assert_eq!(pair[0], pair[1]);
        }
    }

    #[tokio::test]
    async fn test_cache_disabled_get_returns_none() {
        let cache = ExactMatchCache::new(UNREACHABLE_REDIS_URL, TEST_NS).await;
        assert!(cache.get(&sample_request("gpt-4")).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_disabled_set_no_panic() {
        let cache = ExactMatchCache::new(UNREACHABLE_REDIS_URL, TEST_NS).await;
        cache.set(&sample_request("gpt-4"), &sample_response()).await;
    }

    #[test]
    fn test_with_ttl() {
        let cache = offline_cache().with_ttl(60);
        assert_eq!(cache.ttl_secs, 60);
    }

    #[tokio::test]
    async fn test_cache_deserialisation_error() {
        use testcontainers::{core::IntoContainerPort, runners::AsyncRunner, GenericImage};
        use testcontainers_modules::redis::REDIS_PORT;

        let started = GenericImage::new("redis", "7.2.4")
            .with_exposed_port(REDIS_PORT.tcp())
            .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
                "Ready to accept connections",
            ))
            .start()
            .await;
        // Keep `container` bound for the whole test; dropping the handle is
        // expected to tear the Redis container down.
        let container = match started {
            Ok(container) => container,
            Err(e) if running_in_ci() => panic!(
                "FATAL: testcontainers failed to start Redis in CI environment: {}. \
                 This indicates a test infrastructure issue that must be resolved.",
                e
            ),
            Err(e) => {
                println!("Skipping test: testcontainers failed to start Redis ({})", e);
                return;
            }
        };

        let port = container
            .get_host_port_ipv4(REDIS_PORT)
            .await
            .expect("Failed to get port");
        let redis_url = format!("redis://127.0.0.1:{}", port);
        let cache = ExactMatchCache::new(&redis_url, "test-ns-malformed").await;
        let req = sample_request("gpt-4");
        let key = cache.cache_key(&req).unwrap();

        // Plant a value under the request's key that cannot parse as a ChatResponse.
        let seed_client = redis::Client::open(redis_url.as_str()).unwrap();
        let mut seed_conn = seed_client.get_multiplexed_async_connection().await.unwrap();
        let _: () = redis::cmd("SET")
            .arg(&key)
            .arg("not valid json")
            .query_async(&mut seed_conn)
            .await
            .unwrap();

        let result = cache.get(&req).await;
        assert!(
            result.is_none(),
            "Deserialization error should result in a cache miss (None)"
        );
    }
}