Skip to main content

lean_ctx/proxy/
openai.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{Request, StatusCode},
5    response::Response,
6};
7use serde_json::Value;
8
9use super::compress::compress_tool_result;
10use super::forward;
11use super::tool_kind::{self, should_protect, ToolResultKind};
12use super::ProxyState;
13
/// Passed to `history_prune::prune_history` as its second argument.
/// NOTE(review): presumably the number of most recent messages that pruning
/// leaves untouched — confirm against `history_prune`.
const KEEP_RECENT: usize = 6;
15
16pub async fn handler(
17    State(state): State<ProxyState>,
18    req: Request<Body>,
19) -> Result<Response, StatusCode> {
20    let upstream = state.openai_upstream.clone();
21    forward::forward_request(
22        State(state),
23        req,
24        &upstream,
25        "/v1/chat/completions",
26        compress_request_body,
27        "OpenAI",
28        &[],
29    )
30    .await
31}
32
33fn compress_request_body(body: &[u8]) -> (Vec<u8>, usize, usize) {
34    let original_size = body.len();
35
36    let parsed: Value = match serde_json::from_slice(body) {
37        Ok(v) => v,
38        Err(_) => return (body.to_vec(), original_size, original_size),
39    };
40
41    let mut doc = parsed;
42    let mut modified = false;
43
44    if let Some(messages) = doc.get_mut("messages").and_then(|m| m.as_array_mut()) {
45        let tool_names = tool_kind::openai_tool_names(messages);
46
47        super::history_prune::prune_history(messages, KEEP_RECENT, &tool_names);
48        modified = true;
49
50        for msg in messages.iter_mut() {
51            let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
52            if role != "tool" {
53                continue;
54            }
55
56            let name = msg
57                .get("tool_call_id")
58                .and_then(|v| v.as_str())
59                .and_then(|id| tool_names.get(id))
60                .map(String::as_str);
61            let kind = name.map_or(ToolResultKind::Other, tool_kind::classify_tool_name);
62
63            if let Some(content) = msg
64                .get_mut("content")
65                .and_then(|c| c.as_str().map(String::from))
66            {
67                if should_protect(kind, &content) {
68                    continue;
69                }
70                let compressed = compress_tool_result(&content, name);
71                if compressed.len() < content.len() {
72                    msg["content"] = Value::String(compressed);
73                    modified = true;
74                }
75            }
76        }
77    }
78
79    if !modified {
80        return (body.to_vec(), original_size, original_size);
81    }
82
83    match serde_json::to_vec(&doc) {
84        Ok(compressed) => {
85            let compressed_size = compressed.len();
86            (compressed, original_size, compressed_size)
87        }
88        Err(_) => (body.to_vec(), original_size, original_size),
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `compress_request_body` a two-message conversation (an assistant
    /// `read_file` tool call followed by its tool result of 60 code-like lines)
    /// and checks that the tool result at index 1 still contains the last line.
    #[test]
    fn read_file_tool_result_protected() {
        let mut source_lines = Vec::with_capacity(60);
        for i in 0..60 {
            source_lines.push(format!("    const value{i} = computeValue{i}(ctx, opts);"));
        }
        let file_text = source_lines.join("\n");

        let request = serde_json::json!({
            "model": "gpt-5",
            "messages": [
                {"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read_file"}}]},
                {"role": "tool", "tool_call_id": "call_1", "content": file_text}
            ]
        });
        let request_bytes = serde_json::to_vec(&request).unwrap();

        let (rewritten, _, _) = compress_request_body(&request_bytes);

        let decoded: Value = serde_json::from_slice(&rewritten).unwrap();
        let tool_content = decoded["messages"][1]["content"].as_str().unwrap();
        assert!(tool_content.contains("value59"));
    }
}