use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::metadata_filter::MetadataValue;
/// Parameters for fixed-size window chunking (used by `chunk_document`).
///
/// All sizes are measured in `char`s (Unicode scalar values), not bytes.
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; start from `ChunkConfig::default()` and adjust it with the `with_*`
/// builder methods instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChunkConfig {
    /// Length of each window in chars.
    pub chunk_size: usize,
    /// Number of chars shared by consecutive windows; the window start advances by
    /// `chunk_size - overlap` (at least 1) each step.
    pub overlap: usize,
    /// Windows shorter than this many chars are discarded, and a document shorter than
    /// this produces no chunks at all.
    pub min_chunk_size: usize,
}
impl Default for ChunkConfig {
fn default() -> Self {
Self {
chunk_size: 512,
overlap: 64,
min_chunk_size: 32,
}
}
}
impl ChunkConfig {
    /// Checks that the configuration can actually produce chunks.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - `chunk_size` is 0,
    /// - `overlap >= chunk_size` (the window would never advance by a full step), or
    /// - `min_chunk_size > chunk_size` (every window would be shorter than the minimum,
    ///   so `chunk_document` could never emit a chunk).
    pub fn validate(&self) -> Result<(), String> {
        if self.chunk_size == 0 {
            return Err("chunk_size must be > 0".into());
        }
        if self.overlap >= self.chunk_size {
            return Err(format!(
                "overlap ({}) must be < chunk_size ({})",
                self.overlap, self.chunk_size
            ));
        }
        if self.min_chunk_size > self.chunk_size {
            return Err(format!(
                "min_chunk_size ({}) must be <= chunk_size ({})",
                self.min_chunk_size, self.chunk_size
            ));
        }
        Ok(())
    }

    /// Returns the config with the window length (in chars) replaced.
    #[must_use]
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Returns the config with the overlap between consecutive windows (in chars) replaced.
    #[must_use]
    pub fn with_overlap(mut self, overlap: usize) -> Self {
        self.overlap = overlap;
        self
    }

