use std::time::Duration;
use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::{Method, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use tracing::warn;
use futures::StreamExt;
use http_body_util::BodyExt;
use super::classifier::{Classifier, ClassifyOutcome, ClassifyRequest};
use super::index::{CacheResult, TierPolicy};
use super::inject::{inject_into_response_nonstreaming, scan_inject_sse, strip_cache_control};
use super::metrics as cache_metrics;
use super::parse::{ParseError, validate_markers};
use super::sse::SseBufferedStream;
use super::stats::CacheStats;
/// Upper bound on a single background index commit (see `commit_with_deadline`);
/// a commit that runs longer is abandoned and recorded as a `timeout`.
const COMMIT_DEADLINE: Duration = Duration::from_millis(30_000);
/// Shared state handed to [`cache_middleware`] via `from_fn_with_state`.
#[derive(Clone)]
pub struct CacheLayerState {
    /// Supplies the tier policy used for marker validation, classifies requests,
    /// and commits pending index writes.
    pub classifier: Classifier,
    /// How long a response waits on the classification task before treating the
    /// request as un-cached.
    pub deadline: Duration,
    /// Maximum number of request-body bytes the middleware will buffer.
    pub body_limit: usize,
}

impl CacheLayerState {
    /// Creates the state with a fixed 5-second classification deadline.
    pub fn new(classifier: Classifier, body_limit: usize) -> Self {
        let deadline = Duration::from_secs(5);
        Self {
            classifier,
            body_limit,
            deadline,
        }
    }
}
/// Returns `true` only for `POST` requests whose path ends in `/chat/completions`;
/// every other request bypasses the cache middleware.
fn is_cacheable(req: &Request) -> bool {
    if req.method() != Method::POST {
        return false;
    }
    let path = req.uri().path();
    path.ends_with("/chat/completions")
}
fn marker_rejection_response(e: &ParseError, policy: &TierPolicy) -> Response {
let reason = match e {
ParseError::DisabledTier(_) => Some("tier_disabled"),
ParseError::InvalidTtl(_) => Some("invalid_ttl"),
ParseError::UnsupportedType(_) => Some("unsupported_type"),
ParseError::TooManyBreakpoints => Some("too_many_breakpoints"),
ParseError::MalformedCacheControl => Some("malformed_cache_control"),
ParseError::Json(_) => None,
};
if let Some(r) = reason {
cache_metrics::record_markers_rejected(r);
}
let message = match e {
ParseError::DisabledTier(_) => format!("{e}; available tiers: {}", policy.enabled_strs().join(", ")),
_ => e.to_string(),
};
let body = serde_json::json!({
"error": {
"message": message,
"type": "invalid_request_error",
"code": "invalid_cache_control",
"param": "messages[].content[].cache_control",
}
});
(StatusCode::BAD_REQUEST, axum::Json(body)).into_response()
}
/// Axum middleware that performs prompt-cache accounting for `POST …/chat/completions`.
///
/// Any other request is forwarded untouched. For chat completions it:
///
/// 1. Buffers the request body (at most `state.body_limit` bytes). A read failure is
///    answered with a 400 `body_read_failed` error.
/// 2. If the body parses as JSON, validates its `cache_control` markers against the
///    classifier's tier policy. Invalid markers are answered with a 400
///    `invalid_cache_control` error.
/// 3. When the body names a `model`, spawns classification concurrently with the
///    upstream call. The `Authorization` bearer token is passed along when present.
/// 4. Strips `cache_control` markers before forwarding. It drops `Transfer-Encoding`
///    and sets `Content-Length` to the forwarded body's length.
/// 5. Handles the response:
///    - For `text/event-stream` responses, defers joining the classification into the
///      body stream.
///    - Otherwise, waits for it (bounded by `state.deadline`) and injects the cache stats
///      into the response.
///    - Commits pending index writes only when injection reports the response as
///      billable. If not, the commit is vetoed as `no_usage` or `non_2xx`.
///
/// The classification task is held in an [`AbortOnDrop`] guard while the upstream call
/// and the join are in flight. Cancelling this future therefore aborts the task (and
/// records it as abandoned) rather than leaving it running detached.
pub async fn cache_middleware(State(state): State<CacheLayerState>, request: Request, next: Next) -> Response {
    if !is_cacheable(&request) {
        return next.run(request).await;
    }
    let (mut parts, body) = request.into_parts();
    let body_bytes = match axum::body::to_bytes(body, state.body_limit).await {
        Ok(b) => b,
        Err(e) => {
            warn!(error = %e, "Failed to read request body in cache middleware");
            cache_metrics::record_body_read_failed();
            let body = serde_json::json!({
                "error": {
                    "message": format!("failed to read request body: {e}"),
                    "type": "invalid_request_error",
                    "code": "body_read_failed",
                }
            });
            return (StatusCode::BAD_REQUEST, axum::Json(body)).into_response();
        }
    };
    // A body that is not valid JSON skips marker validation and classification, but is
    // still forwarded so upstream can report the problem itself.
    let parsed_body = serde_json::from_slice::<serde_json::Value>(&body_bytes).ok();
    if let Some(body) = &parsed_body
        && let Err(e) = validate_markers(body, state.classifier.tier_policy())
    {
        return marker_rejection_response(&e, state.classifier.tier_policy());
    }
    let virtual_model = parsed_body
        .as_ref()
        .and_then(|v| v.get("model").and_then(|m| m.as_str()).map(String::from));
    let api_key = parts
        .headers
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token);
    let model_label = virtual_model.clone().unwrap_or_default();
    // Classification only runs when the body names a model; it overlaps the upstream call.
    let classify_handle = virtual_model.map(|model| {
        let classifier = state.classifier.clone();
        let body = body_bytes.to_vec();
        tokio::spawn(async move {
            classifier
                .classify(ClassifyRequest {
                    virtual_model: &model,
                    body: &body,
                    api_key: api_key.as_deref(),
                })
                .await
        })
    });
    let (stripped, had_markers) = strip_cache_control(&body_bytes);
    cache_metrics::record_marker_request(had_markers);
    let forward = stripped.unwrap_or(body_bytes);
    // The forwarded body may be shorter than the original, so re-frame it explicitly.
    parts.headers.remove(header::TRANSFER_ENCODING);
    parts
        .headers
        .insert(header::CONTENT_LENGTH, axum::http::HeaderValue::from(forward.len() as u64));
    // Guard the classification while upstream runs. If this future is dropped (e.g. the
    // client disconnects), the task is aborted and recorded as abandoned instead of
    // running on detached.
    let mut classify = AbortOnDrop(classify_handle);
    let response = next.run(Request::from_parts(parts, Body::from(forward))).await;
    if classify.0.is_none() {
        cache_metrics::record_request_outcome("inactive");
        return response;
    }
    if is_streaming(&response)
        && let Some(handle) = classify.take()
    {
        // The body stream takes ownership of the task and installs its own guard.
        return defer_classify_into_stream(response, handle, state.deadline, model_label, state.classifier.clone());
    }
    let outcome = match classify.as_mut() {
        Some(handle) => join_classify(handle, state.deadline, &model_label).await,
        // Not reachable: presence was checked above and only the streaming path takes it.
        None => ClassifyOutcome::inactive(),
    };
    // The task has been joined (or aborted on deadline inside `join_classify`). Disarm
    // the guard so dropping it does not count this request as abandoned.
    classify.take();
    if !outcome.active {
        return response;
    }
    let (response, billing_ok) = inject_into_response_nonstreaming(response, &outcome.stats).await;
    if !outcome.pending.is_empty() {
        if billing_ok {
            spawn_commit(state.classifier.clone(), outcome.pending);
        } else {
            let reason = if response.status().is_success() { "no_usage" } else { "non_2xx" };
            cache_metrics::record_commit_vetoed(reason);
        }
    }
    response
}

