// graphrag_core/text/late_chunking.rs
use crate::{
35 core::{ChunkId, ChunkingStrategy, DocumentId, GraphRAGError, TextChunk},
36 text::chunking::HierarchicalChunker,
37};
38use std::sync::atomic::{AtomicU64, Ordering};
39
/// Process-wide monotonic counter used to mint unique `_lc_` chunk IDs.
static LATE_CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
42
/// Configuration for the late-chunking text-splitting strategy.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LateChunkingConfig {
    /// Target chunk size, forwarded to the inner `HierarchicalChunker`.
    // NOTE(review): unit (chars vs tokens) depends on `chunk_text` — confirm.
    pub chunk_size: usize,

    /// Overlap between consecutive chunks, same unit as `chunk_size`.
    pub chunk_overlap: usize,

    /// Token budget for treating a document as a single embedding context;
    /// used by `fits_in_context` / `split_into_sections` (~4 chars/token).
    pub max_doc_tokens: u32,

    /// When true, `chunk()` records each chunk's relative position
    /// (start offset / document length) in its metadata.
    pub annotate_positions: bool,
}
63
64impl Default for LateChunkingConfig {
65 fn default() -> Self {
66 Self {
67 chunk_size: 512,
68 chunk_overlap: 64,
69 max_doc_tokens: 8192, annotate_positions: true,
71 }
72 }
73}
74
/// Chunking strategy that delegates splitting to a `HierarchicalChunker`
/// and mints `_lc_`-tagged chunk IDs tied to a single document.
pub struct LateChunkingStrategy {
    // Sizes, token budget, and annotation flag supplied by the caller.
    config: LateChunkingConfig,
    // Document all produced chunks are attributed to.
    document_id: DocumentId,
    // Underlying splitter that produces the raw chunk texts.
    inner: HierarchicalChunker,
}
101
102impl LateChunkingStrategy {
103 pub fn new(config: LateChunkingConfig, document_id: DocumentId) -> Self {
105 Self {
106 inner: HierarchicalChunker::new().with_min_size(50),
107 config,
108 document_id,
109 }
110 }
111
112 pub fn with_defaults(document_id: DocumentId) -> Self {
114 Self::new(LateChunkingConfig::default(), document_id)
115 }
116
117 pub fn with_max_doc_tokens(mut self, max_tokens: u32) -> Self {
122 self.config.max_doc_tokens = max_tokens;
123 self
124 }
125
126 pub fn estimate_tokens(text: &str) -> u32 {
128 (text.len() / 4) as u32
129 }
130
131 pub fn fits_in_context(&self, text: &str) -> bool {
133 Self::estimate_tokens(text) <= self.config.max_doc_tokens
134 }
135
136 pub fn split_into_sections(&self, text: &str) -> Vec<String> {
142 if self.fits_in_context(text) {
143 return vec![text.to_string()];
144 }
145
146 let max_chars = (self.config.max_doc_tokens * 4) as usize;
147 let mut sections: Vec<String> = Vec::new();
148 let mut current = String::new();
149
150 for paragraph in text.split("\n\n") {
151 let needed = current.len() + if current.is_empty() { 0 } else { 2 } + paragraph.len();
152 if needed > max_chars && !current.is_empty() {
153 sections.push(current.trim().to_string());
154 current = String::new();
155 }
156 if !current.is_empty() {
157 current.push_str("\n\n");
158 }
159 current.push_str(paragraph);
160 }
161
162 if !current.trim().is_empty() {
163 sections.push(current.trim().to_string());
164 }
165
166 sections
167 }
168}
169
170impl ChunkingStrategy for LateChunkingStrategy {
171 fn chunk(&self, text: &str) -> Vec<TextChunk> {
172 let raw_chunks =
173 self.inner
174 .chunk_text(text, self.config.chunk_size, self.config.chunk_overlap);
175 let doc_len = text.len().max(1);
176 let mut chunks = Vec::with_capacity(raw_chunks.len());
177 let mut current_pos: usize = 0;
178
179 for chunk_content in raw_chunks {
180 if chunk_content.trim().is_empty() {
181 current_pos += chunk_content.len();
182 continue;
183 }
184
185 let chunk_id = ChunkId::new(format!(
186 "{}_lc_{}",
187 self.document_id,
188 LATE_CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst),
189 ));
190
191 let start = current_pos;
192 let end = start + chunk_content.len();
193 let mut chunk = TextChunk::new(
194 chunk_id,
195 self.document_id.clone(),
196 chunk_content.clone(),
197 start,
198 end,
199 );
200
201 if self.config.annotate_positions {
203 chunk.metadata.position_in_document = Some(start as f32 / doc_len as f32);
204 }
205
206 chunks.push(chunk);
207 current_pos = end;
208 }
209
210 chunks
211 }
212}
213
/// Minimal client for the Jina AI embeddings API (late-chunking workflow).
// NOTE(review): the Debug derive will print the raw API key in debug output;
// consider a manual impl that redacts it.
#[derive(Debug, Clone)]
pub struct JinaLateChunkingClient {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Embedding model name; defaults to "jina-embeddings-v3".
    model: String,
}
242
243impl JinaLateChunkingClient {
244 const ENDPOINT: &'static str = "https://api.jina.ai/v1/embeddings";
245
246 pub fn new(api_key: impl Into<String>) -> Self {
248 Self {
249 api_key: api_key.into(),
250 model: "jina-embeddings-v3".to_string(),
251 }
252 }
253
254 pub fn with_model(mut self, model: impl Into<String>) -> Self {
256 self.model = model.into();
257 self
258 }
259
260 #[cfg(feature = "ureq")]
268 pub async fn embed_with_late_chunking(
269 &self,
270 chunks: &[TextChunk],
271 ) -> crate::Result<Vec<Vec<f32>>> {
272 let inputs: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
273
274 let body = serde_json::json!({
275 "model": self.model,
276 "input": inputs,
277 "late_chunking": true,
278 });
279
280 let agent = ureq::AgentBuilder::new().build();
281 let response = agent
282 .post(Self::ENDPOINT)
283 .set("Authorization", &format!("Bearer {}", self.api_key))
284 .set("Content-Type", "application/json")
285 .send_json(&body)
286 .map_err(|e| GraphRAGError::Generation {
287 message: format!("Jina API request failed: {e}"),
288 })?;
289
290 let json: serde_json::Value =
291 response
292 .into_json()
293 .map_err(|e| GraphRAGError::Generation {
294 message: format!("Failed to parse Jina API response: {e}"),
295 })?;
296
297 let data = json["data"]
298 .as_array()
299 .ok_or_else(|| GraphRAGError::Generation {
300 message: "Invalid Jina API response: missing 'data' array".to_string(),
301 })?;
302
303 let embeddings = data
304 .iter()
305 .map(|item| {
306 item["embedding"]
307 .as_array()
308 .unwrap_or(&vec![])
309 .iter()
310 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
311 .collect::<Vec<f32>>()
312 })
313 .collect::<Vec<_>>();
314
315 Ok(embeddings)
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Every produced chunk carries position metadata under default config.
    #[test]
    fn test_late_chunking_produces_chunks_with_position() {
        let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("test-doc".to_string()));

        let text = "First paragraph about machine learning.\n\n\
                    Second paragraph about deep learning.\n\n\
                    Third paragraph about neural networks.";

        let produced = strategy.chunk(text);
        assert!(!produced.is_empty());

        for c in &produced {
            assert!(
                c.metadata.position_in_document.is_some(),
                "chunk {} missing position metadata",
                c.id
            );
        }
    }

    /// IDs minted by the strategy embed the `_lc_` marker.
    #[test]
    fn test_chunk_ids_have_lc_suffix() {
        let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("doc".to_string()));
        for c in &strategy.chunk("Some text to chunk into pieces here.") {
            assert!(c.id.0.contains("_lc_"), "Expected '_lc_' in ID: {}", c.id);
        }
    }

    /// `fits_in_context` honours a small custom token budget.
    #[test]
    fn test_fits_in_context() {
        let config = LateChunkingConfig {
            max_doc_tokens: 10,
            ..Default::default()
        };
        let strategy = LateChunkingStrategy::new(config, DocumentId::new("d".to_string()));

        assert!(strategy.fits_in_context("tiny"));
        assert!(!strategy.fits_in_context(&"x".repeat(100)));
    }

    /// A document under the budget comes back as one untouched section.
    #[test]
    fn test_split_into_sections_short_doc() {
        let strategy = LateChunkingStrategy::with_defaults(DocumentId::new("d".to_string()));
        let text = "Short document.";
        let sections = strategy.split_into_sections(text);
        assert_eq!(sections.len(), 1);
        assert_eq!(sections[0], text);
    }

    /// An over-budget document splits at paragraph boundaries without
    /// dropping any paragraph.
    #[test]
    fn test_split_into_sections_long_doc() {
        let config = LateChunkingConfig {
            max_doc_tokens: 5,
            ..Default::default()
        };
        let strategy = LateChunkingStrategy::new(config, DocumentId::new("d".to_string()));

        let text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.";
        let sections = strategy.split_into_sections(text);
        assert!(
            sections.len() > 1,
            "Expected multiple sections, got {}",
            sections.len()
        );
        let combined = sections.join(" ");
        for expected in ["Paragraph one", "Paragraph two", "Paragraph three"] {
            assert!(combined.contains(expected));
        }
    }

    /// ~4 characters per token; empty input estimates zero tokens.
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(LateChunkingStrategy::estimate_tokens(&"a".repeat(400)), 100);
        assert_eq!(LateChunkingStrategy::estimate_tokens(""), 0);
    }
}