use crate::Chunk;
use crate::recursive::split_recursive;
use tiktoken::CoreBpe;
pub(crate) async fn split_semantic(
text: &str,
max_tokens: usize,
overlap_tokens: usize,
encoder: &CoreBpe,
client: &embedrs::Client,
threshold: f64,
) -> Result<Vec<Chunk>, crate::Error> {
if text.is_empty() {
return Ok(Vec::new());
}
let sentences = split_sentences(text);
if sentences.len() <= 1 {
let chunks = split_recursive(text, 0, max_tokens, overlap_tokens, encoder, &None);
return Ok(chunks);
}
let sentence_texts: Vec<String> = sentences.iter().map(|s| s.text.to_string()).collect();
let embed_result = client
.embed(sentence_texts)
.await
.map_err(crate::Error::Embed)?;
let breakpoints = find_breakpoints(&embed_result.embeddings, threshold);
let groups = merge_at_breakpoints(&sentences, &breakpoints);
let mut all_chunks = Vec::new();
for group in &groups {
let content: String = group.iter().map(|s| s.text).collect::<Vec<_>>().join("");
let group_offset = group.first().map(|s| s.byte_offset).unwrap_or(0);
let sub_chunks = split_recursive(
&content,
group_offset,
max_tokens,
overlap_tokens,
encoder,
&None,
);
all_chunks.extend(sub_chunks);
}
for (i, chunk) in all_chunks.iter_mut().enumerate() {
chunk.index = i;
}
Ok(all_chunks)
}
struct Sentence<'a> {
text: &'a str,
byte_offset: usize,
}
fn split_sentences(text: &str) -> Vec<Sentence<'_>> {
let mut sentences = Vec::new();
let mut start = 0;
let terminators = [". ", "! ", "? ", ".\n", "!\n", "?\n"];
let mut i = 0;
while i < text.len() {
let mut found = false;
for term in &terminators {
if text[i..].starts_with(term) {
let end = i + term.len();
let s = &text[start..end];
if !s.trim().is_empty() {
sentences.push(Sentence {
text: s,
byte_offset: start,
});
}
start = end;
i = end;
found = true;
break;
}
}
if !found {
i += text[i..].chars().next().map_or(1, |c| c.len_utf8());
}
}
if start < text.len() {
let s = &text[start..];
if !s.trim().is_empty() {
sentences.push(Sentence {
text: s,
byte_offset: start,
});
}
}
sentences
}
fn find_breakpoints(embeddings: &[Vec<f32>], threshold: f64) -> Vec<usize> {
let mut breakpoints = Vec::new();
for i in 1..embeddings.len() {
let sim = cosine_similarity(&embeddings[i - 1], &embeddings[i]);
if sim < threshold {
breakpoints.push(i);
}
}
breakpoints
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
let mut dot = 0.0f64;
let mut norm_a = 0.0f64;
let mut norm_b = 0.0f64;
for (x, y) in a.iter().zip(b.iter()) {
let x = *x as f64;
let y = *y as f64;
dot += x * y;
norm_a += x * x;
norm_b += y * y;
}
let denom = norm_a.sqrt() * norm_b.sqrt();
if denom == 0.0 { 0.0 } else { dot / denom }
}
fn merge_at_breakpoints<'a>(
sentences: &'a [Sentence<'a>],
breakpoints: &[usize],
) -> Vec<Vec<&'a Sentence<'a>>> {
let mut groups: Vec<Vec<&Sentence<'a>>> = Vec::new();
let mut current_group: Vec<&Sentence<'a>> = Vec::new();
let break_set: std::collections::HashSet<usize> = breakpoints.iter().copied().collect();
for (i, sentence) in sentences.iter().enumerate() {
if break_set.contains(&i) && !current_group.is_empty() {
groups.push(current_group);
current_group = Vec::new();
}
current_group.push(sentence);
}
if !current_group.is_empty() {
groups.push(current_group);
}
groups
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn split_sentences_basic() {
let sentences = split_sentences("First sentence. Second sentence. Third.");
assert_eq!(sentences.len(), 3);
assert_eq!(sentences[0].text, "First sentence. ");
assert_eq!(sentences[1].text, "Second sentence. ");
assert_eq!(sentences[2].text, "Third.");
}
#[test]
fn split_sentences_empty() {
let sentences = split_sentences("");
assert!(sentences.is_empty());
}
#[test]
fn split_sentences_no_terminators() {
let sentences = split_sentences("no terminators here");
assert_eq!(sentences.len(), 1);
assert_eq!(sentences[0].text, "no terminators here");
}
#[test]
fn split_sentences_with_exclamation() {
let sentences = split_sentences("Hello! How are you? Fine.");
assert_eq!(sentences.len(), 3);
}
#[test]
fn split_sentences_preserves_offset() {
let sentences = split_sentences("First. Second.");
assert_eq!(sentences[0].byte_offset, 0);
assert_eq!(sentences[1].byte_offset, 7);
}
#[test]
fn find_breakpoints_high_threshold() {
let embeddings = vec![
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
];
let bp = find_breakpoints(&embeddings, 0.5);
assert_eq!(bp, vec![1, 2]);
}
#[test]
fn find_breakpoints_low_threshold() {
let embeddings = vec![
vec![1.0, 0.0, 0.0],
vec![0.9, 0.1, 0.0],
vec![0.8, 0.2, 0.0],
];
let bp = find_breakpoints(&embeddings, 0.1);
assert!(bp.is_empty());
}
#[test]
fn cosine_similarity_identical() {
let sim = cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]);
assert!((sim - 1.0).abs() < 1e-6);
}
#[test]
fn cosine_similarity_orthogonal() {
let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
assert!(sim.abs() < 1e-6);
}
#[test]
fn merge_at_breakpoints_basic() {
let s1 = Sentence {
text: "a",
byte_offset: 0,
};
let s2 = Sentence {
text: "b",
byte_offset: 1,
};
let s3 = Sentence {
text: "c",
byte_offset: 2,
};
let sentences = vec![s1, s2, s3];
let groups = merge_at_breakpoints(&sentences, &[2]);
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].len(), 2); assert_eq!(groups[1].len(), 1); }
#[test]
fn split_sentences_cjk_no_panic() {
let sentences = split_sentences("这是第一句。第二句在这里。最后一句。");
assert!(!sentences.is_empty());
for s in &sentences {
assert!(!s.text.contains('\u{FFFD}'));
}
}
#[test]
fn split_sentences_cjk_with_ascii_terminators() {
let sentences = split_sentences("这是第一句. 第二句在这里. 最后一句.");
assert_eq!(sentences.len(), 3);
}
#[test]
fn split_sentences_japanese_no_panic() {
let sentences = split_sentences("最初の文です。次の文です。最後。");
assert!(!sentences.is_empty());
}
#[test]
fn merge_at_breakpoints_no_breaks() {
let s1 = Sentence {
text: "a",
byte_offset: 0,
};
let s2 = Sentence {
text: "b",
byte_offset: 1,
};
let sentences = vec![s1, s2];
let groups = merge_at_breakpoints(&sentences, &[]);
assert_eq!(groups.len(), 1);
assert_eq!(groups[0].len(), 2);
}
}