Skip to main content

lean_ctx/proxy/
openai.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{Request, StatusCode},
5    response::Response,
6};
7use serde_json::Value;
8
9use super::compress::compress_tool_result;
10use super::forward;
11use super::tool_kind::{self, should_protect, ToolResultKind};
12use super::ProxyState;
13
/// Retention count handed to `history_prune::prune_history` by
/// `compress_request_body`.
// NOTE(review): presumably the number of most-recent messages kept intact by
// pruning — confirm against `history_prune`.
const KEEP_RECENT: usize = 6;
15
16pub async fn handler(
17    State(state): State<ProxyState>,
18    req: Request<Body>,
19) -> Result<Response, StatusCode> {
20    let upstream = state.openai_upstream.clone();
21    forward::forward_request(
22        State(state),
23        req,
24        &upstream,
25        "/v1/chat/completions",
26        compress_request_body,
27        "OpenAI",
28        &[],
29    )
30    .await
31}
32
33fn compress_request_body(parsed: Value, original_size: usize) -> (Vec<u8>, usize, usize) {
34    let mut doc = parsed;
35    let mut modified = false;
36
37    if let Some(messages) = doc.get_mut("messages").and_then(|m| m.as_array_mut()) {
38        let tool_names = tool_kind::openai_tool_names(messages);
39
40        super::history_prune::prune_history(messages, KEEP_RECENT, &tool_names);
41        modified = true;
42
43        for msg in messages.iter_mut() {
44            let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
45            if role != "tool" {
46                continue;
47            }
48
49            let name = msg
50                .get("tool_call_id")
51                .and_then(|v| v.as_str())
52                .and_then(|id| tool_names.get(id))
53                .map(String::as_str);
54            let kind = name.map_or(ToolResultKind::Other, tool_kind::classify_tool_name);
55
56            if let Some(content) = msg
57                .get_mut("content")
58                .and_then(|c| c.as_str().map(String::from))
59            {
60                if should_protect(kind, &content) {
61                    continue;
62                }
63                let compressed = compress_tool_result(&content, name);
64                if compressed.len() < content.len() {
65                    msg["content"] = Value::String(compressed);
66                    modified = true;
67                }
68            }
69        }
70    }
71
72    let out = serde_json::to_vec(&doc).unwrap_or_default();
73    let compressed_size = if modified { out.len() } else { original_size };
74    (out, original_size, compressed_size)
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// A `read_file` tool result must reach the upstream intact, including its
    /// last line.
    // NOTE(review): relies on `should_protect` protecting `read_file` output —
    // confirm against `tool_kind`.
    #[test]
    fn read_file_tool_result_protected() {
        let code = (0..60)
            .map(|i| format!("    const value{i} = computeValue{i}(ctx, opts);"))
            .collect::<Vec<_>>()
            .join("\n");
        let body = serde_json::json!({
            "model": "gpt-5",
            "messages": [
                {"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read_file"}}]},
                {"role": "tool", "tool_call_id": "call_1", "content": code}
            ]
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let (out, _orig, _comp) = compress_request_body(body, bytes.len());
        let parsed: Value = serde_json::from_slice(&out).unwrap();
        assert!(
            parsed["messages"][1]["content"]
                .as_str()
                .unwrap()
                .contains("value59"),
            "protected read_file result lost its tail"
        );
    }

    /// Bodies without a `messages` array skip every compression pass: the JSON
    /// is returned unchanged and the caller-supplied size is echoed back as
    /// both the original and the compressed size.
    #[test]
    fn body_without_messages_array_is_passed_through() {
        for body in [
            serde_json::json!({"model": "gpt-5"}),
            serde_json::json!({"model": "gpt-5", "messages": "not-an-array"}),
        ] {
            let (out, orig, comp) = compress_request_body(body.clone(), 123);
            assert_eq!(orig, 123);
            assert_eq!(comp, 123);
            let parsed: Value = serde_json::from_slice(&out).unwrap();
            assert_eq!(parsed, body);
        }
    }
}