1use std::collections::HashMap;
2
3use ailake_vec::cosine_distance;
4
/// A retrieved text chunk together with the metadata used to assemble context.
#[derive(Clone, Debug)]
pub struct Chunk {
    /// Identifier of the owning document; chunks are grouped by this value.
    pub document_id: String,
    /// Index of the chunk, used to order chunks inside a document group.
    pub chunk_index: u32,
    /// Chunk body, rendered inside the `<text>` element.
    pub chunk_text: String,
    /// Optional title; the first chunk of a document supplies the `title` attribute.
    pub document_title: Option<String>,
    /// Optional section path, rendered as the chunk's `section` attribute.
    pub section_path: Option<String>,
    /// Optional source URI; the first chunk of a document supplies the `source` attribute.
    pub source_uri: Option<String>,
    /// Retrieval distance; chunks are processed in ascending order of this value.
    pub distance: f32,
    /// Optional embedding; when present it is compared against already selected
    /// chunks to drop near-duplicates.
    pub embedding: Option<Vec<f32>>,
}
18
/// Tunable limits and switches for [`ContextAssembler`].
#[derive(Clone, Debug)]
pub struct ContextAssemblerConfig {
    /// Output budget in tokens; the assembler converts it to a byte budget of
    /// `max_tokens * 4`.
    pub max_tokens: usize,
    /// A chunk is dropped when its cosine distance to an already selected chunk
    /// is strictly below this value.
    pub dedup_threshold: f32,
    /// When `true`, chunks inside each document group are ordered by `chunk_index`.
    pub group_by_document: bool,
    /// Upper bound on the number of chunks rendered for a single document.
    pub max_chunks_per_document: usize,
}

