graphrag_core/text/
late_chunking.rs1#![cfg_attr(not(feature = "async"), allow(unused_imports))]
2
3use crate::{
37 core::{ChunkId, ChunkingStrategy, DocumentId, GraphRAGError, TextChunk},
38 text::chunking::HierarchicalChunker,
39};
40use std::sync::atomic::{AtomicU64, Ordering};
41
/// Process-wide counter whose successive values are embedded in late-chunk
/// IDs, so every chunk produced in this process receives a distinct suffix.
static LATE_CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
44
45#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct LateChunkingConfig {
48 pub chunk_size: usize,
50
51 pub chunk_overlap: usize,
53
54 pub max_doc_tokens: u32,
59
60 pub annotate_positions: bool,
64}
65
66impl Default for LateChunkingConfig {
67 fn default() -> Self {
68 Self {
69 chunk_size: 512,
70 chunk_overlap: 64,
71 max_doc_tokens: 8192, annotate_positions: true,
73 }
74 }
75}
76
77pub struct LateChunkingStrategy {
99 config: LateChunkingConfig,
100 document_id: DocumentId,
101 inner: HierarchicalChunker,
102}
103
104impl LateChunkingStrategy {
105 pub fn new(config: LateChunkingConfig, document_id: DocumentId) -> Self {
107 Self {
108 inner: HierarchicalChunker::new().with_min_size(50),
109 config,
110 document_id,
111 }
112 }
113
114 pub fn with_defaults(document_id: DocumentId) -> Self {
116 Self::new(LateChunkingConfig::default(), document_id)
117 }
118
119 pub fn with_max_doc_tokens(mut self, max_tokens: u32) -> Self {
124 self.config.max_doc_tokens = max_tokens;
125 self
126 }
127
128 pub fn estimate_tokens(text: &str) -> u32 {
130 (text.len() / 4) as u32
131 }
132
133 pub fn fits_in_context(&self, text: &str) -> bool {
135 Self::estimate_tokens(text) <= self.config.max_doc_tokens
136 }
137
138 pub fn split_into_sections(&self, text: &str) -> Vec<String> {
144 if self.fits_in_context(text) {
145 return vec![text.to_string()];
146 }
147
148 let max_chars = (self.config.max_doc_tokens * 4) as usize;
149 let mut sections: Vec<String> = Vec::new();
150 let mut current = String::new();
151
152 for paragraph in text.split("\n\n") {
153 let needed = current.len() + if current.is_empty() { 0 } else { 2 } + paragraph.len();
154 if needed > max_chars && !current.is_empty() {
155 sections.push(current.trim().to_string());
156 current = String::new();
157 }
158 if !current.is_empty() {
159 current.push_str("\n\n");
160 }
161 current.push_str(paragraph);
162 }
163
164 if !current.trim().is_empty() {
165 sections.push(current.trim().to_string());
166 }
167
168 sections
169 }
170}
171
172impl ChunkingStrategy for LateChunkingStrategy {
173 fn chunk(&self, text: &str) -> Vec<TextChunk> {
174 let raw_chunks =
175 self.inner
176 .chunk_text(text, self.config.chunk_size, self.config.chunk_overlap);
177 let doc_len = text.len().max(1);
178 let mut chunks = Vec::with_capacity(raw_chunks.len());
179 let mut current_pos: usize = 0;
180
181 for chunk_content in raw_chunks {
182 if chunk_content.trim().is_empty() {
183 current_pos += chunk_content.len();
184 continue;
185 }
186
187 let chunk_id = ChunkId::new(format!(
188 "{}_lc_{}",
189 self.document_id,
190 LATE_CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst),
191 ));
192
193 let start = current_pos;
194 let end = start + chunk_content.len();
195 let mut chunk = TextChunk::new(
196 chunk_id,
197 self.document_id.clone(),
198 chunk_content.clone(),
199 start,
200 end,
201 );
202
203 if self.config.annotate_positions {
205 chunk.metadata.position_in_document = Some(start as f32 / doc_len as f32);
206 }
207
208 chunks.push(chunk);
209 current_pos = end;
210 }
211
212 chunks
213 }
214}
215
/// Client holding the credentials and model name used for Jina embedding
/// requests.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is
/// never written to logs or panic messages.
#[derive(Clone)]
pub struct JinaLateChunkingClient {
    /// Secret sent as a bearer token in the `Authorization` header.
    // The attribute silences the unused-field warning when the `async`
    // feature is disabled.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request's `model` field.
    model: String,
}

