use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
response::Response,
};
use serde_json::Value;
use super::compress::compress_tool_result;
use super::forward;
use super::tool_kind::{self, should_protect, ToolResultKind};
use super::ProxyState;
/// Axum handler for the OpenAI Responses endpoint.
///
/// Forwards the incoming request to the configured `openai_upstream` at
/// `/v1/responses`, using [`compress_request_body`] as the body transform so
/// that `function_call_output` items can be shrunk before they are sent on.
pub async fn handler(
    State(proxy): State<ProxyState>,
    req: Request<Body>,
) -> Result<Response, StatusCode> {
    const RESPONSES_PATH: &str = "/v1/responses";

    // The upstream is cloned up front because `proxy` itself is moved back into
    // a `State` wrapper in the same argument list that borrows the upstream.
    let target = proxy.openai_upstream.clone();

    forward::forward_request(
        State(proxy),
        req,
        &target,
        RESPONSES_PATH,
        compress_request_body,
        "OpenAI",
        &[],
    )
    .await
}
fn compress_request_body(parsed: Value, original_size: usize) -> (Vec<u8>, usize, usize) {
let mut doc = parsed;
let mut modified = false;
if let Some(input) = doc.get_mut("input").and_then(|i| i.as_array_mut()) {
let tool_names = tool_kind::responses_tool_names(input);
for item in input.iter_mut() {
if item.get("type").and_then(|t| t.as_str()) != Some("function_call_output") {
continue;
}
let name = item
.get("call_id")
.and_then(|v| v.as_str())
.and_then(|id| tool_names.get(id))
.map(String::as_str);
let kind = name.map_or(ToolResultKind::Other, tool_kind::classify_tool_name);
if let Some(output) = item.get_mut("output") {
modified |= compress_output_field(output, name, kind);
}
}
}
let out = serde_json::to_vec(&doc).unwrap_or_default();
let compressed_size = if modified { out.len() } else { original_size };
(out, original_size, compressed_size)
}
/// Shrinks a `function_call_output` item's `output` value in place.
///
/// Two shapes are handled:
/// - a JSON string, which is compressed directly;
/// - a JSON array of content parts, where every part carrying a string `text`
///   field is compressed independently.
///
/// Any other JSON shape is left untouched.
///
/// Returns `true` if at least one string was replaced.
fn compress_output_field(
    output: &mut Value,
    tool_name: Option<&str>,
    kind: ToolResultKind,
) -> bool {
    match output {
        Value::String(s) => compress_text_in_place(s, tool_name, kind),
        Value::Array(parts) => {
            let mut changed = false;
            for part in parts.iter_mut() {
                if let Some(Value::String(text)) = part.get_mut("text") {
                    // `|=` rather than `||` so every part is visited even after
                    // an earlier part has already been compressed.
                    changed |= compress_text_in_place(text, tool_name, kind);
                }
            }
            changed
        }
        _ => false,
    }
}