/// Extracts the credential from an `Authorization: Bearer <token>` header value.
///
/// The scheme name is compared case-insensitively, as HTTP auth schemes are
/// case-insensitive. The token is trimmed. Returns `None` for any other scheme or when
/// there is no space after the scheme.
fn bearer_token(value: &str) -> Option<String> {
    let (scheme, token) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("bearer")
        .then(|| token.trim().to_string())
}
/// Reports whether `response` is a server-sent-event stream.
///
/// Compares the media type from `Content-Type` (the part before any `;` parameters,
/// trimmed, case-insensitively) with `text/event-stream`. A missing or non-ASCII
/// header yields `false`.
fn is_streaming(response: &Response) -> bool {
    let Some(content_type) = response.headers().get(header::CONTENT_TYPE) else {
        return false;
    };
    let Ok(raw) = content_type.to_str() else {
        return false;
    };
    let media_type = raw.split_once(';').map_or(raw, |(essence, _params)| essence);
    media_type.trim().eq_ignore_ascii_case("text/event-stream")
}
async fn join_classify(
handle: &mut tokio::task::JoinHandle<CacheResult<ClassifyOutcome>>,
deadline: Duration,
model_label: &str,
) -> ClassifyOutcome {
let outcome = match tokio::time::timeout(deadline, &mut *handle).await {
Ok(Ok(Ok(result))) => {
cache_metrics::record_classify("ok");
result
}
Ok(Ok(Err(e))) => {
cache_metrics::record_classify("error");
warn!(error = %e, "cache classify failed — billing un-cached");
ClassifyOutcome::inactive()
}
Ok(Err(e)) => {
cache_metrics::record_classify(if e.is_panic() { "panicked" } else { "error" });
warn!(error = %e, "cache classify task failed");
ClassifyOutcome::inactive()
}
Err(_) => {
cache_metrics::record_classify("deadline_exceeded");
handle.abort(); ClassifyOutcome::inactive()
}
};
cache_metrics::record_request_outcome(outcome_label(&outcome));
if outcome.active && !model_label.is_empty() {
cache_metrics::record_token_volumes(
model_label,
outcome.stats.read,
outcome.stats.creation_5m,
outcome.stats.creation_1h,
outcome.stats.creation_24h,
);
}
outcome
}
/// Metric label describing what an outcome did to the cache.
///
/// Returns `inactive` when classification did not apply. Otherwise the label is chosen
/// by whether any tokens were read and/or created: `read_and_create`, `read`,
/// `create_only`, or `zero_active`.
fn outcome_label(outcome: &ClassifyOutcome) -> &'static str {
    if !outcome.active {
        return "inactive";
    }
    let reads = outcome.stats.read > 0;
    let creates = outcome.stats.creation_total() > 0;
    match (reads, creates) {
        (true, true) => "read_and_create",
        (true, false) => "read",
        (false, true) => "create_only",
        (false, false) => "zero_active",
    }
}
/// Owns an optional spawned task and aborts it if the guard is dropped while still
/// holding it.
///
/// Callers disarm the guard with [`AbortOnDrop::take`] once they have joined or handed
/// off the task. Dropping an armed guard does three things: it aborts the task, records
/// an `abandoned` classification, and records an `aborted` request outcome.
struct AbortOnDrop<T>(Option<tokio::task::JoinHandle<T>>);

