//! `ailake_query/context_assembler.rs`
//!
//! Assembles retrieved chunks into token-budgeted, XML-structured context for LLM input.

use std::cmp::Ordering;
use std::collections::HashMap;

use ailake_vec::cosine_distance;

/// A retrieved text chunk plus the metadata used to rank, dedup and render it.
#[derive(Debug, Clone, PartialEq)]
pub struct Chunk {
    /// Identifier of the source document; chunks sharing it are grouped together.
    pub document_id: String,
    /// Position of this chunk within its document, used to restore reading order.
    pub chunk_index: u32,
    /// The chunk's text content.
    pub chunk_text: String,
    /// Document title, rendered as the `title` attribute of the document element
    /// (taken from the first chunk of the document's group).
    pub document_title: Option<String>,
    /// Section path within the document, rendered as the chunk's `section` attribute.
    pub section_path: Option<String>,
    /// URI of the source document, rendered as the `source` attribute.
    pub source_uri: Option<String>,
    /// Distance from query (lower = more relevant)
    pub distance: f32,
    /// Optional embedding used for similarity-based dedup
    pub embedding: Option<Vec<f32>>,
}

/// Tuning knobs for [`ContextAssembler`].
#[derive(Debug, Clone, PartialEq)]
pub struct ContextAssemblerConfig {
    /// Approximate token budget (4 chars ≈ 1 token)
    pub max_tokens: usize,
    /// Cosine distance below which two chunks are considered duplicates.
    /// Only chunks that both carry an embedding are compared.
    pub dedup_threshold: f32,
    /// When true, each document's chunks are rendered in `chunk_index` order.
    /// Chunks are grouped per document either way; when false they keep
    /// relevance order inside their document group.
    pub group_by_document: bool,
    /// Maximum number of chunks rendered for any single document.
    pub max_chunks_per_document: usize,
}

29impl Default for ContextAssemblerConfig {
30    fn default() -> Self {
31        Self {
32            max_tokens: 4096,
33            dedup_threshold: 0.05,
34            group_by_document: true,
35            max_chunks_per_document: 10,
36        }
37    }
38}
39
/// Result of a [`ContextAssembler`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssembledContext {
    /// XML-structured context ready for LLM input
    pub text: String,
    /// Number of input chunks that made it into `text`.
    pub chunk_count: usize,
    /// Rough token count of `text` (its byte length divided by 4).
    pub token_estimate: usize,
}