    /// Returns the config with the minimum kept window length (in chars) replaced.
    #[must_use]
    pub fn with_min_chunk_size(mut self, min_chunk_size: usize) -> Self {
        self.min_chunk_size = min_chunk_size;
        self
    }
}
/// One piece of a source document produced by the chunkers in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// The chunk's text (owned copy).
    pub text: String,
    /// Caller-supplied identifier of the source document, copied through unchanged.
    pub doc_id: usize,
    /// Position of this chunk in the `Vec` returned by the chunker that produced it.
    pub chunk_idx: usize,
    /// Where the chunk starts in the source text. NOTE: despite the name,
    /// `chunk_document` stores a UTF-8 *byte* offset here (via `byte_offset_of_char`),
    /// not a char index.
    pub char_offset: usize,
    /// Arbitrary key/value metadata; `#[serde(default)]` makes a missing field
    /// deserialize as an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, MetadataValue>,
}
impl Chunk {
pub fn new(text: String, doc_id: usize, chunk_idx: usize, char_offset: usize) -> Self {
Self {
text,
doc_id,
chunk_idx,
char_offset,
metadata: HashMap::new(),
}
}
#[must_use]
pub fn with_metadata(
mut self,
key: impl Into<String>,
value: impl Into<MetadataValue>,
) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
}
/// Splits `text` into fixed-size, overlapping windows of characters.
///
/// Windows are `config.chunk_size` chars long and start every
/// `chunk_size - overlap` chars (at least 1). Only the final window can be shorter than
/// `chunk_size`; any window shorter than `config.min_chunk_size` chars is dropped, and a
/// document with fewer than `min_chunk_size` chars yields no chunks at all.
///
/// `config` is not validated here. A `chunk_size` of 0 yields no chunks rather than
/// a stream of empty ones.
///
/// Each chunk's `char_offset` is the UTF-8 *byte* offset of its first char in `text`.
pub fn chunk_document(text: &str, doc_id: usize, config: &ChunkConfig) -> Vec<Chunk> {
    if text.is_empty() || config.chunk_size == 0 {
        return Vec::new();
    }
    let chars: Vec<char> = text.chars().collect();
    let total = chars.len();
    if total < config.min_chunk_size {
        return Vec::new();
    }
    let step = config.chunk_size.saturating_sub(config.overlap).max(1);
    let mut chunks = Vec::new();
    // Byte position in `text` of char index `cursor_char`. Window starts only move
    // forward, so we advance this cursor instead of rescanning `text` from the beginning
    // for every chunk (which made the whole pass O(n^2)).
    let mut cursor_char = 0usize;
    let mut cursor_byte = 0usize;
    let mut start = 0usize;
    while start < total {
        let end = (start + config.chunk_size).min(total);
        // `end - start` is the window length in chars; no need to re-count the String.
        if end - start >= config.min_chunk_size {
            cursor_byte += byte_offset_of_char(&text[cursor_byte..], start - cursor_char);
            cursor_char = start;
            let chunk_text: String = chars[start..end].iter().collect();
            let chunk_idx = chunks.len();
            chunks.push(Chunk::new(chunk_text, doc_id, chunk_idx, cursor_byte));
        }
        if end == total {
            break;
        }
        start += step;
    }
    chunks
}
/// Groups the sentences of `text` into chunks of at most `max_sentences` sentences each.
///
/// Sentences come from `split_sentences` and are re-joined with a single space. The chunk
/// text is therefore whitespace-normalised and may differ from the exact source span.
/// Returns an empty vector for empty `text` or `max_sentences == 0`.
///
/// Each chunk's `char_offset` is the UTF-8 byte offset in `text` of the chunk's first
/// sentence. It reflects the real position even when sentences are separated by
/// several whitespace characters or the text has leading whitespace.
pub fn chunk_by_sentences(text: &str, doc_id: usize, max_sentences: usize) -> Vec<Chunk> {
    // `max_sentences == 0` must be rejected here: `slice::chunks(0)` panics.
    if text.is_empty() || max_sentences == 0 {
        return Vec::new();
    }
    let sentences = split_sentences(text);
    sentences
        .chunks(max_sentences)
        .enumerate()
        .map(|(chunk_idx, batch)| {
            // `split_sentences` returns trimmed sub-slices of `text`, so the pointer
            // distance is the exact byte offset of the batch's first sentence.
            // `chunks` never yields an empty batch, so `batch[0]` is in bounds.
            let byte_offset = batch[0].as_ptr() as usize - text.as_ptr() as usize;
            Chunk::new(batch.join(" "), doc_id, chunk_idx, byte_offset)
        })
        .collect()
}
/// Splits `text` into trimmed sentence slices.
///
/// A sentence ends after a run of `.`, `!` or `?` bytes, together with any ASCII
/// spaces, tabs, CRs or LFs that immediately follow that run. Every returned `&str` is a
/// sub-slice of `text` with surrounding whitespace trimmed. Empty pieces are skipped, and
/// trailing text without a terminator becomes the final sentence. The split is purely
/// punctuation-based, so abbreviations and decimals ("e.g.", "3.14") are split too.
fn split_sentences(text: &str) -> Vec<&str> {
    let is_terminator = |b: u8| matches!(b, b'.' | b'!' | b'?');
    let is_gap = |b: u8| matches!(b, b' ' | b'\t' | b'\n' | b'\r');
    let bytes = text.as_bytes();
    let mut found = Vec::new();
    let mut seg_start = 0usize;
    let mut pos = 0usize;
    while pos < bytes.len() {
        if !is_terminator(bytes[pos]) {
            pos += 1;
            continue;
        }
        // Swallow the rest of the punctuation run, then the whitespace right after it.
        let mut cut = pos + 1;
        cut += bytes[cut..].iter().take_while(|&&b| is_terminator(b)).count();
        cut += bytes[cut..].iter().take_while(|&&b| is_gap(b)).count();
        // Every skipped byte is ASCII, so `cut` always lands on a char boundary.
        let piece = text[seg_start..cut].trim();
        if !piece.is_empty() {
            found.push(piece);
        }
        seg_start = cut;
        pos = cut;
    }
    let rest = text[seg_start..].trim();
    if !rest.is_empty() {
        found.push(rest);
    }
    found
}
/// Splits `text` into paragraphs separated by one or more blank lines.
///
/// A line is a `\n`-terminated segment (the final line may lack the `\n`). It counts as
/// blank when it is empty or whitespace-only, so `\r\n` endings and lines of spaces also
/// separate paragraphs. Each paragraph's text is the run of consecutive non-blank lines,
/// trimmed at both ends, with its internal line breaks kept.
///
/// Each chunk's `char_offset` is the UTF-8 byte offset in `text` of the paragraph's first
/// non-whitespace character. Empty or whitespace-only input yields no chunks.
pub fn chunk_by_paragraphs(text: &str, doc_id: usize) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    // Emits the paragraph spanning the byte range `start..end` of `text`.
    // The range always contains at least one non-blank line, so the trimmed text is
    // never empty.
    let mut emit = |start: usize, end: usize| {
        let raw = &text[start..end];
        let para = raw.trim();
        let offset = start + (raw.len() - raw.trim_start().len());
        let chunk_idx = chunks.len();
        chunks.push(Chunk::new(para.to_string(), doc_id, chunk_idx, offset));
    };
    // Byte range covered by the paragraph currently being accumulated, if any.
    let mut open: Option<(usize, usize)> = None;
    let mut line_start = 0usize;
    for line in text.split_inclusive('\n') {
        let line_end = line_start + line.len();
        if line.trim().is_empty() {
            if let Some((start, end)) = open.take() {
                emit(start, end);
            }
        } else {
            let start = open.map_or(line_start, |(start, _)| start);
            open = Some((start, line_end));
        }
        line_start = line_end;
    }
    if let Some((start, end)) = open {
        emit(start, end);
    }
    chunks
}
/// Returns the UTF-8 byte index at which the `n`-th char (0-based) of `s` begins.
///
/// If `s` has `n` or fewer chars, this is `s.len()`, i.e. one past the end.
fn byte_offset_of_char(s: &str, n: usize) -> usize {
    match s.char_indices().nth(n) {
        Some((byte_idx, _)) => byte_idx,
        None => s.len(),
    }
}