use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use super::types::{LlmRequest, LlmResponse};
use crate::client::BoxFuture;
use crate::error::{LiterLlmError, Result};
use crate::types::{ChatCompletionResponse, EmbeddingResponse};
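/// Storage backend for cached responses.
///
/// `Memory` is the built-in process-local store; `OpenDal` (available behind
/// the `opendal-cache` feature) carries the scheme and configuration map for
/// an OpenDAL-backed store.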
#[derive(Debug, Clone, Default)]
pub enum CacheBackend {
#[default]
Memory,
#[cfg(feature = "opendal-cache")]
OpenDal {
scheme: String,
config: std::collections::HashMap<String, String>,
},
}
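/// Tuning parameters for the cache layer. The defaults keep up to 256
/// entries for five minutes each in the in-memory backend.
///
/// A construction sketch (the values are illustrative):
///
/// ```ignore
/// let config = CacheConfig {
///     max_entries: 1024,
///     ttl: std::time::Duration::from_secs(60),
///     ..CacheConfig::default()
/// };
/// ```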
#[derive(Debug, Clone)]
pub struct CacheConfig {
pub max_entries: usize,
pub ttl: Duration,
pub backend: CacheBackend,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 256,
ttl: Duration::from_secs(300),
backend: CacheBackend::Memory,
}
}
}
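/// Serializable mirror of the cacheable [`LlmResponse`] variants. Streaming
/// responses are deliberately absent: they bypass the cache entirely (see
/// `cache_key`).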
#[derive(Clone, Serialize, Deserialize)]
pub enum CachedResponse {
Chat(ChatCompletionResponse),
Embed(EmbeddingResponse),
}
impl CachedResponse {
pub fn into_llm_response(self) -> LlmResponse {
match self {
Self::Chat(r) => LlmResponse::Chat(r),
Self::Embed(r) => LlmResponse::Embed(r),
}
}
}
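/// Asynchronous key/value store for cached responses.
///
/// `key` is a hash of the serialized request. Hashes can collide, so
/// implementations also receive the full `request_body` and must report a
/// hit only when the stored body matches it exactly.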
pub trait CacheStore: Send + Sync + 'static {
fn get(&self, key: u64, request_body: &str) -> Pin<Box<dyn Future<Output = Option<CachedResponse>> + Send + '_>>;
fn put(
&self,
key: u64,
request_body: String,
response: CachedResponse,
) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;
fn remove(&self, key: u64) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;
}
#[derive(Clone)]
struct CacheEntry {
request_body: String,
response: CachedResponse,
inserted_at: Instant,
}
struct InnerCache {
map: HashMap<u64, CacheEntry>,
order: VecDeque<u64>,
max_entries: usize,
ttl: Duration,
}
impl InnerCache {
fn new(config: &CacheConfig) -> Self {
Self {
map: HashMap::new(),
order: VecDeque::new(),
max_entries: config.max_entries,
ttl: config.ttl,
}
}
fn get_if_valid(&self, key: u64, request_body: &str) -> Option<CachedResponse> {
let entry = self.map.get(&key)?;
if entry.request_body != request_body {
return None;
}
if entry.inserted_at.elapsed() > self.ttl {
return None;
}
Some(entry.response.clone())
}
fn is_expired(&self, key: u64) -> bool {
self.map.get(&key).is_some_and(|e| e.inserted_at.elapsed() > self.ttl)
}
fn remove_expired(&mut self, key: u64) {
if self.map.get(&key).is_some_and(|e| e.inserted_at.elapsed() > self.ttl) {
self.map.remove(&key);
}
}
    fn insert(&mut self, key: u64, request_body: String, response: CachedResponse) {
        // Drop any existing entry for this key first so that replacing it
        // does not count against capacity and trigger a spurious eviction.
        if self.map.remove(&key).is_some() {
            self.order.retain(|k| *k != key);
        }
        while self.map.len() >= self.max_entries {
            // `order` may hold keys already purged from the map (by `remove`
            // or expiry); popping those is a harmless no-op.
            if let Some(oldest_key) = self.order.pop_front() {
                self.map.remove(&oldest_key);
            } else {
                break;
            }
        }
self.map.insert(
key,
CacheEntry {
request_body,
response,
inserted_at: Instant::now(),
},
);
self.order.push_back(key);
}
}
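/// Process-local store with FIFO eviction and TTL expiry.
///
/// Lookups do not refresh an entry's position in the eviction queue, so this
/// is first-in-first-out rather than true LRU. A poisoned lock degrades to a
/// cache miss (on `get`) or a no-op (on `put`/`remove`) instead of panicking.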
pub struct InMemoryStore {
inner: RwLock<InnerCache>,
}
impl InMemoryStore {
#[must_use]
pub fn new(config: &CacheConfig) -> Self {
Self {
inner: RwLock::new(InnerCache::new(config)),
}
}
}
impl CacheStore for InMemoryStore {
fn get(&self, key: u64, request_body: &str) -> Pin<Box<dyn Future<Output = Option<CachedResponse>> + Send + '_>> {
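        // Check under the read lock first; take the write lock only when a
        // stale entry actually needs purging.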
let result = self.inner.read().ok().and_then(|cache| {
let hit = cache.get_if_valid(key, request_body);
let expired = hit.is_none() && cache.is_expired(key);
drop(cache);
if expired && let Ok(mut w) = self.inner.write() {
w.remove_expired(key);
}
hit
});
Box::pin(std::future::ready(result))
}
fn put(
&self,
key: u64,
request_body: String,
response: CachedResponse,
) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
if let Ok(mut cache) = self.inner.write() {
cache.insert(key, request_body, response);
}
Box::pin(std::future::ready(()))
}
fn remove(&self, key: u64) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
if let Ok(mut cache) = self.inner.write() {
cache.map.remove(&key);
}
Box::pin(std::future::ready(()))
}
}
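/// Tower [`Layer`] that memoizes chat and embedding responses.
///
/// Note that [`CacheLayer::new`] always builds the in-memory store,
/// regardless of `CacheConfig::backend`; to use a different backend,
/// construct the store yourself and pass it to [`CacheLayer::with_store`].
///
/// A wiring sketch (`inner` stands in for the crate's base service):
///
/// ```ignore
/// use tower::ServiceBuilder;
///
/// let svc = ServiceBuilder::new()
///     .layer(CacheLayer::new(CacheConfig::default()))
///     .service(inner);
/// ```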
pub struct CacheLayer {
store: Arc<dyn CacheStore>,
}
impl CacheLayer {
#[must_use]
pub fn new(config: CacheConfig) -> Self {
Self {
store: Arc::new(InMemoryStore::new(&config)),
}
}
#[must_use]
pub fn with_store(store: Arc<dyn CacheStore>) -> Self {
Self { store }
}
}
impl<S> Layer<S> for CacheLayer {
type Service = CacheService<S>;
fn layer(&self, inner: S) -> Self::Service {
CacheService {
inner,
store: Arc::clone(&self.store),
}
}
}
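/// [`Service`] produced by [`CacheLayer`]: consults the store before
/// delegating to the wrapped service, and populates it on a miss.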
pub struct CacheService<S> {
inner: S,
store: Arc<dyn CacheStore>,
}
impl<S: Clone> Clone for CacheService<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
store: Arc::clone(&self.store),
}
}
}
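/// Derives a cache key from the request's JSON serialization; streaming
/// requests return `None` and bypass the cache.
///
/// `DefaultHasher` is not guaranteed stable across Rust releases or
/// processes, so keys must not be persisted; the JSON body is carried
/// alongside the hash to disambiguate collisions.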
fn cache_key(req: &LlmRequest) -> Option<(u64, String)> {
let json = match req {
LlmRequest::Chat(r) => serde_json::to_string(r).ok()?,
LlmRequest::Embed(r) => serde_json::to_string(r).ok()?,
_ => return None,
};
let mut hasher = DefaultHasher::new();
json.hash(&mut hasher);
Some((hasher.finish(), json))
}
impl<S> Service<LlmRequest> for CacheService<S>
where
S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
S::Future: Send + 'static,
{
type Response = LlmResponse;
type Error = LiterLlmError;
type Future = BoxFuture<'static, LlmResponse>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: LlmRequest) -> Self::Future {
let key_and_body = cache_key(&req);
let store = Arc::clone(&self.store);
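        // The inner future is created eagerly because `poll_ready` has
        // already reserved the inner service; on a cache hit it is dropped
        // without ever being polled.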
let fut = self.inner.call(req);
Box::pin(async move {
if let Some((k, ref body)) = key_and_body
&& let Some(cached) = store.get(k, body).await
{
return Ok(cached.into_llm_response());
}
let resp = fut.await?;
if let Some((k, body)) = key_and_body {
let cached = match &resp {
LlmResponse::Chat(r) => Some(CachedResponse::Chat(r.clone())),
LlmResponse::Embed(r) => Some(CachedResponse::Embed(r.clone())),
_ => None,
};
if let Some(cached) = cached {
store.put(k, body, cached).await;
}
}
Ok(resp)
})
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::Ordering;
use tower::{Layer as _, Service as _};
use super::*;
use crate::tower::service::LlmService;
use crate::tower::tests_common::{MockClient, chat_req};
use crate::tower::types::LlmRequest;
#[tokio::test]
async fn cache_returns_cached_response_on_second_call() {
let config = CacheConfig {
backend: CacheBackend::default(),
max_entries: 10,
ttl: Duration::from_secs(60),
};
let layer = CacheLayer::new(config);
let client = MockClient::ok();
let call_count = Arc::clone(&client.call_count);
let inner = LlmService::new(client);
let mut svc = layer.layer(inner);
svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
assert_eq!(call_count.load(Ordering::SeqCst), 1);
svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
assert_eq!(call_count.load(Ordering::SeqCst), 1, "second call should hit cache");
}
#[tokio::test]
async fn cache_does_not_cache_streaming_requests() {
let config = CacheConfig {
backend: CacheBackend::default(),
max_entries: 10,
ttl: Duration::from_secs(60),
};
let layer = CacheLayer::new(config);
let client = MockClient::ok();
let call_count = Arc::clone(&client.call_count);
let inner = LlmService::new(client);
let mut svc = layer.layer(inner);
svc.call(LlmRequest::ChatStream(chat_req("gpt-4"))).await.unwrap();
svc.call(LlmRequest::ChatStream(chat_req("gpt-4"))).await.unwrap();
assert_eq!(call_count.load(Ordering::SeqCst), 2, "streaming should not be cached");
}
#[tokio::test]
async fn cache_evicts_oldest_when_full() {
let config = CacheConfig {
backend: CacheBackend::default(),
max_entries: 1,
ttl: Duration::from_secs(60),
};
let layer = CacheLayer::new(config);
let client = MockClient::ok();
let call_count = Arc::clone(&client.call_count);
let inner = LlmService::new(client);
let mut svc = layer.layer(inner);
svc.call(LlmRequest::Chat(chat_req("model-a"))).await.unwrap();
assert_eq!(call_count.load(Ordering::SeqCst), 1);
svc.call(LlmRequest::Chat(chat_req("model-b"))).await.unwrap();
assert_eq!(call_count.load(Ordering::SeqCst), 2);
svc.call(LlmRequest::Chat(chat_req("model-a"))).await.unwrap();
assert_eq!(
call_count.load(Ordering::SeqCst),
3,
"evicted entry should be a cache miss"
);
}
#[tokio::test]
async fn cache_different_requests_have_different_keys() {
let config = CacheConfig {
backend: CacheBackend::default(),
max_entries: 10,
ttl: Duration::from_secs(60),
};
let layer = CacheLayer::new(config);
let client = MockClient::ok();
let call_count = Arc::clone(&client.call_count);
let inner = LlmService::new(client);
let mut svc = layer.layer(inner);
svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
svc.call(LlmRequest::Chat(chat_req("gpt-3.5-turbo"))).await.unwrap();
assert_eq!(
call_count.load(Ordering::SeqCst),
2,
"different models should be cache misses"
);
}
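    // A TTL-expiry sketch: the entry outlives its 10 ms TTL before the second
    // call, which must therefore reach the inner service again. Assumes
    // tokio's `time` feature is available for `sleep`.
    #[tokio::test]
    async fn cache_expired_entry_is_a_miss() {
        let config = CacheConfig {
            backend: CacheBackend::default(),
            max_entries: 10,
            ttl: Duration::from_millis(10),
        };
        let layer = CacheLayer::new(config);
        let client = MockClient::ok();
        let call_count = Arc::clone(&client.call_count);
        let inner = LlmService::new(client);
        let mut svc = layer.layer(inner);
        svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        tokio::time::sleep(Duration::from_millis(50)).await;
        svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "expired entry should be a cache miss"
        );
    }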
}