47pub struct ContextAssembler {
48    config: ContextAssemblerConfig,
49}
50
51impl ContextAssembler {
52    pub fn new(config: ContextAssemblerConfig) -> Self {
53        Self { config }
54    }
55
56    /// Assemble chunks into structured XML context:
57    /// 1. Sort by relevance (distance ascending)
58    /// 2. Deduplicate similar chunks via embedding cosine distance
59    /// 3. Group by document, sort each group by chunk_index
60    /// 4. Apply token budget
61    /// 5. Render XML ready for LLM consumption
62    pub fn assemble_chunks(&self, mut chunks: Vec<Chunk>) -> AssembledContext {
63        chunks.sort_by(|a, b| {
64            a.distance
65                .partial_cmp(&b.distance)
66                .unwrap_or(std::cmp::Ordering::Equal)
67        });
68
69        let selected = self.dedup(chunks);
70        let groups = self.group(selected);
71        self.render(groups)
72    }
73
74    /// Assemble from plain text strings — no dedup, no XML grouping.
75    /// Kept for simpler callers that don't have document metadata.
76    pub fn assemble_texts(&self, chunks: &[String]) -> AssembledContext {
77        let char_budget = self.config.max_tokens * 4;
78        let mut text = String::new();
79        let mut count = 0;
80        for chunk in chunks {
81            if text.len() + chunk.len() + 2 > char_budget {
82                break;
83            }
84            if !text.is_empty() {
85                text.push_str("\n\n");
86            }
87            text.push_str(chunk);
88            count += 1;
89        }
90        AssembledContext {
91            token_estimate: text.len() / 4,
92            chunk_count: count,
93            text,
94        }
95    }
96
97    fn dedup(&self, chunks: Vec<Chunk>) -> Vec<Chunk> {
98        let mut selected: Vec<Chunk> = Vec::new();
99        'next: for chunk in chunks {
100            if let Some(emb) = &chunk.embedding {
101                for sel in &selected {
102                    if let Some(sel_emb) = &sel.embedding {
103                        if cosine_distance(emb, sel_emb) < self.config.dedup_threshold {
104                            continue 'next;
105                        }
106                    }
107                }
108            }
109            selected.push(chunk);
110        }
111        selected
112    }
113
114    fn group(&self, chunks: Vec<Chunk>) -> Vec<(String, Vec<Chunk>)> {
115        let mut map: HashMap<String, Vec<Chunk>> = HashMap::new();
116        let mut doc_order: Vec<String> = Vec::new();
117        for chunk in chunks {
118            if !map.contains_key(&chunk.document_id) {
119                doc_order.push(chunk.document_id.clone());
120            }
121            map.entry(chunk.document_id.clone())
122                .or_default()
123                .push(chunk);
124        }
125        if self.config.group_by_document {
126            for group in map.values_mut() {
127                group.sort_by_key(|c| c.chunk_index);
128            }
129        }
130        doc_order
131            .into_iter()
132            .map(|id| (id.clone(), map.remove(&id).unwrap_or_default()))
133            .collect()
134    }
135
136    fn render(&self, groups: Vec<(String, Vec<Chunk>)>) -> AssembledContext {
137        let char_budget = self.config.max_tokens * 4;
138        let mut xml = String::from("<context>\n");
139        let mut chunk_count = 0usize;
140
141        'outer: for (doc_id, doc_chunks) in &groups {
142            let title = doc_chunks
143                .first()
144                .and_then(|c| c.document_title.as_deref())
145                .unwrap_or("");
146            let source = doc_chunks
147                .first()
148                .and_then(|c| c.source_uri.as_deref())
149                .unwrap_or("");
150            xml.push_str(&format!(
151                "  <document id=\"{}\" title=\"{}\" source=\"{}\">\n",
152                escape_xml(doc_id),
153                escape_xml(title),
154                escape_xml(source)
155            ));
156
157            for chunk in doc_chunks.iter().take(self.config.max_chunks_per_document) {
158                if xml.len() >= char_budget {
159                    break 'outer;
160                }
161                let section = chunk.section_path.as_deref().unwrap_or("");
162                xml.push_str(&format!(
163                    "    <chunk index=\"{}\" section=\"{}\">\n      <text>{}</text>\n    </chunk>\n",
164                    chunk.chunk_index,
165                    escape_xml(section),
166                    escape_xml(&chunk.chunk_text)
167                ));
168                chunk_count += 1;
169            }
170
171            xml.push_str("  </document>\n");
172        }
173
174        xml.push_str("</context>");
175        AssembledContext {
176            token_estimate: xml.len() / 4,
177            chunk_count,
178            text: xml,
179        }
180    }
181}
182
/// Escape `&`, `<`, `>` and `"` so `s` can sit inside XML text or a
/// double-quoted attribute value. Single quotes are passed through unchanged.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Chunk fixture: title, section and source are derived from `doc`; no embedding.
    fn make_chunk(doc: &str, idx: u32, text: &str, dist: f32) -> Chunk {
        let document_title = Some(format!("Doc {doc}"));
        let source_uri = Some(format!("s3://lake/{doc}.parquet"));
        Chunk {
            document_id: doc.to_owned(),
            chunk_index: idx,
            chunk_text: text.to_owned(),
            document_title,
            section_path: Some(String::from("Introduction")),
            source_uri,
            distance: dist,
            embedding: None,
        }
    }

    /// Assembler using the default configuration.
    fn default_assembler() -> ContextAssembler {
        ContextAssembler::new(ContextAssemblerConfig::default())
    }

    #[test]
    fn produces_valid_xml() {
        let ctx = default_assembler().assemble_chunks(vec![
            make_chunk("doc-1", 0, "First chunk.", 0.1),
            make_chunk("doc-1", 1, "Second chunk.", 0.15),
            make_chunk("doc-2", 0, "Doc 2 chunk.", 0.2),
        ]);
        let xml = ctx.text.as_str();
        assert!(xml.starts_with("<context>"));
        assert!(xml.ends_with("</context>"));
        for doc_id in ["doc-1", "doc-2"] {
            assert!(xml.contains(doc_id));
        }
        assert_eq!(ctx.chunk_count, 3);
    }

    #[test]
    fn dedup_removes_near_identical_embeddings() {
        let assembler = ContextAssembler::new(ContextAssemblerConfig {
            dedup_threshold: 0.01,
            ..Default::default()
        });
        let emb = vec![1.0f32, 0.0, 0.0];
        // Both chunks share the exact same embedding, so the second is a duplicate.
        let embedded = |idx: u32, text: &str, dist: f32| Chunk {
            embedding: Some(emb.clone()),
            ..make_chunk("doc-1", idx, text, dist)
        };
        let ctx = assembler.assemble_chunks(vec![
            embedded(0, "Text A.", 0.1),
            embedded(1, "Text B.", 0.2),
        ]);
        assert_eq!(ctx.chunk_count, 1, "duplicate chunk should be deduplicated");
    }

    #[test]
    fn grouping_restores_chunk_order() {
        // Input is in distance order, not index order; the XML must list the
        // chunks by chunk_index within their document.
        let ctx = default_assembler().assemble_chunks(vec![
            make_chunk("doc-1", 2, "Third chunk.", 0.3),
            make_chunk("doc-1", 0, "First chunk.", 0.1),
            make_chunk("doc-1", 1, "Second chunk.", 0.2),
        ]);
        let offsets: Vec<usize> = ["First chunk.", "Second chunk.", "Third chunk."]
            .iter()
            .map(|&needle| ctx.text.find(needle).unwrap())
            .collect();
        assert!(offsets[0] < offsets[1], "chunk 0 before chunk 1");
        assert!(offsets[1] < offsets[2], "chunk 1 before chunk 2");
    }

    #[test]
    fn token_budget_limits_output() {
        let assembler = ContextAssembler::new(ContextAssemblerConfig {
            max_tokens: 10, // ~40 chars
            ..Default::default()
        });
        let body = "word ".repeat(20);
        let chunks = (0..20u32)
            .map(|i| make_chunk("doc-1", i, &body, i as f32 * 0.01))
            .collect();
        let ctx = assembler.assemble_chunks(chunks);
        assert!(ctx.token_estimate <= 100, "should respect token budget");
    }

    #[test]
    fn xml_escaping_applied() {
        let chunk = Chunk {
            document_id: String::from("doc<1>"),
            ..make_chunk("doc-1", 0, "Text with <b>bold</b> & \"quotes\".", 0.1)
        };
        let ctx = default_assembler().assemble_chunks(vec![chunk]);
        assert!(ctx.text.contains("&lt;b&gt;"), "< should be escaped");
        assert!(ctx.text.contains("&amp;"), "& should be escaped");
    }

    #[test]
    fn assemble_texts_joins_with_budget() {
        let texts: Vec<String> = ["Alpha", "Beta", "Gamma"]
            .iter()
            .map(|&s| String::from(s))
            .collect();
        let ctx = default_assembler().assemble_texts(&texts);
        assert!(ctx.text.contains("Alpha"));
        assert_eq!(ctx.chunk_count, 3);
    }
}