use std::pin::Pin;
use std::task::{Context, Poll};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use futures::Stream;
use http_body_util::BodyExt;
use serde_json::Value;
use tokio::sync::oneshot;
use tracing::error;
use super::sse::SseBufferedStream;
use super::stats::CacheStats;
/// Whether the cache statistics attached to a response may be committed.
///
/// As produced by `inject_cache_stats_into_response`, buffered (non-streaming)
/// responses are judged immediately, while streaming responses are judged only
/// once their body has been dropped.
#[derive(Debug)]
pub enum CommitGate {
    /// The verdict is known up front; `true` means commit.
    Ready(bool),
    /// The verdict arrives on this receiver when the streamed body is dropped;
    /// `true` means commit.
    Deferred(oneshot::Receiver<bool>),
}
/// Recursively deletes every `cache_control` key from `value`, at any depth.
///
/// Returns `true` if at least one key was removed anywhere in the tree.
fn remove_cache_control(value: &mut Value) -> bool {
    match value {
        Value::Object(map) => {
            let removed_here = map.remove("cache_control").is_some();
            // `|` (not `||`) so every child is visited even after a hit.
            map.values_mut()
                .fold(removed_here, |any, child| remove_cache_control(child) | any)
        }
        Value::Array(items) => items
            .iter_mut()
            .fold(false, |any, child| remove_cache_control(child) | any),
        _ => false,
    }
}
pub fn strip_cache_control(body: &[u8]) -> Option<Bytes> {
let mut json: Value = serde_json::from_slice(body).ok()?;
let stripped = remove_cache_control(&mut json);
let mut usage_set = false;
if let Some(obj) = json.as_object_mut() {
let is_streaming = obj.get("stream").and_then(Value::as_bool) == Some(true);
if is_streaming {
let opts = obj.entry("stream_options").or_insert_with(|| serde_json::json!({}));
if let Some(opts_obj) = opts.as_object_mut() {
let already = opts_obj.get("include_usage").and_then(Value::as_bool) == Some(true);
if !already {
opts_obj.insert("include_usage".to_string(), serde_json::json!(true));
usage_set = true;
}
}
}
}
if stripped || usage_set {
serde_json::to_vec(&json).ok().map(Bytes::from)
} else {
None
}
}
/// Writes cache statistics into a `usage` object in place.
///
/// Sets `prompt_tokens_details.cached_tokens`, `cache_read_input_tokens`,
/// `cache_creation_input_tokens` (from `stats.creation_total()`), and a
/// `cache_creation` breakdown by TTL bucket. Existing fields such as
/// `prompt_tokens` are left untouched.
fn splice_cache_fields(usage: &mut serde_json::Map<String, Value>, stats: &CacheStats) {
    let details = usage
        .entry("prompt_tokens_details")
        .or_insert_with(|| serde_json::json!({}));
    // Upstream may send `"prompt_tokens_details": null`; promote it to an
    // object so `cached_tokens` is still reported instead of silently dropped.
    if details.is_null() {
        *details = serde_json::json!({});
    }
    if let Some(details_obj) = details.as_object_mut() {
        details_obj.insert("cached_tokens".to_string(), serde_json::json!(stats.read));
    }
    usage.insert("cache_read_input_tokens".to_string(), serde_json::json!(stats.read));
    usage.insert(
        "cache_creation_input_tokens".to_string(),
        serde_json::json!(stats.creation_total()),
    );
    usage.insert(
        "cache_creation".to_string(),
        serde_json::json!({
            "ephemeral_5m_input_tokens": stats.creation_5m,
            "ephemeral_1h_input_tokens": stats.creation_1h,
            "ephemeral_24h_input_tokens": stats.creation_24h,
        }),
    );
}
pub fn inject_into_usage_json(body: &[u8], stats: &CacheStats) -> Option<Bytes> {
let mut json: Value = serde_json::from_slice(body).ok()?;
let obj = json.as_object_mut()?;
let usage = obj.get_mut("usage")?.as_object_mut()?;
splice_cache_fields(usage, stats);
serde_json::to_vec(&json).ok().map(Bytes::from)
}
/// Outcome of scanning one SSE payload for `usage` and `error` frames.
struct SseScan {
    /// Whether any `data:` frame reported an upstream error.
    saw_error: bool,
    /// Whether any `data:` frame carried an object-valued `usage` field.
    saw_usage: bool,
    /// The payload with cache fields spliced in, or `None` if no frame was edited.
    rewritten: Option<Bytes>,
}
fn scan_inject_sse(body: &[u8], stats: &CacheStats, already_edited: bool) -> SseScan {
let Ok(body_str) = std::str::from_utf8(body) else {
return SseScan {
rewritten: None,
saw_error: false,
saw_usage: false,
};
};
let mut out = String::with_capacity(body_str.len() + 256);
let mut edited = false;
let mut saw_error = false;
let mut saw_usage = false;
let mut first = true;
for line in body_str.split('\n') {
if !first {
out.push('\n');
}
first = false;
if let Some(data) = line.strip_prefix("data:") {
let data = data.strip_prefix(' ').unwrap_or(data);
let trimmed = data.trim();
if trimmed != "[DONE]"
&& let Ok(mut chunk) = serde_json::from_str::<Value>(trimmed)
&& let Some(chunk_obj) = chunk.as_object_mut()
{
if chunk_obj.contains_key("error") {
saw_error = true;
}
if let Some(usage) = chunk_obj.get_mut("usage")
&& let Some(usage_obj) = usage.as_object_mut()
{
saw_usage = true;
if !already_edited && !edited {
let has_cr = line.ends_with('\r');
splice_cache_fields(usage_obj, stats);
if let Ok(reserialized) = serde_json::to_string(&chunk) {
out.push_str("data: ");
out.push_str(&reserialized);
if has_cr {
out.push('\r');
}
edited = true;
continue;
}
}
}
}
}
out.push_str(line);
}
SseScan {
rewritten: if edited { Some(Bytes::from(out)) } else { None },
saw_error,
saw_usage,
}
}
pub fn inject_into_sse_body(body: &[u8], stats: &CacheStats) -> Option<Bytes> {
scan_inject_sse(body, stats, false).rewritten
}
/// Response body stream with two jobs.
///
/// * While the body streams, it splices cache statistics into the first SSE
///   `usage` frame.
/// * When dropped, it sends through `tx` whether the stream was clean enough
///   to commit.
struct VerdictStream {
    /// Upstream body chunks; in this file it is built from an `SseBufferedStream`.
    inner: Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>,
    stats: CacheStats,
    /// Set once a `usage` frame has been rewritten, so later frames pass through.
    edited: bool,
    status_ok: bool,
    saw_error: bool,
    saw_usage: bool,
    /// Taken exactly once, in `Drop`, to deliver the commit verdict.
    tx: Option<oneshot::Sender<bool>>,
}

