1use axum::{
2 body::Body,
3 extract::State,
4 http::{Request, StatusCode},
5 response::Response,
6};
7use serde_json::Value;
8
9use super::compress::compress_tool_result;
10use super::forward;
11use super::ProxyState;
12
/// Upstream base URL that `handler` hands to `forward::forward_request`.
const UPSTREAM: &str = "https://api.openai.com";
14
15pub async fn handler(state: State<ProxyState>, req: Request<Body>) -> Result<Response, StatusCode> {
16 forward::forward_request(
17 state,
18 req,
19 UPSTREAM,
20 "/v1/chat/completions",
21 compress_request_body,
22 "OpenAI",
23 &[],
24 )
25 .await
26}
27
28fn compress_request_body(body: &[u8]) -> (Vec<u8>, usize, usize) {
29 let original_size = body.len();
30
31 let parsed: Value = match serde_json::from_slice(body) {
32 Ok(v) => v,
33 Err(_) => return (body.to_vec(), original_size, original_size),
34 };
35
36 let mut doc = parsed;
37 let mut modified = false;
38
39 if let Some(messages) = doc.get_mut("messages").and_then(|m| m.as_array_mut()) {
40 for msg in messages.iter_mut() {
41 let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
42 if role != "tool" {
43 continue;
44 }
45
46 if let Some(content) = msg
47 .get_mut("content")
48 .and_then(|c| c.as_str().map(String::from))
49 {
50 let compressed = compress_tool_result(&content, None);
51 if compressed.len() < content.len() {
52 msg["content"] = Value::String(compressed);
53 modified = true;
54 }
55 }
56 }
57 }
58
59 if !modified {
60 return (body.to_vec(), original_size, original_size);
61 }
62
63 match serde_json::to_vec(&doc) {
64 Ok(compressed) => {
65 let compressed_size = compressed.len();
66 (compressed, original_size, compressed_size)
67 }
68 Err(_) => (body.to_vec(), original_size, original_size),
69 }
70}