#![cfg_attr(not(feature = "async"), allow(unused_imports))]
use crate::{
core::{ChunkId, ChunkingStrategy, DocumentId, GraphRAGError, TextChunk},
text::chunking::HierarchicalChunker,
};
use std::sync::atomic::{AtomicU64, Ordering};
/// Process-wide counter that supplies the numeric suffix of every late-chunk ID.
///
/// It is a single `static`, so it is shared by every `LateChunkingStrategy`
/// value: suffixes keep increasing across documents for the lifetime of the
/// process instead of restarting at zero for each document.
static LATE_CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Tunable settings for `LateChunkingStrategy`.
///
/// The `Default` implementation yields `chunk_size = 512`, `chunk_overlap = 64`,
/// `max_doc_tokens = 8192` and `annotate_positions = true`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LateChunkingConfig {
/// Chunk size forwarded to the inner chunker's `chunk_text` call.
/// NOTE(review): the unit (bytes, chars or tokens) is defined by that chunker — confirm.
pub chunk_size: usize,
/// Overlap forwarded to the inner chunker's `chunk_text` call
/// (presumably in the same unit as `chunk_size` — confirm against the chunker).
pub chunk_overlap: usize,
/// Largest estimated token count a text may have and still be treated as one
/// section; the estimate used for the comparison is `text.len() / 4`.
pub max_doc_tokens: u32,
/// When `true`, each produced chunk gets `metadata.position_in_document` set to
/// its start offset divided by the document length.
pub annotate_positions: bool,
}
impl Default for LateChunkingConfig {
fn default() -> Self {
Self {
chunk_size: 512,
chunk_overlap: 64,
max_doc_tokens: 8192, annotate_positions: true,
}
}
}
/// Chunking strategy that delegates text splitting to a `HierarchicalChunker`
/// and wraps the resulting pieces in `TextChunk`s tagged with `_lc_` IDs and,
/// optionally, their relative position in the document.
pub struct LateChunkingStrategy {
/// Chunk sizing, token budget and annotation settings.
config: LateChunkingConfig,
/// Document every produced chunk is attributed to; also the prefix of each chunk ID.
document_id: DocumentId,
/// Chunker that performs the actual text splitting.
inner: HierarchicalChunker,
}
impl LateChunkingStrategy {
    /// Bytes-per-token ratio used by the size heuristics in this impl.
    const BYTES_PER_TOKEN: usize = 4;

    /// Creates a strategy for `document_id` using `config`.
    ///
    /// The inner `HierarchicalChunker` is built with a minimum size of 50.
    pub fn new(config: LateChunkingConfig, document_id: DocumentId) -> Self {
        Self {
            inner: HierarchicalChunker::new().with_min_size(50),
            config,
            document_id,
        }
    }

    /// Creates a strategy for `document_id` with `LateChunkingConfig::default()`.
    pub fn with_defaults(document_id: DocumentId) -> Self {
        Self::new(LateChunkingConfig::default(), document_id)
    }

    /// Builder-style setter that replaces the configured `max_doc_tokens`.
    pub fn with_max_doc_tokens(mut self, max_tokens: u32) -> Self {
        self.config.max_doc_tokens = max_tokens;
        self
    }

    /// Estimates the token count of `text` as its byte length divided by four.
    ///
    /// The result saturates at `u32::MAX` instead of wrapping for inputs whose
    /// estimate does not fit in a `u32`.
    pub fn estimate_tokens(text: &str) -> u32 {
        let estimate = text.len() / Self::BYTES_PER_TOKEN;
        // Clamp before narrowing so the cast below can never truncate.
        estimate.min(u32::MAX as usize) as u32
    }

    /// Returns `true` when the estimated token count of `text` does not exceed
    /// the configured `max_doc_tokens`.
    pub fn fits_in_context(&self, text: &str) -> bool {
        Self::estimate_tokens(text) <= self.config.max_doc_tokens
    }

    /// Splits `text` into sections sized to the configured token budget.
    ///
    /// * If `text` already fits (see [`Self::fits_in_context`]) it is returned
    ///   unchanged as the only section.
    /// * Otherwise paragraphs (separated by a blank line, `"\n\n"`) are packed
    ///   greedily into sections of at most `max_doc_tokens * 4` bytes. Each
    ///   section is trimmed, and sections that are empty after trimming are
    ///   never emitted.
    ///
    /// Paragraphs are never split: a single paragraph larger than the budget
    /// becomes its own, oversized, section.
    pub fn split_into_sections(&self, text: &str) -> Vec<String> {
        if self.fits_in_context(text) {
            return vec![text.to_string()];
        }

        // Widen before multiplying so a large token budget cannot overflow `u32`.
        let max_chars = (self.config.max_doc_tokens as usize).saturating_mul(Self::BYTES_PER_TOKEN);
        let mut sections: Vec<String> = Vec::new();
        let mut current = String::new();

        for paragraph in text.split("\n\n") {
            let separator_len = if current.is_empty() { 0 } else { 2 };
            let needed = current.len() + separator_len + paragraph.len();
            if needed > max_chars && !current.is_empty() {
                Self::flush_section(&mut sections, &mut current);
            }
            if !current.is_empty() {
                current.push_str("\n\n");
            }
            current.push_str(paragraph);
        }

        Self::flush_section(&mut sections, &mut current);
        sections
    }

