1use crate::protocol::ChatMessage;
2
/// Category of a piece of context handed to the model.
///
/// Compiled system messages are always emitted in the canonical sequence defined by
/// `order`: persona, rules, memory, retrieval, workspace, task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContextBlockKind {
    Persona,
    Rules,
    Memory,
    Retrieval,
    Workspace,
    Task,
}

impl ContextBlockKind {
    /// Rank of this kind in the compiled output; lower ranks are emitted first.
    fn order(self) -> u8 {
        match self {
            ContextBlockKind::Persona => 0,
            ContextBlockKind::Rules => 1,
            ContextBlockKind::Memory => 2,
            ContextBlockKind::Retrieval => 3,
            ContextBlockKind::Workspace => 4,
            ContextBlockKind::Task => 5,
        }
    }
}
27
28#[derive(Debug, Clone)]
30pub struct ContextBlock {
31 pub kind: ContextBlockKind,
32 pub content: String,
33 pub priority: u8,
36}
37
38#[derive(Debug, Clone)]
40pub struct ContextCompilerConfig {
41 pub total_budget: usize,
43 pub block_budgets: Vec<(ContextBlockKind, usize)>,
46}
47
48impl Default for ContextCompilerConfig {
49 fn default() -> Self {
50 Self {
51 total_budget: 30_000,
52 block_budgets: vec![
53 (ContextBlockKind::Persona, 2_000),
54 (ContextBlockKind::Rules, 5_000),
55 (ContextBlockKind::Memory, 8_000),
56 (ContextBlockKind::Retrieval, 6_000),
57 (ContextBlockKind::Workspace, 5_000),
58 (ContextBlockKind::Task, 4_000),
59 ],
60 }
61 }
62}
63
64impl ContextCompilerConfig {
65 fn budget_for(&self, kind: ContextBlockKind) -> Option<usize> {
66 self.block_budgets
67 .iter()
68 .find(|(k, _)| *k == kind)
69 .map(|(_, b)| *b)
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct CompiledContext {
76 pub system_messages: Vec<ChatMessage>,
78 pub total_tokens: usize,
80 pub dropped_blocks: Vec<ContextBlockKind>,
82}
83
/// Cheap token estimate: one token per four bytes of UTF-8, rounded up.
///
/// Never returns zero; the empty string is counted as a single token.
fn estimate_tokens(text: &str) -> usize {
    match text.len() {
        0 => 1,
        byte_len => byte_len.div_ceil(4),
    }
}
88
/// Shortens `text` so that its size fits `max_tokens` under the 4-bytes-per-token heuristic.
///
/// Returns `text` unchanged when it already fits. Otherwise the cut point is first moved
/// back to a UTF-8 character boundary, so non-ASCII input never panics. If a space occurs
/// in the second half of the allowed prefix, the cut is then moved back to the last such
/// space, so a word is not split.
///
/// A budget of zero, or a budget smaller than the first character, yields `""`.
fn truncate_to_budget(text: &str, max_tokens: usize) -> &str {
    // Saturate: a huge budget (e.g. `usize::MAX`) must mean "no limit", not overflow.
    let max_bytes = max_tokens.saturating_mul(4);
    if text.len() <= max_bytes {
        return text;
    }
    // `max_bytes < text.len()`, so the index is in range. Slicing in the middle of a
    // multi-byte character would panic, so step back to the nearest char boundary
    // (index 0 is always a boundary, so this terminates).
    let mut end = max_bytes;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    let truncated = &text[..end];
    match truncated.rfind(' ') {
        // Only back off to a word boundary if it keeps at least half of the budget.
        Some(pos) if pos > max_bytes / 2 => &text[..pos],
        _ => truncated,
    }
}
102
103pub fn compile_context(blocks: &[ContextBlock], config: &ContextCompilerConfig) -> CompiledContext {
109 if blocks.is_empty() {
110 return CompiledContext {
111 system_messages: Vec::new(),
112 total_tokens: 0,
113 dropped_blocks: Vec::new(),
114 };
115 }
116
117 let mut sorted: Vec<&ContextBlock> = blocks.iter().filter(|b| !b.content.is_empty()).collect();
119 sorted.sort_by_key(|b| b.kind.order());
120
121 let truncated: Vec<(&ContextBlock, &str)> = sorted
123 .iter()
124 .map(|block| {
125 let content = if let Some(budget) = config.budget_for(block.kind) {
126 truncate_to_budget(&block.content, budget)
127 } else {
128 block.content.as_str()
129 };
130 (*block, content)
131 })
132 .collect();
133
134 let total: usize = truncated.iter().map(|(_, c)| estimate_tokens(c)).sum();
136
137 if total <= config.total_budget {
138 let system_messages = truncated
140 .iter()
141 .map(|(_, content)| ChatMessage::system(*content))
142 .collect();
143 return CompiledContext {
144 system_messages,
145 total_tokens: total,
146 dropped_blocks: Vec::new(),
147 };
148 }
149
150 let mut indexed: Vec<(usize, &ContextBlock, &str, usize)> = truncated
152 .iter()
153 .enumerate()
154 .map(|(i, (block, content))| (i, *block, *content, estimate_tokens(content)))
155 .collect();
156
157 indexed.sort_by(|a, b| a.1.priority.cmp(&b.1.priority));
160
161 let mut budget_remaining = config.total_budget;
162 let mut keep_indices: Vec<usize> = Vec::new();
163 let mut dropped_blocks: Vec<ContextBlockKind> = Vec::new();
164
165 for &(original_idx, block, _, tokens) in indexed.iter().rev() {
167 if tokens <= budget_remaining {
168 keep_indices.push(original_idx);
169 budget_remaining = budget_remaining.saturating_sub(tokens);
170 } else {
171 dropped_blocks.push(block.kind);
172 }
173 }
174
175 keep_indices.sort_unstable();
177
178 let system_messages: Vec<ChatMessage> = keep_indices
179 .iter()
180 .map(|&i| ChatMessage::system(truncated[i].1))
181 .collect();
182
183 let total_tokens: usize = keep_indices
184 .iter()
185 .map(|&i| estimate_tokens(truncated[i].1))
186 .sum();
187
188 CompiledContext {
189 system_messages,
190 total_tokens,
191 dropped_blocks,
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    fn make_block(kind: ContextBlockKind, content: &str, priority: u8) -> ContextBlock {
        ContextBlock {
            kind,
            content: content.to_string(),
            priority,
        }
    }

    /// Config with the given total budget and no per-kind caps.
    fn uncapped(total_budget: usize) -> ContextCompilerConfig {
        ContextCompilerConfig {
            total_budget,
            block_budgets: Vec::new(),
        }
    }

    #[test]
    fn empty_blocks_returns_empty() {
        let result = compile_context(&[], &ContextCompilerConfig::default());
        assert!(result.system_messages.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.dropped_blocks.is_empty());
    }

    #[test]
    fn single_block_compiles() {
        let blocks = vec![make_block(
            ContextBlockKind::Persona,
            "You are a helpful assistant.",
            255,
        )];
        let result = compile_context(&blocks, &ContextCompilerConfig::default());
        assert_eq!(result.system_messages.len(), 1);
        assert_eq!(
            result.system_messages[0].content,
            "You are a helpful assistant."
        );
        // 28 bytes -> 7 tokens under the 4-bytes-per-token estimate.
        assert_eq!(result.total_tokens, 7);
        assert!(result.dropped_blocks.is_empty());
    }

    #[test]
    fn all_six_blocks_in_order() {
        let blocks = vec![
            make_block(ContextBlockKind::Task, "Current task info", 50),
            make_block(ContextBlockKind::Persona, "I am an AI", 255),
            make_block(ContextBlockKind::Memory, "User prefers dark mode", 100),
            make_block(ContextBlockKind::Rules, "Never lie", 200),
            make_block(ContextBlockKind::Workspace, "cwd: /home", 80),
            make_block(ContextBlockKind::Retrieval, "Relevant docs", 90),
        ];
        let result = compile_context(&blocks, &uncapped(100_000));
        let expected = [
            "I am an AI",
            "Never lie",
            "User prefers dark mode",
            "Relevant docs",
            "cwd: /home",
            "Current task info",
        ];
        assert_eq!(result.system_messages.len(), expected.len());
        for (msg, want) in result.system_messages.iter().zip(expected) {
            assert_eq!(msg.content, want);
        }
        assert!(result.dropped_blocks.is_empty());
    }

    #[test]
    fn block_truncation_respects_budget() {
        let long_content = "word ".repeat(10000);
        let blocks = vec![make_block(ContextBlockKind::Memory, &long_content, 100)];
        let config = ContextCompilerConfig {
            total_budget: 100_000,
            block_budgets: vec![(ContextBlockKind::Memory, 100)],
        };
        let result = compile_context(&blocks, &config);
        assert_eq!(result.system_messages.len(), 1);
        let content = &result.system_messages[0].content;
        // A 100-token cap allows at most 400 bytes, cut back to a word boundary.
        assert!(content.len() <= 400);
        assert!(content.starts_with("word"));
        assert!(!content.ends_with(' '));
    }

    #[test]
    fn total_budget_overflow_drops_low_priority() {
        // Four 100-token blocks against a 250-token budget: only the two highest fit.
        let blocks = vec![
            make_block(ContextBlockKind::Persona, &"a".repeat(400), 255),
            make_block(ContextBlockKind::Rules, &"b".repeat(400), 200),
            make_block(ContextBlockKind::Memory, &"c".repeat(400), 50),
            make_block(ContextBlockKind::Retrieval, &"d".repeat(400), 30),
        ];
        let result = compile_context(&blocks, &uncapped(250));
        assert_eq!(result.system_messages.len(), 2);
        assert_eq!(result.system_messages[0].content, "a".repeat(400));
        assert_eq!(result.system_messages[1].content, "b".repeat(400));
        assert_eq!(result.total_tokens, 200);
        assert_eq!(
            result.dropped_blocks,
            vec![ContextBlockKind::Memory, ContextBlockKind::Retrieval]
        );
    }

    #[test]
    fn persona_never_dropped() {
        // The persona survives here because it has the highest priority and fits the
        // budget; nothing special-cases the Persona kind itself.
        let blocks = vec![
            make_block(ContextBlockKind::Persona, &"x".repeat(100), 255),
            make_block(ContextBlockKind::Rules, &"y".repeat(4000), 200),
        ];
        let result = compile_context(&blocks, &uncapped(50));
        assert_eq!(result.system_messages.len(), 1);
        assert_eq!(result.system_messages[0].content, "x".repeat(100));
        assert_eq!(result.dropped_blocks, vec![ContextBlockKind::Rules]);
    }

    #[test]
    fn empty_content_skipped() {
        // Only truly empty content is skipped; whitespace-only content is kept.
        let blocks = vec![
            make_block(ContextBlockKind::Persona, "hello", 255),
            make_block(ContextBlockKind::Rules, "", 200),
            make_block(ContextBlockKind::Memory, " ", 100),
        ];
        let result = compile_context(&blocks, &uncapped(100_000));
        assert_eq!(result.system_messages.len(), 2);
        assert_eq!(result.system_messages[0].content, "hello");
        assert_eq!(result.system_messages[1].content, " ");
        assert!(result.dropped_blocks.is_empty());
    }

    #[test]
    fn default_config_reasonable() {
        let config = ContextCompilerConfig::default();
        assert_eq!(config.total_budget, 30_000);
        assert_eq!(config.block_budgets.len(), 6);
        // The per-kind caps exactly partition the total budget.
        let cap_sum: usize = config.block_budgets.iter().map(|&(_, cap)| cap).sum();
        assert_eq!(cap_sum, config.total_budget);
    }

    #[test]
    fn word_boundary_truncation() {
        let content = "hello world this is a test of truncation at word boundaries";
        // 3 tokens -> 12 bytes -> "hello world " -> trimmed back to the last space.
        let truncated = truncate_to_budget(content, 3);
        assert_eq!(truncated, "hello world");
    }

    #[test]
    fn token_count_accuracy() {
        assert_eq!(estimate_tokens(&"a".repeat(100)), 25);
        assert_eq!(estimate_tokens("abcde"), 2);
        assert_eq!(estimate_tokens("a"), 1);
        // The estimate never drops to zero, even for empty input.
        assert_eq!(estimate_tokens(""), 1);
    }

    #[test]
    fn dropped_blocks_reported() {
        let blocks = vec![
            make_block(ContextBlockKind::Persona, &"a".repeat(400), 255),
            make_block(ContextBlockKind::Memory, &"c".repeat(400), 50),
        ];
        let result = compile_context(&blocks, &uncapped(110));
        assert_eq!(result.system_messages.len(), 1);
        assert_eq!(result.system_messages[0].content, "a".repeat(400));
        assert_eq!(result.dropped_blocks, vec![ContextBlockKind::Memory]);
    }

    #[test]
    fn custom_budgets_applied() {
        let blocks = vec![
            make_block(ContextBlockKind::Persona, &"p".repeat(1000), 255),
            make_block(ContextBlockKind::Rules, &"r".repeat(1000), 200),
        ];
        let config = ContextCompilerConfig {
            total_budget: 100_000,
            block_budgets: vec![
                (ContextBlockKind::Persona, 50),
                (ContextBlockKind::Rules, 50),
            ],
        };
        let result = compile_context(&blocks, &config);
        assert_eq!(result.system_messages.len(), 2);
        // 50 tokens -> 200 bytes; with no spaces the cut is exact.
        for msg in &result.system_messages {
            assert_eq!(msg.content.len(), 200);
        }
    }

    #[test]
    fn deterministic_ordering() {
        let blocks = vec![
            make_block(ContextBlockKind::Workspace, "ws", 80),
            make_block(ContextBlockKind::Persona, "persona", 255),
            make_block(ContextBlockKind::Task, "task", 50),
        ];
        let config = uncapped(100_000);
        let r1 = compile_context(&blocks, &config);
        let r2 = compile_context(&blocks, &config);
        assert_eq!(r1.system_messages.len(), 3);
        assert_eq!(r1.system_messages.len(), r2.system_messages.len());
        for (a, b) in r1.system_messages.iter().zip(r2.system_messages.iter()) {
            assert_eq!(a.content, b.content);
        }
        assert_eq!(r1.system_messages[0].content, "persona");
        assert_eq!(r1.system_messages[1].content, "ws");
        assert_eq!(r1.system_messages[2].content, "task");
    }

    #[test]
    fn compiles_alongside_compact_messages() {
        let blocks = vec![
            make_block(ContextBlockKind::Persona, "You are helpful.", 255),
            make_block(ContextBlockKind::Rules, "Be concise.", 200),
        ];
        let config = ContextCompilerConfig::default();
        let compiled = compile_context(&blocks, &config);

        let mut messages = compiled.system_messages;
        messages.push(ChatMessage::user("Hello"));
        messages.push(ChatMessage::assistant("Hi!"));

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, crate::protocol::Role::System);
        assert_eq!(messages[1].role, crate::protocol::Role::System);
    }
}