lean_ctx/proxy/
openai_responses.rs1use axum::{
2 body::Body,
3 extract::State,
4 http::{Request, StatusCode},
5 response::Response,
6};
7use serde_json::Value;
8
9use super::compress::compress_tool_result;
10use super::forward;
11use super::tool_kind::{self, should_protect, ToolResultKind};
12use super::ProxyState;
13
14pub async fn handler(
26 State(state): State<ProxyState>,
27 req: Request<Body>,
28) -> Result<Response, StatusCode> {
29 let upstream = state.openai_upstream.clone();
30 forward::forward_request(
31 State(state),
32 req,
33 &upstream,
34 "/v1/responses",
35 compress_request_body,
36 "OpenAI",
37 &[],
38 )
39 .await
40}
41
42fn compress_request_body(parsed: Value, original_size: usize) -> (Vec<u8>, usize, usize) {
43 let mut doc = parsed;
44 let mut modified = false;
45
46 if let Some(input) = doc.get_mut("input").and_then(|i| i.as_array_mut()) {
56 let tool_names = tool_kind::responses_tool_names(input);
57 for item in input.iter_mut() {
58 if item.get("type").and_then(|t| t.as_str()) != Some("function_call_output") {
59 continue;
60 }
61 let name = item
62 .get("call_id")
63 .and_then(|v| v.as_str())
64 .and_then(|id| tool_names.get(id))
65 .map(String::as_str);
66 let kind = name.map_or(ToolResultKind::Other, tool_kind::classify_tool_name);
67 if let Some(output) = item.get_mut("output") {
68 modified |= compress_output_field(output, name, kind);
69 }
70 }
71 }
72
73 let out = serde_json::to_vec(&doc).unwrap_or_default();
74 let compressed_size = if modified { out.len() } else { original_size };
75 (out, original_size, compressed_size)
76}
77
78fn compress_output_field(
85 output: &mut Value,
86 tool_name: Option<&str>,
87 kind: ToolResultKind,
88) -> bool {
89 match output {
90 Value::String(s) => {
91 if should_protect(kind, s) {
92 return false;
93 }
94 let compressed = compress_tool_result(s, tool_name);
95 if compressed.len() < s.len() {
96 *s = compressed;
97 return true;
98 }
99 false
100 }
101 Value::Array(parts) => {
102 let mut changed = false;
103 for part in parts.iter_mut() {
104 if let Some(Value::String(text)) = part.get_mut("text") {
105 if should_protect(kind, text) {
106 continue;
107 }
108 let compressed = compress_tool_result(text, tool_name);
109 if compressed.len() < text.len() {
110 *text = compressed;
111 changed = true;
112 }
113 }
114 }
115 changed
116 }
117 _ => false,
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a long, highly repetitive `git status` transcript that the
    /// shared compressor is expected to shrink.
    fn long_git_status() -> String {
        let mut s = String::from(
            "$ git status\nOn branch main\nYour branch is up to date with 'origin/main'.\n\nChanges not staged for commit:\n  (use \"git add <file>...\" to update what will be committed)\n",
        );
        for i in 0..80 {
            s.push_str(&format!("\tmodified:   src/module_{i}/file_{i}.rs\n"));
        }
        s.push_str("\nno changes added to commit (use \"git add\" and/or \"git commit -a\")\n");
        s
    }

    /// Runs `body` through `compress_request_body` and asserts it comes back
    /// semantically identical, with the size reported as unchanged.
    fn assert_passthrough(body: Value, why: &str) {
        let bytes = serde_json::to_vec(&body).unwrap();
        let (out, orig, comp) = compress_request_body(body.clone(), bytes.len());
        assert_eq!(comp, orig, "{why}: size must be reported unchanged");
        let reparsed: Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(reparsed, body, "{why}: body must round-trip unchanged");
    }

    #[test]
    fn string_output_mirrors_engine_and_shrinks() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);
        assert!(
            expected.len() < raw.len(),
            "fixture must be compressible by the shared engine"
        );

        let body = serde_json::json!({
            "model": "gpt-5",
            "input": [
                {"type": "function_call_output", "call_id": "call_1", "output": raw}
            ]
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let (out, orig, comp) = compress_request_body(body, bytes.len());

        assert!(comp < orig, "compressed body must be smaller");
        let parsed: Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(
            parsed["input"][0]["output"].as_str().unwrap(),
            expected,
            "output must be exactly what the shared compressor produces"
        );
    }

    #[test]
    fn array_output_text_is_compressed() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);

        let body = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "call_1",
                    "output": [{"type": "input_text", "text": raw}]
                }
            ]
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let (out, orig, comp) = compress_request_body(body, bytes.len());

        assert!(comp < orig);
        let parsed: Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(
            parsed["input"][0]["output"][0]["text"].as_str().unwrap(),
            expected
        );
    }

    #[test]
    fn array_output_preserves_non_text_parts() {
        let raw = long_git_status();
        let expected = compress_tool_result(&raw, None);
        let image = serde_json::json!({
            "type": "input_image",
            "image_url": "https://example.com/a.png"
        });

        let body = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "call_1",
                    "output": [{"type": "input_text", "text": raw}, image.clone()]
                }
            ]
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let (out, orig, comp) = compress_request_body(body, bytes.len());

        assert!(comp < orig);
        let parsed: Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(
            parsed["input"][0]["output"][0]["text"].as_str().unwrap(),
            expected
        );
        assert_eq!(parsed["input"][0]["output"][1], image, "non-text part must be kept verbatim");
    }

    #[test]
    fn array_output_without_text_parts_untouched() {
        let body = serde_json::json!({
            "input": [
                {
                    "type": "function_call_output",
                    "call_id": "c",
                    "output": [{"type": "input_image", "image_url": "https://example.com/a.png"}]
                }
            ]
        });
        assert_passthrough(body, "array output with no text parts");
    }

    #[test]
    fn non_string_non_array_output_untouched() {
        let body = serde_json::json!({
            "input": [
                {"type": "function_call_output", "call_id": "a", "output": {"content": long_git_status()}},
                {"type": "function_call_output", "call_id": "b", "output": null}
            ]
        });
        assert_passthrough(body, "object/null output");
    }

    #[test]
    fn missing_output_field_untouched() {
        let body = serde_json::json!({
            "input": [
                {"type": "function_call_output", "call_id": "c"}
            ]
        });
        assert_passthrough(body, "function_call_output without an output field");
    }

    #[test]
    fn non_tool_output_items_are_untouched() {
        let body = serde_json::json!({
            "input": [
                {"type": "message", "role": "user", "content": long_git_status()},
                {"type": "function_call", "call_id": "c", "name": "x", "arguments": "{}"}
            ]
        });
        assert_passthrough(body, "no function_call_output → passthrough");
    }

    #[test]
    fn plain_string_input_passthrough() {
        let body = serde_json::json!({"model": "gpt-5", "input": "hello world"});
        assert_passthrough(body, "plain string input");
    }

    #[test]
    fn no_input_field_passthrough() {
        let body = serde_json::json!({"model": "gpt-5", "previous_response_id": "resp_abc"});
        assert_passthrough(body, "no input field");
    }

    #[test]
    fn short_output_unchanged() {
        let body = serde_json::json!({
            "input": [
                {"type": "function_call_output", "call_id": "c", "output": "ok"}
            ]
        });
        assert_passthrough(body, "short output");
    }
}