// a3s_code_core/context/assembler.rs

//! Budgeted assembly for context items.
2
use super::{ContextItem, ContextResult};
use std::collections::HashMap;
5
/// Budget limits for prompt-bound context.
///
/// Both limits apply to the assembled set as a whole; see
/// [`ContextAssembler::assemble`] for how they are enforced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ContextBudget {
    /// Maximum number of items that may be selected.
    pub max_items: usize,
    /// Maximum summed token estimate across all selected items.
    pub max_tokens: usize,
}
12
13impl Default for ContextBudget {
14    fn default() -> Self {
15        Self {
16            max_items: 12,
17            max_tokens: 4_000,
18        }
19    }
20}
21
/// Source-aware limits that prevent one context source from dominating the prompt.
///
/// Each field is optional; `None` disables that particular cap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ContextSourcePolicy {
    /// Maximum number of selected items sharing one source key, if capped.
    pub max_items_per_source: Option<usize>,
    /// Maximum summed token estimate of items sharing one source key, if capped.
    pub max_tokens_per_source: Option<usize>,
}
28
29impl Default for ContextSourcePolicy {
30    fn default() -> Self {
31        Self {
32            max_items_per_source: Some(6),
33            max_tokens_per_source: Some(2_500),
34        }
35    }
36}
37
38/// Named budget policy for context assembly.
39#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40pub struct ContextAssemblyPolicy {
41    pub budget: ContextBudget,
42    pub source_policy: ContextSourcePolicy,
43}
44
45impl ContextAssemblyPolicy {
46    pub fn balanced() -> Self {
47        Self {
48            budget: ContextBudget {
49                max_items: 12,
50                max_tokens: 4_000,
51            },
52            source_policy: ContextSourcePolicy {
53                max_items_per_source: Some(6),
54                max_tokens_per_source: Some(2_500),
55            },
56        }
57    }
58
59    pub fn compact() -> Self {
60        Self {
61            budget: ContextBudget {
62                max_items: 8,
63                max_tokens: 2_500,
64            },
65            source_policy: ContextSourcePolicy {
66                max_items_per_source: Some(4),
67                max_tokens_per_source: Some(1_200),
68            },
69        }
70    }
71
72    pub fn expansive() -> Self {
73        Self {
74            budget: ContextBudget {
75                max_items: 20,
76                max_tokens: 8_000,
77            },
78            source_policy: ContextSourcePolicy {
79                max_items_per_source: Some(8),
80                max_tokens_per_source: Some(3_500),
81            },
82        }
83    }
84}
85
86impl Default for ContextAssemblyPolicy {
87    fn default() -> Self {
88        Self::balanced()
89    }
90}
91
92/// Final context selected for prompt rendering.
93#[derive(Debug, Clone, Default)]
94pub struct ContextAssembly {
95    pub items: Vec<ContextItem>,
96    pub total_tokens: usize,
97    pub truncated: bool,
98}
99
100impl ContextAssembly {
101    pub fn to_xml(&self) -> String {
102        self.items
103            .iter()
104            .map(ContextItem::to_xml)
105            .collect::<Vec<_>>()
106            .join("\n\n")
107    }
108
109    pub fn is_empty(&self) -> bool {
110        self.items.is_empty()
111    }
112}
113
114/// Assembles raw provider results into a ranked, deduplicated, budgeted set.
115#[derive(Debug, Clone)]
116pub struct ContextAssembler {
117    budget: ContextBudget,
118    source_policy: ContextSourcePolicy,
119}
120
121impl ContextAssembler {
122    pub fn new(budget: ContextBudget) -> Self {
123        Self::from_policy(ContextAssemblyPolicy {
124            budget,
125            source_policy: ContextSourcePolicy::default(),
126        })
127    }
128
129    pub fn from_policy(policy: ContextAssemblyPolicy) -> Self {
130        Self {
131            budget: policy.budget,
132            source_policy: policy.source_policy,
133        }
134    }
135
136    pub fn with_source_policy(mut self, policy: ContextSourcePolicy) -> Self {
137        self.source_policy = policy;
138        self
139    }
140
141    pub fn with_default_budget() -> Self {
142        Self::from_policy(ContextAssemblyPolicy::balanced())
143    }
144
145    pub fn assemble(&self, results: &[ContextResult]) -> ContextAssembly {
146        let mut deduped: HashMap<String, ContextItem> = HashMap::new();
147        let mut source_count = 0usize;
148
149        for result in results {
150            for item in &result.items {
151                source_count += 1;
152                let key = dedupe_key(item);
153                match deduped.get(&key) {
154                    Some(existing)
155                        if ranking_score(existing)
156                            .total_cmp(&ranking_score(item))
157                            .then_with(|| existing.relevance.total_cmp(&item.relevance))
158                            .is_ge() => {}
159                    _ => {
160                        deduped.insert(key, item.clone());
161                    }
162                }
163            }
164        }
165
166        let mut items = deduped.into_values().collect::<Vec<_>>();
167        items.sort_by(|a, b| {
168            ranking_score(b)
169                .total_cmp(&ranking_score(a))
170                .then_with(|| b.relevance.total_cmp(&a.relevance))
171                .then_with(|| estimated_tokens(a).cmp(&estimated_tokens(b)))
172                .then_with(|| a.id.cmp(&b.id))
173        });
174
175        let mut selected = Vec::new();
176        let mut total_tokens = 0usize;
177        let mut truncated = source_count > items.len();
178        let mut source_item_counts: HashMap<String, usize> = HashMap::new();
179        let mut source_token_counts: HashMap<String, usize> = HashMap::new();
180
181        for item in items {
182            if selected.len() >= self.budget.max_items {
183                truncated = true;
184                break;
185            }
186
187            let item_tokens = estimated_tokens(&item);
188            if total_tokens + item_tokens > self.budget.max_tokens {
189                truncated = true;
190                continue;
191            }
192
193            let source_key = source_policy_key(&item);
194            if let Some(max_items) = self.source_policy.max_items_per_source {
195                let count = source_item_counts.get(&source_key).copied().unwrap_or(0);
196                if count >= max_items {
197                    truncated = true;
198                    continue;
199                }
200            }
201            if let Some(max_tokens) = self.source_policy.max_tokens_per_source {
202                let source_tokens = source_token_counts.get(&source_key).copied().unwrap_or(0);
203                if source_tokens + item_tokens > max_tokens {
204                    truncated = true;
205                    continue;
206                }
207            }
208
209            total_tokens += item_tokens;
210            *source_item_counts.entry(source_key.clone()).or_insert(0) += 1;
211            *source_token_counts.entry(source_key).or_insert(0) += item_tokens;
212            selected.push(item);
213        }
214
215        ContextAssembly {
216            items: selected,
217            total_tokens,
218            truncated,
219        }
220    }
221}
222
223fn dedupe_key(item: &ContextItem) -> String {
224    item.source.clone().unwrap_or_else(|| item.id.clone())
225}
226
227fn source_policy_key(item: &ContextItem) -> String {
228    if let Some(provenance) = item.provenance() {
229        return format!("provenance:{provenance}");
230    }
231
232    if let Some(source) = &item.source {
233        let family = source
234            .split_once(':')
235            .map(|(family, _)| family)
236            .unwrap_or(source);
237        return format!("source:{family}");
238    }
239
240    format!("type:{:?}", item.context_type)
241}
242
243fn estimated_tokens(item: &ContextItem) -> usize {
244    if item.token_count > 0 {
245        item.token_count
246    } else {
247        item.content.split_whitespace().count().max(1)
248    }
249}
250
251fn ranking_score(item: &ContextItem) -> f32 {
252    item.relevance + item.priority() * 0.25 + item.trust() * 0.15 + item.freshness() * 0.10
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextItem, ContextResult, ContextType};

    /// Wraps `items` in a `ContextResult` attributed to `provider`.
    fn result(provider: &str, items: Vec<ContextItem>) -> ContextResult {
        let mut result = ContextResult::new(provider);
        items.into_iter().for_each(|item| {
            result.add_item(item);
        });
        result
    }

    /// Builds `count` one-token resource items with distinct `file://<n>`
    /// sources and slightly decreasing relevance.
    fn numbered_files(count: usize) -> Vec<ContextItem> {
        (0..count)
            .map(|index| {
                ContextItem::new(
                    format!("file-{index}"),
                    ContextType::Resource,
                    format!("file {index}"),
                )
                .with_source(format!("file://{index}"))
                .with_relevance(1.0 - index as f32 * 0.01)
                .with_token_count(1)
            })
            .collect()
    }

    /// Collects the ids of the selected items, in order.
    fn selected_ids(assembly: &ContextAssembly) -> Vec<&str> {
        assembly.items.iter().map(|item| item.id.as_str()).collect()
    }

    #[test]
    fn balanced_policy_matches_default_budget_and_source_caps() {
        let ContextAssemblyPolicy {
            budget,
            source_policy,
        } = ContextAssemblyPolicy::balanced();

        assert_eq!(budget, ContextBudget::default());
        assert_eq!(source_policy, ContextSourcePolicy::default());
    }

    #[test]
    fn compact_policy_applies_tighter_caps() {
        let assembler = ContextAssembler::from_policy(ContextAssemblyPolicy::compact());
        let assembly = assembler.assemble(&[result("test", numbered_files(10))]);

        // All items share the `file` source family, so the per-source item cap applies.
        assert_eq!(assembly.items.len(), 4);
        assert!(assembly.truncated);
    }

    #[test]
    fn expansive_policy_allows_broader_context() {
        let assembler = ContextAssembler::from_policy(ContextAssemblyPolicy::expansive());
        let assembly = assembler.assemble(&[result("test", numbered_files(8))]);

        assert_eq!(assembly.items.len(), 8);
        assert!(!assembly.truncated);
    }

    #[test]
    fn assemble_ranks_by_relevance() {
        let assembler = ContextAssembler::new(ContextBudget {
            max_items: 10,
            max_tokens: 100,
        });
        let low = ContextItem::new("low", ContextType::Resource, "low")
            .with_relevance(0.1)
            .with_token_count(1);
        let high = ContextItem::new("high", ContextType::Resource, "high")
            .with_relevance(0.9)
            .with_token_count(1);

        let assembly = assembler.assemble(&[result("test", vec![low, high])]);

        assert_eq!(selected_ids(&assembly), ["high", "low"]);
        assert!(!assembly.truncated);
    }

    #[test]
    fn assemble_uses_priority_trust_and_freshness_as_ranking_signals() {
        let assembler = ContextAssembler::new(ContextBudget {
            max_items: 10,
            max_tokens: 100,
        });
        let plain = ContextItem::new("plain", ContextType::Resource, "plain")
            .with_relevance(0.7)
            .with_token_count(1);
        let boosted = ContextItem::new("boosted", ContextType::Resource, "boosted")
            .with_relevance(0.6)
            .with_priority(1.0)
            .with_trust(1.0)
            .with_freshness(1.0)
            .with_token_count(1);

        let assembly = assembler.assemble(&[result("test", vec![plain, boosted])]);

        assert_eq!(selected_ids(&assembly), ["boosted", "plain"]);
    }

    #[test]
    fn assemble_dedupes_by_source_and_keeps_more_relevant_item() {
        let assembler = ContextAssembler::with_default_budget();
        let old = ContextItem::new("old", ContextType::Resource, "old")
            .with_source("file://auth.rs")
            .with_relevance(0.2);
        let new = ContextItem::new("new", ContextType::Resource, "new")
            .with_source("file://auth.rs")
            .with_relevance(0.8);

        let assembly = assembler.assemble(&[result("test", vec![old, new])]);

        assert_eq!(selected_ids(&assembly), ["new"]);
        // Dropping a duplicate counts as truncation.
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_dedupes_by_ranking_score() {
        let assembler = ContextAssembler::with_default_budget();
        let plain = ContextItem::new("plain", ContextType::Resource, "plain")
            .with_source("file://auth.rs")
            .with_relevance(0.7);
        let boosted = ContextItem::new("boosted", ContextType::Resource, "boosted")
            .with_source("file://auth.rs")
            .with_relevance(0.6)
            .with_priority(1.0);

        let assembly = assembler.assemble(&[result("test", vec![plain, boosted])]);

        assert_eq!(selected_ids(&assembly), ["boosted"]);
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_respects_item_and_token_budget() {
        let assembler = ContextAssembler::new(ContextBudget {
            max_items: 1,
            max_tokens: 5,
        });
        let first = ContextItem::new("a", ContextType::Resource, "one two")
            .with_relevance(0.9)
            .with_token_count(2);
        let second = ContextItem::new("b", ContextType::Resource, "three four")
            .with_relevance(0.8)
            .with_token_count(2);

        let assembly = assembler.assemble(&[result("test", vec![first, second])]);

        assert_eq!(assembly.items.len(), 1);
        assert_eq!(assembly.total_tokens, 2);
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_caps_items_per_source() {
        let assembler = ContextAssembler::new(ContextBudget {
            max_items: 10,
            max_tokens: 100,
        })
        .with_source_policy(ContextSourcePolicy {
            max_items_per_source: Some(2),
            max_tokens_per_source: None,
        });
        let items = vec![
            ContextItem::new("a", ContextType::Resource, "a")
                .with_source("file://a")
                .with_relevance(0.9)
                .with_token_count(1),
            ContextItem::new("b", ContextType::Resource, "b")
                .with_source("file://b")
                .with_relevance(0.8)
                .with_token_count(1),
            ContextItem::new("c", ContextType::Resource, "c")
                .with_source("file://c")
                .with_relevance(0.7)
                .with_token_count(1),
        ];

        let assembly = assembler.assemble(&[result("test", items)]);

        assert_eq!(selected_ids(&assembly), ["a", "b"]);
        assert!(assembly.truncated);
    }

    #[test]
    fn assemble_caps_tokens_per_source_but_keeps_other_sources() {
        let assembler = ContextAssembler::new(ContextBudget {
            max_items: 10,
            max_tokens: 100,
        })
        .with_source_policy(ContextSourcePolicy {
            max_items_per_source: None,
            max_tokens_per_source: Some(3),
        });
        let items = vec![
            ContextItem::new("file-a", ContextType::Resource, "file a")
                .with_source("file://a")
                .with_relevance(0.9)
                .with_token_count(2),
            ContextItem::new("file-b", ContextType::Resource, "file b")
                .with_source("file://b")
                .with_relevance(0.8)
                .with_token_count(2),
            ContextItem::new("memory", ContextType::Memory, "memory")
                .with_source("memory://a")
                .with_relevance(0.7)
                .with_token_count(2),
        ];

        let assembly = assembler.assemble(&[result("test", items)]);

        assert_eq!(selected_ids(&assembly), ["file-a", "memory"]);
        assert_eq!(assembly.total_tokens, 4);
        assert!(assembly.truncated);
    }
}