use regex::Regex;
use serde_json::json;
use crate::config::{
ChunkerConfig, FixedOverlapChunkerConfig, GroupingConfig, HierarchyChunkerConfig,
SentenceAwareChunkerConfig,
};
#[cfg(feature = "embedder-hub")]
use crate::config::FastembedEmbedderConfig;
use crate::config::SemanticChunkerConfig;
#[cfg(feature = "embedder-hub")]
use crate::embedder::FastembedEmbedder;
use crate::sentence_split::naive_sentences;
use crate::sources::Document;
use crate::summarizer::build_summarizer;
/// Embeds sentences for semantic boundary detection.
///
/// Implementations are expected to return one vector per input text, in the
/// same order as `inputs`.
pub trait BoundaryEmbedder: Send + Sync {
    /// Embeds every string in `inputs` in a single batch.
    fn embed_batch(&mut self, inputs: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
}
/// Shared plumbing for the `if_oversize` fallback: the recursion limit, the
/// error it raises, a warn-once helper and the "is this chunk too big" test.
pub mod oversize {
    use std::sync::atomic::{AtomicBool, Ordering};

    use tracing::warn;

    use crate::chunker::Chunk;

    /// Deepest allowed nesting of `if_oversize` fallbacks.
    pub const MAX_RECURSION_DEPTH: usize = 5;

    /// Failure while applying an `if_oversize` fallback chain.
    #[derive(Debug, thiserror::Error)]
    pub enum OversizeError {
        /// The chain nested deeper than [`MAX_RECURSION_DEPTH`].
        #[error(
            "if_oversize chain exceeded depth {} (chunker={chunker})",
            MAX_RECURSION_DEPTH
        )]
        Recursion { chunker: String },
    }

    /// Logs the "oversize chunk with no fallback configured" warning at most
    /// once per instance.
    pub struct DedupedWarner {
        pub chunker_name: String,
        pub ceiling: usize,
        /// Set to `true` by the first call to [`DedupedWarner::warn_once`].
        warned: AtomicBool,
    }

    impl DedupedWarner {
        /// Creates a warner that has not fired yet.
        pub fn new(chunker_name: impl Into<String>, ceiling: usize) -> Self {
            let chunker_name: String = chunker_name.into();
            Self {
                chunker_name,
                ceiling,
                warned: AtomicBool::new(false),
            }
        }

        /// Emits the warning on the first call only; later calls are no-ops.
        pub fn warn_once(&self, oversize_len: usize) {
            let already_warned = self.warned.swap(true, Ordering::Relaxed);
            if !already_warned {
                warn!(
                    target: "chunkshop::oversize",
                    chunker = %self.chunker_name,
                    ceiling = self.ceiling,
                    oversize_len = oversize_len,
                    "{} emitted oversize chunk(s) (>{} chars), no if_oversize fallback set; \
                     first oversize chunk has {} chars. \
                     To fix: add `if_oversize: {{ type: fixed_overlap, window_words: 200, \
                     step_words: 160, max_chars: {} }}` to the chunker config.",
                    self.chunker_name,
                    self.ceiling,
                    oversize_len,
                    self.ceiling,
                );
            }
        }
    }

    /// `true` when either the embedded or the original text of `c` has more
    /// than `ceiling` characters (Unicode scalar values, not bytes).
    pub fn is_oversize(c: &Chunk, ceiling: usize) -> bool {
        [&c.embedded_content, &c.original_content]
            .iter()
            .any(|text| text.chars().count() > ceiling)
    }
}
pub fn apply_if_oversize(
chunks: Vec<Chunk>,
ceiling: Option<usize>,
if_oversize_cfg: Option<&ChunkerConfig>,
chunker_name: &str,
document: &Document,
depth: usize,
skip_check: Option<&dyn Fn(&Chunk) -> bool>,
warner: Option<&oversize::DedupedWarner>,
) -> Result<Vec<Chunk>, oversize::OversizeError> {
if depth > oversize::MAX_RECURSION_DEPTH {
return Err(oversize::OversizeError::Recursion {
chunker: chunker_name.to_string(),
});
}
let Some(ceiling) = ceiling else {
return Ok(chunks);
};
let mut out: Vec<Chunk> = Vec::new();
let mut seq = 0usize;
for c in chunks {
if let Some(f) = skip_check {
if f(&c) {
out.push(Chunk { seq_num: seq, ..c });
seq += 1;
continue;
}
}
if !oversize::is_oversize(&c, ceiling) {
out.push(Chunk { seq_num: seq, ..c });
seq += 1;
continue;
}
let Some(if_cfg) = if_oversize_cfg else {
if let Some(w) = warner {
let len_chars = c
.embedded_content
.chars()
.count()
.max(c.original_content.chars().count());
w.warn_once(len_chars);
}
out.push(Chunk { seq_num: seq, ..c });
seq += 1;
continue;
};
let synth_doc = Document {
id: c.doc_id.clone(),
content: c.original_content.clone(),
title: document.title.clone(),
metadata: document.metadata.clone(),
fingerprint: None,
};
let fallback =
build_chunker(if_cfg.clone()).map_err(|_| oversize::OversizeError::Recursion {
chunker: chunker_name.to_string(),
})?;
let sub_raw = fallback.chunk(&synth_doc);
let nested_ceiling = if_cfg.effective_max_chars();
let nested_cfg = if_cfg.if_oversize();
let sub = apply_if_oversize(
sub_raw,
nested_ceiling,
nested_cfg,
"fallback",
&synth_doc,
depth + 1,
None,
warner,
)?;
for sc in sub {
let mut merged = c.metadata.as_object().cloned().unwrap_or_default();
if let Some(sub_obj) = sc.metadata.as_object() {
for (k, v) in sub_obj.iter() {
merged.insert(k.clone(), v.clone());
}
}
out.push(Chunk {
doc_id: c.doc_id.clone(),
seq_num: seq,
original_content: sc.original_content.clone(),
embedded_content: sc.original_content.clone(),
metadata: serde_json::Value::Object(merged),
});
seq += 1;
}
}
Ok(out)
}
/// One unit of chunker output.
#[derive(Debug, Clone)]
pub struct Chunk {
/// Id of the source document.
pub doc_id: String,
/// Position of this chunk within the document's output; chunkers in this
/// file renumber it consecutively from 0.
pub seq_num: usize,
/// The source text this chunk covers.
pub original_content: String,
/// The text that is embedded. It may differ from `original_content`, e.g.
/// heading-prefixed (hierarchy), neighbor-expanded, or replaced by a summary.
pub embedded_content: String,
/// Free-form metadata. The wrappers in this file treat it as a JSON object
/// (non-object values are read as an empty map via `as_object()`).
pub metadata: serde_json::Value,
}
/// Chunker that packs text into pieces of at most `max_chars` characters.
///
/// Code (`doc_type == "code"`) is split on blank-line paragraphs. Prose is
/// split on Markdown headings, then paragraphs and sentences.
pub struct SentenceAwareChunker {
    cfg: SentenceAwareChunkerConfig,
    /// Warn-once helper for oversize chunks when no `if_oversize` is configured.
    warner: Option<oversize::DedupedWarner>,
}

impl SentenceAwareChunker {
    /// Creates the chunker; its oversize ceiling is `cfg.max_chars`.
    pub fn new(cfg: SentenceAwareChunkerConfig) -> Self {
        let ceiling = cfg.max_chars;
        Self {
            cfg,
            warner: Some(oversize::DedupedWarner::new("sentence_aware", ceiling)),
        }
    }

    /// Splits `doc`, then applies the `if_oversize` fallback.
    ///
    /// # Panics
    /// Panics if the `if_oversize` chain fails.
    pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
        let max_chars = self.cfg.max_chars;
        let is_code = self.cfg.doc_type == "code";
        let pieces = if is_code {
            split_plain(&doc.content, max_chars)
        } else {
            split_prose(&doc.content, max_chars, self.cfg.min_chars)
        };

        let mut chunks: Vec<Chunk> = Vec::with_capacity(pieces.len());
        for (seq_num, text) in pieces.into_iter().enumerate() {
            chunks.push(Chunk {
                doc_id: doc.id.clone(),
                seq_num,
                original_content: text.clone(),
                embedded_content: text,
                metadata: json!({ "strategy": "sentence_aware" }),
            });
        }

        apply_if_oversize(
            chunks,
            Some(max_chars),
            self.cfg.if_oversize.as_deref(),
            "sentence_aware",
            doc,
            0,
            None,
            self.warner.as_ref(),
        )
        .expect("if_oversize recursion")
    }
}
/// Matches a whole Markdown ATX heading line (`#` through `######`).
/// Compiled on every call.
fn md_heading_re() -> Regex {
    const PATTERN: &str = r"(?m)^#{1,6}\s+.+$";
    Regex::new(PATTERN).unwrap()
}

/// Like [`md_heading_re`], but captures the `#` run (group 1) and the heading
/// text (group 2).
fn heading_with_text_re() -> Regex {
    const PATTERN: &str = r"(?m)^(#{1,6})\s+(.+?)$";
    Regex::new(PATTERN).unwrap()
}

/// Matches a paragraph break: a newline, optional whitespace, another newline.
fn para_break_re() -> Regex {
    const PATTERN: &str = r"\n\s*\n";
    Regex::new(PATTERN).unwrap()
}

