use crate::pipeline::graph::ElementGraph;
use crate::pipeline::{Element, ElementData, ElementMetadata};
/// Rule deciding whether two neighbouring elements may share a chunk.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MergePolicy {
    /// Merge only pairs of the same kind (paragraph with paragraph, list item with list item).
    SameTypeOnly,
    /// Merge any pair in which both elements count as inline content.
    AnyInlineContent,
}
/// Tunable settings for [`HybridChunker`].
#[derive(Clone, Debug)]
pub struct HybridChunkConfig {
    /// Token budget for a single chunk.
    pub max_tokens: usize,
    /// Token budget for trailing content carried from one chunk into the next;
    /// `0` disables the carry-over.
    pub overlap_tokens: usize,
    /// Whether adjacent elements may be combined into one chunk at all.
    pub merge_adjacent: bool,
    /// Whether an element's `parent_heading` is copied onto the chunk it starts.
    pub propagate_headings: bool,
    /// Which pairs of adjacent elements are eligible for merging.
    pub merge_policy: MergePolicy,
}
impl Default for HybridChunkConfig {
fn default() -> Self {
Self {
max_tokens: 512,
overlap_tokens: 50,
merge_adjacent: true,
propagate_headings: true,
merge_policy: MergePolicy::AnyInlineContent,
}
}
}
/// A group of elements emitted together by the chunker.
#[derive(Clone, Debug)]
pub struct HybridChunk {
    /// The elements that make up this chunk, in order.
    elements: Vec<Element>,
    /// Heading attached to the chunk, if any; prepended by `full_text`.
    pub heading_context: Option<String>,
    /// Flag exposed through `is_oversized`.
    oversized: bool,
}
impl HybridChunk {
    /// Borrows the elements contained in this chunk.
    pub fn elements(&self) -> &[Element] {
        self.elements.as_slice()
    }

    /// Concatenates the display text of every element, one element per line.
    pub fn text(&self) -> String {
        let mut joined = String::new();
        for (position, element) in self.elements.iter().enumerate() {
            // The separator goes between elements only, never before the first.
            if position > 0 {
                joined.push('\n');
            }
            joined.push_str(&element.display_text());
        }
        joined
    }

    /// Returns `text()` preceded by the heading context and a blank line when
    /// a heading is present; otherwise just `text()`.
    pub fn full_text(&self) -> String {
        let body = self.text();
        match self.heading_context.as_deref() {
            Some(heading) => format!("{}\n\n{}", heading, body),
            None => body,
        }
    }

    /// Estimated token count of `text()`; the heading context is not included.
    pub fn token_estimate(&self) -> usize {
        let rendered = self.text();
        estimate_tokens(&rendered)
    }

    /// Reports the chunk's `oversized` flag.
    pub fn is_oversized(&self) -> bool {
        self.oversized
    }
}
/// Groups document elements into chunks according to a [`HybridChunkConfig`].
pub struct HybridChunker {
    /// Settings consulted by every chunking call.
    config: HybridChunkConfig,
}
impl Default for HybridChunker {
    /// Builds a chunker that uses `HybridChunkConfig::default()`.
    fn default() -> Self {
        let config = HybridChunkConfig::default();
        Self::new(config)
    }
}
impl HybridChunker {
    /// Creates a chunker driven by `config`.
    pub fn new(config: HybridChunkConfig) -> Self {
        Self { config }
    }

