Skip to main content

graphrag_core/text/
late_chunking.rs

1#![cfg_attr(not(feature = "async"), allow(unused_imports))]
2
3//! Late Chunking — context-preserving embeddings for RAG
4//!
5//! Standard RAG embeds each chunk in isolation, losing cross-chunk context.
6//! Late Chunking (Jina AI, 2024) fixes this by encoding the **whole document**
7//! first and then extracting per-chunk embeddings via span pooling:
8//!
9//! ```text
10//! Standard:  chunk₁ → embed₁   chunk₂ → embed₂   (context-blind)
11//! Late:       [chunk₁ | chunk₂ | …] → model → pool spans → embed₁, embed₂
12//! ```
13//!
//! Each chunk's embedding "sees" the entire document during the attention pass,
//! which reportedly improves retrieval accuracy by roughly 5-10% over standard
//! chunking.
16//!
17//! ## Two usage modes
18//!
19//! 1. **`LateChunkingStrategy`** — a [`ChunkingStrategy`] that splits text and
20//!    records precise byte spans. Use this when you will pass the chunks to a
21//!    late-chunking-aware embedding provider separately.
22//!
23//! 2. **`JinaLateChunkingClient`** — calls the Jina embeddings API with
24//!    `late_chunking=true` to get document-context-aware embeddings directly.
25//!
26//! ## Model context limits
27//!
28//! | Model                  | Max tokens | Notes                          |
//! |------------------------|------------|--------------------------------|
30//! | Jina v3 (default)      | 8 192      | Good for most documents        |
31//! | gte-Qwen2-7B-instruct  | 32 768     | Better quality, needs more GPU |
32//!
33//! For documents exceeding the limit use [`LateChunkingStrategy::split_into_sections`]
34//! to pre-divide the document and apply late chunking section-by-section.
35
36use crate::{
37    core::{ChunkId, ChunkingStrategy, DocumentId, GraphRAGError, TextChunk},
38    text::chunking::HierarchicalChunker,
39};
40use std::sync::atomic::{AtomicU64, Ordering};
41
42/// Global counter for generating unique late-chunking chunk IDs
43static LATE_CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
44
45/// Configuration for the late chunking strategy
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct LateChunkingConfig {
48    /// Target chunk size in characters
49    pub chunk_size: usize,
50
51    /// Chunk overlap in characters
52    pub chunk_overlap: usize,
53
54    /// Maximum document size in tokens before splitting into sections.
55    ///
56    /// - `8192` for Jina v3 (default)
57    /// - `32768` for gte-Qwen2-7B-instruct
58    pub max_doc_tokens: u32,
59
60    /// Annotate each chunk's `position_in_document` metadata field.
61    ///
62    /// Embedding providers can use this to apply position-aware pooling.
63    pub annotate_positions: bool,
64}
65
66impl Default for LateChunkingConfig {
67    fn default() -> Self {
68        Self {
69            chunk_size: 512,
70            chunk_overlap: 64,
71            max_doc_tokens: 8192, // Jina v3 default
72            annotate_positions: true,
73        }
74    }
75}
76
77/// Context-aware chunking strategy for use with late-chunking embedding models
78///
79/// Splits text using [`HierarchicalChunker`] and records precise byte-offset
80/// spans in each chunk's metadata. A late-chunking embedding provider
81/// (Jina API or a local `candle` model) can then use these spans to extract
82/// per-chunk representations from a single full-document forward pass.
83///
84/// # Examples
85///
86/// ```rust
87/// use graphrag_core::text::late_chunking::{LateChunkingStrategy, LateChunkingConfig};
88/// use graphrag_core::core::{ChunkingStrategy, DocumentId};
89///
90/// let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("doc-1".to_string()));
91/// let chunks = strategy.chunk("First paragraph.\n\nSecond paragraph.");
92///
93/// for chunk in &chunks {
94///     // position_in_document ∈ [0.0, 1.0] — used by embedding provider for pooling
95///     assert!(chunk.metadata.position_in_document.is_some());
96/// }
97/// ```
98pub struct LateChunkingStrategy {
99    config: LateChunkingConfig,
100    document_id: DocumentId,
101    inner: HierarchicalChunker,
102}
103
104impl LateChunkingStrategy {
105    /// Create a new late chunking strategy with explicit config
106    pub fn new(config: LateChunkingConfig, document_id: DocumentId) -> Self {
107        Self {
108            inner: HierarchicalChunker::new().with_min_size(50),
109            config,
110            document_id,
111        }
112    }
113
114    /// Create with default config (8192 token limit, 512-char chunks)
115    pub fn with_defaults(document_id: DocumentId) -> Self {
116        Self::new(LateChunkingConfig::default(), document_id)
117    }
118
119    /// Set the maximum document token limit (choose based on embedding model)
120    ///
121    /// - `8192`  → Jina v3
122    /// - `32768` → gte-Qwen2-7B-instruct
123    pub fn with_max_doc_tokens(mut self, max_tokens: u32) -> Self {
124        self.config.max_doc_tokens = max_tokens;
125        self
126    }
127
128    /// Estimate token count from character count (1 token ≈ 4 chars)
129    pub fn estimate_tokens(text: &str) -> u32 {
130        (text.len() / 4) as u32
131    }
132
133    /// Returns `true` if the document fits within the model's context window
134    pub fn fits_in_context(&self, text: &str) -> bool {
135        Self::estimate_tokens(text) <= self.config.max_doc_tokens
136    }
137
138    /// Split an oversized document into sections that fit within the context window
139    ///
140    /// Sections are formed by grouping paragraphs (double-newline boundaries)
141    /// until the next paragraph would exceed the limit. Each section can be
142    /// embedded independently with late chunking applied within it.
143    pub fn split_into_sections(&self, text: &str) -> Vec<String> {
144        if self.fits_in_context(text) {
145            return vec![text.to_string()];
146        }
147
148        let max_chars = (self.config.max_doc_tokens * 4) as usize;
149        let mut sections: Vec<String> = Vec::new();
150        let mut current = String::new();
151
152        for paragraph in text.split("\n\n") {
153            let needed = current.len() + if current.is_empty() { 0 } else { 2 } + paragraph.len();
154            if needed > max_chars && !current.is_empty() {
155                sections.push(current.trim().to_string());
156                current = String::new();
157            }
158            if !current.is_empty() {
159                current.push_str("\n\n");
160            }
161            current.push_str(paragraph);
162        }
163
164        if !current.trim().is_empty() {
165            sections.push(current.trim().to_string());
166        }
167
168        sections
169    }
170}
171
172impl ChunkingStrategy for LateChunkingStrategy {
173    fn chunk(&self, text: &str) -> Vec<TextChunk> {
174        let raw_chunks =
175            self.inner
176                .chunk_text(text, self.config.chunk_size, self.config.chunk_overlap);
177        let doc_len = text.len().max(1);
178        let mut chunks = Vec::with_capacity(raw_chunks.len());
179        let mut current_pos: usize = 0;
180
181        for chunk_content in raw_chunks {
182            if chunk_content.trim().is_empty() {
183                current_pos += chunk_content.len();
184                continue;
185            }
186
187            let chunk_id = ChunkId::new(format!(
188                "{}_lc_{}",
189                self.document_id,
190                LATE_CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst),
191            ));
192
193            let start = current_pos;
194            let end = start + chunk_content.len();
195            let mut chunk = TextChunk::new(
196                chunk_id,
197                self.document_id.clone(),
198                chunk_content.clone(),
199                start,
200                end,
201            );
202
203            // Record relative position so the embedding layer knows the span
204            if self.config.annotate_positions {
205                chunk.metadata.position_in_document = Some(start as f32 / doc_len as f32);
206            }
207
208            chunks.push(chunk);
209            current_pos = end;
210        }
211
212        chunks
213    }
214}
215
/// Jina AI embeddings client with native late chunking support
///
/// Calls the Jina embeddings API with `late_chunking=true`. The API encodes
/// the concatenated inputs as a single sequence and returns per-input embeddings
/// where each embedding reflects the full-document context.
///
/// For fully **local** operation (no API key), configure Ollama with
/// `rjmalagon/gte-qwen2-7b-instruct`. It provides strong 32k-context
/// embeddings: it has no native late chunking, but its context window is far
/// larger.
///
/// The `Debug` representation redacts the API key.
///
/// # Examples
///
/// ```rust,no_run
/// use graphrag_core::text::late_chunking::JinaLateChunkingClient;
///
/// # async fn example(chunks: &[graphrag_core::TextChunk]) -> graphrag_core::Result<()> {
/// let client = JinaLateChunkingClient::new("jina_xxxx".to_string());
/// let embeddings = client.embed_with_late_chunking(chunks).await?;
/// assert_eq!(embeddings.len(), chunks.len());
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct JinaLateChunkingClient {
    /// Jina API key, sent as a bearer token.
    // Only read by the `ureq`-gated request method.
    #[cfg_attr(not(feature = "ureq"), allow(dead_code))]
    api_key: String,
    /// Model name (default: `"jina-embeddings-v3"`)
    model: String,
}

