Skip to main content

mimir_proxy/
optimize.rs

1//! Request-body transforms applied before forwarding to the Anthropic API.
2//!
3//! Three passes, all pure and independently testable:
4//!  - **cache** (default on, safe): if the client set NO `cache_control`, add
5//!    ephemeral breakpoints on the system prompt and the last message — the
6//!    stable prefix that repeats across turns. Actual savings are measured from
7//!    the *response* (`usage.cache_read_input_tokens`), not estimated here.
8//!  - **dedup** (default on, safe): replace later identical large content
9//!    blocks with a short placeholder. The model still sees the content once
10//!    (the first occurrence). Lossless in practice; big in long sessions.
11//!  - **prune** (default off, lossy): replace large `tool_result` blocks in
12//!    older turns with a placeholder. Off by default — it changes what the
13//!    model sees.
14
15use std::collections::HashSet;
16use std::hash::{Hash, Hasher};
17
18use mimir_core::tokens;
19use serde_json::{json, Value};
20
/// Switches selecting which request-side passes [`optimize_request`] runs.
///
/// Each flag gates exactly one pass; the passes themselves are applied in
/// the order cache, dedup, prune.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OptimizeOpts {
    /// Add ephemeral `cache_control` breakpoints when the request has none.
    pub cache: bool,
    /// Replace later identical large content blocks with a placeholder.
    pub dedup: bool,
    /// Replace large `tool_result` blocks in older messages with a
    /// placeholder. Lossy: it changes what the model sees.
    pub prune: bool,
}
27
/// Tokens removed by the request-side passes, for the savings ledger. Cache
/// savings are NOT here — they're measured from the response.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Optimization {
    /// Tokens removed by the dedup pass, net of the placeholder that
    /// replaces each elided block.
    pub deduped: usize,
    /// Tokens removed by the prune pass, net of the placeholder that
    /// replaces each elided tool result.
    pub pruned: usize,
}
35
/// Tool results in this many trailing messages are never pruned; only
/// messages before that window are candidates.
const KEEP_RECENT_MESSAGES: usize = 6;
/// Only prune a tool_result whose text is strictly more than this many tokens.
const PRUNE_MIN_TOKENS: usize = 200;
/// Only dedup a block whose text is at least this many tokens (blocks below
/// the threshold are left alone even when repeated).
const DEDUP_MIN_TOKENS: usize = 100;
/// Text written in place of a pruned `tool_result` block's content.
const PRUNE_PLACEHOLDER: &str = "[older tool result elided by mimir proxy]";
/// Text written in place of a repeated block's text/content.
const DEDUP_PLACEHOLDER: &str = "[identical to an earlier block — elided by mimir proxy]";
44
45pub fn optimize_request(mut req: Value, opts: OptimizeOpts) -> (Value, Optimization) {
46    let mut opt = Optimization::default();
47    if opts.cache {
48        add_cache_breakpoints(&mut req);
49    }
50    if opts.dedup {
51        dedup_blocks(&mut req, &mut opt);
52    }
53    if opts.prune {
54        prune_old_tool_results(&mut req, &mut opt);
55    }
56    (req, opt)
57}
58
59fn blocks_have_cache_control(v: &Value) -> bool {
60    matches!(v, Value::Array(a) if a.iter().any(|b| b.get("cache_control").is_some()))
61}
62
63/// True if the request already carries any cache_control breakpoint.
64pub(crate) fn has_cache_control(req: &Value) -> bool {
65    if blocks_have_cache_control(req.get("system").unwrap_or(&Value::Null)) {
66        return true;
67    }
68    req.get("messages")
69        .and_then(|m| m.as_array())
70        .map(|msgs| {
71            msgs.iter()
72                .any(|m| blocks_have_cache_control(m.get("content").unwrap_or(&Value::Null)))
73        })
74        .unwrap_or(false)
75}
76
77/// Mark the stable prefix for caching: the system prompt and the last message's
78/// final block. No-op if the client already manages caching.
79fn add_cache_breakpoints(req: &mut Value) {
80    if has_cache_control(req) {
81        return;
82    }
83    if let Some(system) = req.get_mut("system") {
84        mark_cache_control(system);
85    }
86    if let Some(content) = req
87        .get_mut("messages")
88        .and_then(|m| m.as_array_mut())
89        .and_then(|m| m.last_mut())
90        .and_then(|m| m.get_mut("content"))
91    {
92        mark_cache_control(content);
93    }
94}
95
96/// Put an ephemeral cache breakpoint on a `system`/`content` value, converting
97/// a bare string into a one-block array so the marker has somewhere to live.
98fn mark_cache_control(v: &mut Value) {
99    match v {
100        Value::String(s) => {
101            let text = std::mem::take(s);
102            *v = json!([{
103                "type": "text", "text": text, "cache_control": { "type": "ephemeral" }
104            }]);
105        }
106        Value::Array(arr) => {
107            if let Some(obj) = arr.last_mut().and_then(|b| b.as_object_mut()) {
108                obj.insert("cache_control".into(), json!({ "type": "ephemeral" }));
109            }
110        }
111        _ => {}
112    }
113}
114
115/// Comparable text of a content block (text blocks and tool_result blocks).
116fn block_text(block: &Value) -> Option<String> {
117    match block.get("type").and_then(|t| t.as_str()) {
118        Some("text") => block
119            .get("text")
120            .and_then(|t| t.as_str())
121            .map(str::to_owned),
122        Some("tool_result") => block.get("content").map(|c| match c {
123            Value::String(s) => s.clone(),
124            other => other.to_string(),
125        }),
126        _ => None,
127    }
128}
129
130fn set_block_text(block: &mut Value, placeholder: &str) {
131    if let Some(obj) = block.as_object_mut() {
132        if obj.contains_key("text") {
133            obj.insert("text".into(), json!(placeholder));
134        } else {
135            obj.insert("content".into(), json!(placeholder));
136        }
137    }
138}
139
140/// Replace later identical large blocks with a placeholder (keep the first).
141fn dedup_blocks(req: &mut Value, opt: &mut Optimization) {
142    let Some(msgs) = req.get_mut("messages").and_then(|m| m.as_array_mut()) else {
143        return;
144    };
145    // Store digests, not the full block texts — those can be KB–MB each.
146    let mut seen: HashSet<u64> = HashSet::new();
147    for msg in msgs.iter_mut() {
148        let Some(blocks) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
149            continue;
150        };
151        for b in blocks.iter_mut() {
152            let Some(text) = block_text(b) else { continue };
153            let toks = tokens::count(&text);
154            if toks < DEDUP_MIN_TOKENS {
155                continue;
156            }
157            if !seen.insert(digest(&text)) {
158                opt.deduped += toks.saturating_sub(tokens::count(DEDUP_PLACEHOLDER));
159                set_block_text(b, DEDUP_PLACEHOLDER);
160            }
161        }
162    }
163}
164
/// 64-bit hash of `s`, computed with a freshly constructed standard-library
/// `DefaultHasher`.
fn digest(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut hasher = DefaultHasher::new();
    Hash::hash(s, &mut hasher);
    hasher.finish()
}
170
171fn prune_old_tool_results(req: &mut Value, opt: &mut Optimization) {
172    let Some(msgs) = req.get_mut("messages").and_then(|m| m.as_array_mut()) else {
173        return;
174    };
175    let keep_from = msgs.len().saturating_sub(KEEP_RECENT_MESSAGES);
176    for msg in msgs.iter_mut().take(keep_from) {
177        let Some(blocks) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
178            continue;
179        };
180        for b in blocks.iter_mut() {
181            if b.get("type").and_then(|t| t.as_str()) != Some("tool_result") {
182                continue;
183            }
184            let Some(text) = block_text(b) else { continue };
185            let toks = tokens::count(&text);
186            if toks > PRUNE_MIN_TOKENS && text != DEDUP_PLACEHOLDER {
187                opt.pruned += toks.saturating_sub(tokens::count(PRUNE_PLACEHOLDER));
188                set_block_text(b, PRUNE_PLACEHOLDER);
189            }
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Positional shorthand for building the pass switches.
    fn opts(cache: bool, dedup: bool, prune: bool) -> OptimizeOpts {
        OptimizeOpts {
            cache,
            dedup,
            prune,
        }
    }

    /// A user message whose content is a bare string.
    fn plain_user(text: &str) -> Value {
        json!({"role": "user", "content": text})
    }

    /// A user message whose content is a single text block.
    fn text_block_user(text: &str) -> Value {
        json!({"role": "user", "content": [{"type": "text", "text": text}]})
    }

    // Cache pass: both the system prompt and the final message get a marker,
    // and bare strings are turned into one-block arrays to hold it.
    #[test]
    fn caches_system_and_last_message() {
        let system_prompt = "You are careful. ".repeat(20);
        let request = json!({
            "system": system_prompt,
            "messages": [plain_user("hi"), plain_user("do the thing")]
        });

        let (out, _) = optimize_request(request, opts(true, false, false));

        assert_eq!(out["system"][0]["cache_control"]["type"], "ephemeral");
        let last = out["messages"].as_array().unwrap().last().unwrap();
        assert_eq!(last["content"][0]["cache_control"]["type"], "ephemeral");
    }

    // Cache pass: a request that already has a marker comes back unchanged.
    #[test]
    fn respects_existing_cache_control() {
        let original = json!({
            "system": [{"type": "text", "text": "x", "cache_control": {"type": "ephemeral"}}],
            "messages": []
        });

        let (out, ledger) = optimize_request(original.clone(), opts(true, true, false));

        assert_eq!(ledger, Optimization::default());
        assert_eq!(out, original);
    }

    // Dedup pass: the first copy of a large block survives, the later copy
    // is replaced by the placeholder.
    #[test]
    fn dedup_elides_later_identical_blocks() {
        let big = "x ".repeat(200); // > 100 tokens
        let request = json!({
            "messages": [
                text_block_user(&big),
                plain_user("middle"),
                text_block_user(&big)
            ]
        });

        let (out, ledger) = optimize_request(request, opts(false, true, false));

        assert!(ledger.deduped > 0);
        let first = &out["messages"][0]["content"][0]["text"];
        let repeat = &out["messages"][2]["content"][0]["text"];
        assert_ne!(*first, DEDUP_PLACEHOLDER);
        assert_eq!(*repeat, DEDUP_PLACEHOLDER);
    }

    // Dedup pass: repeats below the size threshold are left as they are.
    #[test]
    fn dedup_keeps_small_repeats() {
        let request = json!({
            "messages": [text_block_user("short"), text_block_user("short")]
        });

        let (out, ledger) = optimize_request(request, opts(false, true, false));

        assert_eq!(ledger.deduped, 0);
        assert_eq!(out["messages"][1]["content"][0]["text"], "short");
    }

    // Prune pass: a large tool result sitting just outside the protected
    // tail of recent messages is replaced by the placeholder.
    #[test]
    fn prune_elides_old_large_tool_results() {
        let big = "x ".repeat(400);
        let old_tool_result = json!({
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": "a", "content": big}]
        });
        let messages: Vec<Value> = std::iter::once(old_tool_result)
            .chain(std::iter::repeat_with(|| plain_user("recent")).take(KEEP_RECENT_MESSAGES))
            .collect();
        let request = json!({ "messages": messages });

        let (out, ledger) = optimize_request(request, opts(false, false, true));

        assert!(ledger.pruned > 0);
        assert_eq!(
            out["messages"][0]["content"][0]["content"],
            PRUNE_PLACEHOLDER
        );
    }
}