use std::collections::VecDeque;
use crate::chunk::{MeasuredSpan, TextSpan};
pub fn merge_splits(
splits: &[String],
separator: &str,
chunk_size: usize,
chunk_overlap: usize,
strip_whitespace: bool,
length_fn: &dyn Fn(&str) -> usize,
) -> Vec<String> {
let separator_len = length_fn(separator);
let mut docs: Vec<String> = Vec::new();
let mut current_doc: Vec<&str> = Vec::new();
let mut total: usize = 0;
for d in splits {
let len = length_fn(d);
let sep_cost = if current_doc.is_empty() {
0
} else {
separator_len
};
if total + len + sep_cost > chunk_size && !current_doc.is_empty() {
let doc = join_docs(¤t_doc, separator, strip_whitespace);
if let Some(doc) = doc {
docs.push(doc);
}
while total > chunk_overlap
|| (total
+ len
+ if current_doc.is_empty() {
0
} else {
separator_len
}
> chunk_size
&& total > 0)
{
let removed_len = length_fn(current_doc[0]);
let sep = if current_doc.len() > 1 {
separator_len
} else {
0
};
total = total.saturating_sub(removed_len + sep);
current_doc.remove(0);
}
}
current_doc.push(d);
total += len
+ if current_doc.len() > 1 {
separator_len
} else {
0
};
}
if let Some(doc) = join_docs(¤t_doc, separator, strip_whitespace) {
docs.push(doc);
}
docs
}
/// Merges consecutive `splits` of `input` into chunk spans without copying text.
///
/// Starting from a given split, the chunk is extended one split at a time for
/// as long as the measured text stays within `chunk_size`; a single split is
/// always accepted even when it is larger. When `chunk_overlap` is non-zero
/// the next chunk starts at the earliest trailing split whose tail (up to the
/// end of the chunk just produced) measures at most `chunk_overlap`, but always
/// at least one split further than the previous start so the loop advances.
///
/// With `strip_whitespace` set, every span is trimmed via `TextSpan::trim`
/// before it is measured or emitted; a span for which `trim` yields `None` is
/// neither extended further nor emitted.
///
/// Returns the chunk spans in input order.
pub(crate) fn merge_spans(
    input: &str,
    splits: &[TextSpan],
    chunk_size: usize,
    chunk_overlap: usize,
    strip_whitespace: bool,
    length_fn: &dyn Fn(&str) -> usize,
) -> Vec<TextSpan> {
    let split_count = splits.len();
    let mut merged = Vec::new();

    // Applies the optional trimming step; `None` means nothing usable remains.
    let normalize = |span: TextSpan| -> Option<TextSpan> {
        if strip_whitespace {
            span.trim(input)
        } else {
            Some(span)
        }
    };
    // Span covering splits `from..to` (half-open, `to > from`).
    let covering =
        |from: usize, to: usize| TextSpan::new(splits[from].start, splits[to - 1].end);

    let mut first = 0usize;
    while first < split_count {
        // Grow the chunk; `fit_end` is the exclusive end of the last window that fit.
        let mut fit_end = first + 1;
        for end in (first + 1)..=split_count {
            let Some(window) = normalize(covering(first, end)) else {
                break;
            };
            if length_fn(window.text(input)) > chunk_size && end > first + 1 {
                break;
            }
            fit_end = end;
        }

        if let Some(chunk) = normalize(covering(first, fit_end)) {
            merged.push(chunk);
        }
        if fit_end >= split_count {
            break;
        }
        if chunk_overlap == 0 {
            first = fit_end;
            continue;
        }

        // Walk backwards from the chunk end while the tail still fits the overlap budget.
        let mut overlap_from = fit_end;
        while overlap_from > first {
            let Some(tail) = normalize(covering(overlap_from - 1, fit_end)) else {
                break;
            };
            if length_fn(tail.text(input)) > chunk_overlap {
                break;
            }
            overlap_from -= 1;
        }
        // `fit_end >= first + 1` always holds, so the clamp bounds are ordered.
        first = overlap_from.clamp(first + 1, fit_end);
    }
    merged
}
/// Iterator adapter that merges a stream of `TextSpan` splits into chunks,
/// yielding one `MeasuredSpan` per chunk.
pub(crate) struct StreamingMerger<'a, I> {
/// Text that every span is trimmed and measured against.
input: &'a str,
/// Source of splits still to be consumed.
splits: I,
/// Measured-length limit: a split is only appended while the combined span stays within it.
chunk_size: usize,
/// Measured-length budget for spans retained after a chunk is emitted; `0` retains nothing.
chunk_overlap: usize,
/// When set, spans are trimmed with `TextSpan::trim` before being measured or emitted.
strip_whitespace: bool,
/// Measures text; used for every size comparison.
length_fn: &'a dyn Fn(&str) -> usize,
/// Splits that make up the chunk currently being built.
current: VecDeque<TextSpan>,
/// Split that did not fit the current chunk; handed back by `next_split` before `splits` is polled again.
pending: Option<TextSpan>,
/// Set once `splits` has returned `None`.
exhausted: bool,
}
impl<'a, I> StreamingMerger<'a, I>
where
    I: Iterator<Item = TextSpan>,
{
    /// Creates a merger over `splits`, all of which index into `input`.
    ///
    /// The merger starts with no buffered splits, no pending split, and is not
    /// yet marked as exhausted.
    pub(crate) fn new(
        input: &'a str,
        splits: I,
        chunk_size: usize,
        chunk_overlap: usize,
        strip_whitespace: bool,
        length_fn: &'a dyn Fn(&str) -> usize,
    ) -> Self {
        Self {
            current: VecDeque::new(),
            pending: None,
            exhausted: false,
            input,
            splits,
            chunk_size,
            chunk_overlap,
            strip_whitespace,
            length_fn,
        }
    }

    /// Returns the pending split if one was set aside, otherwise pulls the
    /// next split from the underlying iterator.
    fn next_split(&mut self) -> Option<TextSpan> {
        match self.pending.take() {
            Some(split) => Some(split),
            None => self.splits.next(),
        }
    }

    /// Span from the start of the first buffered split to the end of the last,
    /// or `None` when nothing is buffered.
    fn current_span(&self) -> Option<TextSpan> {
        match (self.current.front(), self.current.back()) {
            (Some(first), Some(last)) => Some(TextSpan::new(first.start, last.end)),
            _ => None,
        }
    }

    /// Applies the optional trimming step to `span`; without
    /// `strip_whitespace` the span is passed through unchanged.
    fn normalized(&self, span: TextSpan) -> Option<TextSpan> {
        if self.strip_whitespace {
            span.trim(self.input)
        } else {
            Some(span)
        }
    }

    /// Builds the `MeasuredSpan` for the buffered splits, or `None` when the
    /// buffer is empty or trimming leaves nothing.
    fn measured_current(&self) -> Option<MeasuredSpan> {
        self.current_span()
            .and_then(|span| self.normalized(span))
            .map(|span| MeasuredSpan::new(self.input, span, self.length_fn))
    }

    /// Drops buffered splits from the front after a chunk has been emitted,
    /// keeping only what may be carried over as overlap.
    ///
    /// With a zero overlap the buffer is simply cleared. Otherwise splits are
    /// removed until the remaining (optionally trimmed) span measures at most
    /// `chunk_overlap` and either still fits within `chunk_size` together with
    /// `next`, or `next` is absent / measures zero.
    fn trim_overlap(&mut self, next: Option<TextSpan>) {
        if self.chunk_overlap == 0 {
            self.current.clear();
            return;
        }

        let next_is_empty = match next {
            Some(upcoming) => upcoming.len_with(self.input, self.length_fn) == 0,
            None => true,
        };

        // `current_span` is `None` exactly when the buffer is empty.
        while let Some(raw) = self.current_span() {
            let Some(kept) = self.normalized(raw) else {
                // Nothing but whitespace left in view: discard the leading split and retry.
                self.current.pop_front();
                continue;
            };

            let kept_len = kept.len_with(self.input, self.length_fn);
            let fits_with_next = match next {
                Some(upcoming) => {
                    let combined = TextSpan::new(kept.start, upcoming.end);
                    combined.len_with(self.input, self.length_fn) <= self.chunk_size
                }
                None => true,
            };

            if kept_len <= self.chunk_overlap && (fits_with_next || next_is_empty) {
                break;
            }
            self.current.pop_front();
        }
    }
}
impl<I> Iterator for StreamingMerger<'_, I>
where
    I: Iterator<Item = TextSpan>,
{
    type Item = MeasuredSpan;

    /// Produces the next merged chunk.
    ///
    /// Splits are pulled and buffered while the span from the first buffered
    /// split to the incoming one stays within `chunk_size`. The first split
    /// that does not fit is parked in `pending`, the buffered chunk is
    /// measured, the buffer is cut down to its overlap, and the chunk is
    /// returned if it is non-empty. Once the source is exhausted the remaining
    /// buffer is flushed as a final chunk.
    fn next(&mut self) -> Option<Self::Item> {
        while !self.exhausted {
            let Some(split) = self.next_split() else {
                self.exhausted = true;
                break;
            };

            // An empty buffer accepts any split unconditionally.
            let Some(chunk_start) = self.current.front().map(|head| head.start) else {
                self.current.push_back(split);
                continue;
            };

            let joined = TextSpan::new(chunk_start, split.end);
            let measured = if self.strip_whitespace {
                joined.trim(self.input).unwrap_or(joined)
            } else {
                joined
            };

            if measured.len_with(self.input, self.length_fn) <= self.chunk_size {
                self.current.push_back(split);
                continue;
            }

            // The split does not fit: emit what is buffered and replay the split later.
            self.pending = Some(split);
            let chunk = self.measured_current();
            self.trim_overlap(self.pending);
            if chunk.is_some() {
                return chunk;
            }
        }

        // Source drained: flush whatever is still buffered, once.
        if self.current.is_empty() {
            return None;
        }
        let tail = self.measured_current();
        self.current.clear();
        tail
    }
}
/// Joins `docs` with `separator` into a single string.
///
/// When `strip_whitespace` is set, leading and trailing whitespace of the
/// joined text is removed. Returns `None` if the resulting text is empty.
fn join_docs(docs: &[&str], separator: &str, strip_whitespace: bool) -> Option<String> {
    let joined = docs.join(separator);
    let text = if strip_whitespace {
        String::from(joined.trim())
    } else {
        joined
    };
    (!text.is_empty()).then_some(text)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char_len;

    /// Converts string literals into the owned splits `merge_splits` expects.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    /// Everything fits into one chunk, so all pieces are joined.
    #[test]
    fn test_merge_basic() {
        let splits = owned(&["a", "b", "c"]);
        let result = merge_splits(&splits, " ", 5, 0, true, &char_len);
        assert_eq!(result, vec!["a b c"]);
    }

    /// With an overlap of one piece, consecutive chunks share that piece.
    #[test]
    fn test_merge_exceeds_chunk_size() {
        let splits = owned(&["foo", "bar", "baz", "123"]);
        let result = merge_splits(&splits, " ", 7, 3, true, &char_len);
        assert_eq!(result, vec!["foo bar", "bar baz", "baz 123"]);
    }

    /// Without overlap the pieces are partitioned into disjoint chunks.
    #[test]
    fn test_merge_no_overlap() {
        let splits = owned(&["aa", "bb", "cc", "dd"]);
        let result = merge_splits(&splits, " ", 5, 0, true, &char_len);
        assert_eq!(result, vec!["aa bb", "cc dd"]);
    }

    /// No input pieces means no chunks.
    #[test]
    fn test_merge_empty_splits() {
        let splits: Vec<String> = Vec::new();
        let result = merge_splits(&splits, " ", 10, 0, true, &char_len);
        assert!(result.is_empty());
    }

    /// Stripping applies to the joined chunk only: the outer whitespace is
    /// removed, while whitespace between pieces (trailing space of the first
    /// piece, the separator, leading space of the second) is preserved.
    #[test]
    fn test_merge_strip_whitespace() {
        let splits = owned(&[" a ", " b "]);
        let result = merge_splits(&splits, " ", 100, 0, true, &char_len);
        // Three interior spaces, built explicitly so the count is unambiguous.
        let expected = format!("a{}b", " ".repeat(3));
        assert_eq!(result, vec![expected]);
    }

    /// Sizes are taken from the supplied length function (words here), not bytes.
    #[test]
    fn test_merge_custom_length_fn() {
        let word_len = |s: &str| -> usize { s.split_whitespace().count().max(1) };
        let splits = owned(&["hello world", "foo bar", "baz"]);
        let result = merge_splits(&splits, " ", 3, 0, true, &word_len);
        assert_eq!(result, vec!["hello world", "foo bar", "baz"]);
    }
}