    /// Groups `elements` into chunks of at most `max_tokens` estimated tokens.
    ///
    /// Adjacent elements are accumulated in a buffer while merging is enabled,
    /// the merge policy accepts the pair, and the token budget is not exceeded.
    /// Otherwise the buffer is flushed as a chunk and, when `overlap_tokens > 0`,
    /// its trailing elements are carried into the next chunk.
    ///
    /// An element that alone exceeds `max_tokens` is split by sentences when it
    /// is a splittable kind; any other kind is emitted as a single chunk flagged
    /// `oversized`.
    ///
    /// Returns an empty vector for empty input.
    pub fn chunk(&self, elements: &[Element]) -> Vec<HybridChunk> {
        let max_tokens = self.config.max_tokens;
        let mut chunks = Vec::new();
        let mut buffer: Vec<Element> = Vec::new();
        let mut buffer_tokens = 0usize;
        let mut buffer_heading: Option<String> = None;

        for element in elements {
            let elem_tokens = estimate_tokens(&element.display_text());
            let elem_heading = if self.config.propagate_headings {
                element.metadata().parent_heading.clone()
            } else {
                None
            };

            let can_merge = self.config.merge_adjacent
                && buffer_tokens + elem_tokens <= max_tokens
                && buffer.last().map_or(false, |previous| {
                    can_merge_elements(previous, element, &self.config.merge_policy)
                });
            if can_merge {
                buffer.push(element.clone());
                buffer_tokens += elem_tokens;
                continue;
            }

            // Not mergeable: whatever has accumulated becomes a chunk of its own.
            if !buffer.is_empty() {
                self.flush_buffer(
                    &mut chunks,
                    &mut buffer,
                    &mut buffer_tokens,
                    &mut buffer_heading,
                );
                // The carried overlap only repeats content that was just emitted.
                // Drop it when keeping it would push the next chunk past the
                // budget; this also guarantees an empty buffer for oversized
                // elements so they are split or flagged below.
                if buffer_tokens + elem_tokens > max_tokens {
                    buffer.clear();
                    buffer_tokens = 0;
                    buffer_heading = None;
                }
            }

            if elem_tokens > max_tokens {
                if is_splittable_element(element) {
                    let text = element.display_text();
                    for fragment in split_by_sentences(&text, max_tokens) {
                        let fragment_element = make_text_fragment_element(element, fragment.trim());
                        chunks.push(HybridChunk {
                            elements: vec![fragment_element],
                            heading_context: elem_heading.clone(),
                            oversized: false,
                        });
                    }
                } else {
                    chunks.push(HybridChunk {
                        elements: vec![element.clone()],
                        heading_context: elem_heading,
                        oversized: true,
                    });
                }
                continue;
            }

            // A chunk takes the heading of the element that starts it; a buffer
            // seeded with overlap already had its heading set by the flush.
            if buffer.is_empty() {
                buffer_heading = elem_heading;
            }
            buffer.push(element.clone());
            buffer_tokens += elem_tokens;
        }

        if !buffer.is_empty() {
            chunks.push(HybridChunk {
                elements: buffer,
                heading_context: buffer_heading,
                oversized: false,
            });
        }
        chunks
    }

    /// Chunks `elements` section by section, using `graph` to find the sections.
    ///
    /// Elements before the first top-level section index are chunked with
    /// [`Self::chunk`]. Each top-level section (its title element followed by
    /// the elements `graph` reports for it) becomes a single chunk when it fits
    /// in `max_tokens`; otherwise it is passed through [`Self::chunk`]. Every
    /// chunk of a section gets the section heading: the title element's
    /// `parent_heading` if set, else the title element's own text.
    ///
    /// # Panics
    ///
    /// Panics if `graph` reports an index that is out of bounds for `elements`.
    pub fn chunk_with_graph(&self, elements: &[Element], graph: &ElementGraph) -> Vec<HybridChunk> {
        if elements.is_empty() {
            return Vec::new();
        }
        let mut chunks: Vec<HybridChunk> = Vec::new();
        let top_sections = graph.top_level_sections();

        // With no sections at all, the whole input is treated as preamble.
        let first_title_idx = top_sections.first().copied().unwrap_or(elements.len());
        if first_title_idx > 0 {
            chunks.extend(self.chunk(&elements[..first_title_idx]));
        }

        for &title_idx in &top_sections {
            let title_heading = elements[title_idx]
                .metadata()
                .parent_heading
                .clone()
                .or_else(|| Some(elements[title_idx].text().to_string()));

            let child_indices = graph.elements_in_section(title_idx);
            let mut section_elements: Vec<Element> = Vec::with_capacity(1 + child_indices.len());
            section_elements.push(elements[title_idx].clone());
            for &ci in &child_indices {
                section_elements.push(elements[ci].clone());
            }

            let section_tokens: usize = section_elements
                .iter()
                .map(|e| estimate_tokens(&e.display_text()))
                .sum();

            if section_tokens <= self.config.max_tokens {
                chunks.push(HybridChunk {
                    elements: section_elements,
                    heading_context: title_heading,
                    oversized: false,
                });
            } else {
                let mut sub_chunks = self.chunk(&section_elements);
                for sub in &mut sub_chunks {
                    sub.heading_context = title_heading.clone();
                }
                chunks.extend(sub_chunks);
            }
        }
        chunks
    }

