1use std::collections::HashSet;
16use std::hash::{Hash, Hasher};
17
18use mimir_core::tokens;
19use serde_json::{json, Value};
20
/// Switches selecting which optimization passes `optimize_request` applies.
///
/// Each flag gates exactly one pass; a `false` flag skips that pass entirely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OptimizeOpts {
    /// Add `cache_control` breakpoints when the request does not already carry any.
    pub cache: bool,
    /// Replace later repeats of a large text / tool-result block with a placeholder.
    pub dedup: bool,
    /// Replace large tool results in older messages with a placeholder.
    pub prune: bool,
}
27
/// Summary of how much `optimize_request` elided from a request.
///
/// Both counters are token estimates as reported by `tokens::count`, net of the
/// placeholder text that replaces each elided block.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Optimization {
    /// Tokens saved by the dedup pass.
    pub deduped: usize,
    /// Tokens saved by the prune pass.
    pub pruned: usize,
}
35
/// Number of trailing messages the prune pass never touches.
const KEEP_RECENT_MESSAGES: usize = 6;
/// A tool result is pruned only when its token count is strictly greater than this.
const PRUNE_MIN_TOKENS: usize = 200;
/// Blocks with fewer tokens than this are never deduplicated.
const DEDUP_MIN_TOKENS: usize = 100;
/// Text written in place of a pruned tool result.
const PRUNE_PLACEHOLDER: &str = "[older tool result elided by mimir proxy]";
/// Text written in place of a block whose content repeats an earlier block.
const DEDUP_PLACEHOLDER: &str = "[identical to an earlier block — elided by mimir proxy]";
44
45pub fn optimize_request(mut req: Value, opts: OptimizeOpts) -> (Value, Optimization) {
46 let mut opt = Optimization::default();
47 if opts.cache {
48 add_cache_breakpoints(&mut req);
49 }
50 if opts.dedup {
51 dedup_blocks(&mut req, &mut opt);
52 }
53 if opts.prune {
54 prune_old_tool_results(&mut req, &mut opt);
55 }
56 (req, opt)
57}
58
59fn blocks_have_cache_control(v: &Value) -> bool {
60 matches!(v, Value::Array(a) if a.iter().any(|b| b.get("cache_control").is_some()))
61}
62
63pub(crate) fn has_cache_control(req: &Value) -> bool {
65 if blocks_have_cache_control(req.get("system").unwrap_or(&Value::Null)) {
66 return true;
67 }
68 req.get("messages")
69 .and_then(|m| m.as_array())
70 .map(|msgs| {
71 msgs.iter()
72 .any(|m| blocks_have_cache_control(m.get("content").unwrap_or(&Value::Null)))
73 })
74 .unwrap_or(false)
75}
76
77fn add_cache_breakpoints(req: &mut Value) {
80 if has_cache_control(req) {
81 return;
82 }
83 if let Some(system) = req.get_mut("system") {
84 mark_cache_control(system);
85 }
86 if let Some(content) = req
87 .get_mut("messages")
88 .and_then(|m| m.as_array_mut())
89 .and_then(|m| m.last_mut())
90 .and_then(|m| m.get_mut("content"))
91 {
92 mark_cache_control(content);
93 }
94}
95
96fn mark_cache_control(v: &mut Value) {
99 match v {
100 Value::String(s) => {
101 let text = std::mem::take(s);
102 *v = json!([{
103 "type": "text", "text": text, "cache_control": { "type": "ephemeral" }
104 }]);
105 }
106 Value::Array(arr) => {
107 if let Some(obj) = arr.last_mut().and_then(|b| b.as_object_mut()) {
108 obj.insert("cache_control".into(), json!({ "type": "ephemeral" }));
109 }
110 }
111 _ => {}
112 }
113}
114
115fn block_text(block: &Value) -> Option<String> {
117 match block.get("type").and_then(|t| t.as_str()) {
118 Some("text") => block
119 .get("text")
120 .and_then(|t| t.as_str())
121 .map(str::to_owned),
122 Some("tool_result") => block.get("content").map(|c| match c {
123 Value::String(s) => s.clone(),
124 other => other.to_string(),
125 }),
126 _ => None,
127 }
128}
129
130fn set_block_text(block: &mut Value, placeholder: &str) {
131 if let Some(obj) = block.as_object_mut() {
132 if obj.contains_key("text") {
133 obj.insert("text".into(), json!(placeholder));
134 } else {
135 obj.insert("content".into(), json!(placeholder));
136 }
137 }
138}
139
140fn dedup_blocks(req: &mut Value, opt: &mut Optimization) {
142 let Some(msgs) = req.get_mut("messages").and_then(|m| m.as_array_mut()) else {
143 return;
144 };
145 let mut seen: HashSet<u64> = HashSet::new();
147 for msg in msgs.iter_mut() {
148 let Some(blocks) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
149 continue;
150 };
151 for b in blocks.iter_mut() {
152 let Some(text) = block_text(b) else { continue };
153 let toks = tokens::count(&text);
154 if toks < DEDUP_MIN_TOKENS {
155 continue;
156 }
157 if !seen.insert(digest(&text)) {
158 opt.deduped += toks.saturating_sub(tokens::count(DEDUP_PLACEHOLDER));
159 set_block_text(b, DEDUP_PLACEHOLDER);
160 }
161 }
162 }
163}
164
/// Hashes `s` to a 64-bit value with the standard library's default hasher.
fn digest(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut hasher = DefaultHasher::new();
    Hash::hash(s, &mut hasher);
    hasher.finish()
}
170
171fn prune_old_tool_results(req: &mut Value, opt: &mut Optimization) {
172 let Some(msgs) = req.get_mut("messages").and_then(|m| m.as_array_mut()) else {
173 return;
174 };
175 let keep_from = msgs.len().saturating_sub(KEEP_RECENT_MESSAGES);
176 for msg in msgs.iter_mut().take(keep_from) {
177 let Some(blocks) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
178 continue;
179 };
180 for b in blocks.iter_mut() {
181 if b.get("type").and_then(|t| t.as_str()) != Some("tool_result") {
182 continue;
183 }
184 let Some(text) = block_text(b) else { continue };
185 let toks = tokens::count(&text);
186 if toks > PRUNE_MIN_TOKENS && text != DEDUP_PLACEHOLDER {
187 opt.pruned += toks.saturating_sub(tokens::count(PRUNE_PLACEHOLDER));
188 set_block_text(b, PRUNE_PLACEHOLDER);
189 }
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OptimizeOpts` from positional flags: cache, dedup, prune.
    fn opts(cache: bool, dedup: bool, prune: bool) -> OptimizeOpts {
        OptimizeOpts {
            cache,
            dedup,
            prune,
        }
    }

    /// A string `system` and a string last message both become marked block arrays.
    #[test]
    fn caches_system_and_last_message() {
        let system_prompt = "You are careful. ".repeat(20);
        let request = json!({
            "system": system_prompt,
            "messages": [
                {"role": "user", "content": "hi"},
                {"role": "user", "content": "do the thing"}
            ]
        });

        let (optimized, _) = optimize_request(request, opts(true, false, false));

        assert_eq!(optimized["system"][0]["cache_control"]["type"], "ephemeral");
        let messages = optimized["messages"].as_array().unwrap();
        let last_message = messages.last().unwrap();
        assert_eq!(
            last_message["content"][0]["cache_control"]["type"],
            "ephemeral"
        );
    }

    /// A request that already has a marker passes through unchanged.
    #[test]
    fn respects_existing_cache_control() {
        let original = json!({
            "system": [{"type": "text", "text": "x", "cache_control": {"type": "ephemeral"}}],
            "messages": []
        });

        let (optimized, stats) = optimize_request(original.clone(), opts(true, true, false));

        assert_eq!(stats, Optimization::default());
        assert_eq!(optimized, original);
    }

    /// The first copy of a large block survives; the later copy is replaced.
    #[test]
    fn dedup_elides_later_identical_blocks() {
        let big = "x ".repeat(200);
        let request = json!({
            "messages": [
                {"role": "user", "content": [{"type": "text", "text": big}]},
                {"role": "user", "content": "middle"},
                {"role": "user", "content": [{"type": "text", "text": big}]}
            ]
        });

        let (optimized, stats) = optimize_request(request, opts(false, true, false));

        assert!(stats.deduped > 0);
        let messages = &optimized["messages"];
        assert_ne!(messages[0]["content"][0]["text"], DEDUP_PLACEHOLDER);
        assert_eq!(messages[2]["content"][0]["text"], DEDUP_PLACEHOLDER);
    }

    /// Repeats below the dedup threshold are left alone.
    #[test]
    fn dedup_keeps_small_repeats() {
        let request = json!({
            "messages": [
                {"role": "user", "content": [{"type": "text", "text": "short"}]},
                {"role": "user", "content": [{"type": "text", "text": "short"}]}
            ]
        });

        let (optimized, stats) = optimize_request(request, opts(false, true, false));

        assert_eq!(stats.deduped, 0);
        assert_eq!(optimized["messages"][1]["content"][0]["text"], "short");
    }

    /// A large tool result that falls outside the recent window is replaced.
    #[test]
    fn prune_elides_old_large_tool_results() {
        let big = "x ".repeat(400);
        let old_message = json!({
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": "a", "content": big}]
        });
        let recent_message = json!({"role": "user", "content": "recent"});
        let messages: Vec<Value> = std::iter::once(old_message)
            .chain(std::iter::repeat(recent_message).take(KEEP_RECENT_MESSAGES))
            .collect();
        let request = json!({ "messages": messages });

        let (optimized, stats) = optimize_request(request, opts(false, false, true));

        assert!(stats.pruned > 0);
        assert_eq!(
            optimized["messages"][0]["content"][0]["content"],
            PRUNE_PLACEHOLDER
        );
    }
}