/// Matches sentence-ending punctuation plus the whitespace after it (captured).
fn sent_boundary_re() -> Regex {
    const PATTERN: &str = r"([.!?]\s+)";
    Regex::new(PATTERN).unwrap()
}
fn split_plain(text: &str, max_chars: usize) -> Vec<String> {
let paragraphs: Vec<String> = para_break_re()
.split(text)
.map(|p| p.trim().to_string())
.filter(|p| !p.is_empty())
.collect();
let mut result: Vec<String> = Vec::new();
let mut buffer = String::new();
for para in paragraphs {
if para.chars().count() > max_chars {
if !buffer.is_empty() {
result.push(buffer.trim().to_string());
buffer.clear();
}
result.extend(split_to_max_chars(¶, max_chars));
} else if !buffer.is_empty()
&& buffer.chars().count() + para.chars().count() + 2 > max_chars
{
result.push(buffer.trim().to_string());
buffer = para;
} else if buffer.is_empty() {
buffer = para;
} else {
buffer = format!("{buffer}\n\n{para}");
}
}
if !buffer.is_empty() {
result.push(buffer.trim().to_string());
}
result
}
/// Splits prose at Markdown headings, then bounds each piece to `max_chars`.
///
/// Without headings this defers to [`split_plain`]. Otherwise text before the
/// first heading comes first, followed by one run of pieces per heading
/// section. When the whole `text` is longer than `max_chars`, pieces shorter
/// than `min_chars` characters are dropped; otherwise only empty pieces are.
fn split_prose(text: &str, max_chars: usize, min_chars: usize) -> Vec<String> {
    let section_starts: Vec<usize> = md_heading_re()
        .find_iter(text)
        .map(|m| m.start())
        .collect();
    let Some(&first_start) = section_starts.first() else {
        return split_plain(text, max_chars);
    };

    let mut pieces: Vec<String> = Vec::new();
    if first_start > 0 {
        let preamble = text[..first_start].trim();
        if !preamble.is_empty() {
            pieces.extend(split_to_max_chars(preamble, max_chars));
        }
    }
    for (idx, &start) in section_starts.iter().enumerate() {
        let end = section_starts.get(idx + 1).copied().unwrap_or(text.len());
        let section = text[start..end].trim();
        if !section.is_empty() {
            pieces.extend(split_to_max_chars(section, max_chars));
        }
    }

    let fits_whole = text.chars().count() <= max_chars;
    pieces.retain(|piece| {
        if fits_whole {
            !piece.is_empty()
        } else {
            piece.chars().count() >= min_chars
        }
    });
    pieces
}
pub fn split_to_max_chars(text: &str, max_chars: usize) -> Vec<String> {
assert!(max_chars > 0, "max_chars must be positive");
if text.chars().count() <= max_chars {
return vec![text.to_string()];
}
let paragraphs: Vec<String> = para_break_re()
.split(text)
.map(|p| p.trim().to_string())
.filter(|p| !p.is_empty())
.collect();
let mut out: Vec<String> = Vec::new();
let mut buf = String::new();
for para in paragraphs {
if para.chars().count() > max_chars {
if !buf.is_empty() {
out.push(format!("{buf}\n"));
buf.clear();
}
out.extend(split_sentences(¶, max_chars));
continue;
}
let budget = max_chars - 1;
let candidate = if buf.is_empty() {
para.clone()
} else {
format!("{buf}\n\n{para}")
};
if candidate.chars().count() > budget {
if !buf.is_empty() {
out.push(format!("{buf}\n"));
}
buf = para;
} else {
buf = candidate;
}
}
if !buf.is_empty() {
out.push(buf);
}
out
}
/// Cuts `text` into sentence tokens. Each token runs up to and including a
/// `[.!?]` plus its trailing whitespace. Any remainder after the last boundary
/// becomes a final token if it is non-empty.
fn sentence_tokens(text: &str) -> Vec<String> {
    let mut tokens: Vec<String> = Vec::new();
    let mut cursor = 0usize;
    for boundary in sent_boundary_re().find_iter(text) {
        // A boundary match is never empty, so this token is never empty either.
        tokens.push(text[cursor..boundary.end()].to_string());
        cursor = boundary.end();
    }
    let tail = &text[cursor..];
    if !tail.is_empty() {
        tokens.push(tail.to_string());
    }
    tokens
}
/// Hard-splits `text` into consecutive pieces of `max_chars` characters
/// (the last piece may be shorter). Empty text yields no pieces.
///
/// # Panics
/// Panics if `max_chars` is zero.
fn char_slice(text: &str, max_chars: usize) -> Vec<String> {
    let all_chars: Vec<char> = text.chars().collect();
    let mut pieces: Vec<String> = Vec::new();
    for window in all_chars.chunks(max_chars) {
        pieces.push(String::from_iter(window));
    }
    pieces
}
/// Greedily packs the sentence tokens of `text` into pieces of at most
/// `max_chars` characters. Tokens are concatenated verbatim, keeping their own
/// trailing whitespace. A single token longer than `max_chars` is hard-split
/// with [`char_slice`].
fn split_sentences(text: &str, max_chars: usize) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    let mut pending = String::new();
    for sentence in sentence_tokens(text).into_iter().filter(|t| !t.is_empty()) {
        let sentence_len = sentence.chars().count();
        if sentence_len > max_chars {
            if !pending.is_empty() {
                pieces.push(std::mem::take(&mut pending));
            }
            pieces.extend(char_slice(&sentence, max_chars));
            continue;
        }
        if pending.chars().count() + sentence_len > max_chars {
            if !pending.is_empty() {
                pieces.push(std::mem::take(&mut pending));
            }
            pending = sentence;
        } else {
            pending.push_str(&sentence);
        }
    }
    if !pending.is_empty() {
        pieces.push(pending);
    }
    pieces
}
/// Chunker that cuts a document at heading lines and emits one or more chunks
/// per section.
///
/// A section body longer than `cfg.max_chars` is split with
/// [`split_to_max_chars`]. With `cfg.prefix_heading`, the heading text is
/// prepended to each chunk's `embedded_content`.
pub struct HierarchyChunker {
    cfg: HierarchyChunkerConfig,
    /// Heading matcher: `cfg.heading_pattern` when set, else the built-in
    /// Markdown `#`..`######` pattern.
    heading_re: Regex,
    /// `true` when `heading_re` was compiled from `cfg.heading_pattern`.
    custom_heading_pattern: bool,
    warner: Option<oversize::DedupedWarner>,
}

impl HierarchyChunker {
    /// Compiles the heading matcher and builds the chunker.
    ///
    /// # Errors
    /// Returns an error when `cfg.heading_pattern` is not a valid regex.
    pub fn new(cfg: HierarchyChunkerConfig) -> anyhow::Result<Self> {
        let custom_heading_pattern = cfg.heading_pattern.is_some();
        let heading_re = if let Some(pat) = &cfg.heading_pattern {
            Regex::new(pat).map_err(|e| {
                anyhow::anyhow!("invalid hierarchy heading_pattern {pat:?}: {e}")
            })?
        } else {
            heading_with_text_re()
        };
        let warner = Some(oversize::DedupedWarner::new("hierarchy", cfg.max_chars));
        Ok(Self {
            cfg,
            heading_re,
            custom_heading_pattern,
            warner,
        })
    }

    /// Splits `doc` into heading-delimited sections, then applies the
    /// `if_oversize` fallback against `cfg.max_chars`.
    ///
    /// Text before the first heading, or the whole document when no heading
    /// matches, is labelled with the document title. Sections whose trimmed
    /// body has fewer than `cfg.min_section_chars` characters are skipped. The
    /// heading-less whole-document case only skips an empty body.
    ///
    /// # Panics
    /// Panics if the `if_oversize` chain fails.
    pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
        let text = &doc.content;
        let headings: Vec<(usize, usize, String)> = self
            .heading_re
            .captures_iter(text)
            .map(|caps| {
                let whole = caps.get(0).unwrap();
                // Custom patterns label with group 1 when present, else the
                // whole match. The built-in pattern keeps heading text in group 2.
                let label = if self.custom_heading_pattern {
                    caps.get(1).or_else(|| caps.get(0))
                } else {
                    caps.get(2)
                };
                let label = label
                    .map(|m| m.as_str().trim().to_string())
                    .unwrap_or_default();
                (whole.start(), whole.end(), label)
            })
            .collect();

        let sections = match headings.first() {
            None => {
                let body = text.trim();
                if body.is_empty() {
                    Vec::new()
                } else {
                    let title = doc.title.clone().unwrap_or_default();
                    self.emit_section_chunks(body, &title, &doc.id, 0)
                }
            }
            Some(&(first_start, _, _)) => {
                let mut acc: Vec<Chunk> = Vec::new();
                if first_start > 0 {
                    let preamble = text[..first_start].trim();
                    if preamble.chars().count() >= self.cfg.min_section_chars {
                        let title = doc.title.clone().unwrap_or_default();
                        let start_seq = acc.len();
                        acc.extend(self.emit_section_chunks(preamble, &title, &doc.id, start_seq));
                    }
                }
                for (idx, (_, body_start, label)) in headings.iter().enumerate() {
                    // A section body runs from the end of its heading line to
                    // the start of the next heading (or end of text).
                    let body_end = headings.get(idx + 1).map_or(text.len(), |next| next.0);
                    let body = text[*body_start..body_end].trim();
                    if body.chars().count() < self.cfg.min_section_chars {
                        continue;
                    }
                    let start_seq = acc.len();
                    acc.extend(self.emit_section_chunks(body, label, &doc.id, start_seq));
                }
                acc
            }
        };