    /// Emits the buffered elements as one chunk and reseeds the buffer with overlap.
    ///
    /// When `overlap_tokens > 0`, trailing elements of the flushed chunk are
    /// copied back into `buffer`, walking backwards until the overlap budget
    /// would be exceeded. At least one element is always carried, even if it
    /// alone exceeds the budget. `buffer_tokens` is set to the carried total and
    /// `buffer_heading` to the first carried element's `parent_heading` (only
    /// when heading propagation is enabled). With no overlap the buffer state is
    /// left empty.
    fn flush_buffer(
        &self,
        chunks: &mut Vec<HybridChunk>,
        buffer: &mut Vec<Element>,
        buffer_tokens: &mut usize,
        buffer_heading: &mut Option<String>,
    ) {
        let flushed = std::mem::take(buffer);
        let heading = buffer_heading.take();
        *buffer_tokens = 0;

        if self.config.overlap_tokens > 0 {
            let mut carried_tokens = 0usize;
            // Index of the first carried element; `flushed.len()` means "none yet".
            let mut start = flushed.len();
            for (idx, elem) in flushed.iter().enumerate().rev() {
                let tokens = estimate_tokens(&elem.display_text());
                let has_carry = start < flushed.len();
                if has_carry && carried_tokens + tokens > self.config.overlap_tokens {
                    break;
                }
                start = idx;
                carried_tokens += tokens;
            }
            buffer.extend_from_slice(&flushed[start..]);
            *buffer_tokens = carried_tokens;

            // Honour the same switch that governs headings for fresh chunks.
            if self.config.propagate_headings {
                if let Some(first) = buffer.first() {
                    *buffer_heading = first.metadata().parent_heading.clone();
                }
            }
        }

        chunks.push(HybridChunk {
            elements: flushed,
            heading_context: heading,
            oversized: false,
        });
    }
}
/// Approximates the token count of `text` as its number of whitespace-separated words.
fn estimate_tokens(text: &str) -> usize {
    text.split(char::is_whitespace)
        .filter(|word| !word.is_empty())
        .count()
}
/// Decides whether adjacent elements `a` and `b` may share a chunk under `policy`.
///
/// `SameTypeOnly` accepts paragraph/paragraph and list-item/list-item pairs;
/// `AnyInlineContent` accepts any pair in which both elements are inline.
fn can_merge_elements(a: &Element, b: &Element, policy: &MergePolicy) -> bool {
    match policy {
        MergePolicy::AnyInlineContent => is_inline_element(a) && is_inline_element(b),
        MergePolicy::SameTypeOnly => {
            let both_paragraphs =
                matches!(a, Element::Paragraph(_)) && matches!(b, Element::Paragraph(_));
            let both_list_items =
                matches!(a, Element::ListItem(_)) && matches!(b, Element::ListItem(_));
            both_paragraphs || both_list_items
        }
    }
}
/// Returns `true` for paragraphs, list items and key/value elements.
fn is_inline_element(e: &Element) -> bool {
    match e {
        Element::Paragraph(_) | Element::ListItem(_) | Element::KeyValue(_) => true,
        _ => false,
    }
}
/// Returns `true` for paragraphs and list items, the kinds treated as splittable text.
fn is_splittable_element(e: &Element) -> bool {
    match e {
        Element::Paragraph(_) | Element::ListItem(_) => true,
        _ => false,
    }
}
/// Packs the sentences of `text` into fragments of at most `max_tokens` tokens.
///
/// Sentences are obtained from `split_into_sentences` and measured with
/// `estimate_tokens`. Consecutive sentences are joined with a single space for
/// as long as the running total fits the budget. The joining space is not
/// charged as a token, because token estimates count whitespace-separated words.
///
/// A sentence that alone exceeds the budget is broken on whitespace into runs
/// of at most `max_tokens` words, so no fragment exceeds the budget. A
/// `max_tokens` of zero is treated as one, since no word could fit otherwise.
///
/// Returns a single fragment holding `text` unchanged when no sentence is found
/// (for example for empty or whitespace-only input).
fn split_by_sentences(text: &str, max_tokens: usize) -> Vec<String> {
    let budget = max_tokens.max(1);
    let mut fragments: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut current_tokens = 0usize;

    for sentence in split_into_sentences(text) {
        let sentence = sentence.trim();
        if sentence.is_empty() {
            continue;
        }
        let sentence_tokens = estimate_tokens(sentence);

        if sentence_tokens > budget {
            // Close the fragment in progress so output stays in reading order.
            if !current.is_empty() {
                fragments.push(std::mem::take(&mut current));
                current_tokens = 0;
            }
            // Word-level fallback; assumes one token per whitespace-separated word.
            let words: Vec<&str> = sentence.split_whitespace().collect();
            fragments.extend(words.chunks(budget).map(|group| group.join(" ")));
            continue;
        }

        if current.is_empty() {
            current.push_str(sentence);
            current_tokens = sentence_tokens;
        } else if current_tokens + sentence_tokens <= budget {
            current.push(' ');
            current.push_str(sentence);
            current_tokens += sentence_tokens;
        } else {
            fragments.push(std::mem::replace(&mut current, sentence.to_string()));
            current_tokens = sentence_tokens;
        }
    }

    if !current.is_empty() {
        fragments.push(current);
    }
    if fragments.is_empty() {
        fragments.push(text.to_string());
    }
    fragments
}
/// Breaks `text` into trimmed sentences.
///
/// A sentence ends at `.`, `!` or `?` when that character is immediately
/// followed by a space (the space is consumed), or at a newline. Pieces that
/// are empty after trimming are dropped at newlines and at the end of input.
fn split_into_sentences(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut pending = String::new();
    let mut chars = text.chars().peekable();

    while let Some(c) = chars.next() {
        pending.push(c);
        match c {
            '.' | '!' | '?' => {
                // Punctuation only terminates a sentence when a space follows.
                if chars.peek() == Some(&' ') {
                    chars.next();
                    out.push(pending.trim().to_string());
                    pending.clear();
                }
            }
            '\n' => {
                let piece = pending.trim();
                if !piece.is_empty() {
                    out.push(piece.to_string());
                }
                pending.clear();
            }
            _ => {}
        }
    }

    let tail = pending.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }
    out
}
/// Builds a paragraph element holding `fragment_text`.
///
/// The result is always an `Element::Paragraph`, whatever variant `source` is.
/// Only `page`, `bbox` and `parent_heading` are copied from the source
/// metadata; every other metadata field takes its default value.
fn make_text_fragment_element(source: &Element, fragment_text: &str) -> Element {
    let ElementMetadata {
        page,
        bbox,
        parent_heading,
        ..
    } = source.metadata().clone();
    let metadata = ElementMetadata {
        page,
        bbox,
        parent_heading,
        ..Default::default()
    };
    Element::Paragraph(ElementData {
        text: fragment_text.to_string(),
        metadata,
    })
}