agent_code_lib/services/
context_collapse.rs1use crate::llm::message::Message;
18use crate::services::tokens;
19
20pub struct CollapseResult {
22 pub api_messages: Vec<Message>,
24 pub snipped_count: usize,
26 pub tokens_freed: u64,
28}
29
30pub fn collapse_to_budget(messages: &[Message], max_tokens: u64) -> Option<CollapseResult> {
37 let current = tokens::estimate_context_tokens(messages);
38 if current <= max_tokens {
39 return None; }
41
42 let overshoot = current - max_tokens;
43
44 let groups = group_by_round(messages);
46 if groups.len() <= 2 {
47 return None; }
49
50 let mut freed = 0u64;
53 let mut snip_end = 1; for (group_idx, group) in groups[1..groups.len().saturating_sub(1)].iter().enumerate() {
56 let group_tokens: u64 = group.iter().map(tokens::estimate_message_tokens).sum();
57 freed += group_tokens;
58 snip_end = group_idx + 2; if freed >= overshoot {
61 break;
62 }
63 }
64
65 if freed == 0 {
66 return None;
67 }
68
69 let mut api_messages = Vec::new();
71
72 api_messages.extend(groups[0].iter().cloned());
74
75 api_messages.push(crate::llm::message::user_message(
77 "[Earlier messages collapsed to fit context window]",
78 ));
79
80 for group in &groups[snip_end..] {
82 api_messages.extend(group.iter().cloned());
83 }
84
85 let snipped_count: usize = groups[1..snip_end].iter().map(|g| g.len()).sum();
86
87 Some(CollapseResult {
88 api_messages,
89 snipped_count,
90 tokens_freed: freed,
91 })
92}
93
94pub fn recover_from_overflow(
96 messages: &[Message],
97 token_gap: Option<u64>,
98) -> Option<CollapseResult> {
99 let target = token_gap.map(|gap| gap + gap / 10).unwrap_or(20_000);
101
102 let current = tokens::estimate_context_tokens(messages);
103 let budget = current.saturating_sub(target);
104
105 collapse_to_budget(messages, budget)
106}
107
108fn group_by_round(messages: &[Message]) -> Vec<Vec<Message>> {
110 let mut groups: Vec<Vec<Message>> = Vec::new();
111 let mut current_group: Vec<Message> = Vec::new();
112
113 for msg in messages {
114 match msg {
115 Message::User(u) if !u.is_meta => {
116 if !current_group.is_empty() {
118 groups.push(current_group);
119 current_group = Vec::new();
120 }
121 current_group.push(msg.clone());
122 }
123 _ => {
124 current_group.push(msg.clone());
125 }
126 }
127 }
128
129 if !current_group.is_empty() {
130 groups.push(current_group);
131 }
132
133 groups
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::message::user_message;

    /// Builds an assistant message containing a single text block.
    fn assistant_text(text: String) -> crate::llm::message::Message {
        use crate::llm::message::*;
        Message::Assistant(AssistantMessage {
            uuid: uuid::Uuid::new_v4(),
            timestamp: String::new(),
            content: vec![ContentBlock::Text { text }],
            model: None,
            usage: None,
            stop_reason: None,
            request_id: None,
        })
    }

    /// Builds `rounds` user/assistant exchanges, taking the text of each
    /// message from the given closures.
    fn conversation(
        rounds: usize,
        user_text: impl Fn(usize) -> String,
        reply_text: impl Fn(usize) -> String,
    ) -> Vec<crate::llm::message::Message> {
        let mut messages = Vec::with_capacity(rounds * 2);
        for i in 0..rounds {
            messages.push(user_message(user_text(i)));
            messages.push(assistant_text(reply_text(i)));
        }
        messages
    }

    #[test]
    fn test_no_collapse_within_budget() {
        let messages = vec![user_message("short")];
        assert!(collapse_to_budget(&messages, 1_000_000).is_none());
    }

    #[test]
    fn test_collapse_empty_messages() {
        let messages: Vec<crate::llm::message::Message> = Vec::new();
        assert!(collapse_to_budget(&messages, 100).is_none());
    }

    #[test]
    fn test_collapse_preserves_first_and_last() {
        let messages = conversation(
            10,
            |i| format!("message {i} with some content padding"),
            |i| format!("response {i} with content"),
        );

        // A tight budget may or may not trigger a collapse depending on the
        // token estimator; when it does, something must actually be removed.
        if let Some(result) = collapse_to_budget(&messages, 50) {
            assert!(result.snipped_count > 0);
            assert!(result.tokens_freed > 0);
            assert!(result.api_messages.len() < messages.len());
        }

        // A generous budget never collapses.
        assert!(collapse_to_budget(&messages, 1_000_000).is_none());
    }

    #[test]
    fn test_recover_from_overflow() {
        let messages = conversation(
            20,
            |i| format!("msg {i} {}", "x".repeat(200)),
            |i| format!("resp {i} {}", "y".repeat(200)),
        );

        let recovered = recover_from_overflow(&messages, Some(5000))
            .expect("recovery should collapse an over-long conversation");
        assert!(recovered.snipped_count > 0);
    }
}