/// Compresses one text payload in place; shared by both `output` shapes so they
/// always apply the same rule.
///
/// The text is left alone if `should_protect` returns `true` for it. It is
/// replaced only when the compressed form is strictly shorter (by byte
/// length). Returns whether `text` was replaced.
fn compress_text_in_place(
    text: &mut String,
    tool_name: Option<&str>,
    kind: ToolResultKind,
) -> bool {
    if should_protect(kind, text) {
        return false;
    }
    let compressed = compress_tool_result(text, tool_name);
    if compressed.len() < text.len() {
        *text = compressed;
        true
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `git status` transcript long enough for the shared engine to shrink.
    fn long_git_status() -> String {
        let mut s = String::from(
            "$ git status\nOn branch main\nYour branch is up to date with 'origin/main'.\n\nChanges not staged for commit:\n (use \"git add <file>...\" to update what will be committed)\n",
        );
        for i in 0..80 {
            s.push_str(&format!("\tmodified: src/module_{i}/file_{i}.rs\n"));
        }
        s.push_str("\nno changes added to commit (use \"git add\" and/or \"git commit -a\")\n");
        s
    }

    /// Serializes `body` to measure its size, runs it through
    /// `compress_request_body`, and re-parses the emitted bytes.
    ///
    /// Returns `(reparsed_body, original_size, compressed_size)`.
    fn roundtrip(body: &Value) -> (Value, usize, usize) {
        let original_len = serde_json::to_vec(body).unwrap().len();
        let (out, orig, comp) = compress_request_body(body.clone(), original_len);
        let parsed = serde_json::from_slice(&out).unwrap();
        (parsed, orig, comp)
    }

    /// Asserts that `body` goes through the transform unchanged and is reported
    /// as not shrunk.
    fn assert_passthrough(body: &Value) {
        let (reparsed, orig, comp) = roundtrip(body);
        assert_eq!(comp, orig, "unmodified body must report original size");
        assert_eq!(&reparsed, body);
    }

    #[test]
    fn string_output_mirrors_engine_and_shrinks() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);
        assert!(
            expected.len() < raw.len(),
            "fixture must be compressible by the shared engine"
        );
        let body = serde_json::json!({
            "model": "gpt-5",
            "input": [
                {"type": "function_call_output", "call_id": "call_1", "output": raw}
            ]
        });
        let (parsed, orig, comp) = roundtrip(&body);
        assert!(comp < orig, "compressed body must be smaller");
        assert_eq!(
            parsed["input"][0]["output"].as_str().unwrap(),
            expected,
            "output must be exactly what the shared compressor produces"
        );
    }

    #[test]
    fn array_output_text_is_compressed() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);
        let body = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "call_1",
                    "output": [{"type": "input_text", "text": raw}]
                }
            ]
        });
        let (parsed, orig, comp) = roundtrip(&body);
        assert!(comp < orig);
        assert_eq!(
            parsed["input"][0]["output"][0]["text"].as_str().unwrap(),
            expected
        );
    }

    #[test]
    fn array_output_parts_without_text_are_untouched() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);
        let image = serde_json::json!({
            "type": "input_image",
            "image_url": "data:image/png;base64,AAAA"
        });
        let body = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "call_1",
                    "output": [image.clone(), {"type": "input_text", "text": raw}]
                }
            ]
        });
        let (parsed, orig, comp) = roundtrip(&body);
        assert!(comp < orig, "the text part must still be compressed");
        assert_eq!(parsed["input"][0]["output"][0], image);
        assert_eq!(
            parsed["input"][0]["output"][1]["text"].as_str().unwrap(),
            expected
        );
    }

    #[test]
    fn object_output_passthrough() {
        // Only string and array outputs are rewritten; other shapes pass through
        // even if they contain a large `text` field.
        let body = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "c",
                    "output": {"text": long_git_status()}
                }
            ]
        });
        assert_passthrough(&body);
    }

    #[test]
    fn non_tool_output_items_are_untouched() {
        let body = serde_json::json!({
            "input": [
                {"type": "message", "role": "user", "content": long_git_status()},
                {"type": "function_call", "call_id": "c", "name": "x", "arguments": "{}"}
            ]
        });
        // No function_call_output → passthrough.
        assert_passthrough(&body);
    }

    #[test]
    fn plain_string_input_passthrough() {
        let body = serde_json::json!({"model": "gpt-5", "input": "hello world"});
        assert_passthrough(&body);
    }

    #[test]
    fn no_input_field_passthrough() {
        let body = serde_json::json!({"model": "gpt-5", "previous_response_id": "resp_abc"});
        assert_passthrough(&body);
    }

    #[test]
    fn short_output_unchanged() {
        let body = serde_json::json!({
            "input": [
                {"type": "function_call_output", "call_id": "c", "output": "ok"}
            ]
        });
        assert_passthrough(&body);
    }
}