        apply_if_oversize(
            sections,
            Some(self.cfg.max_chars),
            self.cfg.if_oversize.as_deref(),
            "hierarchy",
            doc,
            0,
            None,
            self.warner.as_ref(),
        )
        .expect("if_oversize recursion")
    }

    /// Turns one section body into chunks numbered from `start_seq`. Bodies
    /// over `cfg.max_chars` are split first. Each chunk records its heading and
    /// its index within the section (`section_part`).
    fn emit_section_chunks(
        &self,
        body: &str,
        heading_text: &str,
        doc_id: &str,
        start_seq: usize,
    ) -> Vec<Chunk> {
        let max_chars = self.cfg.max_chars;
        let parts = if body.chars().count() > max_chars {
            split_to_max_chars(body, max_chars)
        } else {
            vec![body.to_string()]
        };
        let add_prefix = self.cfg.prefix_heading && !heading_text.is_empty();

        let mut chunks: Vec<Chunk> = Vec::with_capacity(parts.len());
        for (section_part, part) in parts.into_iter().enumerate() {
            let embedded_content = if add_prefix {
                format!("{heading_text}\n\n{part}")
            } else {
                part.clone()
            };
            chunks.push(Chunk {
                doc_id: doc_id.to_string(),
                seq_num: start_seq + section_part,
                original_content: part,
                embedded_content,
                metadata: json!({
                    "strategy": "hierarchy",
                    "heading": heading_text,
                    "section_part": section_part,
                }),
            });
        }
        chunks
    }
}
/// Common interface of every chunking strategy in this module.
pub trait ChunkerImpl {
    /// Splits `document` into an ordered list of chunks.
    fn chunk(&self, document: &Document) -> Vec<Chunk>;
}
impl ChunkerImpl for SentenceAwareChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
impl ChunkerImpl for HierarchyChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
/// Sliding-window chunker over whitespace-separated words.
///
/// Each window holds up to `window_words` words joined by single spaces, and a
/// new window starts every `step_words` words. The window that reaches the end
/// of the document is the last one.
pub struct FixedOverlapChunker {
    cfg: FixedOverlapChunkerConfig,
    /// Present only when `cfg.max_chars` is set.
    warner: Option<oversize::DedupedWarner>,
}

impl FixedOverlapChunker {
    /// # Errors
    /// Returns an error when `window_words` or `step_words` is zero.
    pub fn new(cfg: FixedOverlapChunkerConfig) -> anyhow::Result<Self> {
        let sizes_ok = cfg.window_words != 0 && cfg.step_words != 0;
        if !sizes_ok {
            return Err(anyhow::anyhow!(
                "window_words and step_words must be positive"
            ));
        }
        let warner = cfg
            .max_chars
            .map(|ceiling| oversize::DedupedWarner::new("fixed_overlap", ceiling));
        Ok(Self { cfg, warner })
    }

    /// Emits the word windows of `doc`, then applies the `if_oversize` fallback
    /// when `cfg.max_chars` is set.
    ///
    /// # Panics
    /// Panics if the `if_oversize` chain fails.
    pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
        let words: Vec<&str> = doc.content.split_whitespace().collect();
        let window = self.cfg.window_words;
        let step = self.cfg.step_words;
        let mut windows: Vec<Chunk> = Vec::new();
        let mut start = 0usize;
        loop {
            if start >= words.len() {
                break;
            }
            let stop = (start + window).min(words.len());
            let span = &words[start..stop];
            let text = span.join(" ");
            let seq_num = windows.len();
            windows.push(Chunk {
                doc_id: doc.id.clone(),
                seq_num,
                original_content: text.clone(),
                embedded_content: text,
                metadata: json!({
                    "strategy": "fixed_overlap",
                    "start_word": start,
                    "n_words": span.len(),
                }),
            });
            if start + window >= words.len() {
                break;
            }
            start += step;
        }
        apply_if_oversize(
            windows,
            self.cfg.max_chars,
            self.cfg.if_oversize.as_deref(),
            "fixed_overlap",
            doc,
            0,
            None,
            self.warner.as_ref(),
        )
        .expect("if_oversize recursion")
    }
}
impl ChunkerImpl for FixedOverlapChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
/// Wrapper that widens each base chunk's embedded text with its neighbours.
///
/// For base chunk `i`, `embedded_content` becomes the embedded texts of base
/// chunks `i - window ..= i + window` (clamped to the list), joined by blank
/// lines. `original_content` and `seq_num` are kept from the base chunk.
pub struct NeighborExpandChunker {
    window: usize,
    base: Box<dyn ChunkerImpl + Send + Sync>,
    /// Oversize ceiling; `None` disables the `if_oversize` check.
    effective_ceiling: Option<usize>,
    if_oversize_cfg: Option<ChunkerConfig>,
    warner: Option<oversize::DedupedWarner>,
}

impl NeighborExpandChunker {
    /// Wraps `base`; a warner is created only when a ceiling is given.
    pub fn new(
        window: usize,
        base: Box<dyn ChunkerImpl + Send + Sync>,
        effective_ceiling: Option<usize>,
        if_oversize_cfg: Option<ChunkerConfig>,
    ) -> Self {
        Self {
            window,
            base,
            effective_ceiling,
            if_oversize_cfg,
            warner: effective_ceiling
                .map(|ceiling| oversize::DedupedWarner::new("neighbor_expand", ceiling)),
        }
    }