    /// Moves the trimmed contents of `current` into `sections` (skipping
    /// whitespace-only content) and leaves `current` empty.
    fn flush_section(sections: &mut Vec<String>, current: &mut String) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            sections.push(trimmed.to_string());
        }
        current.clear();
    }
}
impl ChunkingStrategy for LateChunkingStrategy {
    /// Splits `text` with the inner chunker and wraps every non-blank piece in
    /// a `TextChunk`.
    ///
    /// * IDs have the form `<document_id>_lc_<n>`, where `<n>` comes from the
    ///   shared `LATE_CHUNK_COUNTER`.
    /// * Start/end offsets are running sums of the piece lengths (blank pieces
    ///   advance the offset but produce no chunk).
    ///   NOTE(review): if the inner chunker returns overlapping pieces these
    ///   sums will not match true offsets into `text` — confirm against
    ///   `HierarchicalChunker::chunk_text`.
    /// * With `annotate_positions` enabled, `metadata.position_in_document` is
    ///   the start offset divided by `text.len()` (treated as 1 for empty text).
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        let pieces =
            self.inner
                .chunk_text(text, self.config.chunk_size, self.config.chunk_overlap);
        // Never zero, so the position ratio below cannot divide by zero.
        let doc_len = text.len().max(1);

        let mut offset: usize = 0;
        let mut chunks: Vec<TextChunk> = Vec::with_capacity(pieces.len());
        chunks.extend(pieces.into_iter().filter_map(|piece| {
            // Every piece, blank or not, consumes its length from the running offset.
            let start = offset;
            let end = start + piece.len();
            offset = end;

            if piece.trim().is_empty() {
                return None;
            }

            let id = ChunkId::new(format!(
                "{}_lc_{}",
                self.document_id,
                LATE_CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst),
            ));
            let mut chunk = TextChunk::new(id, self.document_id.clone(), piece, start, end);
            if self.config.annotate_positions {
                chunk.metadata.position_in_document = Some(start as f32 / doc_len as f32);
            }
            Some(chunk)
        }));
        chunks
    }
}
/// Client for the Jina embeddings HTTP API that requests embeddings with the
/// `late_chunking` option enabled.
#[derive(Debug, Clone)]
pub struct JinaLateChunkingClient {
/// API key sent as the `Authorization: Bearer …` header of the embedding request.
/// It is only read when that request is built; the `allow` below silences the
/// dead-code lint for builds without the `async` feature.
#[cfg_attr(not(feature = "async"), allow(dead_code))]
api_key: String,
/// Model name sent in the request body (`jina-embeddings-v3` unless overridden).
model: String,
}
impl JinaLateChunkingClient {
    /// Jina embeddings endpoint that `embed_with_late_chunking` posts to.
    #[cfg(feature = "async")]
    const ENDPOINT: &'static str = "https://api.jina.ai/v1/embeddings";