// Hand-written instead of derived so that `{:?}` output (logs, panic messages,
// tracing fields) never exposes the secret API key.
impl std::fmt::Debug for JinaLateChunkingClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JinaLateChunkingClient")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
245
246impl JinaLateChunkingClient {
247    #[cfg(feature = "async")]
248    const ENDPOINT: &'static str = "https://api.jina.ai/v1/embeddings";
249
250    /// Create a new client with a Jina API key
251    pub fn new(api_key: impl Into<String>) -> Self {
252        Self {
253            api_key: api_key.into(),
254            model: "jina-embeddings-v3".to_string(),
255        }
256    }
257
258    /// Override the embedding model
259    pub fn with_model(mut self, model: impl Into<String>) -> Self {
260        self.model = model.into();
261        self
262    }
263
264    /// Embed chunks using Jina's late chunking API
265    ///
266    /// Sends all chunk contents in a single request with `late_chunking: true`.
267    /// The Jina API encodes them as one sequence so each chunk's embedding
268    /// incorporates the full document context.
269    ///
270    /// Returns one embedding vector per chunk, in the same order as the input.
271    #[cfg(feature = "ureq")]
272    pub async fn embed_with_late_chunking(
273        &self,
274        chunks: &[TextChunk],
275    ) -> crate::Result<Vec<Vec<f32>>> {
276        let inputs: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
277
278        let body = serde_json::json!({
279            "model": self.model,
280            "input": inputs,
281            "late_chunking": true,
282        });
283
284        let agent = ureq::AgentBuilder::new().build();
285        let response = agent
286            .post(Self::ENDPOINT)
287            .set("Authorization", &format!("Bearer {}", self.api_key))
288            .set("Content-Type", "application/json")
289            .send_json(&body)
290            .map_err(|e| GraphRAGError::Generation {
291                message: format!("Jina API request failed: {e}"),
292            })?;
293
294        let json: serde_json::Value =
295            response
296                .into_json()
297                .map_err(|e| GraphRAGError::Generation {
298                    message: format!("Failed to parse Jina API response: {e}"),
299                })?;
300
301        let response_data = json["data"]
302            .as_array()
303            .ok_or_else(|| GraphRAGError::Generation {
304                message: "Invalid Jina API response: missing 'data' array".to_string(),
305            })?;
306
307        let embeddings = response_data
308            .iter()
309            .map(|item| {
310                item["embedding"]
311                    .as_array()
312                    .unwrap_or(&vec![])
313                    .iter()
314                    .map(|v| v.as_f64().unwrap_or(0.0) as f32)
315                    .collect::<Vec<f32>>()
316            })
317            .collect::<Vec<_>>();
318
319        Ok(embeddings)
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Strategy with default settings for a document called `name`.
    fn default_strategy(name: &str) -> LateChunkingStrategy {
        LateChunkingStrategy::with_defaults(DocumentId::new(name.to_string()))
    }

    /// Strategy whose context window holds `max_doc_tokens` estimated tokens.
    fn strategy_with_budget(max_doc_tokens: u32) -> LateChunkingStrategy {
        let config = LateChunkingConfig {
            max_doc_tokens,
            ..Default::default()
        };
        LateChunkingStrategy::new(config, DocumentId::new("d".to_string()))
    }

    #[test]
    fn test_late_chunking_produces_chunks_with_position() {
        let document = [
            "First paragraph about machine learning.",
            "Second paragraph about deep learning.",
            "Third paragraph about neural networks.",
        ]
        .join("\n\n");

        let chunks = default_strategy("test-doc").chunk(&document);
        assert!(!chunks.is_empty());

        // Every chunk must carry a position annotation.
        let unannotated = chunks
            .iter()
            .find(|c| c.metadata.position_in_document.is_none());
        if let Some(chunk) = unannotated {
            panic!("chunk {} missing position metadata", chunk.id);
        }
    }

    #[test]
    fn test_chunk_ids_have_lc_suffix() {
        let chunks = default_strategy("doc").chunk("Some text to chunk into pieces here.");
        if let Some(chunk) = chunks.iter().find(|c| !c.id.0.contains("_lc_")) {
            panic!("Expected '_lc_' in ID: {}", chunk.id);
        }
    }

    #[test]
    fn test_fits_in_context() {
        let strategy = strategy_with_budget(10);

        // 4 bytes → 1 token, well inside the budget.
        assert!(strategy.fits_in_context("tiny"));
        // 100 bytes → 25 tokens, over the budget.
        assert!(!strategy.fits_in_context(&"x".repeat(100)));
    }

    #[test]
    fn test_split_into_sections_short_doc() {
        let text = "Short document.";
        let sections = default_strategy("d").split_into_sections(text);
        assert_eq!(sections, vec![text.to_string()]);
    }

    #[test]
    fn test_split_into_sections_long_doc() {
        // 5 tokens → 20-byte sections; any two paragraphs together exceed that.
        let strategy = strategy_with_budget(5);
        let text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.";

        let sections = strategy.split_into_sections(text);
        assert!(
            sections.len() > 1,
            "Expected multiple sections, got {}",
            sections.len()
        );

        // No paragraph may be lost along the way.
        let combined = sections.join(" ");
        for needle in ["Paragraph one", "Paragraph two", "Paragraph three"] {
            assert!(combined.contains(needle));
        }
    }

    #[test]
    fn test_estimate_tokens() {
        let four_hundred_bytes = "a".repeat(400);
        assert_eq!(LateChunkingStrategy::estimate_tokens(&four_hundred_bytes), 100);
        assert_eq!(LateChunkingStrategy::estimate_tokens(""), 0);
    }
}