    /// Expands every base chunk, tags its metadata with
    /// `neighbor_expand_window`, then applies the `if_oversize` fallback.
    ///
    /// # Panics
    /// Panics if the `if_oversize` chain fails.
    pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
        let base_chunks = self.base.chunk(doc);
        let Some(last_idx) = base_chunks.len().checked_sub(1) else {
            return Vec::new();
        };
        let radius = self.window;
        let expanded: Vec<Chunk> = base_chunks
            .iter()
            .enumerate()
            .map(|(i, bc)| {
                let lo = i.saturating_sub(radius);
                let hi = (i + radius).min(last_idx);
                let joined = base_chunks[lo..=hi]
                    .iter()
                    .map(|neighbour| neighbour.embedded_content.as_str())
                    .collect::<Vec<&str>>()
                    .join("\n\n");
                let mut merged = bc.metadata.as_object().cloned().unwrap_or_default();
                merged.insert(
                    "neighbor_expand_window".to_string(),
                    serde_json::Value::from(radius as u64),
                );
                Chunk {
                    doc_id: bc.doc_id.clone(),
                    seq_num: bc.seq_num,
                    original_content: bc.original_content.clone(),
                    embedded_content: joined,
                    metadata: serde_json::Value::Object(merged),
                }
            })
            .collect();
        apply_if_oversize(
            expanded,
            self.effective_ceiling,
            self.if_oversize_cfg.as_ref(),
            "neighbor_expand",
            doc,
            0,
            None,
            self.warner.as_ref(),
        )
        .expect("if_oversize recursion")
    }
}
impl ChunkerImpl for NeighborExpandChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
/// Chunker that splits text into sentences and embeds them with a boundary
/// embedder. It starts a new chunk wherever the cosine distance between
/// adjacent sentences is at or above a percentile threshold.
pub struct SemanticChunker {
cfg: SemanticChunkerConfig,
// Behind a Mutex because `BoundaryEmbedder::embed_batch` takes `&mut self`
// while `chunk` only has `&self`.
boundary: std::sync::Mutex<Box<dyn BoundaryEmbedder>>,
// Warn-once helper for oversize chunks when no `if_oversize` is configured.
warner: Option<oversize::DedupedWarner>,
}
/// Rejects semantic-chunker settings this implementation cannot honour.
///
/// # Errors
/// Fails when the sentence splitter is not `"naive"`, the breakpoint
/// percentile is outside `1..=99`, `min_sentences_per_chunk` is zero, or
/// `max_chunk_chars` is below 100.
fn validate_semantic_cfg(cfg: &SemanticChunkerConfig) -> anyhow::Result<()> {
    let splitter_supported = cfg.sentence_splitter == "naive";
    if !splitter_supported {
        return Err(anyhow::anyhow!(
            "sentence_splitter {:?} not supported in chunkshop-rs (only 'naive'); \
             nltk requires Python",
            cfg.sentence_splitter
        ));
    }
    let percentile = cfg.breakpoint_percentile;
    if !(1..100).contains(&percentile) {
        return Err(anyhow::anyhow!(
            "breakpoint_percentile must be in [1, 99], got {}",
            percentile
        ));
    }
    if cfg.min_sentences_per_chunk == 0 {
        return Err(anyhow::anyhow!("min_sentences_per_chunk must be >= 1"));
    }
    match cfg.max_chunk_chars {
        n if n < 100 => Err(anyhow::anyhow!(
            "max_chunk_chars must be >= 100, got {}",
            n
        )),
        _ => Ok(()),
    }
}
impl SemanticChunker {
pub fn with_embedder(
cfg: SemanticChunkerConfig,
embedder: Box<dyn BoundaryEmbedder>,
) -> anyhow::Result<Self> {
validate_semantic_cfg(&cfg)?;
let warner = Some(oversize::DedupedWarner::new(
"semantic",
cfg.max_chunk_chars,
));
Ok(Self {
cfg,
boundary: std::sync::Mutex::new(embedder),
warner,
})
}
#[cfg(feature = "embedder-hub")]
pub fn new(cfg: SemanticChunkerConfig) -> anyhow::Result<Self> {
validate_semantic_cfg(&cfg)?;
let boundary_cfg = FastembedEmbedderConfig {
model_name: cfg.boundary_model.clone(),
dim: 384,
batch_size: 16,
threads: Some(2),
hf_repo: None,
onnx_path: None,
pooling: "cls".to_string(),
additional_files: vec![],
};
let boundary = FastembedEmbedder::new(boundary_cfg)?;
Self::with_embedder(cfg, Box::new(boundary))
}
pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
let chunks = self.chunk_inner(doc);
apply_if_oversize(
chunks,
Some(self.cfg.max_chunk_chars),
self.cfg.if_oversize.as_deref(),
"semantic",
doc,
0,
None,
self.warner.as_ref(),
)
.expect("if_oversize recursion")
}
fn chunk_inner(&self, doc: &Document) -> Vec<Chunk> {
if doc.content.is_empty() || doc.content.trim().is_empty() {
return Vec::new();
}
let sentences = naive_sentences(&doc.content);
if sentences.is_empty() {
return Vec::new();
}
if sentences.len() == 1 {
let mut chunks = Vec::new();
for sub in self.split_if_too_large(&sentences[0], (0, 1)) {
let seq = chunks.len();
chunks.push(self.mk_chunk(&doc.id, seq, &sub));
}
return chunks;
}
let refs: Vec<&str> = sentences.iter().map(String::as_str).collect();
let embeddings = match self
.boundary
.lock()
.expect("poisoned mutex")
.embed_batch(&refs)
{
Ok(v) => v,
Err(e) => {
tracing::error!("semantic chunker boundary embed failed: {e:#}");
return vec![self.mk_chunk(&doc.id, 0, &doc.content)];
}
};
let normed: Vec<Vec<f32>> = embeddings
.iter()
.map(|v| {
let n: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
let denom = if n == 0.0 { 1.0 } else { n };
v.iter().map(|x| x / denom).collect()
})
.collect();
let mut distances: Vec<f32> = Vec::with_capacity(normed.len() - 1);
for i in 0..normed.len() - 1 {
let dot: f32 = normed[i]
.iter()
.zip(&normed[i + 1])
.map(|(a, b)| a * b)
.sum();
distances.push(1.0 - dot);
}
let threshold = percentile_linear(&distances, self.cfg.breakpoint_percentile);
let breakpoints: Vec<usize> = distances
.iter()
.enumerate()
.filter_map(|(i, &d)| if d >= threshold { Some(i) } else { None })
.collect();
let mut starts: Vec<usize> = vec![0];
for &bp in &breakpoints {
starts.push(bp + 1);
}
let mut spans: Vec<(usize, usize)> = Vec::with_capacity(starts.len());
for i in 0..starts.len() {
let s = starts[i];
let e = if i + 1 < starts.len() {
starts[i + 1]
} else {
sentences.len()
};
spans.push((s, e));
}
let spans = merge_small_spans(spans, self.cfg.min_sentences_per_chunk);
let mut chunks: Vec<Chunk> = Vec::new();
for (s, e) in spans {
let body = sentences[s..e].join(" ").trim().to_string();
if body.is_empty() {
continue;
}
for sub in self.split_if_too_large(&body, (s, e)) {
let seq = chunks.len();
chunks.push(self.mk_chunk(&doc.id, seq, &sub));
}
}
chunks
}
fn mk_chunk(&self, doc_id: &str, seq: usize, text: &str) -> Chunk {
Chunk {
doc_id: doc_id.to_string(),
seq_num: seq,
original_content: text.to_string(),
embedded_content: text.to_string(),
metadata: json!({ "strategy": "semantic" }),
}
}
fn split_if_too_large(&self, body: &str, span: (usize, usize)) -> Vec<String> {
if body.chars().count() <= self.cfg.max_chunk_chars {
return vec![body.to_string()];
}
let sents = naive_sentences(body);
let sub_chunks: Vec<String> = if sents.is_empty() {
let chars: Vec<char> = body.chars().collect();
chars
.chunks(self.cfg.max_chunk_chars)
.map(|c| c.iter().collect())
.collect()
} else {
let mut out: Vec<String> = Vec::new();
let mut cur = String::new();
for s in sents {
let candidate = if cur.is_empty() {
s.clone()
} else {
let joined = format!("{cur} {s}");
joined.trim().to_string()
};
if candidate.chars().count() > self.cfg.max_chunk_chars && !cur.is_empty() {
out.push(cur.trim().to_string());
cur = s;
} else {
cur = candidate;
}
}
if !cur.is_empty() {
if cur.chars().count() > self.cfg.max_chunk_chars {
let chars: Vec<char> = cur.chars().collect();
for window in chars.chunks(self.cfg.max_chunk_chars) {
out.push(window.iter().collect());
}
} else {
out.push(cur.trim().to_string());
}
}
out
};
if sub_chunks.len() > 1 {
tracing::warn!(
target: "chunkshop::semantic",
max_chunk_chars = self.cfg.max_chunk_chars,
span_start = span.0,
span_end = span.1,
body_len = body.chars().count(),
sub_chunks = sub_chunks.len(),
"semantic chunk exceeded max_chunk_chars={}; hard-split into {} sub-chunks",
self.cfg.max_chunk_chars,
sub_chunks.len(),
);
}
sub_chunks
}
}
impl ChunkerImpl for SemanticChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
/// Wrapper that embeds a summary of each base chunk instead of its text.
///
/// `original_content` is kept from the base chunk and `embedded_content` is
/// the summarizer output. Metadata gains a `summarizer` key holding `mode`.
pub struct SummaryEmbedChunker {
    base: Box<dyn ChunkerImpl + Send + Sync>,
    summarizer: Box<dyn crate::summarizer::SummarizerImpl>,
    /// Summarizer mode label written into each chunk's metadata.
    mode: &'static str,
    /// Oversize ceiling; `None` disables the `if_oversize` check.
    effective_ceiling: Option<usize>,
    if_oversize_cfg: Option<ChunkerConfig>,
    warner: Option<oversize::DedupedWarner>,
}

impl SummaryEmbedChunker {
    /// Wraps `base`; a warner is created only when a ceiling is given.
    pub fn new(
        base: Box<dyn ChunkerImpl + Send + Sync>,
        summarizer: Box<dyn crate::summarizer::SummarizerImpl>,
        mode: &'static str,
        effective_ceiling: Option<usize>,
        if_oversize_cfg: Option<ChunkerConfig>,
    ) -> Self {
        Self {
            base,
            summarizer,
            mode,
            effective_ceiling,
            if_oversize_cfg,
            warner: effective_ceiling
                .map(|ceiling| oversize::DedupedWarner::new("summary_embed", ceiling)),
        }
    }

