use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use serde_json::Value;
use tracing::error;
use super::stats::CacheStats;
/// Recursively deletes every `"cache_control"` key found in `value`.
///
/// Returns `(rewrote, had_marker)`:
/// * `rewrote` is `true` when at least one `cache_control` key was removed,
///   whatever its value was;
/// * `had_marker` is `true` when at least one removed value was not JSON
///   `null`.
///
/// Objects and arrays are walked to full depth; scalar values are left alone.
fn remove_cache_control(value: &mut Value) -> (bool, bool) {
    /// Combines two `(rewrote, had_marker)` results with a logical OR per slot.
    fn merge(left: (bool, bool), right: (bool, bool)) -> (bool, bool) {
        (left.0 | right.0, left.1 | right.1)
    }

    match value {
        Value::Object(fields) => {
            // Drop the key on this level first, then descend into whatever remains.
            let here = match fields.remove("cache_control") {
                Some(dropped) => (true, !dropped.is_null()),
                None => (false, false),
            };
            fields
                .values_mut()
                .map(remove_cache_control)
                .fold(here, merge)
        }
        Value::Array(elements) => elements
            .iter_mut()
            .map(remove_cache_control)
            .fold((false, false), merge),
        _ => (false, false),
    }
}
/// Prepares a JSON request body by removing `cache_control` keys and, for
/// streaming requests, forcing `stream_options.include_usage` to `true`.
///
/// Returns `(body, had_markers)`:
/// * `body` is `Some(bytes)` with the re-serialized JSON when anything was
///   changed (a `cache_control` key was removed or `include_usage` was set),
///   and `None` when the input is not valid JSON, nothing changed, or
///   re-serialization failed;
/// * `had_markers` is `true` when at least one removed `cache_control` value
///   was not `null`.
///
/// A request counts as streaming only when its top-level `"stream"` field is
/// the boolean `true`.
pub fn strip_cache_control(body: &[u8]) -> (Option<Bytes>, bool) {
    let Ok(mut json) = serde_json::from_slice::<Value>(body) else {
        return (None, false);
    };
    let (rewrote, had_markers) = remove_cache_control(&mut json);

    let mut usage_set = false;
    if let Some(obj) = json.as_object_mut() {
        let is_streaming = obj.get("stream").and_then(Value::as_bool) == Some(true);
        if is_streaming {
            let opts = obj
                .entry("stream_options")
                .or_insert_with(|| serde_json::json!({}));
            // An explicit `"stream_options": null` means "no options". Treat it
            // like a missing key; otherwise `as_object_mut` below yields `None`
            // and `include_usage` would silently never be requested.
            if opts.is_null() {
                *opts = serde_json::json!({});
            }
            // Any other non-object value is left untouched.
            if let Some(opts_obj) = opts.as_object_mut() {
                let already = opts_obj.get("include_usage").and_then(Value::as_bool) == Some(true);
                if !already {
                    opts_obj.insert("include_usage".to_string(), serde_json::json!(true));
                    usage_set = true;
                }
            }
        }
    }

    // Only pay for re-serialization when the document actually changed.
    let body = if rewrote || usage_set {
        serde_json::to_vec(&json).ok().map(Bytes::from)
    } else {
        None
    };
    (body, had_markers)
}
/// Writes the cache accounting fields from `stats` into a `usage` object.
///
/// The following keys are inserted or overwritten:
/// * `prompt_tokens_details.cached_tokens` = `stats.read`
/// * `cache_read_input_tokens` = `stats.read`
/// * `cache_creation_input_tokens` = `stats.creation_total()`
/// * `cache_creation.ephemeral_{5m,1h,24h}_input_tokens` = the matching
///   `stats.creation_*` field
///
/// A missing or `null` `prompt_tokens_details` is replaced with a fresh
/// object so `cached_tokens` is always reported alongside the other fields.
/// If it holds any other non-object value it is left untouched and
/// `cached_tokens` is not written.
fn splice_cache_fields(usage: &mut serde_json::Map<String, Value>, stats: &CacheStats) {
    let details = usage
        .entry("prompt_tokens_details")
        .or_insert_with(|| serde_json::json!({}));
    // A null here would make `as_object_mut` return `None` and drop
    // `cached_tokens` while the sibling fields below are still written.
    if details.is_null() {
        *details = serde_json::json!({});
    }
    if let Some(details_obj) = details.as_object_mut() {
        details_obj.insert("cached_tokens".to_string(), serde_json::json!(stats.read));
    }

    usage.insert("cache_read_input_tokens".to_string(), serde_json::json!(stats.read));
    usage.insert(
        "cache_creation_input_tokens".to_string(),
        serde_json::json!(stats.creation_total()),
    );
    usage.insert(
        "cache_creation".to_string(),
        serde_json::json!({
            "ephemeral_5m_input_tokens": stats.creation_5m,
            "ephemeral_1h_input_tokens": stats.creation_1h,
            "ephemeral_24h_input_tokens": stats.creation_24h,
        }),
    );
}
/// Splices cache accounting fields into the top-level `usage` object of a
/// JSON document.
///
/// Returns the re-serialized document, or `None` when `body` is not valid
/// JSON, is not a JSON object, has no `usage` key, `usage` is not an object,
/// or re-serialization fails.
pub fn inject_into_usage_json(body: &[u8], stats: &CacheStats) -> Option<Bytes> {
    let mut document = serde_json::from_slice::<Value>(body).ok()?;
    let usage = document
        .as_object_mut()?
        .get_mut("usage")?
        .as_object_mut()?;
    splice_cache_fields(usage, stats);
    let encoded = serde_json::to_vec(&document).ok()?;
    Some(Bytes::from(encoded))
}
/// Result of scanning (and optionally editing) an SSE body with
/// `scan_inject_sse`.
pub(crate) struct SseScan {
/// The rewritten body; `Some` only when a usage frame was actually edited.
pub rewritten: Option<Bytes>,
/// `true` when at least one parsed `data:` frame contained an `"error"` key.
pub saw_error: bool,
/// `true` when at least one parsed `data:` frame had an object-valued `"usage"`.
pub saw_usage: bool,
}
pub(crate) fn scan_inject_sse(body: &[u8], stats: &CacheStats, already_edited: bool) -> SseScan {
let Ok(body_str) = std::str::from_utf8(body) else {
return SseScan {
rewritten: None,
saw_error: false,
saw_usage: false,
};
};
if already_edited {
let mut saw_error = false;
let mut saw_usage = false;
for line in body_str.split('\n') {
if let Some(chunk) = sse_data_json(line) {
saw_error |= chunk.get("error").is_some();
saw_usage |= chunk.get("usage").is_some_and(Value::is_object);
}
}
return SseScan {
rewritten: None,
saw_error,
saw_usage,
};
}
let mut out = String::with_capacity(body_str.len() + 256);
let mut edited = false;
let mut saw_error = false;
let mut saw_usage = false;
let mut first = true;
for line in body_str.split('\n') {
if !first {
out.push('\n');
}
first = false;
if let Some(mut chunk) = sse_data_json(line) {
let chunk_obj = chunk.as_object_mut().expect("sse_data_json returns only objects");
if chunk_obj.contains_key("error") {
saw_error = true;
}
if let Some(usage) = chunk_obj.get_mut("usage")
&& let Some(usage_obj) = usage.as_object_mut()
{
saw_usage = true;
if !edited {
let has_cr = line.ends_with('\r');
splice_cache_fields(usage_obj, stats);
if let Ok(reserialized) = serde_json::to_string(&chunk) {
out.push_str("data: ");
out.push_str(&reserialized);
if has_cr {
out.push('\r');
}
edited = true;
continue;
}
}
}
}
out.push_str(line);
}
SseScan {
rewritten: if edited { Some(Bytes::from(out)) } else { None },
saw_error,
saw_usage,
}
}
/// Parses the JSON payload of a single SSE `data:` line.
///
/// Returns `Some` only when the line starts with `data:` and the
/// whitespace-trimmed remainder parses as a JSON object. The `[DONE]`
/// sentinel, non-`data:` lines, invalid JSON and non-object JSON all give
/// `None`.
fn sse_data_json(line: &str) -> Option<Value> {
    // Trimming covers the optional single space after the colon as well as a
    // trailing `\r` left over from CRLF framing.
    let payload = line.strip_prefix("data:")?.trim();
    if payload == "[DONE]" {
        return None;
    }
    match serde_json::from_str::<Value>(payload) {
        Ok(parsed) if parsed.is_object() => Some(parsed),
        _ => None,
    }
}
pub fn inject_into_sse_body(body: &[u8], stats: &CacheStats) -> Option<Bytes> {
scan_inject_sse(body, stats, false).rewritten
}
/// Buffers a non-streaming response and splices cache accounting fields into
/// its JSON `usage` object.
///
/// Returns the (possibly rewritten) response together with a flag that is
/// `true` only when the usage object was rewritten *and* the response status
/// is a success status.
///
/// * A response whose `content-type` is present, readable, and not
///   `application/json` (parameters ignored, case-insensitive) is returned
///   untouched with `false`. A missing or unreadable header is treated as JSON.
/// * If the body cannot be buffered, a `500` JSON error response is returned
///   with `false`.
/// * The buffered body always gets a fresh `Content-Length` and loses
///   `Transfer-Encoding`; `Content-Encoding` is dropped only when the body was
///   rewritten.
pub async fn inject_into_response_nonstreaming(response: Response, stats: &CacheStats) -> (Response, bool) {
    let status_ok = response.status().is_success();

    let declared_type = response
        .headers()
        .get("content-type")
        .and_then(|raw| raw.to_str().ok());
    let is_json = match declared_type {
        Some(raw) => raw
            .split(';')
            .next()
            .map(str::trim)
            .is_some_and(|mime| mime.eq_ignore_ascii_case("application/json")),
        None => true,
    };
    if !is_json {
        return (response, false);
    }

    let (mut parts, body) = response.into_parts();
    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
        Ok(buffered) => buffered,
        Err(e) => {
            error!("Failed to buffer response body for cache injection: {}", e);
            let err_body = serde_json::json!({
                "error": {
                    "message": format!("failed to read upstream response body: {e}"),
                    "type": "internal_error",
                    "code": "response_body_read_failed",
                }
            });
            let failure = (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(err_body)).into_response();
            return (failure, false);
        }
    };

    // Pick the payload to send back and remember whether it was rewritten.
    let (payload, injected) = match inject_into_usage_json(&body_bytes, stats) {
        Some(rewritten) => (rewritten, true),
        None => (body_bytes, false),
    };

    // The body is now fully buffered, so it is sent with an explicit length.
    parts.headers.remove(axum::http::header::TRANSFER_ENCODING);
    if injected {
        parts.headers.remove(axum::http::header::CONTENT_ENCODING);
    }
    let len = payload.len();
    parts
        .headers
        .insert(axum::http::header::CONTENT_LENGTH, axum::http::HeaderValue::from(len as u64));

    (
        Response::from_parts(parts, axum::body::Body::from(payload)),
        injected && status_ok,
    )
}
// Unit tests for the request-stripping and response-injection helpers above.
#[cfg(test)]
mod tests {
use super::*;
// Shared fixture: 1024 read tokens and 10/20/30 creation tokens per TTL bucket.
fn stats() -> CacheStats {
CacheStats {
read: 1024,
creation_5m: 10,
creation_1h: 20,
creation_24h: 30,
}
}
// A nested, non-null cache_control is removed and reported as a marker;
// because the request streams, include_usage is forced on as well.
#[test]
fn strip_removes_nested_cache_control_and_sets_include_usage() {
let body = serde_json::json!({
"stream": true,
"messages": [{"role":"system","content":[{"type":"text","text":"x","cache_control":{"type":"ephemeral"}}]}]
})
.to_string();
let (out, had_markers) = strip_cache_control(body.as_bytes());
assert!(had_markers, "body had cache_control");
let out = out.expect("changed");
let v: Value = serde_json::from_slice(&out).unwrap();
// 13 == "cache_control".len(): scan the raw bytes for any leftover key.
assert!(!out.windows(13).any(|w| w == b"cache_control"));
assert_eq!(v["stream_options"]["include_usage"], true);
}
// Non-streaming body without cache_control: no rewrite, no markers.
#[test]
fn strip_none_when_nothing_to_do() {
let body = serde_json::json!({"messages":[{"role":"user","content":"hi"}]}).to_string();
let (out, had_markers) = strip_cache_control(body.as_bytes());
assert!(out.is_none());
assert!(!had_markers);
}
// Streaming body without cache_control: rewritten only to add include_usage.
#[test]
fn strip_stream_without_markers_changes_body_but_not_marked() {
let body = serde_json::json!({"stream": true, "messages":[{"role":"user","content":"hi"}]}).to_string();
let (out, had_markers) = strip_cache_control(body.as_bytes());
assert!(out.is_some(), "include_usage injected");
assert!(!had_markers, "no markers present");
}
// A null cache_control key is still stripped, but does not count as a marker.
#[test]
fn strip_null_cache_control_is_removed_but_not_marked() {
let body = serde_json::json!({
"messages": [{"role":"system","content":[{"type":"text","text":"x","cache_control":null}]}]
})
.to_string();
let (out, had_markers) = strip_cache_control(body.as_bytes());
assert!(!had_markers, "null cache_control is not a marker");
let out = out.expect("body rewritten to drop the null cache_control key");
assert!(!out.windows(13).any(|w| w == b"cache_control"), "cache_control key removed");
}
// All cache fields are added to `usage`; existing fields are preserved.
#[test]
fn inject_non_streaming_adds_cache_fields() {
let body = serde_json::json!({"usage":{"prompt_tokens":2000,"completion_tokens":5}}).to_string();
let out = inject_into_usage_json(body.as_bytes(), &stats()).unwrap();
let v: Value = serde_json::from_slice(&out).unwrap();
assert_eq!(v["usage"]["prompt_tokens"], 2000, "total preserved");
assert_eq!(v["usage"]["prompt_tokens_details"]["cached_tokens"], 1024);
assert_eq!(v["usage"]["cache_read_input_tokens"], 1024);
// 60 is what creation_total() yields for the 10/20/30 fixture.
assert_eq!(v["usage"]["cache_creation_input_tokens"], 60);
assert_eq!(v["usage"]["cache_creation"]["ephemeral_1h_input_tokens"], 20);
}
// A document without a `usage` key is not rewritten.
#[test]
fn inject_non_streaming_none_when_no_usage() {
let body = serde_json::json!({"choices":[]}).to_string();
assert!(inject_into_usage_json(body.as_bytes(), &stats()).is_none());
}
// CRLF-framed stream: the re-serialized frame must keep its trailing `\r`.
#[test]
fn inject_sse_preserves_crlf_on_edited_frame() {
let sse = "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\r\n\r\ndata: [DONE]\r\n\r\n";
let out = inject_into_sse_body(sse.as_bytes(), &stats()).unwrap();
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cache_read_input_tokens\":1024"), "got: {s}");
assert!(s.contains("}\r\n\r\n"), "edited frame must keep CRLF framing, got: {s}");
assert!(!s.contains("}\n\r"), "must not produce a malformed \\n\\r, got: {s}");
}
// Only the usage frame is edited; the delta frame and [DONE] pass through verbatim.
#[test]
fn inject_sse_edits_only_terminal_usage_frame() {
let sse = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\n\ndata: [DONE]\n\n";
let out = inject_into_sse_body(sse.as_bytes(), &stats()).unwrap();
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cached_tokens\":1024"));
assert!(s.contains("\"cache_read_input_tokens\":1024"));
assert_eq!(s.matches("cached_tokens").count(), 1);
assert!(s.contains("data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}"));
assert!(s.contains("data: [DONE]"));
}
// No frame carries `usage`, so nothing is rewritten.
#[test]
fn inject_sse_none_when_no_usage_frame() {
let sse = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n\n";
assert!(inject_into_sse_body(sse.as_bytes(), &stats()).is_none());
}
// `data:` immediately followed by the payload (no space) is still recognized.
#[test]
fn inject_sse_handles_data_prefix_without_space() {
let sse = "data:{\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\n\ndata:[DONE]\n\n";
let out = inject_into_sse_body(sse.as_bytes(), &stats()).expect("no-space data: frame is injected");
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cache_read_input_tokens\":1024"), "got: {s}");
}
// Byte-slice input: the usage frame is edited, the other frames are kept.
#[test]
fn inject_into_sse_body_edits_the_usage_frame() {
let body = b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\n\ndata: [DONE]\n\n";
let out = inject_into_sse_body(body, &stats()).expect("usage frame present → edited");
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cached_tokens\":1024"), "got: {s}");
assert!(s.contains("data: [DONE]"), "DONE preserved");
assert!(s.contains("\"content\":\"hi\""), "delta preserved");
}
// Byte-slice input without a usage frame yields `None`.
#[test]
fn inject_into_sse_body_none_without_usage() {
let body = b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n\n";
assert!(inject_into_sse_body(body, &stats()).is_none(), "no usage frame → nothing to edit");
}
// A 400 JSON error body has no `usage`, so the returned flag is false.
#[tokio::test]
async fn inject_nonstreaming_error_body_vetoes_commit() {
use axum::body::Body;
let resp = Response::builder()
.status(400)
.header("content-type", "application/json")
.body(Body::from(serde_json::json!({"error":{"message":"bad request"}}).to_string()))
.unwrap();
let (_out, billing_ok) = inject_into_response_nonstreaming(resp, &stats()).await;
assert!(!billing_ok, "error body → no commit");
}
// A 200 JSON body with `usage` is rewritten and the returned flag is true.
#[tokio::test]
async fn inject_nonstreaming_success_injects_and_allows_commit() {
use axum::body::Body;
let resp = Response::builder()
.status(200)
.header("content-type", "application/json")
.body(Body::from(serde_json::json!({"usage":{"prompt_tokens":2000}}).to_string()))
.unwrap();
let (out, billing_ok) = inject_into_response_nonstreaming(resp, &stats()).await;
assert!(billing_ok, "2xx with usage → commit allowed");
let collected = axum::body::to_bytes(out.into_body(), usize::MAX).await.unwrap();
let s = std::str::from_utf8(&collected).unwrap();
assert!(s.contains("\"cached_tokens\":1024"), "got: {s}");
}
}