use std::time::Duration;
use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::{Method, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use tracing::warn;
use super::classifier::{Classifier, ClassifyOutcome, ClassifyRequest};
use super::inject::{CommitGate, inject_cache_stats_into_response, strip_cache_control};
/// Upper bound on a background cache-index commit; `commit_with_deadline`
/// logs a warning and gives up once this elapses.
const COMMIT_DEADLINE: Duration = Duration::from_millis(30_000);
/// Shared state handed to [`cache_middleware`] via `from_fn_with_state`.
#[derive(Clone)]
pub struct CacheLayerState {
    /// Classifier used both to classify requests and to commit pending index writes.
    pub classifier: Classifier,
    /// How long the middleware waits for classification once the upstream
    /// response has arrived before falling back to un-cached billing.
    pub deadline: Duration,
    /// Maximum request body size, in bytes, that the middleware will buffer.
    pub body_limit: usize,
}

impl CacheLayerState {
    /// Classification deadline installed by [`CacheLayerState::new`].
    pub const DEFAULT_DEADLINE: Duration = Duration::from_secs(5);

    /// Builds the state from a classifier and a body limit (in bytes), using
    /// [`DEFAULT_DEADLINE`](Self::DEFAULT_DEADLINE) as the classification deadline.
    pub fn new(classifier: Classifier, body_limit: usize) -> Self {
        Self {
            classifier,
            deadline: Self::DEFAULT_DEADLINE,
            body_limit,
        }
    }
}
/// Reports whether the prompt cache should handle `req`: only `POST` requests
/// whose path ends in `/chat/completions` (e.g. `/chat/completions` or
/// `/v1/chat/completions`) qualify.
fn is_cacheable(req: &Request) -> bool {
    if *req.method() != Method::POST {
        return false;
    }
    req.uri().path().ends_with("/chat/completions")
}
/// Aborts the wrapped Tokio task when dropped.
///
/// Dropping a bare [`tokio::task::JoinHandle`] detaches its task instead of
/// cancelling it. If the middleware future is dropped while `next.run` is still
/// pending (e.g. the client disconnects), the classification task would
/// otherwise run to completion with nobody left to read its result. Aborting a
/// task that has already finished is a no-op, so the guard is free on the
/// normal path.
struct AbortOnDrop<T>(tokio::task::JoinHandle<T>);

impl<T> Drop for AbortOnDrop<T> {
    fn drop(&mut self) {
        self.0.abort();
    }
}

/// Extracts the token from an `Authorization: <scheme> <token>` header whose
/// scheme is `Bearer`.
///
/// The scheme is compared case-insensitively (RFC 7235 §2.1), and whitespace
/// around the token is trimmed. Returns `None` when the header is missing,
/// is not visible ASCII, or uses a different scheme.
fn bearer_token(headers: &axum::http::HeaderMap) -> Option<String> {
    let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("bearer")
        .then(|| token.trim().to_string())
}

/// Prompt-cache middleware for `POST …/chat/completions`.
///
/// Requests rejected by [`is_cacheable`] are forwarded untouched. For cacheable
/// requests it:
///
/// 1. buffers the body (up to `state.body_limit` bytes), answering `400` with an
///    OpenAI-style error object if that fails;
/// 2. if the body is JSON with a string `model` field, spawns classification
///    concurrently with the upstream call, passing the bearer token if present;
/// 3. forwards the body as rewritten by [`strip_cache_control`] (or the original
///    bytes if it returns nothing), resetting `Content-Length` and removing
///    `Transfer-Encoding`;
/// 4. after the upstream response arrives, waits up to `state.deadline` for
///    classification; failure, panic, or timeout yields an inactive outcome and
///    the upstream response is returned unchanged;
/// 5. for an active outcome, injects cache stats into the response and, when
///    there are pending index writes, commits them in the background if the
///    returned [`CommitGate`] allows it.
pub async fn cache_middleware(State(state): State<CacheLayerState>, request: Request, next: Next) -> Response {
    if !is_cacheable(&request) {
        return next.run(request).await;
    }
    let (mut parts, body) = request.into_parts();
    let body_bytes = match axum::body::to_bytes(body, state.body_limit).await {
        Ok(b) => b,
        Err(e) => {
            warn!(error = %e, "Failed to read request body in cache middleware");
            let body = serde_json::json!({
                "error": {
                    "message": format!("failed to read request body: {e}"),
                    "type": "invalid_request_error",
                    "code": "body_read_failed",
                }
            });
            return (StatusCode::BAD_REQUEST, axum::Json(body)).into_response();
        }
    };
    let virtual_model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
        .ok()
        .and_then(|v| v.get("model").and_then(|m| m.as_str()).map(String::from));
    let api_key = bearer_token(&parts.headers);
    // Classification overlaps with the upstream call. The guard aborts it if
    // this future is dropped before the result is awaited below.
    let classify_task = virtual_model.map(|model| {
        let classifier = state.classifier.clone();
        let body = body_bytes.to_vec();
        AbortOnDrop(tokio::spawn(async move {
            classifier
                .classify(ClassifyRequest {
                    virtual_model: &model,
                    body: &body,
                    api_key: api_key.as_deref(),
                })
                .await
        }))
    });
    let forward = strip_cache_control(&body_bytes).unwrap_or(body_bytes);
    // The forwarded body may differ in size from the original, so the framing
    // headers are rewritten to describe the new, fully-buffered body.
    parts.headers.remove(header::TRANSFER_ENCODING);
    parts
        .headers
        .insert(header::CONTENT_LENGTH, axum::http::HeaderValue::from(forward.len() as u64));
    let response = next.run(Request::from_parts(parts, Body::from(forward))).await;
    let outcome = match classify_task {
        // `task` is dropped at the end of this arm; on the timeout path that
        // aborts the still-running classification.
        Some(mut task) => match tokio::time::timeout(state.deadline, &mut task.0).await {
            Ok(Ok(Ok(result))) => result,
            Ok(Ok(Err(e))) => {
                warn!(error = %e, "cache classify failed — billing un-cached");
                ClassifyOutcome::inactive()
            }
            Ok(Err(e)) => {
                warn!(error = %e, "cache classify task panicked");
                ClassifyOutcome::inactive()
            }
            Err(_) => {
                warn!(deadline = ?state.deadline, "cache classify timed out — billing un-cached");
                ClassifyOutcome::inactive()
            }
        },
        None => ClassifyOutcome::inactive(),
    };
    if !outcome.active {
        return response;
    }
    let (response, gate) = inject_cache_stats_into_response(response, &outcome.stats).await;
    if !outcome.pending.is_empty() {
        let classifier = state.classifier.clone();
        let pending = outcome.pending;
        match gate {
            CommitGate::Ready(true) => spawn_commit(classifier, pending),
            // The gate says not to commit (presumably a non-2xx upstream; see
            // `inject_cache_stats_into_response`).
            CommitGate::Ready(false) => {}
            // Decision arrives later; a dropped sender is treated as "don't commit".
            CommitGate::Deferred(rx) => {
                tokio::spawn(async move {
                    if rx.await.unwrap_or(false) {
                        commit_with_deadline(&classifier, &pending).await;
                    }
                });
            }
        }
    }
    response
}
async fn commit_with_deadline(classifier: &Classifier, pending: &super::stats::PendingWrite) {
match tokio::time::timeout(COMMIT_DEADLINE, classifier.commit(pending)).await {
Ok(Ok(())) => {}
Ok(Err(e)) => warn!(error = %e, "cache index commit failed"),
Err(_) => warn!("cache index commit timed out"),
}
}
/// Fire-and-forget commit of `pending` on a new Tokio task.
///
/// The join handle is discarded, so the caller never waits for the commit.
/// Its outcome is only visible through the logs emitted by
/// [`commit_with_deadline`].
fn spawn_commit(classifier: Classifier, pending: super::stats::PendingWrite) {
    let commit = async move {
        commit_with_deadline(&classifier, &pending).await;
    };
    tokio::spawn(commit);
}
#[cfg(test)]
mod tests {
    //! Integration tests for `cache_middleware`. They use a real Postgres pool
    //! (via `sqlx::test`), a wiremock tokenizer service, and an in-process
    //! upstream handler mounted behind the middleware.
    use super::*;
    use crate::api::models::users::Role;
    use crate::prompt_cache::{
        CacheIndex, IndexScope, ModelConfigResolver, PostgresIndex, PrincipalResolver, TokenizerClient, parse_chat_completions,
    };
    use crate::test::utils::{create_test_api_key_for_user, create_test_endpoint, create_test_model, create_test_user};
    use axum::middleware::from_fn_with_state;
    use axum::routing::post;
    use axum::{Json, Router};
    use sqlx::PgPool;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Alias the test model is deployed under and requested by.
    const ALIAS: &str = "layer-model";
    /// Tokenizer version reported by the mocked tokenizer service.
    const TOK_VER: &str = "sha256:lv1";
    /// How many times to poll the index for the background commit.
    const COMMIT_POLL_ATTEMPTS: usize = 100;
    /// Pause between index polls, giving the detached commit task real time to finish.
    const COMMIT_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(20);

    /// Upstream stand-in: a fixed, non-streaming completion reporting 2000 prompt tokens.
    async fn mock_upstream() -> Json<serde_json::Value> {
        Json(serde_json::json!({
            "id": "chatcmpl-1", "object": "chat.completion",
            "choices": [{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],
            "usage": {"prompt_tokens": 2000, "completion_tokens": 2, "total_tokens": 2002}
        }))
    }

    /// Chat request whose system message carries a 1h ephemeral `cache_control` marker.
    fn body() -> serde_json::Value {
        serde_json::json!({
            "model": ALIAS,
            "messages": [
                {"role":"system","content":[{"type":"text","text":"static system","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
                {"role":"user","content":"hi"}
            ]
        })
    }

    /// The first request is billed as a cache write, and the write is committed
    /// after the 2xx. An identical second request is then billed as a cache read.
    #[sqlx::test]
    async fn end_to_end_injects_then_reads(pool: PgPool) {
        let user = create_test_user(&pool, Role::StandardUser).await;
        let key = create_test_api_key_for_user(&pool, user.id).await;
        let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
        let id = create_test_model(&pool, "m", ALIAS, endpoint, user.id).await;
        // Cache tariff for the model. min_prefix_tokens (1024) is below the
        // 1500-token prefix reported by the mocked tokenizer.
        sqlx::query!(
            r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens)
VALUES ($1, 1.25, 2.0, 2.5, 1024)"#,
            id
        )
        .execute(&pool)
        .await
        .unwrap();

        // Tokenizer service: advertises the alias and reports one 1500-token segment.
        let tok = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/models"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "models": [{"alias": ALIAS, "hf_repo": "o/m", "tokenizer_version": TOK_VER}]
            })))
            .mount(&tok)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/tokenize"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "virtual_model": ALIAS, "tokenizer_version": TOK_VER,
                "segment_counts": [1500], "cumulative": [1500], "total": 1500
            })))
            .mount(&tok)
            .await;

        let classifier = Classifier::new(
            PrincipalResolver::new(pool.clone()),
            ModelConfigResolver::new(pool.clone()),
            TokenizerClient::new(tok.uri()),
            Arc::new(PostgresIndex::new(pool.clone())),
        );
        let app = Router::new()
            .route("/v1/chat/completions", post(mock_upstream))
            .layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
        let server = axum_test::TestServer::new(app).unwrap();

        // First request: nothing cached yet, so the prefix is a cache write.
        let r1 = server
            .post("/v1/chat/completions")
            .add_header("authorization", format!("Bearer {}", key.secret))
            .json(&body())
            .await;
        r1.assert_status_ok();
        let v1: serde_json::Value = r1.json();
        assert_eq!(v1["usage"]["prompt_tokens"], 2000, "upstream total preserved");
        assert_eq!(v1["usage"]["cache_read_input_tokens"], 0);
        assert_eq!(v1["usage"]["cache_creation_input_tokens"], 1500);
        assert_eq!(v1["usage"]["prompt_tokens_details"]["cached_tokens"], 0);

        // The commit runs on a detached task, so poll the index until the
        // prefix hash shows up (bounded by COMMIT_POLL_ATTEMPTS).
        let scope = IndexScope {
            principal_id: user.id,
            virtual_model: ALIAS.into(),
            tokenizer_version: TOK_VER.into(),
        };
        let hash = parse_chat_completions(&serde_json::to_vec(&body()).unwrap())
            .unwrap()
            .cumulative_hashes[0]
            .clone();
        let idx = PostgresIndex::new(pool.clone());
        let mut committed = false;
        for _ in 0..COMMIT_POLL_ATTEMPTS {
            if !idx.lookup(&scope, std::slice::from_ref(&hash)).await.unwrap().is_empty() {
                committed = true;
                break;
            }
            // Sleeping rather than just yielding gives the commit's DB write
            // wall-clock time, so slow CI does not make this flaky.
            tokio::time::sleep(COMMIT_POLL_INTERVAL).await;
        }
        assert!(committed, "the write should have committed after a 2xx");

        // Second identical request: the committed prefix is read back.
        let r2 = server
            .post("/v1/chat/completions")
            .add_header("authorization", format!("Bearer {}", key.secret))
            .json(&body())
            .await;
        r2.assert_status_ok();
        let v2: serde_json::Value = r2.json();
        assert_eq!(
            v2["usage"]["cache_read_input_tokens"], 1500,
            "second request reads the cached prefix"
        );
        assert_eq!(v2["usage"]["cache_creation_input_tokens"], 0);
    }

    /// Non-chat paths bypass the middleware. The tokenizer URL points at an
    /// unreachable port, and the upstream body comes back without cache fields.
    #[sqlx::test]
    async fn non_cacheable_path_passes_through(pool: PgPool) {
        let classifier = Classifier::new(
            PrincipalResolver::new(pool.clone()),
            ModelConfigResolver::new(pool.clone()),
            TokenizerClient::new("http://127.0.0.1:1"),
            Arc::new(PostgresIndex::new(pool.clone())),
        );
        let app = Router::new()
            .route("/v1/embeddings", post(mock_upstream))
            .layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
        let server = axum_test::TestServer::new(app).unwrap();
        let r = server
            .post("/v1/embeddings")
            .json(&serde_json::json!({"model": "x", "input": "hi"}))
            .await;
        r.assert_status_ok();
        let v: serde_json::Value = r.json();
        // Pin the upstream usage block first, so the absence check below
        // cannot pass vacuously on a missing or error body.
        assert_eq!(v["usage"]["prompt_tokens"], 2000, "upstream body passed through");
        assert!(v["usage"].get("cache_read_input_tokens").is_none());
    }
}