// graphrag_core/text/late_chunking.rs
//! Late Chunking — context-preserving embeddings for RAG
//!
//! Standard RAG embeds each chunk in isolation, losing cross-chunk context.
//! Late Chunking (Jina AI, 2024) fixes this by encoding the **whole document**
//! first and then extracting per-chunk embeddings via span pooling:
//!
//! ```text
//! Standard:  chunk₁ → embed₁   chunk₂ → embed₂   (context-blind)
//! Late:      [chunk₁ | chunk₂ | …] → model → pool spans → embed₁, embed₂
//! ```
//!
//! Each chunk's embedding "sees" the entire document during the attention pass,
//! giving it +5-10% retrieval accuracy over standard chunking.
//!
//! ## Two usage modes
//!
//! 1. **`LateChunkingStrategy`** — a [`ChunkingStrategy`] that splits text and
//!    records precise byte spans. Use this when you will pass the chunks to a
//!    late-chunking-aware embedding provider separately.
//!
//! 2. **`JinaLateChunkingClient`** — calls the Jina embeddings API with
//!    `late_chunking=true` to get document-context-aware embeddings directly.
//!
//! ## Model context limits
//!
//! | Model                  | Max tokens | Notes                          |
//! |------------------------|------------|--------------------------------|
//! | Jina v3 (default)      | 8,192      | Good for most documents        |
//! | gte-Qwen2-7B-instruct  | 32,768     | Better quality, needs more GPU |
//!
//! For documents exceeding the limit use [`LateChunkingStrategy::split_into_sections`]
//! to pre-divide the document and apply late chunking section-by-section.
33
34use crate::{
35    core::{ChunkId, ChunkingStrategy, DocumentId, GraphRAGError, TextChunk},
36    text::chunking::HierarchicalChunker,
37};
38use std::sync::atomic::{AtomicU64, Ordering};
39
/// Global monotonic counter used to generate unique late-chunking chunk IDs.
///
/// Process-wide and atomic, so IDs stay unique even when multiple
/// `LateChunkingStrategy` instances chunk documents concurrently.
/// Note: the counter never resets, so IDs are unique per process run,
/// not stable across runs.
static LATE_CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
42
/// Configuration for the late chunking strategy.
///
/// Sizes are measured in characters except [`LateChunkingConfig::max_doc_tokens`],
/// which is a model-token budget (estimated at ~4 characters per token).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LateChunkingConfig {
    /// Target chunk size in characters (default: 512)
    pub chunk_size: usize,

    /// Overlap between consecutive chunks in characters (default: 64)
    pub chunk_overlap: usize,

    /// Maximum document size in tokens before splitting into sections.
    ///
    /// Choose based on the embedding model's context window:
    /// - `8192` for Jina v3 (default)
    /// - `32768` for gte-Qwen2-7B-instruct
    pub max_doc_tokens: u32,

    /// Annotate each chunk's `position_in_document` metadata field with the
    /// chunk's relative start offset in `[0.0, 1.0]`.
    ///
    /// Embedding providers can use this to apply position-aware pooling.
    pub annotate_positions: bool,
}
63
64impl Default for LateChunkingConfig {
65    fn default() -> Self {
66        Self {
67            chunk_size: 512,
68            chunk_overlap: 64,
69            max_doc_tokens: 8192, // Jina v3 default
70            annotate_positions: true,
71        }
72    }
73}
74
/// Context-aware chunking strategy for use with late-chunking embedding models
///
/// Splits text using [`HierarchicalChunker`] and records precise byte-offset
/// spans in each chunk's metadata. A late-chunking embedding provider
/// (Jina API or a local `candle` model) can then use these spans to extract
/// per-chunk representations from a single full-document forward pass.
///
/// # Examples
///
/// ```rust
/// use graphrag_core::text::late_chunking::{LateChunkingStrategy, LateChunkingConfig};
/// use graphrag_core::core::{ChunkingStrategy, DocumentId};
///
/// let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("doc-1".to_string()));
/// let chunks = strategy.chunk("First paragraph.\n\nSecond paragraph.");
///
/// for chunk in &chunks {
///     // position_in_document ∈ [0.0, 1.0] — used by embedding provider for pooling
///     assert!(chunk.metadata.position_in_document.is_some());
/// }
/// ```
pub struct LateChunkingStrategy {
    /// Chunk sizing and context-window limits
    config: LateChunkingConfig,
    /// Document the produced chunks belong to; also embedded in chunk IDs
    document_id: DocumentId,
    /// Underlying splitter that produces the raw text pieces
    inner: HierarchicalChunker,
}
101
102impl LateChunkingStrategy {
103    /// Create a new late chunking strategy with explicit config
104    pub fn new(config: LateChunkingConfig, document_id: DocumentId) -> Self {
105        Self {
106            inner: HierarchicalChunker::new().with_min_size(50),
107            config,
108            document_id,
109        }
110    }
111
112    /// Create with default config (8192 token limit, 512-char chunks)
113    pub fn with_defaults(document_id: DocumentId) -> Self {
114        Self::new(LateChunkingConfig::default(), document_id)
115    }
116
117    /// Set the maximum document token limit (choose based on embedding model)
118    ///
119    /// - `8192`  → Jina v3
120    /// - `32768` → gte-Qwen2-7B-instruct
121    pub fn with_max_doc_tokens(mut self, max_tokens: u32) -> Self {
122        self.config.max_doc_tokens = max_tokens;
123        self
124    }
125
126    /// Estimate token count from character count (1 token ≈ 4 chars)
127    pub fn estimate_tokens(text: &str) -> u32 {
128        (text.len() / 4) as u32
129    }
130
131    /// Returns `true` if the document fits within the model's context window
132    pub fn fits_in_context(&self, text: &str) -> bool {
133        Self::estimate_tokens(text) <= self.config.max_doc_tokens
134    }
135
136    /// Split an oversized document into sections that fit within the context window
137    ///
138    /// Sections are formed by grouping paragraphs (double-newline boundaries)
139    /// until the next paragraph would exceed the limit. Each section can be
140    /// embedded independently with late chunking applied within it.
141    pub fn split_into_sections(&self, text: &str) -> Vec<String> {
142        if self.fits_in_context(text) {
143            return vec![text.to_string()];
144        }
145
146        let max_chars = (self.config.max_doc_tokens * 4) as usize;
147        let mut sections: Vec<String> = Vec::new();
148        let mut current = String::new();
149
150        for paragraph in text.split("\n\n") {
151            let needed = current.len() + if current.is_empty() { 0 } else { 2 } + paragraph.len();
152            if needed > max_chars && !current.is_empty() {
153                sections.push(current.trim().to_string());
154                current = String::new();
155            }
156            if !current.is_empty() {
157                current.push_str("\n\n");
158            }
159            current.push_str(paragraph);
160        }
161
162        if !current.trim().is_empty() {
163            sections.push(current.trim().to_string());
164        }
165
166        sections
167    }
168}
169
170impl ChunkingStrategy for LateChunkingStrategy {
171    fn chunk(&self, text: &str) -> Vec<TextChunk> {
172        let raw_chunks =
173            self.inner
174                .chunk_text(text, self.config.chunk_size, self.config.chunk_overlap);
175        let doc_len = text.len().max(1);
176        let mut chunks = Vec::with_capacity(raw_chunks.len());
177        let mut current_pos: usize = 0;
178
179        for chunk_content in raw_chunks {
180            if chunk_content.trim().is_empty() {
181                current_pos += chunk_content.len();
182                continue;
183            }
184
185            let chunk_id = ChunkId::new(format!(
186                "{}_lc_{}",
187                self.document_id,
188                LATE_CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst),
189            ));
190
191            let start = current_pos;
192            let end = start + chunk_content.len();
193            let mut chunk = TextChunk::new(
194                chunk_id,
195                self.document_id.clone(),
196                chunk_content.clone(),
197                start,
198                end,
199            );
200
201            // Record relative position so the embedding layer knows the span
202            if self.config.annotate_positions {
203                chunk.metadata.position_in_document = Some(start as f32 / doc_len as f32);
204            }
205
206            chunks.push(chunk);
207            current_pos = end;
208        }
209
210        chunks
211    }
212}
213
/// Jina AI embeddings client with native late chunking support
///
/// Calls the Jina embeddings API with `late_chunking=true`. The API encodes
/// the concatenated inputs as a single sequence and returns per-input embeddings
/// where each embedding reflects the full-document context.
///
/// For fully **local** operation (no API key), configure Ollama with
/// `rjmalagon/gte-qwen2-7b-instruct` — it provides excellent 32k-context
/// embeddings without native late chunking but with a far larger context window.
///
/// # Examples
///
/// ```rust,no_run
/// use graphrag_core::text::late_chunking::JinaLateChunkingClient;
///
/// # async fn example(chunks: &[graphrag_core::TextChunk]) -> graphrag_core::Result<()> {
/// let client = JinaLateChunkingClient::new("jina_xxxx".to_string());
/// let embeddings = client.embed_with_late_chunking(chunks).await?;
/// assert_eq!(embeddings.len(), chunks.len());
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct JinaLateChunkingClient {
    /// Bearer token sent in the `Authorization` header of every request
    api_key: String,
    /// Model name (default: `"jina-embeddings-v3"`)
    model: String,
}
242
243impl JinaLateChunkingClient {
244    const ENDPOINT: &'static str = "https://api.jina.ai/v1/embeddings";
245
246    /// Create a new client with a Jina API key
247    pub fn new(api_key: impl Into<String>) -> Self {
248        Self {
249            api_key: api_key.into(),
250            model: "jina-embeddings-v3".to_string(),
251        }
252    }
253
254    /// Override the embedding model
255    pub fn with_model(mut self, model: impl Into<String>) -> Self {
256        self.model = model.into();
257        self
258    }
259
260    /// Embed chunks using Jina's late chunking API
261    ///
262    /// Sends all chunk contents in a single request with `late_chunking: true`.
263    /// The Jina API encodes them as one sequence so each chunk's embedding
264    /// incorporates the full document context.
265    ///
266    /// Returns one embedding vector per chunk, in the same order as the input.
267    #[cfg(feature = "ureq")]
268    pub async fn embed_with_late_chunking(
269        &self,
270        chunks: &[TextChunk],
271    ) -> crate::Result<Vec<Vec<f32>>> {
272        let inputs: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
273
274        let body = serde_json::json!({
275            "model": self.model,
276            "input": inputs,
277            "late_chunking": true,
278        });
279
280        let agent = ureq::AgentBuilder::new().build();
281        let response = agent
282            .post(Self::ENDPOINT)
283            .set("Authorization", &format!("Bearer {}", self.api_key))
284            .set("Content-Type", "application/json")
285            .send_json(&body)
286            .map_err(|e| GraphRAGError::Generation {
287                message: format!("Jina API request failed: {e}"),
288            })?;
289
290        let json: serde_json::Value =
291            response
292                .into_json()
293                .map_err(|e| GraphRAGError::Generation {
294                    message: format!("Failed to parse Jina API response: {e}"),
295                })?;
296
297        let data = json["data"]
298            .as_array()
299            .ok_or_else(|| GraphRAGError::Generation {
300                message: "Invalid Jina API response: missing 'data' array".to_string(),
301            })?;
302
303        let embeddings = data
304            .iter()
305            .map(|item| {
306                item["embedding"]
307                    .as_array()
308                    .unwrap_or(&vec![])
309                    .iter()
310                    .map(|v| v.as_f64().unwrap_or(0.0) as f32)
311                    .collect::<Vec<f32>>()
312            })
313            .collect::<Vec<_>>();
314
315        Ok(embeddings)
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Chunks from the default strategy must all carry position metadata
    /// (annotate_positions defaults to true).
    #[test]
    fn test_late_chunking_produces_chunks_with_position() {
        let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("test-doc".to_string()));

        let text = "First paragraph about machine learning.\n\n\
             Second paragraph about deep learning.\n\n\
             Third paragraph about neural networks.";

        let chunks = strategy.chunk(text);
        assert!(!chunks.is_empty());

        // Every chunk should have a position annotation
        for chunk in &chunks {
            assert!(
                chunk.metadata.position_in_document.is_some(),
                "chunk {} missing position metadata",
                chunk.id
            );
        }
    }

    /// IDs are formatted as "{doc_id}_lc_{counter}" — the "_lc_" marker
    /// distinguishes late-chunking chunks from other strategies' output.
    #[test]
    fn test_chunk_ids_have_lc_suffix() {
        let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("doc".to_string()));
        let chunks = strategy.chunk("Some text to chunk into pieces here.");
        for chunk in &chunks {
            assert!(
                chunk.id.0.contains("_lc_"),
                "Expected '_lc_' in ID: {}",
                chunk.id
            );
        }
    }

    /// fits_in_context uses the ~4-chars-per-token estimate against
    /// config.max_doc_tokens.
    #[test]
    fn test_fits_in_context() {
        let config = LateChunkingConfig {
            max_doc_tokens: 10,
            ..Default::default()
        };
        let strategy = LateChunkingStrategy::new(config, DocumentId::new("d".to_string()));

        assert!(strategy.fits_in_context("tiny")); // 4 chars → 1 token
        assert!(!strategy.fits_in_context(&"x".repeat(100))); // 100 chars → 25 tokens
    }

    /// A document within the context window is returned as a single section.
    #[test]
    fn test_split_into_sections_short_doc() {
        let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("d".to_string()));
        let text = "Short document.";
        let sections = strategy.split_into_sections(text);
        assert_eq!(sections.len(), 1);
        assert_eq!(sections[0], text);
    }

    /// An oversized document is split on paragraph boundaries and no content
    /// is lost across sections.
    #[test]
    fn test_split_into_sections_long_doc() {
        let config = LateChunkingConfig {
            max_doc_tokens: 5, // 20 chars max
            ..Default::default()
        };
        let strategy = LateChunkingStrategy::new(config, DocumentId::new("d".to_string()));

        // Each paragraph is ~15 chars, exceeding the 20-char section limit when combined
        let text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.";
        let sections = strategy.split_into_sections(text);
        // Should be split into multiple sections
        assert!(
            sections.len() > 1,
            "Expected multiple sections, got {}",
            sections.len()
        );
        // All content should be present
        let combined = sections.join(" ");
        assert!(combined.contains("Paragraph one"));
        assert!(combined.contains("Paragraph two"));
        assert!(combined.contains("Paragraph three"));
    }

    /// Estimate is len/4 (floor); empty input yields zero tokens.
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(LateChunkingStrategy::estimate_tokens(&"a".repeat(400)), 100);
        assert_eq!(LateChunkingStrategy::estimate_tokens(""), 0);
    }
}