impl Default for ContextAssemblerConfig {
    /// Returns the stock configuration: 4096 tokens, dedup threshold 0.05,
    /// per-document ordering enabled, at most 10 chunks per document.
    fn default() -> Self {
        ContextAssemblerConfig {
            max_chunks_per_document: 10,
            group_by_document: true,
            dedup_threshold: 0.05,
            max_tokens: 4096,
        }
    }
}
39
/// Result of assembling chunks into a single context string.
///
/// Derives `Debug` and `Clone` like the other public types in this module so
/// callers can log and duplicate results.
#[derive(Debug, Clone)]
pub struct AssembledContext {
    /// The assembled context text.
    pub text: String,
    /// Number of chunks that were included in `text`.
    pub chunk_count: usize,
    /// Rough token count, computed as `text.len() / 4`.
    pub token_estimate: usize,
}
46
47pub struct ContextAssembler {
48 config: ContextAssemblerConfig,
49}
50
51impl ContextAssembler {
52 pub fn new(config: ContextAssemblerConfig) -> Self {
53 Self { config }
54 }
55
56 pub fn assemble_chunks(&self, mut chunks: Vec<Chunk>) -> AssembledContext {
63 chunks.sort_by(|a, b| {
64 a.distance
65 .partial_cmp(&b.distance)
66 .unwrap_or(std::cmp::Ordering::Equal)
67 });
68
69 let selected = self.dedup(chunks);
70 let groups = self.group(selected);
71 self.render(groups)
72 }
73
74 pub fn assemble_texts(&self, chunks: &[String]) -> AssembledContext {
77 let char_budget = self.config.max_tokens * 4;
78 let mut text = String::new();
79 let mut count = 0;
80 for chunk in chunks {
81 if text.len() + chunk.len() + 2 > char_budget {
82 break;
83 }
84 if !text.is_empty() {
85 text.push_str("\n\n");
86 }
87 text.push_str(chunk);
88 count += 1;
89 }
90 AssembledContext {
91 token_estimate: text.len() / 4,
92 chunk_count: count,
93 text,
94 }
95 }
96
97 fn dedup(&self, chunks: Vec<Chunk>) -> Vec<Chunk> {
98 let mut selected: Vec<Chunk> = Vec::new();
99 'next: for chunk in chunks {
100 if let Some(emb) = &chunk.embedding {
101 for sel in &selected {
102 if let Some(sel_emb) = &sel.embedding {
103 if cosine_distance(emb, sel_emb) < self.config.dedup_threshold {
104 continue 'next;
105 }
106 }
107 }
108 }
109 selected.push(chunk);
110 }
111 selected
112 }
113
114 fn group(&self, chunks: Vec<Chunk>) -> Vec<(String, Vec<Chunk>)> {
115 let mut map: HashMap<String, Vec<Chunk>> = HashMap::new();
116 let mut doc_order: Vec<String> = Vec::new();
117 for chunk in chunks {
118 if !map.contains_key(&chunk.document_id) {
119 doc_order.push(chunk.document_id.clone());
120 }
121 map.entry(chunk.document_id.clone())
122 .or_default()
123 .push(chunk);
124 }
125 if self.config.group_by_document {
126 for group in map.values_mut() {
127 group.sort_by_key(|c| c.chunk_index);
128 }
129 }
130 doc_order
131 .into_iter()
132 .map(|id| (id.clone(), map.remove(&id).unwrap_or_default()))
133 .collect()
134 }
135
136 fn render(&self, groups: Vec<(String, Vec<Chunk>)>) -> AssembledContext {
137 let char_budget = self.config.max_tokens * 4;
138 let mut xml = String::from("<context>\n");
139 let mut chunk_count = 0usize;
140
141 'outer: for (doc_id, doc_chunks) in &groups {
142 let title = doc_chunks
143 .first()
144 .and_then(|c| c.document_title.as_deref())
145 .unwrap_or("");
146 let source = doc_chunks
147 .first()
148 .and_then(|c| c.source_uri.as_deref())
149 .unwrap_or("");
150 xml.push_str(&format!(
151 " <document id=\"{}\" title=\"{}\" source=\"{}\">\n",
152 escape_xml(doc_id),
153 escape_xml(title),
154 escape_xml(source)
155 ));
156
157 for chunk in doc_chunks.iter().take(self.config.max_chunks_per_document) {
158 if xml.len() >= char_budget {
159 break 'outer;
160 }
161 let section = chunk.section_path.as_deref().unwrap_or("");
162 xml.push_str(&format!(
163 " <chunk index=\"{}\" section=\"{}\">\n <text>{}</text>\n </chunk>\n",
164 chunk.chunk_index,
165 escape_xml(section),
166 escape_xml(&chunk.chunk_text)
167 ));
168 chunk_count += 1;
169 }
170
171 xml.push_str(" </document>\n");
172 }
173
174 xml.push_str("</context>");
175 AssembledContext {
176 token_estimate: xml.len() / 4,
177 chunk_count,
178 text: xml,
179 }
180 }
181}
182
/// Escapes `s` for use in XML text content and double-quoted attribute values.
///
/// Replaces `&`, `<`, `>` and `"` with the predefined XML entities in a single
/// pass, so an ampersand that is already part of an entity in the input is
/// escaped again rather than passed through. Apostrophes are left untouched;
/// this is sufficient because the renderer only emits double-quoted attributes.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chunk for `doc` with fixed metadata and no embedding.
    fn make_chunk(doc: &str, idx: u32, text: &str, dist: f32) -> Chunk {
        Chunk {
            document_id: doc.to_string(),
            chunk_index: idx,
            chunk_text: text.to_string(),
            document_title: Some(format!("Doc {doc}")),
            section_path: Some("Introduction".into()),
            source_uri: Some(format!("s3://lake/{doc}.parquet")),
            distance: dist,
            embedding: None,
        }
    }

    #[test]
    fn produces_valid_xml() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let chunks = vec![
            make_chunk("doc-1", 0, "First chunk.", 0.1),
            make_chunk("doc-1", 1, "Second chunk.", 0.15),
            make_chunk("doc-2", 0, "Doc 2 chunk.", 0.2),
        ];
        let ctx = ca.assemble_chunks(chunks);
        assert!(ctx.text.starts_with("<context>"));
        assert!(ctx.text.ends_with("</context>"));
        assert!(ctx.text.contains("doc-1"));
        assert!(ctx.text.contains("doc-2"));
        assert_eq!(ctx.chunk_count, 3);
    }

    #[test]
    fn dedup_removes_near_identical_embeddings() {
        let cfg = ContextAssemblerConfig {
            dedup_threshold: 0.01,
            ..Default::default()
        };
        let ca = ContextAssembler::new(cfg);
        let emb = vec![1.0f32, 0.0, 0.0];
        let mut c1 = make_chunk("doc-1", 0, "Text A.", 0.1);
        c1.embedding = Some(emb.clone());
        let mut c2 = make_chunk("doc-1", 1, "Text B.", 0.2);
        c2.embedding = Some(emb.clone());
        let ctx = ca.assemble_chunks(vec![c1, c2]);
        assert_eq!(ctx.chunk_count, 1, "duplicate chunk should be deduplicated");
    }

    #[test]
    fn grouping_restores_chunk_order() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        // Supplied out of order on purpose; grouping must sort by chunk index.
        let chunks = vec![
            make_chunk("doc-1", 2, "Third chunk.", 0.3),
            make_chunk("doc-1", 0, "First chunk.", 0.1),
            make_chunk("doc-1", 1, "Second chunk.", 0.2),
        ];
        let ctx = ca.assemble_chunks(chunks);
        let first_pos = ctx.text.find("First chunk.").unwrap();
        let second_pos = ctx.text.find("Second chunk.").unwrap();
        let third_pos = ctx.text.find("Third chunk.").unwrap();
        assert!(first_pos < second_pos, "chunk 0 before chunk 1");
        assert!(second_pos < third_pos, "chunk 1 before chunk 2");
    }

    #[test]
    fn token_budget_limits_output() {
        let cfg = ContextAssemblerConfig {
            max_tokens: 10,
            ..Default::default()
        };
        let ca = ContextAssembler::new(cfg);
        let chunks: Vec<Chunk> = (0..20)
            .map(|i| make_chunk("doc-1", i, &"word ".repeat(20), i as f32 * 0.01))
            .collect();
        let ctx = ca.assemble_chunks(chunks);
        assert!(ctx.token_estimate <= 100, "should respect token budget");
    }

    #[test]
    fn xml_escaping_applied() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let mut chunk = make_chunk("doc-1", 0, "Text with <b>bold</b> & \"quotes\".", 0.1);
        chunk.document_id = "doc<1>".into();
        let ctx = ca.assemble_chunks(vec![chunk]);
        // The escaped forms must be present and the raw markup must be gone.
        assert!(ctx.text.contains("&lt;b&gt;"), "< should be escaped");
        assert!(ctx.text.contains("&amp;"), "& should be escaped");
        assert!(!ctx.text.contains("<b>"), "raw markup must not survive");
    }

    #[test]
    fn assemble_texts_joins_with_budget() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let texts: Vec<String> = vec!["Alpha".into(), "Beta".into(), "Gamma".into()];
        let ctx = ca.assemble_texts(&texts);
        assert!(ctx.text.contains("Alpha"));
        assert_eq!(ctx.chunk_count, 3);
    }
}