use lazy_static::lazy_static;
use regex::Regex;
/// A contiguous piece of source text, carrying approximate byte offsets
/// back into the original string and an estimated token count.
#[derive(Debug, Clone, PartialEq)]
pub struct Chunk {
    // The chunk's text (may include overlap carried from the previous chunk).
    pub text: String,
    // Byte offset where this chunk starts in the original text — approximate,
    // since chunks are rebuilt by joining sentences with single spaces.
    pub start: usize,
    // Byte offset one past the chunk's end in the original text (approximate).
    pub end: usize,
    // Estimated token count for `text` (length-based heuristic, not a real tokenizer).
    pub tokens: usize,
}
// Rough bytes-per-token ratio used to turn byte counts into token estimates.
// NOTE(review): 4 chars/token is a common heuristic for English — confirm it
// matches whatever tokenizer consumes these chunks downstream.
const CHARS_PER_TOKEN: usize = 4;
lazy_static! {
    // Two or more consecutive newlines delimit paragraphs.
    static ref PARAGRAPH_PATTERN: Regex = Regex::new(r"\n\n+").unwrap();
    // Sentence terminator followed by whitespace.
    // NOTE(review): unused in this file — `split_by_sentence_endings` does a
    // manual scan instead; consider removing or switching to this regex.
    static ref SENTENCE_PATTERN: Regex = Regex::new(r"[.!?]\s+").unwrap();
}
fn estimate_tokens(text: &str) -> usize {
(text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN
}
/// Split `text` into trimmed sentence slices on `.`, `!`, or `?`.
///
/// A terminator that is the final byte of `text` does not trigger a split on
/// its own; that last sentence is collected by the remainder handling instead.
/// If no non-empty sentence is produced but the text itself is non-blank, the
/// whole trimmed text is returned as a single sentence.
fn split_by_sentence_endings(text: &str) -> Vec<&str> {
    let mut out: Vec<&str> = Vec::new();
    let mut seg_begin = 0usize;

    for (idx, ch) in text.char_indices() {
        let is_terminator = matches!(ch, '.' | '!' | '?');
        // Only split when the terminator is not the last byte of the input.
        if is_terminator && idx + 1 < text.len() {
            let after = idx + ch.len_utf8();
            let piece = text[seg_begin..after].trim();
            if !piece.is_empty() {
                out.push(piece);
            }
            seg_begin = after;
        }
    }

    // Whatever follows the final split point (including a trailing terminator).
    if seg_begin < text.len() {
        let tail = text[seg_begin..].trim();
        if !tail.is_empty() {
            out.push(tail);
        }
    }

    // Fallback: terminator-free, non-blank input yields one sentence.
    if out.is_empty() && !text.trim().is_empty() {
        out.push(text.trim());
    }
    out
}
/// Split `text` into overlapping chunks of roughly `target_tokens` tokens.
///
/// Paragraph boundaries (blank lines) are the primary split points, then
/// sentence boundaries within each paragraph. Consecutive chunks share a tail
/// of roughly `overlap_ratio` of a chunk so context is not lost at the seams.
///
/// * `target_tokens` — desired chunk size in estimated tokens (default 768).
/// * `overlap_ratio` — fraction of a chunk carried into the next (default 0.1).
///
/// NOTE: `start`/`end` offsets are approximate — sentences are re-joined with
/// single spaces, which does not exactly reproduce the original whitespace.
pub fn chunk_text(text: &str, target_tokens: Option<usize>, overlap_ratio: Option<f64>) -> Vec<Chunk> {
    let target = target_tokens.unwrap_or(768);
    let overlap = overlap_ratio.unwrap_or(0.1);
    let total_tokens = estimate_tokens(text);

    // Fast path: the whole text already fits in one chunk.
    if total_tokens <= target {
        return vec![Chunk {
            text: text.to_string(),
            start: 0,
            end: text.len(),
            tokens: total_tokens,
        }];
    }

    let target_chars = target * CHARS_PER_TOKEN;
    let overlap_chars = (target_chars as f64 * overlap) as usize;
    let mut chunks = Vec::new();
    let mut current = String::new();
    let mut current_start = 0usize;

    for paragraph in PARAGRAPH_PATTERN.split(text) {
        for sentence in split_by_sentence_endings(paragraph) {
            let potential = if current.is_empty() {
                sentence.to_string()
            } else {
                format!("{} {}", current, sentence)
            };
            if potential.len() > target_chars && !current.is_empty() {
                let prev_len = current.len();
                chunks.push(Chunk {
                    text: current.clone(),
                    start: current_start,
                    end: current_start + prev_len,
                    // BUG FIX: was `estimate_tokens(¤t)` — mojibake for
                    // `&current` (HTML-entity corruption); did not compile.
                    tokens: estimate_tokens(&current),
                });
                // Carry the tail of the finished chunk into the next one.
                // BUG FIX: nudge the cut point forward to the nearest char
                // boundary so multi-byte UTF-8 text cannot panic the slice.
                let overlap_text = if prev_len > overlap_chars {
                    let mut cut = prev_len - overlap_chars;
                    while !current.is_char_boundary(cut) {
                        cut += 1;
                    }
                    current[cut..].to_string()
                } else {
                    current.clone()
                };
                let overlap_len = overlap_text.len();
                // BUG FIX: the next chunk starts where the overlap tail began
                // inside the previous chunk. The old code computed this from
                // the *new* buffer's length (after reassignment), drifting the
                // offsets by one sentence per split and risking usize underflow.
                current_start += prev_len - overlap_len;
                current = format!("{} {}", overlap_text, sentence);
            } else {
                current = potential;
            }
        }
    }

    // Flush the final partial chunk.
    if !current.is_empty() {
        chunks.push(Chunk {
            text: current.clone(),
            start: current_start,
            end: current_start + current.len(),
            // BUG FIX: was the mojibake `¤t` here as well.
            tokens: estimate_tokens(&current),
        });
    }
    chunks
}
/// Mean-pool a batch of vectors into a single vector.
///
/// The output dimension is taken from the first vector. As before, a later
/// vector that is *longer* than the first panics on out-of-bounds indexing,
/// and a shorter one silently contributes only its own components.
///
/// # Panics
/// Panics when `vectors` is empty.
pub fn aggregate_vectors(vectors: &[Vec<f32>]) -> Vec<f32> {
    if vectors.is_empty() {
        panic!("Cannot aggregate empty vector list");
    }
    // Single vector: the mean is the vector itself.
    if vectors.len() == 1 {
        return vectors[0].clone();
    }

    let dim = vectors[0].len();
    let inv_count = 1.0f32 / vectors.len() as f32;
    let mut pooled = vec![0.0f32; dim];

    for v in vectors {
        for (i, &component) in v.iter().enumerate() {
            pooled[i] += component;
        }
    }
    for slot in pooled.iter_mut() {
        *slot *= inv_count;
    }
    pooled
}
/// Concatenate the text of all chunks, separated by single spaces.
/// An empty slice yields an empty string.
pub fn join_chunks(chunks: &[Chunk]) -> String {
    if chunks.is_empty() {
        return String::new();
    }
    let mut joined = String::new();
    for (i, chunk) in chunks.iter().enumerate() {
        if i > 0 {
            joined.push(' ');
        }
        joined.push_str(&chunk.text);
    }
    joined
}
/// Split `text` into owned, trimmed sentences, breaking after each
/// `.`, `!`, or `?`. A trailing fragment without a terminator is kept
/// as its own sentence; blank fragments are dropped.
pub fn split_sentences(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut buf = String::new();

    for ch in text.chars() {
        buf.push(ch);
        if matches!(ch, '.' | '!' | '?') {
            let sentence = buf.trim();
            if !sentence.is_empty() {
                out.push(sentence.to_string());
            }
            buf.clear();
        }
    }

    // Keep any terminator-free tail.
    let tail = buf.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }
    out
}
/// Split `text` into owned, trimmed paragraphs on runs of two or more
/// newlines, dropping blank paragraphs.
pub fn split_paragraphs(text: &str) -> Vec<String> {
    let mut paragraphs = Vec::new();
    for piece in PARAGRAPH_PATTERN.split(text) {
        let trimmed = piece.trim();
        if !trimmed.is_empty() {
            paragraphs.push(trimmed.to_string());
        }
    }
    paragraphs
}
#[cfg(test)]
mod tests {
    use super::*;

