1use axum::{
2 body::Body,
3 extract::State,
4 http::{Request, StatusCode},
5 response::Response,
6};
7use serde_json::Value;
8
9use super::compress::compress_tool_result;
10use super::forward;
11use super::tool_kind::{self, should_protect, ToolResultKind};
12use super::ProxyState;
13
14pub async fn handler(
26 State(state): State<ProxyState>,
27 req: Request<Body>,
28) -> Result<Response, StatusCode> {
29 let upstream = state.openai_upstream.clone();
30 forward::forward_request(
31 State(state),
32 req,
33 &upstream,
34 "/v1/responses",
35 compress_request_body,
36 "OpenAI",
37 &[],
38 )
39 .await
40}
41
42fn compress_request_body(body: &[u8]) -> (Vec<u8>, usize, usize) {
43 let original_size = body.len();
44
45 let parsed: Value = match serde_json::from_slice(body) {
46 Ok(v) => v,
47 Err(_) => return (body.to_vec(), original_size, original_size),
48 };
49
50 let mut doc = parsed;
51 let mut modified = false;
52
53 if let Some(input) = doc.get_mut("input").and_then(|i| i.as_array_mut()) {
63 let tool_names = tool_kind::responses_tool_names(input);
64 for item in input.iter_mut() {
65 if item.get("type").and_then(|t| t.as_str()) != Some("function_call_output") {
66 continue;
67 }
68 let name = item
69 .get("call_id")
70 .and_then(|v| v.as_str())
71 .and_then(|id| tool_names.get(id))
72 .map(String::as_str);
73 let kind = name.map_or(ToolResultKind::Other, tool_kind::classify_tool_name);
74 if let Some(output) = item.get_mut("output") {
75 modified |= compress_output_field(output, name, kind);
76 }
77 }
78 }
79
80 if !modified {
81 return (body.to_vec(), original_size, original_size);
82 }
83
84 match serde_json::to_vec(&doc) {
85 Ok(compressed) => {
86 let compressed_size = compressed.len();
87 (compressed, original_size, compressed_size)
88 }
89 Err(_) => (body.to_vec(), original_size, original_size),
90 }
91}
92
93fn compress_output_field(
100 output: &mut Value,
101 tool_name: Option<&str>,
102 kind: ToolResultKind,
103) -> bool {
104 match output {
105 Value::String(s) => {
106 if should_protect(kind, s) {
107 return false;
108 }
109 let compressed = compress_tool_result(s, tool_name);
110 if compressed.len() < s.len() {
111 *s = compressed;
112 return true;
113 }
114 false
115 }
116 Value::Array(parts) => {
117 let mut changed = false;
118 for part in parts.iter_mut() {
119 if let Some(Value::String(text)) = part.get_mut("text") {
120 if should_protect(kind, text) {
121 continue;
122 }
123 let compressed = compress_tool_result(text, tool_name);
124 if compressed.len() < text.len() {
125 *text = compressed;
126 changed = true;
127 }
128 }
129 }
130 changed
131 }
132 _ => false,
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `body` and feeds the bytes to [`compress_request_body`].
    ///
    /// Returns `(sent_bytes, output_bytes, original_size, compressed_size)`.
    fn roundtrip(body: &Value) -> (Vec<u8>, Vec<u8>, usize, usize) {
        let sent = serde_json::to_vec(body).unwrap();
        let (output, original, compressed) = compress_request_body(&sent);
        (sent, output, original, compressed)
    }

    /// A long, repetitive `git status` transcript listing 80 modified files.
    fn long_git_status() -> String {
        let header = "$ git status\nOn branch main\nYour branch is up to date with 'origin/main'.\n\nChanges not staged for commit:\n (use \"git add <file>...\" to update what will be committed)\n";
        let footer = "\nno changes added to commit (use \"git add\" and/or \"git commit -a\")\n";
        let files: String = (0..80)
            .map(|i| format!("\tmodified: src/module_{i}/file_{i}.rs\n"))
            .collect();
        [header, files.as_str(), footer].concat()
    }

    #[test]
    fn string_output_mirrors_engine_and_shrinks() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);
        assert!(
            expected.len() < raw.len(),
            "fixture must be compressible by the shared engine"
        );

        let request = serde_json::json!({
            "model": "gpt-5",
            "input": [
                {"type": "function_call_output", "call_id": "call_1", "output": raw}
            ]
        });
        let (_, out, orig, comp) = roundtrip(&request);

        assert!(comp < orig, "compressed body must be smaller");
        let reparsed: Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(
            reparsed["input"][0]["output"].as_str().unwrap(),
            expected,
            "output must be exactly what the shared compressor produces"
        );
    }

    #[test]
    fn array_output_text_is_compressed() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);

        let request = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "call_1",
                    "output": [{"type": "input_text", "text": raw}]
                }
            ]
        });
        let (_, out, orig, comp) = roundtrip(&request);

        assert!(comp < orig);
        let reparsed: Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(
            reparsed["input"][0]["output"][0]["text"].as_str().unwrap(),
            expected
        );
    }

    #[test]
    fn non_tool_output_items_are_untouched() {
        let request = serde_json::json!({
            "input": [
                {"type": "message", "role": "user", "content": long_git_status()},
                {"type": "function_call", "call_id": "c", "name": "x", "arguments": "{}"}
            ]
        });
        let (sent, out, orig, comp) = roundtrip(&request);

        assert_eq!(comp, orig, "no function_call_output → passthrough");
        assert_eq!(out, sent, "body must be byte-identical");
    }

    #[test]
    fn plain_string_input_passthrough() {
        let request = serde_json::json!({"model": "gpt-5", "input": "hello world"});
        let (sent, out, orig, comp) = roundtrip(&request);
        assert_eq!(comp, orig);
        assert_eq!(out, sent);
    }

    #[test]
    fn no_input_field_passthrough() {
        let request = serde_json::json!({"model": "gpt-5", "previous_response_id": "resp_abc"});
        let (sent, out, orig, comp) = roundtrip(&request);
        assert_eq!(comp, orig);
        assert_eq!(out, sent);
    }

    #[test]
    fn invalid_json_passthrough() {
        // Not JSON at all, so this cannot go through `roundtrip`.
        let garbage = b"this is not json".to_vec();
        let (out, orig, comp) = compress_request_body(&garbage);
        assert_eq!(comp, orig);
        assert_eq!(out, garbage);
    }

    #[test]
    fn short_output_unchanged() {
        let request = serde_json::json!({
            "input": [
                {"type": "function_call_output", "call_id": "c", "output": "ok"}
            ]
        });
        let (sent, out, orig, comp) = roundtrip(&request);
        assert_eq!(comp, orig);
        assert_eq!(out, sent);
    }
}