    /// Summarizes every base chunk, then applies the `if_oversize` fallback.
    ///
    /// If the summarizer fails on any chunk, the error is logged and no
    /// chunks are returned for the document.
    ///
    /// # Panics
    /// Panics if the `if_oversize` chain fails.
    pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
        let base_chunks = self.base.chunk(doc);
        let mut summarized: Vec<Chunk> = Vec::with_capacity(base_chunks.len());
        for base in base_chunks {
            let summary = match self
                .summarizer
                .summarize(&base.original_content, &doc.metadata)
            {
                Ok(text) => text,
                Err(e) => {
                    tracing::error!(
                        "summary_embed: summarizer failed on doc={} seq={}: {e:#}",
                        base.doc_id,
                        base.seq_num
                    );
                    return Vec::new();
                }
            };
            let mut meta = base.metadata.as_object().cloned().unwrap_or_default();
            meta.insert(
                "summarizer".to_string(),
                serde_json::Value::String(self.mode.to_string()),
            );
            let Chunk {
                doc_id,
                seq_num,
                original_content,
                ..
            } = base;
            summarized.push(Chunk {
                doc_id,
                seq_num,
                original_content,
                embedded_content: summary,
                metadata: serde_json::Value::Object(meta),
            });
        }
        apply_if_oversize(
            summarized,
            self.effective_ceiling,
            self.if_oversize_cfg.as_ref(),
            "summary_embed",
            doc,
            0,
            None,
            self.warner.as_ref(),
        )
        .expect("if_oversize recursion")
    }
}
impl ChunkerImpl for SummaryEmbedChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
/// Wrapper that groups base chunks and emits, per group, the base ("fine")
/// chunks followed by one "coarse" chunk embedding a summary of the group.
pub struct HierarchicalSummaryChunker {
base: Box<dyn ChunkerImpl + Send + Sync>,
summarizer: Box<dyn crate::summarizer::SummarizerImpl>,
// Summarizer mode label written into every emitted chunk's metadata.
mode: &'static str,
// How consecutive base chunks are grouped before summarizing.
grouping: HierarchicalGrouping,
// Oversize ceiling for fine chunks; `None` disables the check.
effective_ceiling: Option<usize>,
if_oversize_cfg: Option<ChunkerConfig>,
warner: Option<oversize::DedupedWarner>,
}
/// Strategy for grouping consecutive base chunks in
/// `HierarchicalSummaryChunker`.
///
/// Derives the standard value traits so the public type can be logged,
/// compared and copied like the small value it is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HierarchicalGrouping {
    /// Groups of `n` consecutive chunks (the last group may be smaller).
    /// `0` puts everything into one group.
    FixedN(usize),
    /// Greedy groups whose whitespace-word count stays within the budget. A
    /// single chunk over budget forms its own group.
    WordBudget(usize),
    /// A new group starts whenever the chunk metadata's `heading` changes.
    SectionAware,
}
impl HierarchicalSummaryChunker {
/// Wraps `base`; a warner is created only when a ceiling is given.
pub fn new(
base: Box<dyn ChunkerImpl + Send + Sync>,
summarizer: Box<dyn crate::summarizer::SummarizerImpl>,
mode: &'static str,
grouping: HierarchicalGrouping,
effective_ceiling: Option<usize>,
if_oversize_cfg: Option<ChunkerConfig>,
) -> Self {
let warner =
effective_ceiling.map(|m| oversize::DedupedWarner::new("hierarchical_summary", m));
Self {
base,
summarizer,
mode,
grouping,
effective_ceiling,
if_oversize_cfg,
warner,
}
}
/// Emits, for each group of base chunks, the group's fine chunks followed by
/// one coarse summary chunk, numbering `seq_num` consecutively across both.
///
/// Fine chunks keep the base text and gain `granularity = "fine"`,
/// `group_id = "<doc.id>::g<idx>"` and `summarizer`. The coarse chunk's
/// original text is the group's original texts joined by blank lines, and its
/// embedded text is the summary. Coarse chunks are exempt from the oversize
/// check. If the summarizer fails for any group, the error is logged and no
/// chunks are returned for the document.
///
/// # Panics
/// Panics if the `if_oversize` chain fails.
pub fn chunk(&self, doc: &Document) -> Vec<Chunk> {
let base_chunks = self.base.chunk(doc);
if base_chunks.is_empty() {
return Vec::new();
}
let groups = self.group(base_chunks);
let mut out: Vec<Chunk> = Vec::new();
// Shared counter: fine and coarse chunks are numbered in emission order.
let mut seq: usize = 0;
for (group_idx, group_chunks) in groups.into_iter().enumerate() {
let group_id = format!("{}::g{}", doc.id, group_idx);
// Fine chunks: the base chunks of this group, tagged with their group.
for bc in &group_chunks {
let mut meta = bc.metadata.as_object().cloned().unwrap_or_default();
meta.insert(
"granularity".to_string(),
serde_json::Value::String("fine".to_string()),
);
meta.insert(
"group_id".to_string(),
serde_json::Value::String(group_id.clone()),
);
meta.insert(
"summarizer".to_string(),
serde_json::Value::String(self.mode.to_string()),
);
out.push(Chunk {
doc_id: bc.doc_id.clone(),
seq_num: seq,
original_content: bc.original_content.clone(),
embedded_content: bc.embedded_content.clone(),
metadata: serde_json::Value::Object(meta),
});
seq += 1;
}
// Coarse chunk: summary of the whole group's original text.
let joined = group_chunks
.iter()
.map(|c| c.original_content.as_str())
.collect::<Vec<_>>()
.join("\n\n");
let summary = match self.summarizer.summarize(&joined, &doc.metadata) {
Ok(s) => s,
Err(e) => {
tracing::error!(
"hierarchical_summary: summarizer failed on doc={} group={group_idx}: {e:#}",
doc.id
);
// Drops every chunk already built for this document, not just this group.
return Vec::new();
}
};
// Coarse metadata is built fresh; it does not inherit base-chunk keys.
let mut coarse_meta = serde_json::Map::new();
coarse_meta.insert(
"granularity".to_string(),
serde_json::Value::String("coarse".to_string()),
);
coarse_meta.insert("group_id".to_string(), serde_json::Value::String(group_id));
coarse_meta.insert(
"summarizer".to_string(),
serde_json::Value::String(self.mode.to_string()),
);
coarse_meta.insert(
"strategy".to_string(),
serde_json::Value::String("hierarchical_summary".to_string()),
);
out.push(Chunk {
doc_id: doc.id.clone(),
seq_num: seq,
original_content: joined,
embedded_content: summary,
metadata: serde_json::Value::Object(coarse_meta),
});
seq += 1;
}
// Coarse chunks hold a whole group's text, so they bypass the size check.
let skip =
|c: &Chunk| c.metadata.get("granularity").and_then(|v| v.as_str()) == Some("coarse");
apply_if_oversize(
out,
self.effective_ceiling,
self.if_oversize_cfg.as_ref(),
"hierarchical_summary",
doc,
0,
Some(&skip),
self.warner.as_ref(),
)
.expect("if_oversize recursion")
}
/// Partitions `chunks` into consecutive, order-preserving groups according
/// to `self.grouping`.
fn group(&self, chunks: Vec<Chunk>) -> Vec<Vec<Chunk>> {
match self.grouping {
HierarchicalGrouping::FixedN(n) => {
// n == 0 means "no grouping": everything in one group.
if n == 0 {
return vec![chunks];
}
let mut groups: Vec<Vec<Chunk>> = Vec::new();
let mut cur: Vec<Chunk> = Vec::new();
for c in chunks {
if cur.len() == n {
groups.push(std::mem::take(&mut cur));
}
cur.push(c);
}
if !cur.is_empty() {
groups.push(cur);
}
groups
}
HierarchicalGrouping::WordBudget(max_words) => {
let mut groups: Vec<Vec<Chunk>> = Vec::new();
let mut cur: Vec<Chunk> = Vec::new();
let mut cur_words = 0usize;
for c in chunks {
let w = c.original_content.split_whitespace().count();
// Start a new group only if the current one is non-empty, so a
// single chunk over budget still lands in a group of its own.
if !cur.is_empty() && cur_words + w > max_words {
groups.push(std::mem::take(&mut cur));
cur_words = 0;
}
cur.push(c);
cur_words += w;
}
if !cur.is_empty() {
groups.push(cur);
}
groups
}
HierarchicalGrouping::SectionAware => {
let mut groups: Vec<Vec<Chunk>> = Vec::new();
let mut cur: Vec<Chunk> = Vec::new();
let mut cur_heading: Option<String> = None;
// Becomes true after the first chunk; a missing `heading` (None)
// counts as its own value when comparing against the previous one.
let mut have_heading = false;
for c in chunks {
let h = c
.metadata
.get("heading")
.and_then(|v| v.as_str())
.map(String::from);
if have_heading && h != cur_heading && !cur.is_empty() {
groups.push(std::mem::take(&mut cur));
}
cur_heading = h;
have_heading = true;
cur.push(c);
}
if !cur.is_empty() {
groups.push(cur);
}
groups
}
}
}
}
impl ChunkerImpl for HierarchicalSummaryChunker {
fn chunk(&self, doc: &Document) -> Vec<Chunk> {
Self::chunk(self, doc)
}
}
/// Returns the `p`-th percentile of `values` using linear interpolation
/// between the two nearest ranks (index `(n - 1) * p / 100` into the sorted
/// values).
///
/// Returns `0.0` for an empty slice. `p` above 100 is treated as 100.
fn percentile_linear(values: &[f32], p: u32) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let mut sorted: Vec<f32> = values.to_vec();
    // `total_cmp` is a true total order (positive NaNs sort last). A
    // `partial_cmp(..).unwrap_or(Equal)` comparator is inconsistent when NaN is
    // present, which newer std sort implementations may panic on.
    sorted.sort_unstable_by(f32::total_cmp);
    let n = sorted.len();
    if n == 1 {
        return sorted[0];
    }
    // Clamp so the interpolation index can never run past the last element.
    let p = p.min(100);
    let idx = (n as f64 - 1.0) * f64::from(p) / 100.0;
    let lo = idx.floor() as usize;
    let hi = idx.ceil() as usize;
    if lo == hi {
        return sorted[lo];
    }
    let frac = (idx - lo as f64) as f32;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}
