// code_chunker/late.rs
1//! Late Chunking: Embed first, then chunk.
2//!
3//! ## The Problem with Traditional Chunking
4//!
5//! Traditional chunking embeds chunks independently:
6//!
7//! ```text
8//! Document: "Einstein developed relativity. He became famous."
9//! Chunks:   ["Einstein developed relativity.", "He became famous."]
10//! Embeddings: [embed(chunk1), embed(chunk2)]
11//!                              ↑
12//!                              "He" loses context!
13//! ```
14//!
15//! The second chunk embeds "He" without knowing it refers to Einstein.
16//!
17//! ## Late Chunking Solution
18//!
19//! Late chunking (Günther et al. 2024) embeds the full document first,
20//! then pools token embeddings for each chunk:
21//!
22//! ```text
23//! Document: "Einstein developed relativity. He became famous."
24//!
25//! Step 1: Embed full document → Token embeddings [t1, t2, ..., tn]
26//!         Each token "sees" the full document via attention.
27//!
28//! Step 2: Pool chunks from token embeddings:
29//!         Chunk 1: mean_pool([t1, ..., t4])  ← "Einstein developed relativity."
30//!         Chunk 2: mean_pool([t5, ..., t7])  ← "He became famous."
31//!                                               "He" now has Einstein context!
32//! ```
33//!
34//! ## The Math
35//!
36//! Given token embeddings H = [h1, h2, ..., hn] from full document,
37//! and chunk boundaries [(s1, e1), (s2, e2), ...]:
38//!
39//! ```text
//! chunk_embedding_i = (1 / (ei - si)) * Σ_{t=si}^{ei-1} ht
41//! ```
42//!
43//! Mean pooling preserves the contextual information each token gained
44//! from attending to the full document.
45//!
46//! ## When to Use
47//!
48//! - **Use Late Chunking**: When chunks reference each other (pronouns,
49//!   acronym definitions, temporal references). Long coherent documents.
50//!
51//! - **Use Traditional**: Independent chunks, real-time embedding needed,
52//!   memory-constrained (late chunking needs full doc in memory).
53//!
54//! ## Trade-offs
55//!
56//! | Aspect | Traditional | Late Chunking |
57//! |--------|-------------|---------------|
58//! | Memory | O(chunk_size) | O(doc_length × dim) |
59//! | Context | Local only | Full document |
60//! | Speed | Parallel chunks | Sequential doc first |
61//! | Quality | Baseline | +5-15% recall typically |
62//!
63//! ## References
64//!
//! Günther, Mohr, et al. (2024). "Late Chunking: Contextual Chunk
//! Embeddings Using Long-Context Embedding Models." arXiv:2409.04701.
67
68use crate::Slab;
69
/// Late chunking pooler: pools token embeddings into chunk embeddings.
///
/// This is the core operation of late chunking. Given token-level embeddings
/// from a full document, it pools the tokens within each chunk boundary
/// to create contextualized chunk embeddings.
///
/// The pooler itself is stateless apart from the configured dimension, so it
/// is cheap to `Clone` and safe to share across calls.
#[derive(Debug, Clone)]
pub struct LateChunkingPooler {
    /// Embedding dimension (for validation).
    ///
    /// Only used to size the zero-vector fallbacks returned for empty
    /// inputs; actual pooling infers the dimension from the first token
    /// embedding it encounters.
    dim: usize,
}
80
81impl LateChunkingPooler {
82    /// Create a new late chunking pooler.
83    ///
84    /// # Arguments
85    ///
86    /// * `dim` - Embedding dimension (e.g., 384 for all-MiniLM-L6-v2)
87    pub fn new(dim: usize) -> Self {
88        Self { dim }
89    }
90
91    /// Pool token embeddings into chunk embeddings.
92    ///
93    /// # Arguments
94    ///
95    /// * `token_embeddings` - Token-level embeddings from full document.
96    ///   Shape: [n_tokens, dim]. Each token has "seen" the full document.
97    /// * `chunks` - Chunk boundaries from any chunker.
98    /// * `doc_len` - Total document length in bytes (for mapping).
99    ///
100    /// # Returns
101    ///
102    /// Contextualized chunk embeddings. Each chunk embedding is the mean
103    /// of its constituent token embeddings.
104    ///
105    /// # Panics
106    ///
107    /// Panics if token embeddings have inconsistent dimensions.
108    pub fn pool(
109        &self,
110        token_embeddings: &[Vec<f32>],
111        chunks: &[Slab],
112        doc_len: usize,
113    ) -> Vec<Vec<f32>> {
114        if token_embeddings.is_empty() || chunks.is_empty() || doc_len == 0 {
115            return vec![vec![0.0; self.dim]; chunks.len()];
116        }
117
118        let n_tokens = token_embeddings.len();
119
120        chunks
121            .iter()
122            .map(|chunk| {
123                // Map byte offsets to token indices (linear approximation)
124                let token_start = (chunk.start as f64 / doc_len as f64 * n_tokens as f64) as usize;
125                let token_end =
126                    ((chunk.end as f64 / doc_len as f64 * n_tokens as f64) as usize).min(n_tokens);
127
128                if token_end <= token_start {
129                    // Fallback: use full document average
130                    return self.mean_pool(token_embeddings);
131                }
132
133                self.mean_pool(&token_embeddings[token_start..token_end])
134            })
135            .collect()
136    }
137
138    /// Pool with exact token-to-character mappings.
139    ///
140    /// Use this when you have exact token offsets from the tokenizer,
141    /// rather than relying on linear approximation.
142    ///
143    /// # Arguments
144    ///
145    /// * `token_embeddings` - Token-level embeddings [n_tokens, dim].
146    /// * `token_offsets` - Character offset for each token [(start, end), ...].
147    /// * `chunks` - Chunk boundaries.
148    pub fn pool_with_offsets(
149        &self,
150        token_embeddings: &[Vec<f32>],
151        token_offsets: &[(usize, usize)],
152        chunks: &[Slab],
153    ) -> Vec<Vec<f32>> {
154        if token_embeddings.is_empty() || chunks.is_empty() {
155            return vec![vec![0.0; self.dim]; chunks.len()];
156        }
157
158        chunks
159            .iter()
160            .map(|chunk| {
161                // Find tokens that overlap with this chunk
162                let token_indices: Vec<usize> = token_offsets
163                    .iter()
164                    .enumerate()
165                    .filter(|(_, (start, end))| {
166                        // Token overlaps with chunk
167                        *start < chunk.end && *end > chunk.start
168                    })
169                    .map(|(i, _)| i)
170                    .collect();
171
172                if token_indices.is_empty() {
173                    return self.mean_pool(token_embeddings);
174                }
175
176                let selected: Vec<&[f32]> = token_indices
177                    .iter()
178                    .filter_map(|&i| token_embeddings.get(i).map(Vec::as_slice))
179                    .collect();
180
181                self.mean_pool_refs(&selected)
182            })
183            .collect()
184    }
185
186    /// Mean pool a slice of token embeddings.
187    fn mean_pool(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
188        if embeddings.is_empty() {
189            return vec![0.0; self.dim];
190        }
191
192        let dim = embeddings[0].len();
193        let mut result = vec![0.0; dim];
194        let count = embeddings.len() as f32;
195
196        for emb in embeddings {
197            for (i, &v) in emb.iter().enumerate() {
198                result[i] += v;
199            }
200        }
201
202        for v in &mut result {
203            *v /= count;
204        }
205
206        // L2 normalize
207        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
208        if norm > 1e-9 {
209            for v in &mut result {
210                *v /= norm;
211            }
212        }
213
214        result
215    }
216
217    /// Mean pool from references.
218    fn mean_pool_refs(&self, embeddings: &[&[f32]]) -> Vec<f32> {
219        if embeddings.is_empty() {
220            return vec![0.0; self.dim];
221        }
222
223        let dim = embeddings[0].len();
224        let mut result = vec![0.0; dim];
225        let count = embeddings.len() as f32;
226
227        for emb in embeddings {
228            for (i, &v) in emb.iter().enumerate() {
229                result[i] += v;
230            }
231        }
232
233        for v in &mut result {
234            *v /= count;
235        }
236
237        // L2 normalize
238        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
239        if norm > 1e-9 {
240            for v in &mut result {
241                *v /= norm;
242            }
243        }
244
245        result
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_late_chunking_pooler_basic() {
        let pooler = LateChunkingPooler::new(4);

        // Six synthetic 4-dim token vectors standing in for a real encoder.
        let tokens: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 1.0],
        ];

        // Two chunks splitting a 20-byte document in half.
        let slabs = vec![
            Slab::new("first chunk", 0, 10, 0),
            Slab::new("second chunk", 10, 20, 1),
        ];

        let pooled = pooler.pool(&tokens, &slabs, 20);

        assert_eq!(pooled.len(), 2);
        assert_eq!(pooled[0].len(), 4);
        assert_eq!(pooled[1].len(), 4);

        // Each output must be L2-normalized (unit length).
        let norm0 = pooled[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm0 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_pool_with_exact_offsets() {
        let pooler = LateChunkingPooler::new(3);

        // Five tokens with known character offsets over "Hello world. Bye".
        let tokens = vec![
            vec![1.0, 0.0, 0.0], // "Hello"
            vec![0.0, 1.0, 0.0], // " "
            vec![0.0, 0.0, 1.0], // "world"
            vec![1.0, 1.0, 0.0], // "."
            vec![0.0, 1.0, 1.0], // " Bye"
        ];

        let offsets = vec![
            (0, 5),   // "Hello"
            (5, 6),   // " "
            (6, 11),  // "world"
            (11, 12), // "."
            (12, 16), // " Bye"
        ];

        let slabs = vec![
            Slab::new("Hello world.", 0, 12, 0),
            Slab::new(" Bye", 12, 16, 1),
        ];

        let pooled = pooler.pool_with_offsets(&tokens, &offsets, &slabs);

        assert_eq!(pooled.len(), 2);
        // First chunk should average tokens 0-3; second should be token 4.
    }

    #[test]
    fn test_empty_inputs() {
        let pooler = LateChunkingPooler::new(4);

        // Everything empty: no chunks means no output rows at all.
        let none = pooler.pool(&[], &[], 0);
        assert!(none.is_empty());

        // Chunks but no token embeddings: zero vectors of the configured dim.
        let slabs = vec![Slab::new("test", 0, 4, 0)];
        let zeros = pooler.pool(&[], &slabs, 4);
        assert_eq!(zeros.len(), 1);
        assert_eq!(zeros[0].len(), 4);
    }
}