Skip to main content

ailake_query/
context_assembler.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::collections::HashMap;
3
4use ailake_vec::cosine_distance;
5
/// A retrieved text fragment plus the document metadata needed to render it.
#[derive(Clone, Debug)]
pub struct Chunk {
    /// Identifier of the source document; chunks sharing it are rendered together.
    pub document_id: String,
    /// Position of this chunk within its document.
    pub chunk_index: u32,
    /// The chunk's text content.
    pub chunk_text: String,
    /// Document title, if known.
    pub document_title: Option<String>,
    /// Section path inside the document, if known.
    pub section_path: Option<String>,
    /// Location of the source document, if known.
    pub source_uri: Option<String>,
    /// Distance from the query; lower means more relevant.
    pub distance: f32,
    /// Optional embedding, consulted only for similarity-based deduplication.
    pub embedding: Option<Vec<f32>>,
}
19
/// Tuning knobs for `ContextAssembler`.
#[derive(Clone, Debug)]
pub struct ContextAssemblerConfig {
    /// Approximate token budget for the assembled output, converted to a
    /// length budget at 4 bytes of UTF-8 (as counted by `str::len`) per token.
    pub max_tokens: usize,
    /// Two chunks whose embeddings lie at a cosine distance strictly below this
    /// value are treated as duplicates; only the earlier (more relevant) one is kept.
    pub dedup_threshold: f32,
    /// When `true`, the chunks of each document are re-sorted by `chunk_index`
    /// so they read in source order.
    pub group_by_document: bool,
    /// Upper bound on the number of chunks rendered for any single document.
    pub max_chunks_per_document: usize,
}
29
30impl Default for ContextAssemblerConfig {
31    fn default() -> Self {
32        Self {
33            max_tokens: 4096,
34            dedup_threshold: 0.05,
35            group_by_document: true,
36            max_chunks_per_document: 10,
37        }
38    }
39}
40
/// Result of assembling context: the rendered text plus bookkeeping about it.
///
/// Derives the common traits so callers can log (`Debug`), copy (`Clone`) and
/// compare (`PartialEq`/`Eq`) results, e.g. in tests or caches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssembledContext {
    /// Rendered context: XML from `assemble_chunks`, or blank-line-joined
    /// text from `assemble_texts`.
    pub text: String,
    /// Number of input chunks included in `text`.
    pub chunk_count: usize,
    /// Rough token count of `text` (its byte length divided by 4).
    pub token_estimate: usize,
}
47
48pub struct ContextAssembler {
49    config: ContextAssemblerConfig,
50}
51
52impl ContextAssembler {
53    pub fn new(config: ContextAssemblerConfig) -> Self {
54        Self { config }
55    }
56
57    /// Assemble chunks into structured XML context:
58    /// 1. Sort by relevance (distance ascending)
59    /// 2. Deduplicate similar chunks via embedding cosine distance
60    /// 3. Group by document, sort each group by chunk_index
61    /// 4. Apply token budget
62    /// 5. Render XML ready for LLM consumption
63    pub fn assemble_chunks(&self, mut chunks: Vec<Chunk>) -> AssembledContext {
64        chunks.sort_by(|a, b| {
65            a.distance
66                .partial_cmp(&b.distance)
67                .unwrap_or(std::cmp::Ordering::Equal)
68        });
69
70        let selected = self.dedup(chunks);
71        let groups = self.group(selected);
72        self.render(groups)
73    }
74
75    /// Assemble from plain text strings — no dedup, no XML grouping.
76    /// Kept for simpler callers that don't have document metadata.
77    pub fn assemble_texts(&self, chunks: &[String]) -> AssembledContext {
78        let char_budget = self.config.max_tokens * 4;
79        let mut text = String::new();
80        let mut count = 0;
81        for chunk in chunks {
82            if text.len() + chunk.len() + 2 > char_budget {
83                break;
84            }
85            if !text.is_empty() {
86                text.push_str("\n\n");
87            }
88            text.push_str(chunk);
89            count += 1;
90        }
91        AssembledContext {
92            token_estimate: text.len() / 4,
93            chunk_count: count,
94            text,
95        }
96    }
97
98    fn dedup(&self, chunks: Vec<Chunk>) -> Vec<Chunk> {
99        let mut selected: Vec<Chunk> = Vec::new();
100        'next: for chunk in chunks {
101            if let Some(emb) = &chunk.embedding {
102                for sel in &selected {
103                    if let Some(sel_emb) = &sel.embedding {
104                        if cosine_distance(emb, sel_emb) < self.config.dedup_threshold {
105                            continue 'next;
106                        }
107                    }
108                }
109            }
110            selected.push(chunk);
111        }
112        selected
113    }
114
115    fn group(&self, chunks: Vec<Chunk>) -> Vec<(String, Vec<Chunk>)> {
116        let mut map: HashMap<String, Vec<Chunk>> = HashMap::new();
117        let mut doc_order: Vec<String> = Vec::new();
118        for chunk in chunks {
119            if !map.contains_key(&chunk.document_id) {
120                doc_order.push(chunk.document_id.clone());
121            }
122            map.entry(chunk.document_id.clone())
123                .or_default()
124                .push(chunk);
125        }
126        if self.config.group_by_document {
127            for group in map.values_mut() {
128                group.sort_by_key(|c| c.chunk_index);
129            }
130        }
131        doc_order
132            .into_iter()
133            .map(|id| (id.clone(), map.remove(&id).unwrap_or_default()))
134            .collect()
135    }
136
137    fn render(&self, groups: Vec<(String, Vec<Chunk>)>) -> AssembledContext {
138        let char_budget = self.config.max_tokens * 4;
139        let mut xml = String::from("<context>\n");
140        let mut chunk_count = 0usize;
141
142        'outer: for (doc_id, doc_chunks) in &groups {
143            let title = doc_chunks
144                .first()
145                .and_then(|c| c.document_title.as_deref())
146                .unwrap_or("");
147            let source = doc_chunks
148                .first()
149                .and_then(|c| c.source_uri.as_deref())
150                .unwrap_or("");
151            xml.push_str(&format!(
152                "  <document id=\"{}\" title=\"{}\" source=\"{}\">\n",
153                escape_xml(doc_id),
154                escape_xml(title),
155                escape_xml(source)
156            ));
157
158            for chunk in doc_chunks.iter().take(self.config.max_chunks_per_document) {
159                if xml.len() >= char_budget {
160                    break 'outer;
161                }
162                let section = chunk.section_path.as_deref().unwrap_or("");
163                xml.push_str(&format!(
164                    "    <chunk index=\"{}\" section=\"{}\">\n      <text>{}</text>\n    </chunk>\n",
165                    chunk.chunk_index,
166                    escape_xml(section),
167                    escape_xml(&chunk.chunk_text)
168                ));
169                chunk_count += 1;
170            }
171
172            xml.push_str("  </document>\n");
173        }
174
175        xml.push_str("</context>");
176        AssembledContext {
177            token_estimate: xml.len() / 4,
178            chunk_count,
179            text: xml,
180        }
181    }
182}
183
/// Escape `&`, `<`, `>` and `"` so `s` can be embedded in XML text or in a
/// double-quoted attribute value. Single quotes are passed through unchanged.
/// Works in a single pass over the input.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a chunk whose title and source URI are derived from `doc` and
    /// whose section is always "Introduction"; no embedding.
    fn chunk_fixture(doc: &str, index: u32, body: &str, distance: f32) -> Chunk {
        Chunk {
            document_id: doc.to_owned(),
            chunk_index: index,
            chunk_text: body.to_owned(),
            document_title: Some(format!("Doc {doc}")),
            section_path: Some(String::from("Introduction")),
            source_uri: Some(format!("s3://lake/{doc}.parquet")),
            distance,
            embedding: None,
        }
    }

    /// Assembler with the stock configuration.
    fn default_assembler() -> ContextAssembler {
        ContextAssembler::new(ContextAssemblerConfig::default())
    }

    #[test]
    fn produces_valid_xml() {
        let ctx = default_assembler().assemble_chunks(vec![
            chunk_fixture("doc-1", 0, "First chunk.", 0.1),
            chunk_fixture("doc-1", 1, "Second chunk.", 0.15),
            chunk_fixture("doc-2", 0, "Doc 2 chunk.", 0.2),
        ]);
        let xml = &ctx.text;
        assert!(xml.starts_with("<context>"));
        assert!(xml.ends_with("</context>"));
        for id in ["doc-1", "doc-2"] {
            assert!(xml.contains(id));
        }
        assert_eq!(ctx.chunk_count, 3);
    }

    #[test]
    fn dedup_removes_near_identical_embeddings() {
        let ca = ContextAssembler::new(ContextAssemblerConfig {
            dedup_threshold: 0.01,
            ..Default::default()
        });
        // Both chunks carry the same embedding, so the second must be dropped.
        let with_embedding = |idx: u32, text: &str, dist: f32| Chunk {
            embedding: Some(vec![1.0f32, 0.0, 0.0]),
            ..chunk_fixture("doc-1", idx, text, dist)
        };
        let ctx = ca.assemble_chunks(vec![
            with_embedding(0, "Text A.", 0.1),
            with_embedding(1, "Text B.", 0.2),
        ]);
        assert_eq!(ctx.chunk_count, 1, "duplicate chunk should be deduplicated");
    }

    #[test]
    fn grouping_restores_chunk_order() {
        // Chunks arrive out-of-order (by distance), but XML should group by doc + sort by index
        let ctx = default_assembler().assemble_chunks(vec![
            chunk_fixture("doc-1", 2, "Third chunk.", 0.3),
            chunk_fixture("doc-1", 0, "First chunk.", 0.1),
            chunk_fixture("doc-1", 1, "Second chunk.", 0.2),
        ]);
        let position = |needle: &str| ctx.text.find(needle).unwrap();
        let (first, second, third) = (
            position("First chunk."),
            position("Second chunk."),
            position("Third chunk."),
        );
        assert!(first < second, "chunk 0 before chunk 1");
        assert!(second < third, "chunk 1 before chunk 2");
    }

    #[test]
    fn token_budget_limits_output() {
        // ~40 chars of budget
        let ca = ContextAssembler::new(ContextAssemblerConfig {
            max_tokens: 10,
            ..Default::default()
        });
        let filler = "word ".repeat(20);
        let chunks = (0..20u32)
            .map(|i| chunk_fixture("doc-1", i, &filler, i as f32 * 0.01))
            .collect();
        let ctx = ca.assemble_chunks(chunks);
        assert!(ctx.token_estimate <= 100, "should respect token budget");
    }

    #[test]
    fn xml_escaping_applied() {
        let chunk = Chunk {
            document_id: "doc<1>".into(),
            ..chunk_fixture("doc-1", 0, "Text with <b>bold</b> & \"quotes\".", 0.1)
        };
        let ctx = default_assembler().assemble_chunks(vec![chunk]);
        assert!(ctx.text.contains("&lt;b&gt;"), "< should be escaped");
        assert!(ctx.text.contains("&amp;"), "& should be escaped");
    }

    #[test]
    fn assemble_texts_joins_with_budget() {
        let texts: Vec<String> = ["Alpha", "Beta", "Gamma"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let ctx = default_assembler().assemble_texts(&texts);
        assert!(ctx.text.contains("Alpha"));
        assert_eq!(ctx.chunk_count, 3);
    }
}