// code_chunker/late.rs
1//! Late Chunking: Embed first, then chunk.
2//!
3//! ## The Problem with Traditional Chunking
4//!
5//! Traditional chunking embeds chunks independently:
6//!
7//! ```text
8//! Document: "Einstein developed relativity. He became famous."
9//! Chunks: ["Einstein developed relativity.", "He became famous."]
10//! Embeddings: [embed(chunk1), embed(chunk2)]
11//! ↑
12//! "He" loses context!
13//! ```
14//!
15//! The second chunk embeds "He" without knowing it refers to Einstein.
16//!
17//! ## Late Chunking Solution
18//!
19//! Late chunking (Günther et al. 2024) embeds the full document first,
20//! then pools token embeddings for each chunk:
21//!
22//! ```text
23//! Document: "Einstein developed relativity. He became famous."
24//!
25//! Step 1: Embed full document → Token embeddings [t1, t2, ..., tn]
26//! Each token "sees" the full document via attention.
27//!
28//! Step 2: Pool chunks from token embeddings:
29//! Chunk 1: mean_pool([t1, ..., t4]) ← "Einstein developed relativity."
30//! Chunk 2: mean_pool([t5, ..., t7]) ← "He became famous."
31//! "He" now has Einstein context!
32//! ```
33//!
34//! ## The Math
35//!
36//! Given token embeddings H = [h1, h2, ..., hn] from full document,
37//! and chunk boundaries [(s1, e1), (s2, e2), ...]:
38//!
39//! ```text
40//! chunk_embedding_i = (1 / |ei - si|) * Σ_{t=si}^{ei} ht
41//! ```
42//!
43//! Mean pooling preserves the contextual information each token gained
44//! from attending to the full document.
45//!
46//! ## When to Use
47//!
48//! - **Use Late Chunking**: When chunks reference each other (pronouns,
49//! acronym definitions, temporal references). Long coherent documents.
50//!
51//! - **Use Traditional**: Independent chunks, real-time embedding needed,
52//! memory-constrained (late chunking needs full doc in memory).
53//!
54//! ## Trade-offs
55//!
56//! | Aspect | Traditional | Late Chunking |
57//! |--------|-------------|---------------|
58//! | Memory | O(chunk_size) | O(doc_length × dim) |
59//! | Context | Local only | Full document |
60//! | Speed | Parallel chunks | Sequential doc first |
61//! | Quality | Baseline | +5-15% recall typically |
62//!
63//! ## References
64//!
//! Günther, Mohr, Williams, Wang, Xiao (2024). "Late Chunking: Contextual
//! Chunk Embeddings Using Long-Context Embedding Models." arXiv:2409.04701.
67
68use crate::Slab;
69
/// Late chunking pooler: pools token embeddings into chunk embeddings.
///
/// This is the core operation of late chunking. Given token-level embeddings
/// from a full document, it pools the tokens within each chunk boundary
/// to create contextualized chunk embeddings.
#[derive(Debug, Clone)]
pub struct LateChunkingPooler {
    /// Embedding dimension (for validation).
    ///
    /// Concretely, this is the length of the zero-vector fallback returned
    /// for empty inputs; the pooling routines themselves take the dimension
    /// from the first embedding they see rather than from this field.
    dim: usize,
}
80
81impl LateChunkingPooler {
82 /// Create a new late chunking pooler.
83 ///
84 /// # Arguments
85 ///
86 /// * `dim` - Embedding dimension (e.g., 384 for all-MiniLM-L6-v2)
87 pub fn new(dim: usize) -> Self {
88 Self { dim }
89 }
90
91 /// Pool token embeddings into chunk embeddings.
92 ///
93 /// # Arguments
94 ///
95 /// * `token_embeddings` - Token-level embeddings from full document.
96 /// Shape: [n_tokens, dim]. Each token has "seen" the full document.
97 /// * `chunks` - Chunk boundaries from any chunker.
98 /// * `doc_len` - Total document length in bytes (for mapping).
99 ///
100 /// # Returns
101 ///
102 /// Contextualized chunk embeddings. Each chunk embedding is the mean
103 /// of its constituent token embeddings.
104 ///
105 /// # Panics
106 ///
107 /// Panics if token embeddings have inconsistent dimensions.
108 pub fn pool(
109 &self,
110 token_embeddings: &[Vec<f32>],
111 chunks: &[Slab],
112 doc_len: usize,
113 ) -> Vec<Vec<f32>> {
114 if token_embeddings.is_empty() || chunks.is_empty() || doc_len == 0 {
115 return vec![vec![0.0; self.dim]; chunks.len()];
116 }
117
118 let n_tokens = token_embeddings.len();
119
120 chunks
121 .iter()
122 .map(|chunk| {
123 // Map byte offsets to token indices (linear approximation)
124 let token_start = (chunk.start as f64 / doc_len as f64 * n_tokens as f64) as usize;
125 let token_end =
126 ((chunk.end as f64 / doc_len as f64 * n_tokens as f64) as usize).min(n_tokens);
127
128 if token_end <= token_start {
129 // Fallback: use full document average
130 return self.mean_pool(token_embeddings);
131 }
132
133 self.mean_pool(&token_embeddings[token_start..token_end])
134 })
135 .collect()
136 }
137
138 /// Pool with exact token-to-character mappings.
139 ///
140 /// Use this when you have exact token offsets from the tokenizer,
141 /// rather than relying on linear approximation.
142 ///
143 /// # Arguments
144 ///
145 /// * `token_embeddings` - Token-level embeddings [n_tokens, dim].
146 /// * `token_offsets` - Character offset for each token [(start, end), ...].
147 /// * `chunks` - Chunk boundaries.
148 pub fn pool_with_offsets(
149 &self,
150 token_embeddings: &[Vec<f32>],
151 token_offsets: &[(usize, usize)],
152 chunks: &[Slab],
153 ) -> Vec<Vec<f32>> {
154 if token_embeddings.is_empty() || chunks.is_empty() {
155 return vec![vec![0.0; self.dim]; chunks.len()];
156 }
157
158 chunks
159 .iter()
160 .map(|chunk| {
161 // Find tokens that overlap with this chunk
162 let token_indices: Vec<usize> = token_offsets
163 .iter()
164 .enumerate()
165 .filter(|(_, (start, end))| {
166 // Token overlaps with chunk
167 *start < chunk.end && *end > chunk.start
168 })
169 .map(|(i, _)| i)
170 .collect();
171
172 if token_indices.is_empty() {
173 return self.mean_pool(token_embeddings);
174 }
175
176 let selected: Vec<&[f32]> = token_indices
177 .iter()
178 .filter_map(|&i| token_embeddings.get(i).map(Vec::as_slice))
179 .collect();
180
181 self.mean_pool_refs(&selected)
182 })
183 .collect()
184 }
185
186 /// Mean pool a slice of token embeddings.
187 fn mean_pool(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
188 if embeddings.is_empty() {
189 return vec![0.0; self.dim];
190 }
191
192 let dim = embeddings[0].len();
193 let mut result = vec![0.0; dim];
194 let count = embeddings.len() as f32;
195
196 for emb in embeddings {
197 for (i, &v) in emb.iter().enumerate() {
198 result[i] += v;
199 }
200 }
201
202 for v in &mut result {
203 *v /= count;
204 }
205
206 // L2 normalize
207 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
208 if norm > 1e-9 {
209 for v in &mut result {
210 *v /= norm;
211 }
212 }
213
214 result
215 }
216
217 /// Mean pool from references.
218 fn mean_pool_refs(&self, embeddings: &[&[f32]]) -> Vec<f32> {
219 if embeddings.is_empty() {
220 return vec![0.0; self.dim];
221 }
222
223 let dim = embeddings[0].len();
224 let mut result = vec![0.0; dim];
225 let count = embeddings.len() as f32;
226
227 for emb in embeddings {
228 for (i, &v) in emb.iter().enumerate() {
229 result[i] += v;
230 }
231 }
232
233 for v in &mut result {
234 *v /= count;
235 }
236
237 // L2 normalize
238 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
239 if norm > 1e-9 {
240 for v in &mut result {
241 *v /= norm;
242 }
243 }
244
245 result
246 }
247}
248
249#[cfg(test)]
250mod tests {
251 use super::*;
252
253 #[test]
254 fn test_late_chunking_pooler_basic() {
255 let pooler = LateChunkingPooler::new(4);
256
257 // Simulate 6 tokens, 4-dim embeddings
258 let token_embeddings = vec![
259 vec![1.0, 0.0, 0.0, 0.0],
260 vec![0.0, 1.0, 0.0, 0.0],
261 vec![0.0, 0.0, 1.0, 0.0],
262 vec![0.0, 0.0, 0.0, 1.0],
263 vec![1.0, 1.0, 0.0, 0.0],
264 vec![0.0, 0.0, 1.0, 1.0],
265 ];
266
267 let chunks = vec![
268 Slab::new("first chunk", 0, 10, 0),
269 Slab::new("second chunk", 10, 20, 1),
270 ];
271
272 let chunk_embeddings = pooler.pool(&token_embeddings, &chunks, 20);
273
274 assert_eq!(chunk_embeddings.len(), 2);
275 assert_eq!(chunk_embeddings[0].len(), 4);
276 assert_eq!(chunk_embeddings[1].len(), 4);
277
278 // Embeddings should be normalized
279 let norm0: f32 = chunk_embeddings[0]
280 .iter()
281 .map(|x| x * x)
282 .sum::<f32>()
283 .sqrt();
284 assert!((norm0 - 1.0).abs() < 0.01);
285 }
286
287 #[test]
288 fn test_pool_with_exact_offsets() {
289 let pooler = LateChunkingPooler::new(3);
290
291 // 5 tokens with known character offsets
292 let token_embeddings = vec![
293 vec![1.0, 0.0, 0.0], // "Hello"
294 vec![0.0, 1.0, 0.0], // " "
295 vec![0.0, 0.0, 1.0], // "world"
296 vec![1.0, 1.0, 0.0], // "."
297 vec![0.0, 1.0, 1.0], // " Bye"
298 ];
299
300 let token_offsets = vec![
301 (0, 5), // "Hello"
302 (5, 6), // " "
303 (6, 11), // "world"
304 (11, 12), // "."
305 (12, 16), // " Bye"
306 ];
307
308 let chunks = vec![
309 Slab::new("Hello world.", 0, 12, 0),
310 Slab::new(" Bye", 12, 16, 1),
311 ];
312
313 let embeddings = pooler.pool_with_offsets(&token_embeddings, &token_offsets, &chunks);
314
315 assert_eq!(embeddings.len(), 2);
316 // First chunk should average tokens 0-3
317 // Second chunk should be token 4
318 }
319
320 #[test]
321 fn test_empty_inputs() {
322 let pooler = LateChunkingPooler::new(4);
323
324 let result = pooler.pool(&[], &[], 0);
325 assert!(result.is_empty());
326
327 let chunks = vec![Slab::new("test", 0, 4, 0)];
328
329 let result = pooler.pool(&[], &chunks, 4);
330 assert_eq!(result.len(), 1);
331 assert_eq!(result[0].len(), 4);
332 }
333}