    // Ceiling division: 0 bytes -> 0 tokens, 4 bytes -> 1, 11 bytes -> 3.
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    // Text under the default target comes back as a single untouched chunk.
    #[test]
    fn test_chunk_small_text() {
        let chunks = chunk_text("Small text.", None, None);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "Small text.");
        assert_eq!(chunks[0].start, 0);
    }

    // Long text must split into multiple chunks, each within a small
    // tolerance (+50 tokens) of the requested 256-token target.
    #[test]
    fn test_chunk_large_text() {
        let long_text = "This is a sentence. ".repeat(500);
        let chunks = chunk_text(&long_text, Some(256), Some(0.1));
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.tokens <= 256 + 50);
        }
    }

    // A single vector aggregates to itself.
    #[test]
    fn test_aggregate_vectors_single() {
        let vecs = vec![vec![1.0, 2.0, 3.0]];
        let result = aggregate_vectors(&vecs);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    // Multiple vectors aggregate to their element-wise mean.
    #[test]
    fn test_aggregate_vectors_multiple() {
        let vecs = vec![vec![1.0, 2.0, 3.0], vec![3.0, 4.0, 5.0]];
        let result = aggregate_vectors(&vecs);
        assert_eq!(result, vec![2.0, 3.0, 4.0]);
    }

    // Aggregating an empty batch is a programming error and must panic.
    #[test]
    #[should_panic(expected = "Cannot aggregate empty")]
    fn test_aggregate_vectors_empty() {
        let vecs: Vec<Vec<f32>> = vec![];
        aggregate_vectors(&vecs);
    }

    // Chunks join with single-space separators.
    #[test]
    fn test_join_chunks() {
        let chunks = vec![
            Chunk {
                text: "Hello".to_string(),
                start: 0,
                end: 5,
                tokens: 2,
            },
            Chunk {
                text: "World".to_string(),
                start: 5,
                end: 10,
                tokens: 2,
            },
        ];
        assert_eq!(join_chunks(&chunks), "Hello World");
    }

    // Joining no chunks yields the empty string.
    #[test]
    fn test_join_chunks_empty() {
        let chunks: Vec<Chunk> = vec![];
        assert_eq!(join_chunks(&chunks), "");
    }

    // All three terminators (. ! ?) end a sentence.
    #[test]
    fn test_split_sentences() {
        let text = "First sentence. Second sentence! Third one?";
        let sentences = split_sentences(text);
        assert_eq!(sentences.len(), 3);
        assert_eq!(sentences[0], "First sentence.");
    }

    // Two-or-more newlines split paragraphs; three newlines is one boundary.
    #[test]
    fn test_split_paragraphs() {
        let text = "Paragraph one.\n\nParagraph two.\n\n\nParagraph three.";
        let paragraphs = split_paragraphs(text);
        assert_eq!(paragraphs.len(), 3);
    }
}