impl<T> AbortOnDrop<T> {
    /// Disarms the guard and returns the task handle, if it is still held.
    fn take(&mut self) -> Option<tokio::task::JoinHandle<T>> {
        Option::take(&mut self.0)
    }

    /// Borrows the held task handle without disarming the guard.
    fn as_mut(&mut self) -> Option<&mut tokio::task::JoinHandle<T>> {
        Option::as_mut(&mut self.0)
    }
}

impl<T> Drop for AbortOnDrop<T> {
    fn drop(&mut self) {
        let Some(task) = self.0.take() else {
            return;
        };
        task.abort();
        cache_metrics::record_classify("abandoned");
        cache_metrics::record_request_outcome("aborted");
    }
}
fn defer_classify_into_stream(
response: Response,
handle: tokio::task::JoinHandle<CacheResult<ClassifyOutcome>>,
deadline: Duration,
model_label: String,
classifier: Classifier,
) -> Response {
let (parts, body) = response.into_parts();
let status_ok = parts.status.is_success();
let body_stream = BodyExt::into_data_stream(body).map(|r| r.map_err(std::io::Error::other));
let buffered = SseBufferedStream::new(body_stream);
let stream = async_stream::stream! {
futures::pin_mut!(buffered);
let mut handle = AbortOnDrop(Some(handle));
let mut outcome: Option<ClassifyOutcome> = None;
let mut edited = false;
let mut saw_error = false;
let mut saw_usage = false;
while let Some(item) = buffered.next().await {
let chunk = match item {
Ok(c) => c,
Err(e) => {
saw_error = true;
yield Err(e);
continue;
}
};
let probe = scan_inject_sse(&chunk, &CacheStats::default(), true);
saw_error |= probe.saw_error;
if probe.saw_usage && outcome.is_none() {
if let Some(h) = handle.as_mut() {
outcome = Some(join_classify(h, deadline, &model_label).await);
}
handle.take();
}
saw_usage |= probe.saw_usage;
let out = if !edited && probe.saw_usage && outcome.as_ref().is_some_and(|o| o.active) {
let stats = outcome.as_ref().map(|o| o.stats).unwrap_or_default();
let scan = scan_inject_sse(&chunk, &stats, false);
edited |= scan.rewritten.is_some();
scan.rewritten.unwrap_or(chunk)
} else {
chunk
};
yield Ok(out);
}
let outcome = match outcome {
Some(o) => o,
None => {
if let Some(h) = handle.as_mut() {
let o = join_classify(h, deadline, &model_label).await;
handle.take();
o
} else {
ClassifyOutcome::inactive()
}
}
};
if outcome.active && !outcome.pending.is_empty() {
if status_ok && !saw_error && saw_usage {
spawn_commit(classifier, outcome.pending);
} else {
let reason = if !status_ok {
"non_2xx"
} else if saw_error {
"error_frame"
} else {
"no_usage"
};
cache_metrics::record_commit_vetoed(reason);
}
}
};
let mut response = Response::from_parts(parts, Body::from_stream(stream));
response.headers_mut().remove(header::CONTENT_LENGTH);
response
}
async fn commit_with_deadline(classifier: &Classifier, pending: &super::stats::PendingWrite) {
let start = std::time::Instant::now();
let result = tokio::time::timeout(COMMIT_DEADLINE, classifier.commit(pending)).await;
cache_metrics::record_commit_duration(start.elapsed().as_secs_f64());
match result {
Ok(Ok(())) => cache_metrics::record_commit("ok"),
Ok(Err(e)) => {
cache_metrics::record_commit("error");
warn!(error = %e, "cache index commit failed");
}
Err(_) => {
cache_metrics::record_commit("timeout");
warn!("cache index commit timed out");
}
}
}
/// Runs [`commit_with_deadline`] on a detached tokio task, so the caller's response is
/// not held up by the index write.
fn spawn_commit(classifier: Classifier, pending: super::stats::PendingWrite) {
    let commit_task = async move {
        commit_with_deadline(&classifier, &pending).await;
    };
    tokio::spawn(commit_task);
}
// Integration tests for `cache_middleware`. The `#[sqlx::test]` cases run against a
// Postgres pool supplied by sqlx, and wiremock stands in for the tokenizer service.
#[cfg(test)]
mod tests {
use super::*;
use crate::api::models::users::Role;
use crate::prompt_cache::{
CacheIndex, IndexScope, ModelConfigResolver, PostgresIndex, PrincipalResolver, TokenizerClient, parse_chat_completions,
};
use crate::test::utils::{create_test_api_key_for_user, create_test_endpoint, create_test_model, create_test_user};
use axum::middleware::from_fn_with_state;
use axum::routing::post;
use axum::{Json, Router};
use sqlx::PgPool;
use std::sync::Arc;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Virtual-model alias sent in every request body and reported by the tokenizer mocks.
const ALIAS: &str = "layer-model";
/// Tokenizer version reported by the mocks; it is also part of the index scope checked below.
const TOK_VER: &str = "sha256:lv1";
/// Tier policy with the 5m, 1h and 24h tiers enabled.
/// NOTE(review): the second argument ("5m") is presumably the default tier; confirm against `TierPolicy::from_config`.
fn all_tiers() -> TierPolicy {
TierPolicy::from_config(&["5m".to_string(), "1h".to_string(), "24h".to_string()], "5m")
}
/// Non-streaming upstream stub: a fixed chat completion reporting 2000 prompt tokens.
async fn mock_upstream() -> Json<serde_json::Value> {
Json(serde_json::json!({
"id": "chatcmpl-1", "object": "chat.completion",
"choices": [{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],
"usage": {"prompt_tokens": 2000, "completion_tokens": 2, "total_tokens": 2002}
}))
}
/// Non-streaming request whose system message carries a 1h `cache_control` breakpoint.
fn body() -> serde_json::Value {
serde_json::json!({
"model": ALIAS,
"messages": [
{"role":"system","content":[{"type":"text","text":"static system","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
{"role":"user","content":"hi"}
]
})
}
/// The first request creates a 1500-token cached prefix (the tokenizer mock's count) and
/// commits it to the index. The second identical request reads that prefix back.
#[sqlx::test]
async fn end_to_end_injects_then_reads(pool: PgPool) {
let user = create_test_user(&pool, Role::StandardUser).await;
let key = create_test_api_key_for_user(&pool, user.id).await;
let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
let id = create_test_model(&pool, "m", ALIAS, endpoint, user.id).await;
// Give the deployed model a cache tariff with min_prefix_tokens = 1024, below the mocked 1500-token prefix.
sqlx::query!(
r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens)
VALUES ($1, 1.25, 2.0, 2.5, 1024)"#,
id
)
.execute(&pool)
.await
.unwrap();
// Tokenizer service stub: a model listing plus a fixed tokenization (one 1500-token segment).
let tok = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/models"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"models": [{"alias": ALIAS, "hf_repo": "o/m", "tokenizer_version": TOK_VER}]
})))
.mount(&tok)
.await;
Mock::given(method("POST"))
.and(path("/v1/tokenize"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"virtual_model": ALIAS, "tokenizer_version": TOK_VER,
"segment_counts": [1500], "cumulative": [1500], "total": 1500
})))
.mount(&tok)
.await;
let classifier = Classifier::new(
PrincipalResolver::new(pool.clone()),
ModelConfigResolver::new(pool.clone()),
TokenizerClient::new(tok.uri()),
Arc::new(PostgresIndex::new(pool.clone())),
all_tiers(),
);
let app = Router::new()
.route("/v1/chat/completions", post(mock_upstream))
.layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
let server = axum_test::TestServer::new(app).unwrap();
let r1 = server
.post("/v1/chat/completions")
.add_header("authorization", format!("Bearer {}", key.secret))
.json(&body())
.await;
r1.assert_status_ok();
let v1: serde_json::Value = r1.json();
assert_eq!(v1["usage"]["prompt_tokens"], 2000, "upstream total preserved");
assert_eq!(v1["usage"]["cache_read_input_tokens"], 0);
assert_eq!(v1["usage"]["cache_creation_input_tokens"], 1500);
assert_eq!(v1["usage"]["prompt_tokens_details"]["cached_tokens"], 0);
let scope = IndexScope {
principal_id: user.id,
virtual_model: ALIAS.into(),
tokenizer_version: TOK_VER.into(),
};
let hash = parse_chat_completions(&serde_json::to_vec(&body()).unwrap(), &all_tiers())
.unwrap()
.cumulative_hashes[0]
.clone();
// The commit runs on a spawned task, so poll the index (at most 100 attempts) until the entry appears.
let idx = PostgresIndex::new(pool.clone());
let mut committed = false;
for _ in 0..100 {
if !idx.lookup(&scope, std::slice::from_ref(&hash)).await.unwrap().is_empty() {
committed = true;
break;
}
tokio::task::yield_now().await;
}
assert!(committed, "the write should have committed after a 2xx");
let r2 = server
.post("/v1/chat/completions")
.add_header("authorization", format!("Bearer {}", key.secret))
.json(&body())
.await;
let v2: serde_json::Value = r2.json();
assert_eq!(
v2["usage"]["cache_read_input_tokens"], 1500,
"second request reads the cached prefix"
);
assert_eq!(v2["usage"]["cache_creation_input_tokens"], 0);
}
/// Streaming upstream stub: one content delta, a usage-only frame, then `[DONE]`.
async fn mock_upstream_streaming() -> Response {
let sse = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n\
data: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000,\"completion_tokens\":2,\"total_tokens\":2002}}\n\n\
data: [DONE]\n\n";
Response::builder()
.header("content-type", "text/event-stream")
.body(Body::from(sse))
.unwrap()
}
/// Same request as `body()`, but with `"stream": true`.
fn body_streaming() -> serde_json::Value {
serde_json::json!({
"model": ALIAS,
"stream": true,
"messages": [
{"role":"system","content":[{"type":"text","text":"static system","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
{"role":"user","content":"hi"}
]
})
}
/// Streaming variant of the end-to-end test. Classification is joined at the usage frame,
/// creation is injected into the stream, and the write commits so a second stream reads it.
#[sqlx::test]
async fn streaming_defers_classify_then_injects_and_commits(pool: PgPool) {
let user = create_test_user(&pool, Role::StandardUser).await;
let key = create_test_api_key_for_user(&pool, user.id).await;
let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
let id = create_test_model(&pool, "m", ALIAS, endpoint, user.id).await;
// Give the deployed model a cache tariff with min_prefix_tokens = 1024, below the mocked 1500-token prefix.
sqlx::query!(
r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens)
VALUES ($1, 1.25, 2.0, 2.5, 1024)"#,
id
)
.execute(&pool)
.await
.unwrap();
// Tokenizer service stub: a model listing plus a fixed tokenization (one 1500-token segment).
let tok = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/models"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"models": [{"alias": ALIAS, "hf_repo": "o/m", "tokenizer_version": TOK_VER}]
})))
.mount(&tok)
.await;
Mock::given(method("POST"))
.and(path("/v1/tokenize"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"virtual_model": ALIAS, "tokenizer_version": TOK_VER,
"segment_counts": [1500], "cumulative": [1500], "total": 1500
})))
.mount(&tok)
.await;
let classifier = Classifier::new(
PrincipalResolver::new(pool.clone()),
ModelConfigResolver::new(pool.clone()),
TokenizerClient::new(tok.uri()),
Arc::new(PostgresIndex::new(pool.clone())),
all_tiers(),
);
let app = Router::new()
.route("/v1/chat/completions", post(mock_upstream_streaming))
.layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
let server = axum_test::TestServer::new(app).unwrap();
let r1 = server
.post("/v1/chat/completions")
.add_header("authorization", format!("Bearer {}", key.secret))
.json(&body_streaming())
.await;
r1.assert_status_ok();
let t1 = r1.text();
assert!(t1.contains("\"cache_creation_input_tokens\":1500"), "creation injected: {t1}");
assert!(t1.contains("\"cache_read_input_tokens\":0"), "no read on first sight: {t1}");
assert!(t1.contains("data: [DONE]"), "DONE preserved: {t1}");
assert!(t1.contains("\"content\":\"hi\""), "delta preserved: {t1}");
let scope = IndexScope {
principal_id: user.id,
virtual_model: ALIAS.into(),
tokenizer_version: TOK_VER.into(),
};
let hash = parse_chat_completions(&serde_json::to_vec(&body_streaming()).unwrap(), &all_tiers())
.unwrap()
.cumulative_hashes[0]
.clone();
// The commit runs on a spawned task, so poll the index (at most 100 attempts) until the entry appears.
let idx = PostgresIndex::new(pool.clone());
let mut committed = false;
for _ in 0..100 {
if !idx.lookup(&scope, std::slice::from_ref(&hash)).await.unwrap().is_empty() {
committed = true;
break;
}
tokio::task::yield_now().await;
}
assert!(committed, "streaming write commits after a clean usage frame");
let r2 = server
.post("/v1/chat/completions")
.add_header("authorization", format!("Bearer {}", key.secret))
.json(&body_streaming())
.await;
let t2 = r2.text();
assert!(
t2.contains("\"cache_read_input_tokens\":1500"),
"second stream reads the prefix: {t2}"
);
assert!(t2.contains("\"cache_creation_input_tokens\":0"), "no creation on a read: {t2}");
}
/// Streaming upstream stub that emits a content delta and then an error frame, with no
/// usage frame and no `[DONE]`.
async fn mock_upstream_streaming_error() -> Response {
let sse = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n\
data: {\"error\":{\"message\":\"upstream exploded\"}}\n\n";
Response::builder()
.header("content-type", "text/event-stream")
.body(Body::from(sse))
.unwrap()
}
/// A stream that ends in an error frame without usage must not commit its pending write.
#[sqlx::test]
async fn streaming_error_frame_vetoes_the_write(pool: PgPool) {
let user = create_test_user(&pool, Role::StandardUser).await;
let key = create_test_api_key_for_user(&pool, user.id).await;
let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
let id = create_test_model(&pool, "m", ALIAS, endpoint, user.id).await;
// Give the deployed model a cache tariff with min_prefix_tokens = 1024, below the mocked 1500-token prefix.
sqlx::query!(
r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens)
VALUES ($1, 1.25, 2.0, 2.5, 1024)"#,
id
)
.execute(&pool)
.await
.unwrap();
// Tokenizer service stub: a model listing plus a fixed tokenization (one 1500-token segment).
let tok = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/models"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"models": [{"alias": ALIAS, "hf_repo": "o/m", "tokenizer_version": TOK_VER}]
})))
.mount(&tok)
.await;
Mock::given(method("POST"))
.and(path("/v1/tokenize"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"virtual_model": ALIAS, "tokenizer_version": TOK_VER,
"segment_counts": [1500], "cumulative": [1500], "total": 1500
})))
.mount(&tok)
.await;
let classifier = Classifier::new(
PrincipalResolver::new(pool.clone()),
ModelConfigResolver::new(pool.clone()),
TokenizerClient::new(tok.uri()),
Arc::new(PostgresIndex::new(pool.clone())),
all_tiers(),
);
let app = Router::new()
.route("/v1/chat/completions", post(mock_upstream_streaming_error))
.layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
let server = axum_test::TestServer::new(app).unwrap();
let r = server
.post("/v1/chat/completions")
.add_header("authorization", format!("Bearer {}", key.secret))
.json(&body_streaming())
.await;
// Consume the response text; its contents are not asserted on.
let _ = r.text();
let scope = IndexScope {
principal_id: user.id,
virtual_model: ALIAS.into(),
tokenizer_version: TOK_VER.into(),
};
let hash = parse_chat_completions(&serde_json::to_vec(&body_streaming()).unwrap(), &all_tiers())
.unwrap()
.cumulative_hashes[0]
.clone();
let idx = PostgresIndex::new(pool.clone());
// NOTE(review): yielding a fixed 50 times is a best-effort window for a stray commit to
// land. A negative check like this cannot prove the write would never appear later.
for _ in 0..50 {
tokio::task::yield_now().await;
}
assert!(
idx.lookup(&scope, std::slice::from_ref(&hash)).await.unwrap().is_empty(),
"an unbilled stream (error frame, no usage) must not commit a write"
);
}
/// Paths other than `/chat/completions` bypass the middleware, so no cache fields are
/// injected. The classifier's tokenizer URL is unreachable, so classification must not run.
#[sqlx::test]
async fn non_cacheable_path_passes_through(pool: PgPool) {
let classifier = Classifier::new(
PrincipalResolver::new(pool.clone()),
ModelConfigResolver::new(pool.clone()),
TokenizerClient::new("http://127.0.0.1:1"),
Arc::new(PostgresIndex::new(pool.clone())),
all_tiers(),
);
let app = Router::new()
.route("/v1/embeddings", post(mock_upstream))
.layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
let server = axum_test::TestServer::new(app).unwrap();
let r = server
.post("/v1/embeddings")
.json(&serde_json::json!({"model": "x", "input": "hi"}))
.await;
let v: serde_json::Value = r.json();
assert!(v["usage"].get("cache_read_input_tokens").is_none());
}
/// A marker naming a tier the policy does not enable (24h when only 5m is enabled) is
/// rejected with 400 `invalid_cache_control`. The message names both the rejected tier
/// and the available ones.
#[sqlx::test]
async fn disabled_tier_marker_rejected_with_400(pool: PgPool) {
let classifier = Classifier::new(
PrincipalResolver::new(pool.clone()),
ModelConfigResolver::new(pool.clone()),
TokenizerClient::new("http://127.0.0.1:1"),
Arc::new(PostgresIndex::new(pool.clone())),
TierPolicy::from_config(&["5m".to_string()], "5m"),
);
let app = Router::new()
.route("/v1/chat/completions", post(mock_upstream)) .layer(from_fn_with_state(CacheLayerState::new(classifier, usize::MAX), cache_middleware));
let server = axum_test::TestServer::new(app).unwrap();
let r = server
.post("/v1/chat/completions")
.add_header("authorization", "Bearer anything")
.json(&serde_json::json!({
"model": ALIAS,
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "ephemeral", "ttl": "24h"}}
]}]
}))
.await;
r.assert_status(StatusCode::BAD_REQUEST);
let v: serde_json::Value = r.json();
assert_eq!(v["error"]["code"], "invalid_cache_control");
let msg = v["error"]["message"].as_str().unwrap();
assert!(msg.contains("24h"), "message names the rejected tier: {msg}");
assert!(msg.contains("available tiers: 5m"), "message names the available tiers: {msg}");
}
}