/// Merges sentence spans shorter than `min` sentences into their neighbours.
///
/// Each short span after the first is folded into the span before it. A
/// short first span is then folded into the second, and a short final span
/// into the one before it. Spans are `(start, end)` half-open ranges.
fn merge_small_spans(spans: Vec<(usize, usize)>, min: usize) -> Vec<(usize, usize)> {
    let mut acc: Vec<(usize, usize)> = Vec::with_capacity(spans.len());
    for (start, end) in spans {
        match acc.last_mut() {
            Some(prev) if end - start < min => prev.1 = end,
            _ => acc.push((start, end)),
        }
    }
    if acc.len() > 1 && acc[0].1 - acc[0].0 < min {
        let (_, second_end) = acc.remove(1);
        acc[0].1 = second_end;
    }
    if acc.len() > 1 {
        let (last_start, last_end) = acc[acc.len() - 1];
        if last_end - last_start < min {
            acc.pop();
            if let Some(prev) = acc.last_mut() {
                prev.1 = last_end;
            }
        }
    }
    acc
}
/// Builds a boxed chunker from its configuration.
///
/// Wrapper strategies (neighbor_expand, summary_embed, hierarchical_summary,
/// consolidation) recursively build their `base` chunker. For the first three,
/// the oversize ceiling is the wrapper's own `max_chars`, falling back to the
/// base config's `effective_max_chars()`.
///
/// # Errors
/// Propagates constructor failures: an invalid hierarchy heading regex, zero
/// fixed-overlap window/step, semantic config validation or model loading,
/// and summarizer construction. Without the `embedder-hub` feature, a
/// semantic config is rejected with an error.
pub fn build_chunker(cfg: ChunkerConfig) -> anyhow::Result<Box<dyn ChunkerImpl + Send + Sync>> {
Ok(match cfg {
ChunkerConfig::SentenceAware(c) => Box::new(SentenceAwareChunker::new(c)),
ChunkerConfig::Hierarchy(c) => Box::new(HierarchyChunker::new(c)?),
ChunkerConfig::FixedOverlap(c) => Box::new(FixedOverlapChunker::new(c)?),
ChunkerConfig::NeighborExpand(c) => {
let window = c.window;
// Read everything needed from `c` before `c.base` is moved below.
let effective_ceiling = c.max_chars.or_else(|| c.base.effective_max_chars());
let if_oversize_cfg = c.if_oversize.as_deref().cloned();
let base = build_chunker(*c.base)?;
Box::new(NeighborExpandChunker::new(
window,
base,
effective_ceiling,
if_oversize_cfg,
))
}
#[cfg(feature = "embedder-hub")]
ChunkerConfig::Semantic(c) => Box::new(SemanticChunker::new(c)?),
#[cfg(not(feature = "embedder-hub"))]
ChunkerConfig::Semantic(_) => {
return Err(anyhow::anyhow!(
"semantic chunker requires the `embedder-hub` Cargo feature \
(rebuild with --features embedder-hub, --features embedder, or --features full)"
));
}
ChunkerConfig::SummaryEmbed(c) => {
let mode = c.summarizer.mode_str();
let effective_ceiling = c.max_chars.or_else(|| c.base.effective_max_chars());
let if_oversize_cfg = c.if_oversize.as_deref().cloned();
let base = build_chunker(*c.base)?;
let summarizer = build_summarizer(&c.summarizer)?;
Box::new(SummaryEmbedChunker::new(
base,
summarizer,
mode,
effective_ceiling,
if_oversize_cfg,
))
}
ChunkerConfig::HierarchicalSummary(c) => {
let mode = c.summarizer.mode_str();
let effective_ceiling = c.max_chars.or_else(|| c.base.effective_max_chars());
let if_oversize_cfg = c.if_oversize.as_deref().cloned();
let base = build_chunker(*c.base)?;
let summarizer = build_summarizer(&c.summarizer)?;
// Translate the config-level grouping into the runtime enum.
let grouping = match c.grouping {
GroupingConfig::FixedN(g) => HierarchicalGrouping::FixedN(g.n),
GroupingConfig::WordBudget(g) => HierarchicalGrouping::WordBudget(g.max_words),
GroupingConfig::SectionAware(_) => HierarchicalGrouping::SectionAware,
};
Box::new(HierarchicalSummaryChunker::new(
base,
summarizer,
mode,
grouping,
effective_ceiling,
if_oversize_cfg,
))
}
// No oversize ceiling / if_oversize handling is wired for consolidation here.
ChunkerConfig::Consolidation(c) => {
let base = build_chunker(*c.base)?;
let consolidator = crate::consolidators::build_consolidator(&c.consolidator);
let fact_max_chars = c.fact_max_chars;
Box::new(ConsolidationChunker::new(
base,
consolidator,
fact_max_chars,
))
}
#[cfg(feature = "code-aware")]
ChunkerConfig::SymbolAware(c) => Box::new(
crate::chunkers::symbol_aware::SymbolAwareChunker::new(c),
),
})
}
/// Wrapper that tags the base chunker's output as episode chunks and appends
/// fact chunks produced by a consolidator (see its `ChunkerImpl` impl).
pub struct ConsolidationChunker {
    base: Box<dyn ChunkerImpl + Send + Sync>,
    consolidator: Box<dyn crate::consolidators::Consolidator>,
    /// Character limit applied to each fact's text and support span.
    fact_max_chars: usize,
}

