use cognee_utils::NAMESPACE_OID;
use uuid::Uuid;
use crate::chunk_by_sentence::chunk_by_sentence;
use crate::cut_type::CutType;
use crate::token_counter::TokenCounter;
/// One chunk of text produced by [`chunk_by_paragraph`].
///
/// The chunk borrows its text from the input string, so it cannot outlive it.
#[derive(Debug, Clone)]
pub struct ParagraphChunk<'a> {
/// Slice of the input covering the chunk, from the start of its first
/// sentence to the end of its last sentence.
pub text: &'a str,
/// Sum of the `size` values of the sentences accumulated into this chunk.
pub chunk_size: usize,
/// Version-5 UUID computed from the bytes of `text` under `NAMESPACE_OID`,
/// so identical chunk text always yields the same id.
pub chunk_id: Uuid,
/// Paragraph id of every sentence in the chunk, in order. There is one entry
/// per sentence, so the same paragraph id may appear more than once.
pub paragraph_ids: Vec<Uuid>,
/// Zero-based position of the chunk in the returned vector.
pub chunk_index: usize,
/// Cut type of the last sentence in the chunk. For the final chunk of the
/// input, a trailing `CutType::Word` is reported as `CutType::SentenceCut`.
pub cut_type: CutType,
}
#[allow(
clippy::expect_used,
reason = "chunk_start invariants are upheld by the accumulation logic above each emit branch"
)]
pub fn chunk_by_paragraph<'a, C: TokenCounter>(
data: &'a str,
max_chunk_size: usize,
batch_paragraphs: bool,
counter: &C,
) -> Vec<ParagraphChunk<'a>> {
let sentences = chunk_by_sentence(data, Some(max_chunk_size), counter);
let mut result = Vec::new();
let mut chunk_index: usize = 0;
let mut paragraph_ids: Vec<Uuid> = Vec::new();
let mut last_cut_type = CutType::SentenceCut;
let mut current_chunk_size: usize = 0;
let mut chunk_start: Option<usize> = None;
let mut chunk_end: usize = 0;
for sentence in &sentences {
let sent_start = sentence.text.as_ptr() as usize - data.as_ptr() as usize;
let sent_end = sent_start + sentence.text.len();
if current_chunk_size > 0 && (current_chunk_size + sentence.size > max_chunk_size) {
let text = &data[chunk_start.expect("chunk_start is Some because current_chunk_size > 0 only after a sentence was accumulated")..chunk_end];
result.push(ParagraphChunk {
text,
chunk_size: current_chunk_size,
chunk_id: Uuid::new_v5(&NAMESPACE_OID, text.as_bytes()),
paragraph_ids: std::mem::take(&mut paragraph_ids),
chunk_index,
cut_type: last_cut_type.clone(),
});
current_chunk_size = 0;
chunk_start = None;
chunk_index += 1;
}
paragraph_ids.push(sentence.paragraph_id);
if chunk_start.is_none() {
chunk_start = Some(sent_start);
}
chunk_end = sent_end;
current_chunk_size += sentence.size;
if !batch_paragraphs
&& matches!(
sentence.cut_type,
CutType::ParagraphEnd | CutType::SentenceCut
)
{
let text = &data[chunk_start.expect(
"chunk_start is Some because it was set above before this emit branch is reached",
)..chunk_end];
result.push(ParagraphChunk {
text,
chunk_size: current_chunk_size,
chunk_id: Uuid::new_v5(&NAMESPACE_OID, text.as_bytes()),
paragraph_ids: std::mem::take(&mut paragraph_ids),
chunk_index,
cut_type: sentence.cut_type.clone(),
});
current_chunk_size = 0;
chunk_start = None;
chunk_index += 1;
}
last_cut_type = sentence.cut_type.clone();
}
if let Some(start) = chunk_start {
let final_cut_type = if last_cut_type == CutType::Word {
CutType::SentenceCut
} else {
last_cut_type
};
let text = &data[start..chunk_end];
result.push(ParagraphChunk {
chunk_id: Uuid::new_v5(&NAMESPACE_OID, text.as_bytes()),
text,
chunk_size: current_chunk_size,
paragraph_ids,
chunk_index,
cut_type: final_cut_type,
});
}
result
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::cut_type::CutType;
    use crate::token_counter::WordCounter;

    /// Three-paragraph fixture shared by the ground-truth tests. It has no
    /// closing full stop; tests that need a terminated text append one.
    const GROUND_TRUTH_TEXT: &str = "The quick brown fox jumps over the lazy dog. It was a sunny day.\n\
        The rain in Spain falls mainly on the plain. A stitch in time saves nine. An apple a day keeps the doctor away.\n\
        To be or not to be that is the question. All that glitters is not gold. Actions speak louder than words. The pen is mightier than the sword. Knowledge is power above all else";

    /// An empty input produces no chunks.
    #[test]
    fn empty_input() {
        assert!(chunk_by_paragraph("", 10, true, &WordCounter).is_empty());
    }

    /// A text well under the limit comes back as one chunk.
    #[test]
    fn single_short_paragraph() {
        let chunks = chunk_by_paragraph("Hello world.", 100, true, &WordCounter);
        assert_eq!(chunks.len(), 1);

        let only = &chunks[0];
        assert_eq!(only.text, "Hello world.");
        assert_eq!(only.chunk_size, 2);
        assert_eq!(only.chunk_index, 0);
    }

    /// In batch mode, sentences that fit together share one chunk.
    #[test]
    fn batch_mode_accumulates() {
        let chunks = chunk_by_paragraph(
            "First sentence. Second sentence. Third sentence.",
            100,
            true,
            &WordCounter,
        );
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].chunk_size, 6);
    }

    /// In batch mode, exceeding the limit starts a new chunk.
    #[test]
    fn batch_mode_overflow() {
        let chunks = chunk_by_paragraph("One two. Three four. Five six.", 3, true, &WordCounter);
        assert!(chunks.len() >= 2);
        assert_eq!(chunks[0].chunk_index, 0);
        assert_eq!(chunks[1].chunk_index, 1);
    }

    /// Without batching, a paragraph break splits the output.
    #[test]
    fn non_batch_mode_yields_at_paragraph() {
        let chunks = chunk_by_paragraph(
            "First paragraph.\nSecond paragraph.",
            100,
            false,
            &WordCounter,
        );
        assert!(chunks.len() >= 2);
    }

    /// Chunk indices count up from zero without gaps.
    #[test]
    fn sequential_chunk_indices() {
        let chunks = chunk_by_paragraph("A. B. C. D. E.", 2, true, &WordCounter);
        let actual: Vec<usize> = chunks.iter().map(|chunk| chunk.chunk_index).collect();
        let expected: Vec<usize> = (0..chunks.len()).collect();
        assert_eq!(actual, expected);
    }

    /// Chunking the same text twice yields the same id for the first chunk.
    #[test]
    fn deterministic_ids() {
        let text = "Hello world. Foo bar.";
        let first_run = chunk_by_paragraph(text, 100, true, &WordCounter);
        let second_run = chunk_by_paragraph(text, 100, true, &WordCounter);
        assert_eq!(first_run[0].chunk_id, second_run[0].chunk_id);
    }

    /// Fixture ending in a full stop: sizes stay within the limit, the last
    /// chunk ends on a sentence end, and indices are sequential.
    #[test]
    fn ground_truth_whole_text() {
        let input = format!("{GROUND_TRUTH_TEXT}.");
        let chunks = chunk_by_paragraph(&input, 12, true, &WordCounter);

        let count = chunks.len();
        assert!(count >= 2, "expected at least 2 chunks, got {}", count);

        for (i, chunk) in chunks.iter().enumerate() {
            let size = chunk.chunk_size;
            assert!(size <= 12, "chunk {i} has size {} > 12", size);
        }

        let final_chunk = chunks.last().unwrap();
        assert_eq!(final_chunk.cut_type, CutType::SentenceEnd);

        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.chunk_index, i, "chunk_index mismatch at {i}");
        }
    }

    /// Fixture without a closing full stop: the last chunk is a sentence cut.
    #[test]
    fn ground_truth_cut_text() {
        let chunks = chunk_by_paragraph(GROUND_TRUTH_TEXT, 12, true, &WordCounter);
        assert!(chunks.len() >= 2, "expected at least 2 chunks");

        let final_chunk = chunks.last().unwrap();
        assert_eq!(
            final_chunk.cut_type,
            CutType::SentenceCut,
            "last chunk should be SentenceCut when text doesn't end with punctuation"
        );
    }
}