1use crate::Chunk;
7use std::collections::{HashMap, HashSet};
8use uuid::Uuid;
9
/// A retrieval result: a chunk paired with its relevance score and the
/// path of the document it came from.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// The underlying chunk (carries id, text, sequence, byte offset).
    pub chunk: Chunk,
    /// Relevance score from retrieval; higher is better.
    pub score: f32,
    /// Path of the source document; used for grouping and `[DOC: …]` headers.
    pub document_path: String,
}
17
/// Describes how a model context window is partitioned.
///
/// [`TokenBudget::available`] is what remains for retrieved chunks once the
/// fixed reservations for the system prompt, the query, and the response
/// have been subtracted.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Total size of the context window, in tokens.
    pub total_tokens: usize,
    /// Tokens reserved for the system prompt.
    pub reserved_system: usize,
    /// Tokens reserved for the user query.
    pub reserved_query: usize,
    /// Tokens reserved for the model's response.
    pub reserved_response: usize,
    /// Flat token cost charged the first time each document appears
    /// (accounts for its header line in the formatted output).
    pub per_doc_overhead: usize,
}

impl TokenBudget {
    /// Creates a budget for a window of `total_tokens` using the standard
    /// reservations: 200 (system) / 100 (query) / 500 (response), with a
    /// per-document overhead of 10 tokens.
    pub fn new(total_tokens: usize) -> Self {
        Self {
            total_tokens,
            reserved_system: 200,
            reserved_query: 100,
            reserved_response: 500,
            per_doc_overhead: 10,
        }
    }

    /// Tokens left for retrieved content after all reservations.
    ///
    /// Each subtraction saturates, so the result clamps at zero instead of
    /// underflowing when the window is smaller than the reservations.
    pub fn available(&self) -> usize {
        self.total_tokens
            .saturating_sub(self.reserved_system)
            .saturating_sub(self.reserved_query)
            .saturating_sub(self.reserved_response)
    }
}