impl ConsolidationChunker {
    /// Wraps `base` with `consolidator`.
    pub fn new(
        base: Box<dyn ChunkerImpl + Send + Sync>,
        consolidator: Box<dyn crate::consolidators::Consolidator>,
        fact_max_chars: usize,
    ) -> Self {
        ConsolidationChunker {
            base,
            consolidator,
            fact_max_chars,
        }
    }
}
impl ChunkerImpl for ConsolidationChunker {
    /// Chunks `doc` with the base chunker and tags those chunks as episodes.
    /// It then appends one `fact` chunk per non-blank triple returned by the consolidator.
    ///
    /// Optional document metadata read:
    /// - `session_id`: stamped onto episode chunks that do not already carry one and
    ///   onto every fact chunk. When it is not a JSON string, `doc.id` is what the
    ///   consolidator receives as the session id.
    /// - `episode_start_ts` / `episode_end_ts` (f64, default `0.0`) and `frame_seq`
    ///   (u64, default `0`) are forwarded to the consolidator. `episode_end_ts` is
    ///   also stamped on every emitted chunk.
    ///
    /// A consolidator error is not propagated. The episode chunks are returned with
    /// a `consolidation_error` string in their metadata and no fact chunks are added.
    fn chunk(&self, doc: &Document) -> Vec<Chunk> {
        use serde_json::{json, Value};

        let extractor_mode = self.consolidator.mode();
        let session_id_value = doc
            .metadata
            .get("session_id")
            .cloned()
            .unwrap_or(Value::Null);
        let start_ts = doc
            .metadata
            .get("episode_start_ts")
            .and_then(Value::as_f64)
            .unwrap_or(0.0);
        let end_ts = doc
            .metadata
            .get("episode_end_ts")
            .and_then(Value::as_f64)
            .unwrap_or(0.0);
        let frame_seq = doc
            .metadata
            .get("frame_seq")
            .and_then(Value::as_u64)
            .unwrap_or(0);
        // Missing or non-string session ids fall back to the document id for the
        // consolidator. The raw JSON value is still what gets stamped on chunks.
        let session_id_str = session_id_value.as_str().unwrap_or(&doc.id);
        let episode_input = crate::consolidators::EpisodeInput {
            text: &doc.content,
            frame_seq,
            session_id: session_id_str,
            episode_start_ts: start_ts,
            episode_end_ts: end_ts,
        };

        let mut out = self.base.chunk(doc);
        for c in out.iter_mut() {
            // A chunk with null metadata would otherwise be skipped by
            // `as_object_mut` and end up untagged. Non-null, non-object metadata
            // is left untouched.
            if c.metadata.is_null() {
                c.metadata = Value::Object(serde_json::Map::new());
            }
            if let Some(obj) = c.metadata.as_object_mut() {
                obj.insert("kind".into(), Value::String("episode".into()));
                obj.insert("extractor".into(), Value::String(extractor_mode.into()));
                if !obj.contains_key("session_id") && !session_id_value.is_null() {
                    obj.insert("session_id".into(), session_id_value.clone());
                }
                obj.insert("episode_end_ts".into(), json!(end_ts));
            }
        }

        let cons_out = match self.consolidator.consolidate(&episode_input) {
            Ok(o) => o,
            Err(e) => {
                // Best-effort: keep the episode chunks and record why facts are missing.
                let msg = format!("{:#}", e);
                for c in out.iter_mut() {
                    if let Some(obj) = c.metadata.as_object_mut() {
                        obj.insert("consolidation_error".into(), Value::String(msg.clone()));
                    }
                }
                return out;
            }
        };

        // Fact seq numbers continue after the highest base seq number, not after
        // `out.len()`. This keeps them from colliding with base chunks when the base
        // numbering is not dense from 0. It is identical when it is dense.
        let mut next_seq = out.iter().map(|c| c.seq_num + 1).max().unwrap_or(0);
        for (i, fact) in cons_out.facts.iter().enumerate() {
            // Join only non-empty parts so an empty subject/predicate/object does not
            // leave stray spaces. A triple with no content at all is skipped rather
            // than emitted as a whitespace-only chunk.
            let parts = [fact.subject.trim(), fact.predicate.trim(), fact.object.trim()];
            let text_full = parts
                .iter()
                .copied()
                .filter(|p| !p.is_empty())
                .collect::<Vec<&str>>()
                .join(" ");
            if text_full.is_empty() {
                continue;
            }
            let truncated = take_chars(&text_full, self.fact_max_chars);
            let support = fact
                .support_span
                .as_ref()
                .map(|s| take_chars(s, self.fact_max_chars))
                .unwrap_or_default();

            let mut meta = serde_json::Map::new();
            meta.insert("kind".into(), Value::String("fact".into()));
            meta.insert("extractor".into(), Value::String(extractor_mode.into()));
            if !session_id_value.is_null() {
                meta.insert("session_id".into(), session_id_value.clone());
            }
            meta.insert("subject".into(), Value::String(fact.subject.clone()));
            meta.insert("predicate".into(), Value::String(fact.predicate.clone()));
            meta.insert("object".into(), Value::String(fact.object.clone()));
            if !support.is_empty() {
                meta.insert("support_span".into(), Value::String(support));
            }
            if let Some(conf) = fact.confidence {
                meta.insert("confidence".into(), json!(conf));
            }
            // NOTE(review): this is the fact's index in the consolidator output, not
            // the seq_num of an episode chunk. Confirm this is what consumers expect.
            meta.insert("source_chunk_seq".into(), json!(i as u64));
            meta.insert("episode_end_ts".into(), json!(end_ts));

            out.push(Chunk {
                doc_id: doc.id.clone(),
                seq_num: next_seq,
                original_content: truncated.clone(),
                embedded_content: truncated,
                metadata: Value::Object(meta),
            });
            next_seq += 1;
        }
        out
    }
}
/// Returns an owned copy of at most the first `n` chars (Unicode scalar values) of `s`.
///
/// Cuts on a char boundary, so multi-byte characters are never split.
fn take_chars(s: &str, n: usize) -> String {
    match s.char_indices().nth(n) {
        // Byte offset of the (n+1)-th char: everything before it is the first n chars.
        Some((cut, _)) => s[..cut].to_owned(),
        // `s` has at most `n` chars, so it is kept whole.
        None => s.to_owned(),
    }
}
// Unit tests for the chunkers in this file: sentence-aware, hierarchy, semantic
// (driven by a stub boundary embedder), hierarchical-summary grouping and
// consolidation. Also covers the percentile_linear / merge_small_spans helpers.
#[cfg(test)]
mod tests {
use super::*;
// A document shorter than max_chars yields exactly one chunk, even though it is
// below min_chars (200). Embedded content equals the original text.
#[test]
fn short_text_single_chunk() {
let doc = Document {
id: "t".into(),
content: "Just a short sentence.".into(),
title: None,
metadata: json!({}),
fingerprint: None,
};
let chunker = SentenceAwareChunker::new(SentenceAwareChunkerConfig {
doc_type: "prose".into(),
max_chars: 2000,
min_chars: 200,
if_oversize: None,
});
let chunks = chunker.chunk(&doc);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].original_content, "Just a short sentence.");
assert_eq!(chunks[0].embedded_content, chunks[0].original_content);
}
// With min_chars = 0, each markdown heading starts its own chunk (3 headings ->
// 3 chunks), and each chunk begins with its heading line.
#[test]
fn markdown_headings_split_sections() {
let content = "# Top\n\nIntro para.\n\n## Section A\n\nBody A.\n\n## Section B\n\nBody B.";
let doc = Document {
id: "h".into(),
content: content.into(),
title: None,
metadata: json!({}),
fingerprint: None,
};
let chunker = SentenceAwareChunker::new(SentenceAwareChunkerConfig {
doc_type: "prose".into(),
max_chars: 2000,
min_chars: 0,
if_oversize: None,
});
let chunks = chunker.chunk(&doc);
assert_eq!(chunks.len(), 3);
assert!(chunks[0].original_content.starts_with("# Top"));
assert!(chunks[1].original_content.starts_with("## Section A"));
assert!(chunks[2].original_content.starts_with("## Section B"));
}
// heading_pattern: None falls back to markdown `#` headings. Each section becomes
// a chunk whose `heading` metadata is the heading text without the `#` markers.
// Section bodies are padded to clear min_section_chars (50).
#[test]
fn hierarchy_default_uses_markdown_pattern() {
let content = "# Top\n\n\
Intro paragraph that is long enough to clear the minimum section threshold so the first section is not skipped by the chunker. We add more words here for safety.\n\n\
## Section A\n\n\
Body A. Body A continued so the section is long enough to clear min_section_chars without effort. Padding padding padding padding padding padding padding padding.\n\n\
## Section B\n\n\
Body B. Body B continued so the section is long enough to clear min_section_chars without effort. Padding padding padding padding padding padding padding padding.";
let doc = Document {
id: "doc-default".into(),
content: content.into(),
title: Some("Doc".into()),
metadata: json!({}),
fingerprint: None,
};
let chunker = HierarchyChunker::new(HierarchyChunkerConfig {
prefix_heading: true,
min_section_chars: 50,
max_chars: 2000,
if_oversize: None,
heading_pattern: None,
})
.expect("default config must compile");
let chunks = chunker.chunk(&doc);
assert_eq!(chunks.len(), 3, "expected 3 chunks, got {}", chunks.len());
assert_eq!(chunks[0].metadata["heading"], "Top");
assert_eq!(chunks[1].metadata["heading"], "Section A");
assert_eq!(chunks[2].metadata["heading"], "Section B");
}
// A custom heading regex replaces the default. Here the `heading` metadata is the
// text after `>>> `, i.e. what the pattern's capture group matches.
#[test]
fn hierarchy_custom_pattern_honored() {
let content = ">>> Alpha\n\n\
Body alpha — long enough body so the section easily clears the minimum-section threshold and the chunker emits it as a real chunk. Padding padding padding padding padding padding.\n\n\
>>> Beta\n\n\
Body beta — long enough body so the section easily clears the minimum-section threshold and the chunker emits it as a real chunk. Padding padding padding padding padding padding.";
let doc = Document {
id: "doc-custom".into(),
content: content.into(),
title: None,
metadata: json!({}),
fingerprint: None,
};
let chunker = HierarchyChunker::new(HierarchyChunkerConfig {
prefix_heading: false,
min_section_chars: 50,
max_chars: 2000,
if_oversize: None,
heading_pattern: Some(r"(?m)^>>>\s+(.+)$".to_string()),
})
.expect("valid custom pattern must compile");
let chunks = chunker.chunk(&doc);
assert_eq!(
chunks.len(),
2,
"expected 2 chunks for two `>>>`-delimited sections, got {}",
chunks.len()
);
assert_eq!(chunks[0].metadata["heading"], "Alpha");
assert_eq!(chunks[1].metadata["heading"], "Beta");
assert!(chunks[0].original_content.contains("Body alpha"));
assert!(chunks[1].original_content.contains("Body beta"));
}
// An uncompilable heading_pattern is reported as an Err from HierarchyChunker::new.
#[test]
fn hierarchy_invalid_pattern_returns_err() {
let result = HierarchyChunker::new(HierarchyChunkerConfig {
prefix_heading: true,
min_section_chars: 100,
max_chars: 2000,
if_oversize: None,
heading_pattern: Some("[invalid".to_string()),
});
assert!(result.is_err(), "expected Err for invalid regex, got Ok");
}
// percentile_linear interpolates linearly between sorted values. The expected
// numbers follow numpy's default percentile method, per the test name. A single
// value is returned as-is, and an empty slice yields 0.0.
#[test]
fn percentile_linear_matches_numpy() {
let p = percentile_linear(&[1.0_f32, 2.0, 3.0, 4.0, 5.0], 95);
assert!((p - 4.8).abs() < 1e-5, "got {p}");
let p = percentile_linear(&[1.0, 2.0, 3.0, 4.0], 50);
assert!((p - 2.5).abs() < 1e-5, "got {p}");
assert_eq!(percentile_linear(&[7.0], 95), 7.0);
assert_eq!(percentile_linear(&[], 95), 0.0);
}
// A too-small middle span (width 1 < 3) is absorbed into the span before it.
#[test]
fn merge_small_spans_forward() {
let m = merge_small_spans(vec![(0, 5), (5, 6), (6, 10)], 3);
assert_eq!(m, vec![(0, 6), (6, 10)]);
}
// A too-small trailing span has no successor, so it is merged into the previous span.
#[test]
fn merge_small_spans_backward_last() {
let m = merge_small_spans(vec![(0, 5), (5, 10), (10, 11)], 3);
assert_eq!(m, vec![(0, 5), (5, 11)]);
}
// A too-small leading span is extended through the following span.
#[test]
fn merge_small_spans_first_too_small_pulls_next() {
let m = merge_small_spans(vec![(0, 1), (1, 5), (5, 10)], 3);
assert_eq!(m, vec![(0, 5), (5, 10)]);
}
// No input spans -> no output spans.
#[test]
fn merge_small_spans_empty_returns_empty() {
let m: Vec<(usize, usize)> = merge_small_spans(Vec::new(), 3);
assert!(m.is_empty());
}
// Test double for BoundaryEmbedder. For each text in a batch it returns
// `canned[index % canned.len()]`, and it counts embed_batch calls in `calls`.
struct StubEmbedder {
canned: Vec<Vec<f32>>,
calls: std::sync::atomic::AtomicUsize,
}
impl BoundaryEmbedder for StubEmbedder {
fn embed_batch(&mut self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let mut out = Vec::with_capacity(texts.len());
for i in 0..texts.len() {
out.push(self.canned[i % self.canned.len()].clone());
}
Ok(out)
}
}
// Baseline semantic config for these tests. `boundary_model` is a placeholder
// because the tests inject an embedder through SemanticChunker::with_embedder.
fn semantic_cfg() -> SemanticChunkerConfig {
SemanticChunkerConfig {
sentence_splitter: "naive".into(),
boundary_model: "ignored-by-with_embedder".into(),
breakpoint_percentile: 50,
min_sentences_per_chunk: 1,
max_chunk_chars: 2000,
if_oversize: None,
}
}
// Smoke test: with_embedder accepts a caller-supplied embedder, and chunking a
// three-sentence document produces non-empty output.
// NOTE(review): `calls` is never asserted, so this does not prove the stub was
// actually invoked.
#[test]
fn semantic_chunker_with_embedder_uses_supplied_impl() {
let stub = StubEmbedder {
canned: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
calls: std::sync::atomic::AtomicUsize::new(0),
};
let chunker = SemanticChunker::with_embedder(semantic_cfg(), Box::new(stub))
.expect("with_embedder ok");
let doc = Document {
id: "d".into(),
content: "First sentence here. Second sentence here. Third sentence here.".into(),
title: None,
metadata: json!({}),
fingerprint: None,
};
let chunks = chunker.chunk(&doc);
assert!(
!chunks.is_empty(),
"stub embedder should produce at least one chunk"
);
assert!(chunks.iter().any(|c| !c.original_content.is_empty()));
}
// Compile-time check that BoundaryEmbedder can be boxed as a trait object.
#[test]
fn boundary_embedder_trait_is_object_safe() {
let stub = StubEmbedder {
canned: vec![vec![0.0]],
calls: std::sync::atomic::AtomicUsize::new(0),
};
let _boxed: Box<dyn BoundaryEmbedder> = Box::new(stub);
}
// with_embedder still validates the config: a non-"naive" sentence_splitter
// (here "nltk") is rejected.
#[test]
fn semantic_chunker_with_embedder_validates_cfg() {
let bad_cfg = SemanticChunkerConfig {
sentence_splitter: "nltk".into(),
..semantic_cfg()
};
let stub = StubEmbedder {
canned: vec![vec![0.0]],
calls: std::sync::atomic::AtomicUsize::new(0),
};
let r = SemanticChunker::with_embedder(bad_cfg, Box::new(stub));
assert!(r.is_err(), "non-naive splitter should error");
}
// Builds a chunk of doc "d" with identical original/embedded content and, when
// given, a `heading` metadata entry.
fn mk(seq: usize, content: &str, heading: Option<&str>) -> Chunk {
let meta = match heading {
Some(h) => json!({"heading": h}),
None => json!({}),
};
Chunk {
doc_id: "d".into(),
seq_num: seq,
original_content: content.to_string(),
embedded_content: content.to_string(),
metadata: meta,
}
}
// Runs HierarchicalSummaryChunker::group with a no-op base chunker and a no-op
// summarizer, and returns the number of chunks in each resulting group.
fn group_sizes(grouping: HierarchicalGrouping, chunks: Vec<Chunk>) -> Vec<usize> {
struct NoopBase;
impl ChunkerImpl for NoopBase {
fn chunk(&self, _: &Document) -> Vec<Chunk> {
Vec::new()
}
}
struct NoopSum;
impl crate::summarizer::SummarizerImpl for NoopSum {
fn summarize(&self, _: &str, _: &serde_json::Value) -> anyhow::Result<String> {
Ok(String::new())
}
}
let chunker = HierarchicalSummaryChunker::new(
Box::new(NoopBase),
Box::new(NoopSum),
"passthrough",
grouping,
None,
None,
);
chunker.group(chunks).into_iter().map(|g| g.len()).collect()
}
// FixedN(3) over 7 chunks -> groups of 3, 3 and a trailing group of 1.
#[test]
fn fixed_n_groups_with_leftover() {
let chunks: Vec<Chunk> = (0..7).map(|i| mk(i, "x", None)).collect();
let sizes = group_sizes(HierarchicalGrouping::FixedN(3), chunks);
assert_eq!(sizes, vec![3, 3, 1]);
}
// Each chunk has 10 words. With a 25-word budget, two chunks fit (20 words) and a
// third would exceed it, so 5 chunks group as [2, 2, 1].
#[test]
fn word_budget_starts_new_group_when_over() {
let body = "one two three four five six seven eight nine ten";
let chunks: Vec<Chunk> = (0..5).map(|i| mk(i, body, None)).collect();
let sizes = group_sizes(HierarchicalGrouping::WordBudget(25), chunks);
assert_eq!(sizes, vec![2, 2, 1]);
}
// A new group starts whenever the `heading` metadata changes (A, A | B, B, B | C).
#[test]
fn section_aware_groups_by_heading_change() {
let chunks = vec![
mk(0, "a1", Some("A")),
mk(1, "a2", Some("A")),
mk(2, "b1", Some("B")),
mk(3, "b2", Some("B")),
mk(4, "b3", Some("B")),
mk(5, "c1", Some("C")),
];
let sizes = group_sizes(HierarchicalGrouping::SectionAware, chunks);
assert_eq!(sizes, vec![2, 3, 1]);
}
// --- ConsolidationChunker tests ---
use crate::consolidators::{ConsolidationOutput, Consolidator, EpisodeInput, FactTriple};
// Consolidator that always succeeds, returning the configured facts and mode.
struct FakeCons {
facts: Vec<FactTriple>,
mode: &'static str,
}
impl Consolidator for FakeCons {
fn consolidate(&self, _e: &EpisodeInput<'_>) -> anyhow::Result<ConsolidationOutput> {
Ok(ConsolidationOutput {
summary: "fake summary".into(),
facts: self.facts.clone(),
})
}
fn mode(&self) -> &'static str {
self.mode
}
}
// Consolidator that always fails, used to exercise the error path.
struct ErrCons;
impl Consolidator for ErrCons {
fn consolidate(&self, _e: &EpisodeInput<'_>) -> anyhow::Result<ConsolidationOutput> {
Err(anyhow::anyhow!("simulated consolidator failure"))
}
fn mode(&self) -> &'static str {
"broken"
}
}
// Episode document carrying the metadata keys ConsolidationChunker reads:
// session_id, frame_seq, episode_start_ts and episode_end_ts.
fn episode_doc(text: &str) -> Document {
Document {
id: "s1".into(),
content: text.into(),
title: None,
metadata: serde_json::json!({
"session_id": "s1",
"frame_seq": 0u64,
"episode_start_ts": 100.0_f64,
"episode_end_ts": 200.0_f64,
}),
fingerprint: None,
}
}
// Builds a fact triple with a derived support span and confidence 0.8.
fn fact(subj: &str, pred: &str, obj: &str) -> FactTriple {
FactTriple {
subject: subj.into(),
predicate: pred.into(),
object: obj.into(),
support_span: Some(format!("because: {} {}", subj, obj)),
confidence: Some(0.8),
}
}
// Sentence-aware base chunker placed under ConsolidationChunker in these tests.
fn base_chunker() -> Box<dyn ChunkerImpl + Send + Sync> {
Box::new(SentenceAwareChunker::new(SentenceAwareChunkerConfig {
doc_type: "prose".into(),
max_chars: 2000,
min_chars: 50,
if_oversize: None,
}))
}
// Episode chunks are kept, and each consolidator triple adds one `fact` chunk
// carrying subject/predicate/object, extractor and support_span metadata.
#[test]
fn consolidation_emits_episode_plus_one_fact_per_triple() {
let cons = Box::new(FakeCons {
facts: vec![
fact("queue", "uses", "postgres"),
fact("api", "calls", "search"),
],
mode: "fake",
});
let chunker = ConsolidationChunker::new(base_chunker(), cons, 1200);
let doc = episode_doc("We migrated the queue to postgres. The api calls search.");
let chunks = chunker.chunk(&doc);
let kinds: Vec<&str> = chunks
.iter()
.map(|c| c.metadata["kind"].as_str().unwrap_or(""))
.collect();
assert!(
kinds.iter().any(|k| *k == "episode"),
"missing episode kind: {kinds:?}"
);
let fact_count = kinds.iter().filter(|k| **k == "fact").count();
assert_eq!(fact_count, 2, "expected 2 fact chunks; got {kinds:?}");
let fact_chunk = chunks
.iter()
.find(|c| c.metadata["kind"] == "fact")
.unwrap();
assert_eq!(fact_chunk.metadata["subject"], "queue");
assert_eq!(fact_chunk.metadata["predicate"], "uses");
assert_eq!(fact_chunk.metadata["object"], "postgres");
assert_eq!(fact_chunk.metadata["extractor"], "fake");
assert!(fact_chunk.metadata.get("support_span").is_some());
}
// Fact chunk content is capped at fact_max_chars (200) characters.
#[test]
fn consolidation_fact_chunk_truncated_to_fact_max_chars() {
let long_obj = "x".repeat(2000);
let cons = Box::new(FakeCons {
facts: vec![fact("s", "p", &long_obj)],
mode: "fake",
});
let chunker = ConsolidationChunker::new(base_chunker(), cons, 200);
let chunks = chunker.chunk(&episode_doc("episode body."));
let fact_chunk = chunks
.iter()
.find(|c| c.metadata["kind"] == "fact")
.unwrap();
assert!(fact_chunk.original_content.chars().count() <= 200);
}
// A failing consolidator must not drop the episode. Episode chunks are still
// returned, each stamped with consolidation_error, and no fact chunks are emitted.
#[test]
fn consolidation_o4_resilience_on_consolidator_error() {
let chunker = ConsolidationChunker::new(base_chunker(), Box::new(ErrCons), 1200);
let chunks = chunker.chunk(&episode_doc("an episode."));
let kinds: Vec<&str> = chunks
.iter()
.map(|c| c.metadata["kind"].as_str().unwrap_or(""))
.collect();
assert!(
kinds.iter().all(|k| *k == "episode"),
"no facts allowed: {kinds:?}"
);
assert!(!chunks.is_empty(), "should still emit episode chunks");
for c in &chunks {
assert!(
c.metadata.get("consolidation_error").is_some(),
"missing consolidation_error stamp: {:?}",
c.metadata
);
}
}
// Episode chunks inherit session_id and episode_end_ts from the document metadata
// and record the consolidator's mode under `extractor`.
#[test]
fn consolidation_episode_carries_session_id_and_extractor() {
let cons = Box::new(FakeCons {
facts: vec![],
mode: "extractive",
});
let chunker = ConsolidationChunker::new(base_chunker(), cons, 1200);
let chunks = chunker.chunk(&episode_doc("hello world. another sentence here."));
for c in chunks.iter().filter(|c| c.metadata["kind"] == "episode") {
assert_eq!(c.metadata["session_id"], "s1");
assert_eq!(c.metadata["extractor"], "extractive");
assert_eq!(c.metadata["episode_end_ts"], 200.0);
}
}
}