    /// Creates a client that authenticates with `api_key` and uses the
    /// `jina-embeddings-v3` model until overridden via [`Self::with_model`].
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "jina-embeddings-v3".to_string(),
        }
    }

    /// Builder-style setter that replaces the model name sent with each request.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Requests one embedding per chunk from the Jina API with
    /// `late_chunking` enabled and returns them in response order.
    ///
    /// An empty `chunks` slice returns `Ok(vec![])` without contacting the API.
    ///
    /// Gated on the same `async` feature as [`Self::ENDPOINT`] so the method
    /// and the constant it depends on are always compiled together.
    ///
    /// NOTE(review): the request goes through `ureq`, a blocking HTTP client,
    /// and the body contains no `.await`; awaiting this future blocks the
    /// calling thread for the duration of the request.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::Generation` when
    /// * the HTTP request fails,
    /// * the response body is not valid JSON or has no `data` array,
    /// * an item in `data` has no `embedding` array or contains a non-numeric
    ///   component (previously these were silently replaced by an empty vector
    ///   or `0.0`), or
    /// * the number of returned embeddings differs from `chunks.len()`.
    #[cfg(feature = "async")]
    pub async fn embed_with_late_chunking(
        &self,
        chunks: &[TextChunk],
    ) -> crate::Result<Vec<Vec<f32>>> {
        if chunks.is_empty() {
            return Ok(Vec::new());
        }

        let inputs: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
        let body = serde_json::json!({
            "model": self.model,
            "input": inputs,
            "late_chunking": true,
        });

        let agent = ureq::AgentBuilder::new().build();
        let response = agent
            .post(Self::ENDPOINT)
            .set("Authorization", &format!("Bearer {}", self.api_key))
            .set("Content-Type", "application/json")
            .send_json(&body)
            .map_err(|e| GraphRAGError::Generation {
                message: format!("Jina API request failed: {e}"),
            })?;

        let json: serde_json::Value =
            response
                .into_json()
                .map_err(|e| GraphRAGError::Generation {
                    message: format!("Failed to parse Jina API response: {e}"),
                })?;

        let response_data = json["data"]
            .as_array()
            .ok_or_else(|| GraphRAGError::Generation {
                message: "Invalid Jina API response: missing 'data' array".to_string(),
            })?;

        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(response_data.len());
        for (index, item) in response_data.iter().enumerate() {
            let values = item["embedding"]
                .as_array()
                .ok_or_else(|| GraphRAGError::Generation {
                    message: format!(
                        "Invalid Jina API response: item {index} is missing 'embedding' array"
                    ),
                })?;

            let mut embedding: Vec<f32> = Vec::with_capacity(values.len());
            for value in values {
                let component = value.as_f64().ok_or_else(|| GraphRAGError::Generation {
                    message: format!(
                        "Invalid Jina API response: item {index} has a non-numeric embedding value"
                    ),
                })?;
                embedding.push(component as f32);
            }
            embeddings.push(embedding);
        }

        // Callers pair embeddings with chunks by index, so a count mismatch
        // would silently attach vectors to the wrong chunks.
        if embeddings.len() != chunks.len() {
            return Err(GraphRAGError::Generation {
                message: format!(
                    "Invalid Jina API response: expected {} embeddings, got {}",
                    chunks.len(),
                    embeddings.len()
                ),
            });
        }

        Ok(embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Default-configured strategy bound to a document named `name`.
    fn default_strategy(name: &str) -> LateChunkingStrategy {
        LateChunkingStrategy::with_defaults(DocumentId::new(name.to_string()))
    }

    /// Strategy for document `"d"` whose only non-default setting is the token budget.
    fn strategy_with_budget(max_doc_tokens: u32) -> LateChunkingStrategy {
        let config = LateChunkingConfig {
            max_doc_tokens,
            ..Default::default()
        };
        LateChunkingStrategy::new(config, DocumentId::new("d".to_string()))
    }

    /// Every chunk produced with default settings carries a position annotation.
    #[test]
    fn test_late_chunking_produces_chunks_with_position() {
        let text = "First paragraph about machine learning.\n\n\
                    Second paragraph about deep learning.\n\n\
                    Third paragraph about neural networks.";
        let chunks = default_strategy("test-doc").chunk(text);

        assert!(!chunks.is_empty());
        for chunk in chunks.iter() {
            let annotated = chunk.metadata.position_in_document.is_some();
            assert!(annotated, "chunk {} missing position metadata", chunk.id);
        }
    }

    /// Chunk IDs contain the `_lc_` marker.
    #[test]
    fn test_chunk_ids_have_lc_suffix() {
        let chunks = default_strategy("doc").chunk("Some text to chunk into pieces here.");
        for chunk in chunks.iter() {
            let tagged = chunk.id.0.contains("_lc_");
            assert!(tagged, "Expected '_lc_' in ID: {}", chunk.id);
        }
    }

    /// A 10-token budget accepts a 4-byte text and rejects a 100-byte one.
    #[test]
    fn test_fits_in_context() {
        let strategy = strategy_with_budget(10);
        let oversized = "x".repeat(100);

        assert!(strategy.fits_in_context("tiny"));
        assert!(!strategy.fits_in_context(&oversized));
    }

    /// Text within the budget comes back as a single, unchanged section.
    #[test]
    fn test_split_into_sections_short_doc() {
        let text = "Short document.";
        let sections = default_strategy("d").split_into_sections(text);

        assert_eq!(sections.len(), 1);
        assert_eq!(sections[0], text);
    }

    /// Text over the budget is split into several sections without losing paragraphs.
    #[test]
    fn test_split_into_sections_long_doc() {
        let text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.";
        let sections = strategy_with_budget(5).split_into_sections(text);

        assert!(
            sections.len() > 1,
            "Expected multiple sections, got {}",
            sections.len()
        );

        let combined = sections.join(" ");
        for needle in &["Paragraph one", "Paragraph two", "Paragraph three"] {
            assert!(combined.contains(*needle));
        }
    }

    /// The estimate is the byte length divided by four.
    #[test]
    fn test_estimate_tokens() {
        let four_hundred_bytes = "a".repeat(400);

        assert_eq!(LateChunkingStrategy::estimate_tokens(&four_hundred_bytes), 100);
        assert_eq!(LateChunkingStrategy::estimate_tokens(""), 0);
    }
}