impl VerdictStream {
    /// Records what one upstream chunk reveals and returns the bytes to forward.
    fn observe(&mut self, chunk: Bytes) -> Bytes {
        let scan = scan_inject_sse(&chunk, &self.stats, self.edited);
        self.saw_error |= scan.saw_error;
        self.saw_usage |= scan.saw_usage;
        if let Some(rewritten) = scan.rewritten {
            self.edited = true;
            rewritten
        } else {
            chunk
        }
    }

    /// Commit only for a 2xx status, a usage frame seen, and no error observed.
    fn verdict(&self) -> bool {
        self.status_ok && self.saw_usage && !self.saw_error
    }
}

impl Stream for VerdictStream {
    type Item = Result<Bytes, std::io::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `VerdictStream` is `Unpin` (the inner stream is boxed), so it can be
        // accessed directly.
        let this = Pin::into_inner(self);
        let polled = this.inner.as_mut().poll_next(cx);
        polled.map(|next| {
            next.map(|item| match item {
                Ok(chunk) => Ok(this.observe(chunk)),
                Err(err) => {
                    // A body error mid-stream vetoes the commit.
                    this.saw_error = true;
                    Err(err)
                }
            })
        })
    }
}

impl Drop for VerdictStream {
    fn drop(&mut self) {
        let verdict = self.verdict();
        if let Some(sender) = self.tx.take() {
            // The receiver may already be gone; the result is intentionally ignored.
            let _ = sender.send(verdict);
        }
    }
}
pub async fn inject_cache_stats_into_response(mut response: Response, stats: &CacheStats) -> (Response, CommitGate) {
let status_ok = response.status().is_success();
let is_sse = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|v| v.contains("text/event-stream"))
.unwrap_or(false);
if is_sse {
use futures::StreamExt;
let (tx, rx) = oneshot::channel();
let body_stream = BodyExt::into_data_stream(std::mem::take(response.body_mut()));
let buffered = SseBufferedStream::new(body_stream.map(|r| r.map_err(std::io::Error::other)));
let transformed = VerdictStream {
inner: Box::pin(buffered),
stats: *stats,
edited: false,
status_ok,
saw_error: false,
saw_usage: false,
tx: Some(tx),
};
*response.body_mut() = axum::body::Body::from_stream(transformed);
response.headers_mut().remove(axum::http::header::CONTENT_LENGTH);
(response, CommitGate::Deferred(rx))
} else {
let is_json = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|v| v.contains("application/json"))
.unwrap_or(true);
if !is_json {
return (response, CommitGate::Ready(false));
}
let (mut parts, body) = response.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(b) => b,
Err(e) => {
error!("Failed to buffer response body for cache injection: {}", e);
let err_body = serde_json::json!({
"error": {
"message": format!("failed to read upstream response body: {e}"),
"type": "internal_error",
"code": "response_body_read_failed",
}
});
return (
(StatusCode::INTERNAL_SERVER_ERROR, axum::Json(err_body)).into_response(),
CommitGate::Ready(false),
);
}
};
match inject_into_usage_json(&body_bytes, stats) {
Some(rewritten) => {
let len = rewritten.len();
parts.headers.remove(axum::http::header::TRANSFER_ENCODING);
parts.headers.remove(axum::http::header::CONTENT_ENCODING);
parts
.headers
.insert(axum::http::header::CONTENT_LENGTH, axum::http::HeaderValue::from(len as u64));
(
Response::from_parts(parts, axum::body::Body::from(rewritten)),
CommitGate::Ready(status_ok),
)
}
None => {
let len = body_bytes.len();
parts.headers.remove(axum::http::header::TRANSFER_ENCODING);
parts
.headers
.insert(axum::http::header::CONTENT_LENGTH, axum::http::HeaderValue::from(len as u64));
(
Response::from_parts(parts, axum::body::Body::from(body_bytes)),
CommitGate::Ready(false),
)
}
}
}
}
#[cfg(test)]
// Unit tests for request stripping, usage injection (JSON and SSE), and the
// commit-gate verdicts produced by `inject_cache_stats_into_response`.
mod tests {
use super::*;
// Fixture stats: 1024 cached-read tokens, and 10 + 20 + 30 = 60 creation
// tokens across the 5m/1h/24h buckets (asserted as 60 below).
fn stats() -> CacheStats {
CacheStats {
read: 1024,
creation_5m: 10,
creation_1h: 20,
creation_24h: 30,
}
}
// A streaming request loses nested `cache_control` keys and gains
// `stream_options.include_usage = true`.
#[test]
fn strip_removes_nested_cache_control_and_sets_include_usage() {
let body = serde_json::json!({
"stream": true,
"messages": [{"role":"system","content":[{"type":"text","text":"x","cache_control":{"type":"ephemeral"}}]}]
})
.to_string();
let out = strip_cache_control(body.as_bytes()).expect("changed");
let v: Value = serde_json::from_slice(&out).unwrap();
// 13 == "cache_control".len(): the key must not survive anywhere in the bytes.
assert!(!out.windows(13).any(|w| w == b"cache_control"));
assert_eq!(v["stream_options"]["include_usage"], true);
}
// Non-streaming body without `cache_control`: nothing changes, so `None`.
#[test]
fn strip_none_when_nothing_to_do() {
let body = serde_json::json!({"messages":[{"role":"user","content":"hi"}]}).to_string();
assert!(strip_cache_control(body.as_bytes()).is_none());
}
// Non-streaming JSON: all cache fields are added and `prompt_tokens` is kept.
#[test]
fn inject_non_streaming_adds_cache_fields() {
let body = serde_json::json!({"usage":{"prompt_tokens":2000,"completion_tokens":5}}).to_string();
let out = inject_into_usage_json(body.as_bytes(), &stats()).unwrap();
let v: Value = serde_json::from_slice(&out).unwrap();
assert_eq!(v["usage"]["prompt_tokens"], 2000, "total preserved");
assert_eq!(v["usage"]["prompt_tokens_details"]["cached_tokens"], 1024);
assert_eq!(v["usage"]["cache_read_input_tokens"], 1024);
assert_eq!(v["usage"]["cache_creation_input_tokens"], 60);
assert_eq!(v["usage"]["cache_creation"]["ephemeral_1h_input_tokens"], 20);
}
// No `usage` object means no rewrite.
#[test]
fn inject_non_streaming_none_when_no_usage() {
let body = serde_json::json!({"choices":[]}).to_string();
assert!(inject_into_usage_json(body.as_bytes(), &stats()).is_none());
}
// CRLF-framed SSE: the rewritten frame must keep its "\r\n" line ending.
#[test]
fn inject_sse_preserves_crlf_on_edited_frame() {
let sse = "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\r\n\r\ndata: [DONE]\r\n\r\n";
let out = inject_into_sse_body(sse.as_bytes(), &stats()).unwrap();
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cache_read_input_tokens\":1024"), "got: {s}");
assert!(s.contains("}\r\n\r\n"), "edited frame must keep CRLF framing, got: {s}");
assert!(!s.contains("}\n\r"), "must not produce a malformed \\n\\r, got: {s}");
}
// Only the usage frame is edited; content and [DONE] frames pass through verbatim.
#[test]
fn inject_sse_edits_only_terminal_usage_frame() {
let sse = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\n\ndata: [DONE]\n\n";
let out = inject_into_sse_body(sse.as_bytes(), &stats()).unwrap();
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cached_tokens\":1024"));
assert!(s.contains("\"cache_read_input_tokens\":1024"));
assert_eq!(s.matches("cached_tokens").count(), 1);
assert!(s.contains("data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}"));
assert!(s.contains("data: [DONE]"));
}
// An SSE body without any usage frame yields no rewrite.
#[test]
fn inject_sse_none_when_no_usage_frame() {
let sse = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n\n";
assert!(inject_into_sse_body(sse.as_bytes(), &stats()).is_none());
}
// `data:` with no following space is still recognised as a data field.
#[test]
fn inject_sse_handles_data_prefix_without_space() {
let sse = "data:{\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\n\ndata:[DONE]\n\n";
let out = inject_into_sse_body(sse.as_bytes(), &stats()).expect("no-space data: frame is injected");
let s = std::str::from_utf8(&out).unwrap();
assert!(s.contains("\"cache_read_input_tokens\":1024"), "got: {s}");
}
// The usage frame arrives split across two body chunks. The end-to-end path
// must still edit it, and the fully drained body must produce a commit.
#[tokio::test]
async fn inject_response_streaming_edits_split_usage_frame() {
use axum::body::Body;
let chunks: Vec<Result<Bytes, std::io::Error>> = vec![
Ok(Bytes::from_static(b"data: {\"choices\":[],\"usage\":{\"prompt_")),
Ok(Bytes::from_static(b"tokens\":2000}}\n\ndata: [DONE]\n\n")),
];
let resp = Response::builder()
.header("content-type", "text/event-stream")
.body(Body::from_stream(futures::stream::iter(chunks)))
.unwrap();
let (out, gate) = inject_cache_stats_into_response(resp, &stats()).await;
let collected = axum::body::to_bytes(out.into_body(), usize::MAX).await.unwrap();
let s = std::str::from_utf8(&collected).unwrap();
assert!(s.contains("\"cached_tokens\":1024"), "got: {s}");
match gate {
CommitGate::Deferred(rx) => assert!(rx.await.unwrap(), "clean stream → commit"),
CommitGate::Ready(_) => panic!("streaming response must yield a deferred gate"),
}
}
// NOTE(review): this stream also has no usage frame, so the verdict would be
// `false` even if error-frame detection regressed. Consider adding a usage
// frame so that only the error frame can cause the veto.
#[tokio::test]
async fn inject_response_streaming_error_frame_vetoes_commit() {
use axum::body::Body;
let chunks: Vec<Result<Bytes, std::io::Error>> = vec![
Ok(Bytes::from_static(b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")),
Ok(Bytes::from_static(b"data: {\"error\":{\"message\":\"upstream exploded\"}}\n\n")),
];
let resp = Response::builder()
.header("content-type", "text/event-stream")
.body(Body::from_stream(futures::stream::iter(chunks)))
.unwrap();
let (out, gate) = inject_cache_stats_into_response(resp, &stats()).await;
let _ = axum::body::to_bytes(out.into_body(), usize::MAX).await.unwrap();
match gate {
CommitGate::Deferred(rx) => assert!(!rx.await.unwrap(), "error frame → veto"),
CommitGate::Ready(_) => panic!("streaming response must yield a deferred gate"),
}
}
// The response is dropped without its body ever being polled (simulating a
// client disconnect). No usage frame is observed, so the drop-time verdict is a veto.
#[tokio::test]
async fn inject_response_streaming_disconnect_vetoes_commit() {
use axum::body::Body;
let chunks: Vec<Result<Bytes, std::io::Error>> = vec![
Ok(Bytes::from_static(b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")),
Ok(Bytes::from_static(
b"data: {\"choices\":[],\"usage\":{\"prompt_tokens\":2000}}\n\ndata: [DONE]\n\n",
)),
];
let resp = Response::builder()
.header("content-type", "text/event-stream")
.body(Body::from_stream(futures::stream::iter(chunks)))
.unwrap();
let (out, gate) = inject_cache_stats_into_response(resp, &stats()).await;
drop(out);
match gate {
CommitGate::Deferred(rx) => assert!(!rx.await.unwrap(), "disconnect before terminal frame → veto"),
CommitGate::Ready(_) => panic!("streaming response must yield a deferred gate"),
}
}
// NOTE(review): this body has no `usage` object, so it takes the no-rewrite
// path (`Ready(false)`) regardless of the 400 status. A 400 body *with*
// usage would pin the status check specifically.
#[tokio::test]
async fn inject_non_streaming_error_body_vetoes_commit() {
use axum::body::Body;
let resp = Response::builder()
.status(400)
.header("content-type", "application/json")
.body(Body::from(serde_json::json!({"error":{"message":"bad request"}}).to_string()))
.unwrap();
let (_out, gate) = inject_cache_stats_into_response(resp, &stats()).await;
match gate {
CommitGate::Ready(ok) => assert!(!ok, "error body → no commit"),
CommitGate::Deferred(_) => panic!("non-streaming response must yield a ready gate"),
}
}
}