1use super::{ContextItem, ContextResult};
4use std::collections::HashMap;
5
/// Overall limits that `ContextAssembler::assemble` applies to one assembly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ContextBudget {
    /// Maximum number of items that may be selected.
    pub max_items: usize,
    /// Maximum sum of estimated tokens across all selected items.
    pub max_tokens: usize,
}
12
13impl Default for ContextBudget {
14 fn default() -> Self {
15 Self {
16 max_items: 12,
17 max_tokens: 4_000,
18 }
19 }
20}
21
/// Optional per-source caps applied while items are selected.
///
/// A limit of `None` switches that particular cap off.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ContextSourcePolicy {
    /// Most items that may be taken from a single source group, if capped.
    pub max_items_per_source: Option<usize>,
    /// Most estimated tokens that may be taken from a single source group, if capped.
    pub max_tokens_per_source: Option<usize>,
}
28
29impl Default for ContextSourcePolicy {
30 fn default() -> Self {
31 Self {
32 max_items_per_source: Some(6),
33 max_tokens_per_source: Some(2_500),
34 }
35 }
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40pub struct ContextAssemblyPolicy {
41 pub budget: ContextBudget,
42 pub source_policy: ContextSourcePolicy,
43}
44
45impl ContextAssemblyPolicy {
46 pub fn balanced() -> Self {
47 Self {
48 budget: ContextBudget {
49 max_items: 12,
50 max_tokens: 4_000,
51 },
52 source_policy: ContextSourcePolicy {
53 max_items_per_source: Some(6),
54 max_tokens_per_source: Some(2_500),
55 },
56 }
57 }
58
59 pub fn compact() -> Self {
60 Self {
61 budget: ContextBudget {
62 max_items: 8,
63 max_tokens: 2_500,
64 },
65 source_policy: ContextSourcePolicy {
66 max_items_per_source: Some(4),
67 max_tokens_per_source: Some(1_200),
68 },
69 }
70 }
71
72 pub fn expansive() -> Self {
73 Self {
74 budget: ContextBudget {
75 max_items: 20,
76 max_tokens: 8_000,
77 },
78 source_policy: ContextSourcePolicy {
79 max_items_per_source: Some(8),
80 max_tokens_per_source: Some(3_500),
81 },
82 }
83 }
84}
85
86impl Default for ContextAssemblyPolicy {
87 fn default() -> Self {
88 Self::balanced()
89 }
90}
91
92#[derive(Debug, Clone, Default)]
94pub struct ContextAssembly {
95 pub items: Vec<ContextItem>,
96 pub total_tokens: usize,
97 pub truncated: bool,
98}
99
100impl ContextAssembly {
101 pub fn to_xml(&self) -> String {
102 self.items
103 .iter()
104 .map(ContextItem::to_xml)
105 .collect::<Vec<_>>()
106 .join("\n\n")
107 }
108
109 pub fn is_empty(&self) -> bool {
110 self.items.is_empty()
111 }
112}
113
114#[derive(Debug, Clone)]
116pub struct ContextAssembler {
117 budget: ContextBudget,
118 source_policy: ContextSourcePolicy,
119}
120
121impl ContextAssembler {
122 pub fn new(budget: ContextBudget) -> Self {
123 Self::from_policy(ContextAssemblyPolicy {
124 budget,
125 source_policy: ContextSourcePolicy::default(),
126 })
127 }
128
129 pub fn from_policy(policy: ContextAssemblyPolicy) -> Self {
130 Self {
131 budget: policy.budget,
132 source_policy: policy.source_policy,
133 }
134 }
135
136 pub fn with_source_policy(mut self, policy: ContextSourcePolicy) -> Self {
137 self.source_policy = policy;
138 self
139 }
140
141 pub fn with_default_budget() -> Self {
142 Self::from_policy(ContextAssemblyPolicy::balanced())
143 }
144
145 pub fn assemble(&self, results: &[ContextResult]) -> ContextAssembly {
146 let mut deduped: HashMap<String, ContextItem> = HashMap::new();
147 let mut source_count = 0usize;
148
149 for result in results {
150 for item in &result.items {
151 source_count += 1;
152 let key = dedupe_key(item);
153 match deduped.get(&key) {
154 Some(existing)
155 if ranking_score(existing)
156 .total_cmp(&ranking_score(item))
157 .then_with(|| existing.relevance.total_cmp(&item.relevance))
158 .is_ge() => {}
159 _ => {
160 deduped.insert(key, item.clone());
161 }
162 }
163 }
164 }
165
166 let mut items = deduped.into_values().collect::<Vec<_>>();
167 items.sort_by(|a, b| {
168 ranking_score(b)
169 .total_cmp(&ranking_score(a))
170 .then_with(|| b.relevance.total_cmp(&a.relevance))
171 .then_with(|| estimated_tokens(a).cmp(&estimated_tokens(b)))
172 .then_with(|| a.id.cmp(&b.id))
173 });
174
175 let mut selected = Vec::new();
176 let mut total_tokens = 0usize;
177 let mut truncated = source_count > items.len();
178 let mut source_item_counts: HashMap<String, usize> = HashMap::new();
179 let mut source_token_counts: HashMap<String, usize> = HashMap::new();
180
181 for item in items {
182 if selected.len() >= self.budget.max_items {
183 truncated = true;
184 break;
185 }
186
187 let item_tokens = estimated_tokens(&item);
188 if total_tokens + item_tokens > self.budget.max_tokens {
189 truncated = true;
190 continue;
191 }
192
193 let source_key = source_policy_key(&item);
194 if let Some(max_items) = self.source_policy.max_items_per_source {
195 let count = source_item_counts.get(&source_key).copied().unwrap_or(0);
196 if count >= max_items {
197 truncated = true;
198 continue;
199 }
200 }
201 if let Some(max_tokens) = self.source_policy.max_tokens_per_source {
202 let source_tokens = source_token_counts.get(&source_key).copied().unwrap_or(0);
203 if source_tokens + item_tokens > max_tokens {
204 truncated = true;
205 continue;
206 }
207 }
208
209 total_tokens += item_tokens;
210 *source_item_counts.entry(source_key.clone()).or_insert(0) += 1;
211 *source_token_counts.entry(source_key).or_insert(0) += item_tokens;
212 selected.push(item);
213 }
214
215 ContextAssembly {
216 items: selected,
217 total_tokens,
218 truncated,
219 }
220 }
221}
222
223fn dedupe_key(item: &ContextItem) -> String {
224 item.source.clone().unwrap_or_else(|| item.id.clone())
225}
226
227fn source_policy_key(item: &ContextItem) -> String {
228 if let Some(provenance) = item.provenance() {
229 return format!("provenance:{provenance}");
230 }
231
232 if let Some(source) = &item.source {
233 let family = source
234 .split_once(':')
235 .map(|(family, _)| family)
236 .unwrap_or(source);
237 return format!("source:{family}");
238 }
239
240 format!("type:{:?}", item.context_type)
241}
242
243fn estimated_tokens(item: &ContextItem) -> usize {
244 if item.token_count > 0 {
245 item.token_count
246 } else {
247 item.content.split_whitespace().count().max(1)
248 }
249}
250
251fn ranking_score(item: &ContextItem) -> f32 {
252 item.relevance + item.priority() * 0.25 + item.trust() * 0.15 + item.freshness() * 0.10
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextItem, ContextResult, ContextType};

    /// Wraps `items` in a `ContextResult` created for `provider`.
    fn result(provider: &str, items: Vec<ContextItem>) -> ContextResult {
        let mut collected = ContextResult::new(provider);
        items.into_iter().for_each(|item| {
            collected.add_item(item);
        });
        collected
    }

    /// Builds a `Resource` item with the given id and content.
    fn resource(id: &str, content: &str) -> ContextItem {
        ContextItem::new(id, ContextType::Resource, content)
    }

    /// Builds `count` one-token file items with distinct sources and
    /// slowly decreasing relevance.
    fn numbered_files(count: usize) -> Vec<ContextItem> {
        (0..count)
            .map(|index| {
                ContextItem::new(
                    format!("file-{index}"),
                    ContextType::Resource,
                    format!("file {index}"),
                )
                .with_source(format!("file://{index}"))
                .with_relevance(1.0 - index as f32 * 0.01)
                .with_token_count(1)
            })
            .collect()
    }

    /// Assembler with an explicit overall budget and default source caps.
    fn assembler_with_budget(max_items: usize, max_tokens: usize) -> ContextAssembler {
        ContextAssembler::new(ContextBudget {
            max_items,
            max_tokens,
        })
    }

    /// Ids of the selected items, in selection order.
    fn selected_ids(assembly: &ContextAssembly) -> Vec<&str> {
        assembly.items.iter().map(|item| item.id.as_str()).collect()
    }

    #[test]
    fn balanced_policy_matches_default_budget_and_source_caps() {
        let ContextAssemblyPolicy {
            budget,
            source_policy,
        } = ContextAssemblyPolicy::balanced();

        assert_eq!(budget, ContextBudget::default());
        assert_eq!(source_policy, ContextSourcePolicy::default());
    }

    #[test]
    fn compact_policy_applies_tighter_caps() {
        let candidates = result("test", numbered_files(10));
        let assembly = ContextAssembler::from_policy(ContextAssemblyPolicy::compact())
            .assemble(&[candidates]);

        assert_eq!(assembly.items.len(), 4);
        assert!(assembly.truncated);
    }

    #[test]
    fn expansive_policy_allows_broader_context() {
        let candidates = result("test", numbered_files(8));
        let assembly = ContextAssembler::from_policy(ContextAssemblyPolicy::expansive())
            .assemble(&[candidates]);

        assert_eq!(assembly.items.len(), 8);
        assert!(!assembly.truncated);
    }

    #[test]
    fn assemble_ranks_by_relevance() {
        let candidates = vec![
            resource("low", "low")
                .with_relevance(0.1)
                .with_token_count(1),
            resource("high", "high")
                .with_relevance(0.9)
                .with_token_count(1),
        ];

        let assembly = assembler_with_budget(10, 100).assemble(&[result("test", candidates)]);

        assert_eq!(assembly.items[0].id, "high");
        assert_eq!(assembly.items[1].id, "low");
        assert!(!assembly.truncated);
    }

    #[test]
    fn assemble_uses_priority_trust_and_freshness_as_ranking_signals() {
        let candidates = vec![
            resource("plain", "plain")
                .with_relevance(0.7)
                .with_token_count(1),
            resource("boosted", "boosted")
                .with_relevance(0.6)
                .with_priority(1.0)
                .with_trust(1.0)
                .with_freshness(1.0)
                .with_token_count(1),
        ];

        let assembly = assembler_with_budget(10, 100).assemble(&[result("test", candidates)]);

        assert_eq!(assembly.items[0].id, "boosted");
        assert_eq!(assembly.items[1].id, "plain");
    }

    #[test]
    fn assemble_dedupes_by_source_and_keeps_more_relevant_item() {
        let candidates = vec![
            resource("old", "old")
                .with_source("file://auth.rs")
                .with_relevance(0.2),
            resource("new", "new")
                .with_source("file://auth.rs")
                .with_relevance(0.8),
        ];

        let assembly =
            ContextAssembler::with_default_budget().assemble(&[result("test", candidates)]);

        assert_eq!(assembly.items.len(), 1);
        assert_eq!(assembly.items[0].id, "new");
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_dedupes_by_ranking_score() {
        let candidates = vec![
            resource("plain", "plain")
                .with_source("file://auth.rs")
                .with_relevance(0.7),
            resource("boosted", "boosted")
                .with_source("file://auth.rs")
                .with_relevance(0.6)
                .with_priority(1.0),
        ];

        let assembly =
            ContextAssembler::with_default_budget().assemble(&[result("test", candidates)]);

        assert_eq!(assembly.items.len(), 1);
        assert_eq!(assembly.items[0].id, "boosted");
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_respects_item_and_token_budget() {
        let candidates = vec![
            resource("a", "one two")
                .with_relevance(0.9)
                .with_token_count(2),
            resource("b", "three four")
                .with_relevance(0.8)
                .with_token_count(2),
        ];

        let assembly = assembler_with_budget(1, 5).assemble(&[result("test", candidates)]);

        assert_eq!(assembly.items.len(), 1);
        assert_eq!(assembly.total_tokens, 2);
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_caps_items_per_source() {
        let assembler = assembler_with_budget(10, 100).with_source_policy(ContextSourcePolicy {
            max_items_per_source: Some(2),
            max_tokens_per_source: None,
        });
        let candidates = vec![
            resource("a", "a")
                .with_source("file://a")
                .with_relevance(0.9)
                .with_token_count(1),
            resource("b", "b")
                .with_source("file://b")
                .with_relevance(0.8)
                .with_token_count(1),
            resource("c", "c")
                .with_source("file://c")
                .with_relevance(0.7)
                .with_token_count(1),
        ];

        let assembly = assembler.assemble(&[result("test", candidates)]);

        assert_eq!(assembly.items.len(), 2);
        assert_eq!(assembly.items[0].id, "a");
        assert_eq!(assembly.items[1].id, "b");
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_caps_tokens_per_source_but_keeps_other_sources() {
        let assembler = assembler_with_budget(10, 100).with_source_policy(ContextSourcePolicy {
            max_items_per_source: None,
            max_tokens_per_source: Some(3),
        });
        let candidates = vec![
            resource("file-a", "file a")
                .with_source("file://a")
                .with_relevance(0.9)
                .with_token_count(2),
            resource("file-b", "file b")
                .with_source("file://b")
                .with_relevance(0.8)
                .with_token_count(2),
            ContextItem::new("memory", ContextType::Memory, "memory")
                .with_source("memory://a")
                .with_relevance(0.7)
                .with_token_count(2),
        ];

        let assembly = assembler.assemble(&[result("test", candidates)]);

        assert_eq!(selected_ids(&assembly), vec!["file-a", "memory"]);
        assert_eq!(assembly.total_tokens, 4);
        assert!(assembly.truncated);
    }
}