impl std::fmt::Debug for JinaLateChunkingClient {
    /// Formats the client with the API key replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JinaLateChunkingClient")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
245
246impl JinaLateChunkingClient {
247 #[cfg(feature = "async")]
248 const ENDPOINT: &'static str = "https://api.jina.ai/v1/embeddings";
249
250 pub fn new(api_key: impl Into<String>) -> Self {
252 Self {
253 api_key: api_key.into(),
254 model: "jina-embeddings-v3".to_string(),
255 }
256 }
257
258 pub fn with_model(mut self, model: impl Into<String>) -> Self {
260 self.model = model.into();
261 self
262 }
263
264 #[cfg(feature = "ureq")]
272 pub async fn embed_with_late_chunking(
273 &self,
274 chunks: &[TextChunk],
275 ) -> crate::Result<Vec<Vec<f32>>> {
276 let inputs: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
277
278 let body = serde_json::json!({
279 "model": self.model,
280 "input": inputs,
281 "late_chunking": true,
282 });
283
284 let agent = ureq::AgentBuilder::new().build();
285 let response = agent
286 .post(Self::ENDPOINT)
287 .set("Authorization", &format!("Bearer {}", self.api_key))
288 .set("Content-Type", "application/json")
289 .send_json(&body)
290 .map_err(|e| GraphRAGError::Generation {
291 message: format!("Jina API request failed: {e}"),
292 })?;
293
294 let json: serde_json::Value =
295 response
296 .into_json()
297 .map_err(|e| GraphRAGError::Generation {
298 message: format!("Failed to parse Jina API response: {e}"),
299 })?;
300
301 let response_data = json["data"]
302 .as_array()
303 .ok_or_else(|| GraphRAGError::Generation {
304 message: "Invalid Jina API response: missing 'data' array".to_string(),
305 })?;
306
307 let embeddings = response_data
308 .iter()
309 .map(|item| {
310 item["embedding"]
311 .as_array()
312 .unwrap_or(&vec![])
313 .iter()
314 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
315 .collect::<Vec<f32>>()
316 })
317 .collect::<Vec<_>>();
318
319 Ok(embeddings)
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Builds a `DocumentId` from a string slice.
    fn doc_id(raw: &str) -> DocumentId {
        DocumentId::new(raw.to_string())
    }

    /// Every chunk of a multi-paragraph text carries a relative position.
    #[test]
    fn test_late_chunking_produces_chunks_with_position() {
        let text = "First paragraph about machine learning.\n\n\
                    Second paragraph about deep learning.\n\n\
                    Third paragraph about neural networks.";

        let produced = LateChunkingStrategy::with_defaults(doc_id("test-doc")).chunk(text);
        assert!(!produced.is_empty());

        for chunk in produced.iter() {
            assert!(
                chunk.metadata.position_in_document.is_some(),
                "chunk {} missing position metadata",
                chunk.id
            );
        }
    }

    /// Chunk IDs contain the `_lc_` marker.
    #[test]
    fn test_chunk_ids_have_lc_suffix() {
        let produced = LateChunkingStrategy::with_defaults(doc_id("doc"))
            .chunk("Some text to chunk into pieces here.");

        for chunk in produced.iter() {
            assert!(
                chunk.id.0.contains("_lc_"),
                "Expected '_lc_' in ID: {}",
                chunk.id
            );
        }
    }

    /// A tiny text fits a 10-token budget; a 100-byte text does not.
    #[test]
    fn test_fits_in_context() {
        let strategy = LateChunkingStrategy::new(
            LateChunkingConfig {
                max_doc_tokens: 10,
                ..Default::default()
            },
            doc_id("d"),
        );

        let oversized = "x".repeat(100);
        assert!(strategy.fits_in_context("tiny"));
        assert!(!strategy.fits_in_context(&oversized));
    }

    /// A document within the budget comes back as one unchanged section.
    #[test]
    fn test_split_into_sections_short_doc() {
        let text = "Short document.";
        let sections = LateChunkingStrategy::with_defaults(doc_id("d")).split_into_sections(text);

        assert_eq!(sections.len(), 1);
        assert_eq!(sections[0], text);
    }

    /// A document over the budget is split and no paragraph is lost.
    #[test]
    fn test_split_into_sections_long_doc() {
        let strategy = LateChunkingStrategy::new(
            LateChunkingConfig {
                max_doc_tokens: 5,
                ..Default::default()
            },
            doc_id("d"),
        );

        let sections = strategy
            .split_into_sections("Paragraph one.\n\nParagraph two.\n\nParagraph three.");
        assert!(
            sections.len() > 1,
            "Expected multiple sections, got {}",
            sections.len()
        );

        let combined = sections.join(" ");
        assert!(combined.contains("Paragraph one"));
        assert!(combined.contains("Paragraph two"));
        assert!(combined.contains("Paragraph three"));
    }

    /// The estimate is one token per four bytes.
    #[test]
    fn test_estimate_tokens() {
        let four_hundred_bytes = "a".repeat(400);
        assert_eq!(LateChunkingStrategy::estimate_tokens(&four_hundred_bytes), 100);
        assert_eq!(LateChunkingStrategy::estimate_tokens(""), 0);
    }
}