1use std::collections::HashMap;
3
4use ailake_vec::cosine_distance;
5
/// A retrieved text chunk together with the metadata used to rank,
/// deduplicate, group, and render it.
#[derive(Clone, Debug)]
pub struct Chunk {
    /// Identifier of the owning document. Chunks sharing it are grouped
    /// together, and it is emitted as the document's `id` attribute.
    pub document_id: String,
    /// Ordering key within a document (used to sort a document's chunks);
    /// also emitted as the chunk's `index` attribute.
    pub chunk_index: u32,
    /// Chunk body, emitted inside the `<text>` element.
    pub chunk_text: String,
    /// Optional document title, emitted as the `title` attribute.
    pub document_title: Option<String>,
    /// Optional section location, emitted as the chunk's `section` attribute.
    pub section_path: Option<String>,
    /// Optional origin of the document, emitted as the `source` attribute.
    pub source_uri: Option<String>,
    /// Ranking score; chunks are processed in ascending order of this value.
    pub distance: f32,
    /// Optional embedding vector. When present it is compared against the
    /// embeddings of already selected chunks to drop near-duplicates.
    pub embedding: Option<Vec<f32>>,
}
19
/// Tunable limits that control how a context is assembled.
#[derive(Clone, Debug)]
pub struct ContextAssemblerConfig {
    /// Output budget in tokens. The assembler converts it to a character
    /// budget by multiplying by four.
    pub max_tokens: usize,
    /// A chunk is dropped when the cosine distance between its embedding and
    /// the embedding of an already selected chunk is below this value.
    pub dedup_threshold: f32,
    /// When `true`, the chunks of each document are reordered by
    /// `chunk_index` before rendering.
    pub group_by_document: bool,
    /// Upper bound on the number of chunks rendered for a single document.
    pub max_chunks_per_document: usize,
}
29
30impl Default for ContextAssemblerConfig {
31 fn default() -> Self {
32 Self {
33 max_tokens: 4096,
34 dedup_threshold: 0.05,
35 group_by_document: true,
36 max_chunks_per_document: 10,
37 }
38 }
39}
40
/// Result of an assembly run: the rendered text plus simple size statistics.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so results can be logged,
/// copied, and compared directly in assertions, like the other public types
/// in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssembledContext {
    /// The assembled output text.
    pub text: String,
    /// Number of chunks that were written into `text`.
    pub chunk_count: usize,
    /// Rough token count, computed as `text.len() / 4`.
    pub token_estimate: usize,
}
47
48pub struct ContextAssembler {
49 config: ContextAssemblerConfig,
50}
51
52impl ContextAssembler {
53 pub fn new(config: ContextAssemblerConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn assemble_chunks(&self, mut chunks: Vec<Chunk>) -> AssembledContext {
64 chunks.sort_by(|a, b| {
65 a.distance
66 .partial_cmp(&b.distance)
67 .unwrap_or(std::cmp::Ordering::Equal)
68 });
69
70 let selected = self.dedup(chunks);
71 let groups = self.group(selected);
72 self.render(groups)
73 }
74
75 pub fn assemble_texts(&self, chunks: &[String]) -> AssembledContext {
78 let char_budget = self.config.max_tokens * 4;
79 let mut text = String::new();
80 let mut count = 0;
81 for chunk in chunks {
82 if text.len() + chunk.len() + 2 > char_budget {
83 break;
84 }
85 if !text.is_empty() {
86 text.push_str("\n\n");
87 }
88 text.push_str(chunk);
89 count += 1;
90 }
91 AssembledContext {
92 token_estimate: text.len() / 4,
93 chunk_count: count,
94 text,
95 }
96 }
97
98 fn dedup(&self, chunks: Vec<Chunk>) -> Vec<Chunk> {
99 let mut selected: Vec<Chunk> = Vec::new();
100 'next: for chunk in chunks {
101 if let Some(emb) = &chunk.embedding {
102 for sel in &selected {
103 if let Some(sel_emb) = &sel.embedding {
104 if cosine_distance(emb, sel_emb) < self.config.dedup_threshold {
105 continue 'next;
106 }
107 }
108 }
109 }
110 selected.push(chunk);
111 }
112 selected
113 }
114
115 fn group(&self, chunks: Vec<Chunk>) -> Vec<(String, Vec<Chunk>)> {
116 let mut map: HashMap<String, Vec<Chunk>> = HashMap::new();
117 let mut doc_order: Vec<String> = Vec::new();
118 for chunk in chunks {
119 if !map.contains_key(&chunk.document_id) {
120 doc_order.push(chunk.document_id.clone());
121 }
122 map.entry(chunk.document_id.clone())
123 .or_default()
124 .push(chunk);
125 }
126 if self.config.group_by_document {
127 for group in map.values_mut() {
128 group.sort_by_key(|c| c.chunk_index);
129 }
130 }
131 doc_order
132 .into_iter()
133 .map(|id| (id.clone(), map.remove(&id).unwrap_or_default()))
134 .collect()
135 }
136
137 fn render(&self, groups: Vec<(String, Vec<Chunk>)>) -> AssembledContext {
138 let char_budget = self.config.max_tokens * 4;
139 let mut xml = String::from("<context>\n");
140 let mut chunk_count = 0usize;
141
142 'outer: for (doc_id, doc_chunks) in &groups {
143 let title = doc_chunks
144 .first()
145 .and_then(|c| c.document_title.as_deref())
146 .unwrap_or("");
147 let source = doc_chunks
148 .first()
149 .and_then(|c| c.source_uri.as_deref())
150 .unwrap_or("");
151 xml.push_str(&format!(
152 " <document id=\"{}\" title=\"{}\" source=\"{}\">\n",
153 escape_xml(doc_id),
154 escape_xml(title),
155 escape_xml(source)
156 ));
157
158 for chunk in doc_chunks.iter().take(self.config.max_chunks_per_document) {
159 if xml.len() >= char_budget {
160 break 'outer;
161 }
162 let section = chunk.section_path.as_deref().unwrap_or("");
163 xml.push_str(&format!(
164 " <chunk index=\"{}\" section=\"{}\">\n <text>{}</text>\n </chunk>\n",
165 chunk.chunk_index,
166 escape_xml(section),
167 escape_xml(&chunk.chunk_text)
168 ));
169 chunk_count += 1;
170 }
171
172 xml.push_str(" </document>\n");
173 }
174
175 xml.push_str("</context>");
176 AssembledContext {
177 token_estimate: xml.len() / 4,
178 chunk_count,
179 text: xml,
180 }
181 }
182}
183
/// Escapes the characters that are significant in XML text nodes and in
/// double-quoted attribute values.
///
/// `&`, `<`, `>` and `"` become `&amp;`, `&lt;`, `&gt;` and `&quot;`; every
/// other character is copied through unchanged. The input is scanned once,
/// so an `&` that is already part of the input is escaped exactly one time.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chunk with fixed title/section/source metadata and no
    /// embedding, so tests only spell out the fields they care about.
    fn make_chunk(doc: &str, idx: u32, text: &str, dist: f32) -> Chunk {
        Chunk {
            document_id: doc.to_string(),
            chunk_index: idx,
            chunk_text: text.to_string(),
            document_title: Some(format!("Doc {doc}")),
            section_path: Some("Introduction".into()),
            source_uri: Some(format!("s3://lake/{doc}.parquet")),
            distance: dist,
            embedding: None,
        }
    }

    /// The rendered output is wrapped in `<context>` and mentions every
    /// document that contributed a chunk.
    #[test]
    fn produces_valid_xml() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let chunks = vec![
            make_chunk("doc-1", 0, "First chunk.", 0.1),
            make_chunk("doc-1", 1, "Second chunk.", 0.15),
            make_chunk("doc-2", 0, "Doc 2 chunk.", 0.2),
        ];
        let ctx = ca.assemble_chunks(chunks);
        assert!(ctx.text.starts_with("<context>"));
        assert!(ctx.text.ends_with("</context>"));
        assert!(ctx.text.contains("doc-1"));
        assert!(ctx.text.contains("doc-2"));
        assert_eq!(ctx.chunk_count, 3);
    }

    /// Two chunks with identical embeddings collapse into one.
    #[test]
    fn dedup_removes_near_identical_embeddings() {
        let cfg = ContextAssemblerConfig {
            dedup_threshold: 0.01,
            ..Default::default()
        };
        let ca = ContextAssembler::new(cfg);
        let emb = vec![1.0f32, 0.0, 0.0];
        let mut c1 = make_chunk("doc-1", 0, "Text A.", 0.1);
        c1.embedding = Some(emb.clone());
        let mut c2 = make_chunk("doc-1", 1, "Text B.", 0.2);
        c2.embedding = Some(emb.clone());
        let ctx = ca.assemble_chunks(vec![c1, c2]);
        assert_eq!(ctx.chunk_count, 1, "duplicate chunk should be deduplicated");
    }

    /// With the default configuration, chunks of one document are rendered in
    /// `chunk_index` order regardless of the order they were supplied in.
    #[test]
    fn grouping_restores_chunk_order() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let chunks = vec![
            make_chunk("doc-1", 2, "Third chunk.", 0.3),
            make_chunk("doc-1", 0, "First chunk.", 0.1),
            make_chunk("doc-1", 1, "Second chunk.", 0.2),
        ];
        let ctx = ca.assemble_chunks(chunks);
        let first_pos = ctx.text.find("First chunk.").unwrap();
        let second_pos = ctx.text.find("Second chunk.").unwrap();
        let third_pos = ctx.text.find("Third chunk.").unwrap();
        assert!(first_pos < second_pos, "chunk 0 before chunk 1");
        assert!(second_pos < third_pos, "chunk 1 before chunk 2");
    }

    /// A tiny budget keeps the output small. The bound is deliberately loose
    /// because the wrapper and document header are written regardless.
    #[test]
    fn token_budget_limits_output() {
        let cfg = ContextAssemblerConfig {
            max_tokens: 10,
            ..Default::default()
        };
        let ca = ContextAssembler::new(cfg);
        let chunks: Vec<Chunk> = (0..20)
            .map(|i| make_chunk("doc-1", i, &"word ".repeat(20), i as f32 * 0.01))
            .collect();
        let ctx = ca.assemble_chunks(chunks);
        assert!(ctx.token_estimate <= 100, "should respect token budget");
    }

    /// Markup characters in chunk text and attributes are emitted as entity
    /// references, never raw.
    #[test]
    fn xml_escaping_applied() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let mut chunk = make_chunk("doc-1", 0, "Text with <b>bold</b> & \"quotes\".", 0.1);
        chunk.document_id = "doc<1>".into();
        let ctx = ca.assemble_chunks(vec![chunk]);
        assert!(ctx.text.contains("&lt;b&gt;"), "< should be escaped");
        assert!(!ctx.text.contains("<b>"), "raw markup must not leak through");
        assert!(ctx.text.contains("&amp;"), "& should be escaped");
        assert!(ctx.text.contains("&quot;quotes&quot;"), "\" should be escaped");
        assert!(
            ctx.text.contains("id=\"doc&lt;1&gt;\""),
            "attribute values should be escaped"
        );
    }

    /// Plain-text assembly keeps every chunk when the budget is generous.
    #[test]
    fn assemble_texts_joins_with_budget() {
        let ca = ContextAssembler::new(ContextAssemblerConfig::default());
        let texts = vec!["Alpha".into(), "Beta".into(), "Gamma".into()];
        let ctx = ca.assemble_texts(&texts);
        assert!(ctx.text.contains("Alpha"));
        assert_eq!(ctx.chunk_count, 3);
    }
}