impl Default for TokenBudget {
    /// A 2000-token window with the standard reservations.
    fn default() -> Self {
        Self::new(2000)
    }
}
63
/// A chunk as it appears in the final assembled context.
#[derive(Debug, Clone)]
pub struct ContextChunk {
    /// Id of the source chunk.
    pub chunk_id: Uuid,
    /// Path of the document the chunk came from.
    pub document_path: String,
    /// Chunk text included in the context.
    pub text: String,
    /// Retrieval score the chunk was selected with.
    pub score: f32,
    /// Position of the chunk within its document.
    pub sequence: u32,
}
73
/// Provenance metadata bound to an assembled context.
#[derive(Debug, Clone)]
pub struct ContextMetadata {
    /// BLAKE3 hash of the query string the context was assembled for.
    pub query_hash: [u8; 32],
    /// Caller-supplied state root, passed through verbatim by `assemble`.
    /// NOTE(review): presumably identifies the index/corpus version the
    /// chunks were retrieved from — confirm with callers.
    pub state_root: [u8; 32],
}
82
/// Result of context assembly: the selected chunks plus accounting.
#[derive(Debug, Clone)]
pub struct AssembledContext {
    /// Selected chunks, grouped by document and deterministically ordered.
    pub chunks: Vec<ContextChunk>,
    /// Approximate token count consumed by the selected chunks, including
    /// per-document overhead.
    pub total_tokens: usize,
    /// True when at least one candidate chunk was dropped to fit the budget.
    pub truncated: bool,
    /// Provenance metadata (query hash and state root).
    pub metadata: ContextMetadata,
}
91
/// Assembles scored chunks into a deterministic context that fits a
/// token budget.
pub struct ContextAssembler {
    /// Budget governing how many tokens of chunk text may be included.
    budget: TokenBudget,
}
96
97impl ContextAssembler {
98 pub fn new(budget: TokenBudget) -> Self {
99 Self { budget }
100 }
101
102 pub fn with_budget(total_tokens: usize) -> Self {
104 Self {
105 budget: TokenBudget::new(total_tokens),
106 }
107 }
108
109 pub fn assemble(
118 &self,
119 chunks: Vec<ScoredChunk>,
120 query: &str,
121 state_root: [u8; 32],
122 ) -> AssembledContext {
123 let deduped = Self::deduplicate(chunks);
125
126 let sorted = Self::deterministic_sort(deduped);
128
129 let groups = Self::group_by_document(&sorted);
131
132 let available = self.budget.available();
134 let (packed, total_tokens, truncated) =
135 Self::greedy_pack(groups, available, self.budget.per_doc_overhead);
136
137 let context_chunks: Vec<ContextChunk> = packed
139 .into_iter()
140 .map(|sc| ContextChunk {
141 chunk_id: sc.chunk.id,
142 document_path: sc.document_path.clone(),
143 text: sc.chunk.text.clone(),
144 score: sc.score,
145 sequence: sc.chunk.sequence,
146 })
147 .collect();
148
149 AssembledContext {
150 chunks: context_chunks,
151 total_tokens,
152 truncated,
153 metadata: ContextMetadata {
154 query_hash: *blake3::hash(query.as_bytes()).as_bytes(),
155 state_root,
156 },
157 }
158 }
159
160 pub fn format(context: &AssembledContext) -> String {
164 let mut formatted = String::new();
165 let mut current_doc: Option<&str> = None;
166
167 for chunk in &context.chunks {
168 if current_doc != Some(&chunk.document_path) {
169 if current_doc.is_some() {
170 formatted.push('\n');
171 }
172 use std::fmt::Write;
173 let _ = writeln!(formatted, "[DOC: {}]", chunk.document_path);
174 current_doc = Some(&chunk.document_path);
175 }
176 formatted.push_str(&chunk.text);
177 formatted.push('\n');
178 }
179
180 formatted
181 }
182
183 fn deduplicate(chunks: Vec<ScoredChunk>) -> Vec<ScoredChunk> {
185 let mut seen_ids: HashSet<Uuid> = HashSet::new();
186 let mut seen_text_hashes: HashSet<[u8; 32]> = HashSet::new();
187 let mut result = Vec::new();
188
189 for sc in chunks {
190 if !seen_ids.insert(sc.chunk.id) {
192 continue;
193 }
194
195 let text_hash = *blake3::hash(sc.chunk.text.as_bytes()).as_bytes();
197 if !seen_text_hashes.insert(text_hash) {
198 continue;
199 }
200
201 result.push(sc);
202 }
203
204 result
205 }
206
207 fn deterministic_sort(mut chunks: Vec<ScoredChunk>) -> Vec<ScoredChunk> {
211 chunks.sort_by(|a, b| {
212 b.score
214 .partial_cmp(&a.score)
215 .unwrap_or(std::cmp::Ordering::Equal)
216 .then_with(|| a.document_path.cmp(&b.document_path))
218 .then_with(|| a.chunk.sequence.cmp(&b.chunk.sequence))
220 .then_with(|| a.chunk.byte_offset.cmp(&b.chunk.byte_offset))
222 });
223 chunks
224 }
225
226 fn group_by_document(chunks: &[ScoredChunk]) -> Vec<Vec<&ScoredChunk>> {
228 let mut groups: HashMap<&str, Vec<&ScoredChunk>> = HashMap::new();
229 let mut max_scores: HashMap<&str, f32> = HashMap::new();
230
231 for sc in chunks {
232 let path = sc.document_path.as_str();
233 groups.entry(path).or_default().push(sc);
234 let entry = max_scores.entry(path).or_insert(0.0);
235 if sc.score > *entry {
236 *entry = sc.score;
237 }
238 }
239
240 let mut group_list: Vec<(&str, Vec<&ScoredChunk>)> = groups.into_iter().collect();
242 group_list.sort_by(|a, b| {
243 let score_a = max_scores.get(a.0).copied().unwrap_or(0.0);
244 let score_b = max_scores.get(b.0).copied().unwrap_or(0.0);
245 score_b
246 .partial_cmp(&score_a)
247 .unwrap_or(std::cmp::Ordering::Equal)
248 .then_with(|| a.0.cmp(b.0))
249 });
250
251 group_list
253 .into_iter()
254 .map(|(_, mut chunks)| {
255 chunks.sort_by_key(|sc| (sc.chunk.sequence, sc.chunk.byte_offset));
256 chunks
257 })
258 .collect()
259 }
260
261 fn greedy_pack(
265 groups: Vec<Vec<&ScoredChunk>>,
266 budget: usize,
267 per_doc_overhead: usize,
268 ) -> (Vec<ScoredChunk>, usize, bool) {
269 let mut tokens_used = 0usize;
270 let mut packed = Vec::new();
271 let mut truncated = false;
272 let mut seen_docs: HashSet<&str> = HashSet::new();
273
274 for group in groups {
275 if tokens_used >= budget {
276 truncated = true;
277 break;
278 }
279
280 for sc in group {
281 let doc_overhead = if seen_docs.contains(sc.document_path.as_str()) {
283 0
284 } else {
285 per_doc_overhead
286 };
287
288 let chunk_tokens = approx_tokens(&sc.chunk.text);
289 let total_needed = chunk_tokens + doc_overhead;
290
291 if tokens_used + total_needed <= budget {
292 seen_docs.insert(&sc.document_path);
293 packed.push(sc.clone());
294 tokens_used += total_needed;
295 } else {
296 truncated = true;
297 break;
298 }
299 }
300 }
301
302 (packed, tokens_used, truncated)
303 }
304}
305
/// Rough token estimate for budgeting: assumes ~4 bytes per token and
/// rounds up, so any non-empty text costs at least one token.
fn approx_tokens(text: &str) -> usize {
    let byte_len = text.len();
    byte_len.div_ceil(4)
}
312
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a chunk whose byte offset is derived from its sequence.
    fn make_chunk(doc_id: Uuid, text: &str, seq: u32) -> Chunk {
        Chunk::new(doc_id, text, u64::from(seq) * 100, seq)
    }

    // Wraps a chunk with a score and a document path.
    fn make_scored(chunk: Chunk, score: f32, path: &str) -> ScoredChunk {
        ScoredChunk {
            chunk,
            score,
            document_path: path.to_string(),
        }
    }

    // Identical inputs must format to byte-identical output.
    #[test]
    fn test_deterministic_assembly() {
        let doc_id = Uuid::from_bytes([1u8; 16]);
        let assembler = ContextAssembler::with_budget(4000);

        let chunks = vec![
            make_scored(make_chunk(doc_id, "Chunk A", 0), 0.9, "doc1.md"),
            make_scored(make_chunk(doc_id, "Chunk B", 1), 0.8, "doc1.md"),
        ];

        let ctx1 = assembler.assemble(chunks.clone(), "query", [0u8; 32]);
        let ctx2 = assembler.assemble(chunks, "query", [0u8; 32]);

        let fmt1 = ContextAssembler::format(&ctx1);
        let fmt2 = ContextAssembler::format(&ctx2);

        assert_eq!(
            fmt1, fmt2,
            "Identical inputs must produce byte-identical context"
        );
    }

    // The same chunk submitted twice (same id and text) survives once.
    #[test]
    fn test_deduplication() {
        let doc_id = Uuid::from_bytes([1u8; 16]);
        let assembler = ContextAssembler::with_budget(4000);

        let chunk = make_chunk(doc_id, "Same text", 0);
        let chunks = vec![
            make_scored(chunk.clone(), 0.9, "doc.md"),
            make_scored(chunk.clone(), 0.8, "doc.md"),
        ];

        let ctx = assembler.assemble(chunks, "query", [0u8; 32]);
        assert_eq!(ctx.chunks.len(), 1);
    }

    // Far more candidate tokens than budget: the packed total must never
    // exceed what the budget makes available.
    #[test]
    fn test_budget_compliance() {
        let doc_id = Uuid::from_bytes([1u8; 16]);
        let assembler = ContextAssembler::with_budget(900);
        let available = assembler.budget.available();

        // 20 chunks of 200 bytes (~50 tokens each) against a small budget.
        let chunks: Vec<ScoredChunk> = (0..20)
            .map(|i| {
                make_scored(
                    make_chunk(doc_id, &"x".repeat(200), i),
                    1.0 - (i as f32 * 0.01),
                    "doc.md",
                )
            })
            .collect();

        let ctx = assembler.assemble(chunks, "query", [0u8; 32]);
        assert!(
            ctx.total_tokens <= available,
            "Context must not exceed budget"
        );
    }

    // Chunks interleaved across two documents still get one header each.
    #[test]
    fn test_document_grouping() {
        let doc_a = Uuid::from_bytes([1u8; 16]);
        let doc_b = Uuid::from_bytes([2u8; 16]);
        let assembler = ContextAssembler::with_budget(4000);

        let chunks = vec![
            make_scored(make_chunk(doc_a, "A chunk 0", 0), 0.9, "a.md"),
            make_scored(make_chunk(doc_b, "B chunk 0", 0), 0.85, "b.md"),
            make_scored(make_chunk(doc_a, "A chunk 1", 1), 0.8, "a.md"),
        ];

        let ctx = assembler.assemble(chunks, "query", [0u8; 32]);
        let formatted = ContextAssembler::format(&ctx);

        assert!(formatted.contains("[DOC: a.md]"));
        assert!(formatted.contains("[DOC: b.md]"));
    }

    // The formatted output starts with the document header marker.
    #[test]
    fn test_format_markers() {
        let doc_id = Uuid::from_bytes([1u8; 16]);
        let assembler = ContextAssembler::with_budget(4000);

        let chunks = vec![make_scored(
            make_chunk(doc_id, "Hello world", 0),
            0.9,
            "test.md",
        )];

        let ctx = assembler.assemble(chunks, "query", [0u8; 32]);
        let formatted = ContextAssembler::format(&ctx);

        assert!(formatted.starts_with("[DOC: test.md]\n"));
        assert!(formatted.contains("Hello world"));
    }

    // Equal scores tie-break on document path (lexicographic ascending),
    // so "a.md" must appear before "z.md" regardless of input order.
    #[test]
    fn test_sort_tiebreaker() {
        let assembler = ContextAssembler::with_budget(4000);

        let chunks = vec![
            make_scored(
                Chunk::new(Uuid::from_bytes([1u8; 16]), "Chunk Z", 0, 0),
                0.9,
                "z.md",
            ),
            make_scored(
                Chunk::new(Uuid::from_bytes([2u8; 16]), "Chunk A", 0, 0),
                0.9,
                "a.md",
            ),
        ];

        let ctx = assembler.assemble(chunks, "query", [0u8; 32]);
        let formatted = ContextAssembler::format(&ctx);

        let pos_a = formatted.find("a.md");
        let pos_z = formatted.find("z.md");
        assert!(
            pos_a.is_some() && pos_z.is_some(),
            "Paths not found in formatted"
        );
        assert!(pos_a.unwrap() < pos_z.unwrap());
    }

    // Empty input yields an empty, non-truncated, zero-token context.
    #[test]
    fn test_empty_chunks() {
        let assembler = ContextAssembler::with_budget(4000);
        let ctx = assembler.assemble(vec![], "query", [0u8; 32]);
        assert!(ctx.chunks.is_empty());
        assert_eq!(ctx.total_tokens, 0);
        assert!(!ctx.truncated);
    }
}