use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
pub text: String,
pub start_char: usize,
pub end_char: usize,
pub token_estimate: usize,
pub chunk_type: ChunkType,
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkType {
Text,
Code,
Heading,
List,
Table,
Paragraph,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitConfig {
pub max_tokens: usize, pub overlap_tokens: usize, pub min_chunk_size: usize, pub respect_sentences: bool, pub respect_paragraphs: bool,
pub respect_headings: bool,
pub code_aware: bool, pub chars_per_token: f64, }
impl Default for SplitConfig {
fn default() -> Self {
Self { max_tokens: 256, overlap_tokens: 32, min_chunk_size: 50,
respect_sentences: true, respect_paragraphs: true,
respect_headings: true, code_aware: true, chars_per_token: 4.0 }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitStats {
pub input_chars: usize,
pub input_tokens_est: usize,
pub chunks: usize,
pub avg_chunk_tokens: f64,
pub avg_chunk_chars: f64,
pub max_chunk_chars: usize,
pub min_chunk_chars: usize,
}
pub struct TileSplit {
config: SplitConfig,
}
impl TileSplit {
pub fn new(config: SplitConfig) -> Self {
Self { config }
}
pub fn split(&self, text: &str) -> Vec<Chunk> {
if text.is_empty() { return Vec::new(); }
let max_chars = (self.config.max_tokens as f64 * self.config.chars_per_token) as usize;
let overlap_chars = (self.config.overlap_tokens as f64 * self.config.chars_per_token) as usize;
if text.len() <= max_chars {
return vec![Chunk {
text: text.to_string(), start_char: 0, end_char: text.len(),
token_estimate: Self::estimate_tokens(text, self.config.chars_per_token),
chunk_type: self.detect_type(text), metadata: HashMap::new()
}];
}
let segments = self.find_segments(text);
let mut chunks = Vec::new();
let mut buffer = String::new();
let mut buffer_start = 0;
let mut buffer_tokens = 0;
for segment in &segments {
let seg_tokens = Self::estimate_tokens(segment, self.config.chars_per_token);
if buffer_tokens + seg_tokens > self.config.max_tokens && !buffer.is_empty() {
let chunk_text = buffer.trim().to_string();
if chunk_text.len() >= self.config.min_chunk_size {
chunks.push(Chunk {
text: chunk_text.clone(), start_char: buffer_start,
end_char: buffer_start + buffer.len(),
token_estimate: buffer_tokens,
chunk_type: self.detect_type(&chunk_text),
metadata: HashMap::new()
});
}
let overlap_start = if buffer.len() > overlap_chars {
buffer.len() - overlap_chars
} else { 0 };
buffer = buffer[overlap_start..].to_string();
buffer_start = buffer_start + overlap_start;
buffer_tokens = Self::estimate_tokens(&buffer, self.config.chars_per_token);
}
buffer.push_str(segment);
buffer.push('\n');
buffer_tokens += seg_tokens;
}
let chunk_text = buffer.trim().to_string();
if chunk_text.len() >= self.config.min_chunk_size {
chunks.push(Chunk {
text: chunk_text.clone(), start_char: buffer_start,
end_char: buffer_start + buffer.len(),
token_estimate: buffer_tokens,
chunk_type: self.detect_type(&chunk_text),
metadata: HashMap::new()
});
}
chunks
}
pub fn split_n(&self, text: &str, n: usize) -> Vec<Chunk> {
if n <= 1 { return self.split(text); }
let chunk_size = text.len() / n;
let mut chunks = Vec::new();
let mut pos = 0;
for i in 0..n {
let end = if i == n - 1 { text.len() } else {
let mut boundary = pos + chunk_size;
if self.config.respect_sentences {
if let Some(idx) = text[pos..].find(". ") {
let candidate = pos + idx + 2;
if candidate <= pos + chunk_size + 50 {
boundary = candidate;
}
}
}
boundary.min(text.len())
};
let chunk_text = text[pos..end].trim().to_string();
chunks.push(Chunk {
text: chunk_text.clone(), start_char: pos, end_char: end,
token_estimate: Self::estimate_tokens(&chunk_text, self.config.chars_per_token),
chunk_type: self.detect_type(&chunk_text), metadata: HashMap::new()
});
pos = end;
}
chunks
}
pub fn split_by(&self, text: &str, delimiter: &str) -> Vec<Chunk> {
let mut chunks = Vec::new();
let mut pos = 0;
for part in text.split(delimiter) {
let trimmed = part.trim();
if trimmed.len() >= self.config.min_chunk_size {
let end = pos + part.len();
chunks.push(Chunk {
text: trimmed.to_string(), start_char: pos, end_char: end,
token_estimate: Self::estimate_tokens(trimmed, self.config.chars_per_token),
chunk_type: self.detect_type(trimmed), metadata: HashMap::from([("delimiter".into(), delimiter.to_string())])
});
}
pos += part.len() + delimiter.len();
}
chunks
}
pub fn split_code(&self, code: &str) -> Vec<Chunk> {
let mut chunks = Vec::new();
let mut pos = 0;
let mut brace_depth: usize = 0;
let mut block_start = 0;
let mut in_string = false;
let mut string_char = ' ';
for (i, c) in code.char_indices() {
match c {
'"' | '\'' if !in_string => { in_string = true; string_char = c; }
c if in_string && c == string_char => { in_string = false; }
'{' if !in_string => {
if brace_depth == 0 { block_start = pos; }
brace_depth += 1;
}
'}' if !in_string => {
brace_depth = brace_depth.saturating_sub(1);
if brace_depth == 0 && i > block_start {
let block = code[block_start..=i].trim().to_string();
if block.len() >= self.config.min_chunk_size {
chunks.push(Chunk {
text: block.clone(), start_char: block_start, end_char: i + 1,
token_estimate: Self::estimate_tokens(&block, self.config.chars_per_token),
chunk_type: ChunkType::Code, metadata: HashMap::new()
});
}
pos = i + 1;
}
}
'\n' if !in_string && brace_depth == 0 => {
let line = code[pos..i].trim();
if line.len() >= self.config.min_chunk_size {
chunks.push(Chunk {
text: line.to_string(), start_char: pos, end_char: i,
token_estimate: Self::estimate_tokens(line, self.config.chars_per_token),
chunk_type: ChunkType::Code, metadata: HashMap::new()
});
}
pos = i + 1;
}
_ => {}
}
}
chunks
}
fn find_segments(&self, text: &str) -> Vec<String> {
let mut segments = Vec::new();
if self.config.code_aware {
let mut in_code = false;
let mut code_buf = String::new();
let mut text_buf = String::new();
for line in text.lines() {
if line.trim().starts_with("```") {
if in_code {
code_buf.push_str(line);
code_buf.push('\n');
segments.push(code_buf.clone());
code_buf.clear();
in_code = false;
} else {
if !text_buf.is_empty() {
segments.extend(self.split_text_segments(&text_buf));
text_buf.clear();
}
code_buf.push_str(line);
code_buf.push('\n');
in_code = true;
}
} else if in_code {
code_buf.push_str(line);
code_buf.push('\n');
} else {
text_buf.push_str(line);
text_buf.push('\n');
}
}
if !text_buf.is_empty() {
segments.extend(self.split_text_segments(&text_buf));
}
if !code_buf.is_empty() {
segments.push(code_buf);
}
} else {
segments.extend(self.split_text_segments(text));
}
if segments.is_empty() {
segments.push(text.to_string());
}
segments
}
fn split_text_segments(&self, text: &str) -> Vec<String> {
if self.config.respect_headings {
let mut segments = Vec::new();
let mut current = String::new();
for line in text.lines() {
if line.starts_with('#') && !current.is_empty() {
segments.push(current.trim().to_string());
current.clear();
}
current.push_str(line);
current.push('\n');
}
if !current.is_empty() {
segments.push(current.trim().to_string());
}
return segments;
}
if self.config.respect_paragraphs {
return text.split("\n\n")
.filter(|s| !s.trim().is_empty())
.map(|s| s.to_string())
.collect();
}
if self.config.respect_sentences {
return text.split_inclusive(". ")
.filter(|s| s.trim().len() >= 10)
.map(|s| s.to_string())
.collect();
}
vec![text.to_string()]
}
fn detect_type(&self, text: &str) -> ChunkType {
let trimmed = text.trim();
if trimmed.starts_with("```") || trimmed.contains("fn ") || trimmed.contains("def ")
|| trimmed.contains("function ") || trimmed.contains("class ") {
return ChunkType::Code;
}
if trimmed.starts_with('#') { return ChunkType::Heading; }
if trimmed.lines().all(|l| l.trim().starts_with("- ") || l.trim().starts_with("* ")
|| l.trim().starts_with("• ")) { return ChunkType::List; }
if trimmed.contains('|') && trimmed.lines().filter(|l| l.contains('|')).count() >= 2 {
return ChunkType::Table;
}
if trimmed.lines().count() <= 2 { return ChunkType::Paragraph; }
ChunkType::Text
}
fn estimate_tokens(text: &str, chars_per_token: f64) -> usize {
(text.len() as f64 / chars_per_token).ceil() as usize
}
pub fn stats(&self, text: &str, chunks: &[Chunk]) -> SplitStats {
let input_tokens = Self::estimate_tokens(text, self.config.chars_per_token);
let chunk_tokens: Vec<usize> = chunks.iter().map(|c| c.token_estimate).collect();
let chunk_chars: Vec<usize> = chunks.iter().map(|c| c.text.len()).collect();
SplitStats {
input_chars: text.len(), input_tokens_est: input_tokens,
chunks: chunks.len(),
avg_chunk_tokens: if chunk_tokens.is_empty() { 0.0 } else { chunk_tokens.iter().sum::<usize>() as f64 / chunks.len() as f64 },
avg_chunk_chars: if chunk_chars.is_empty() { 0.0 } else { chunk_chars.iter().sum::<usize>() as f64 / chunks.len() as f64 },
max_chunk_chars: chunk_chars.iter().cloned().max().unwrap_or(0),
min_chunk_chars: chunk_chars.iter().cloned().min().unwrap_or(0),
}
}
}
use std::collections::HashMap;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_split() {
let splitter = TileSplit::new(SplitConfig::default());
let text = "Hello world. This is a test. Another sentence here.";
let chunks = splitter.split(text);
assert!(!chunks.is_empty());
}
#[test]
fn test_code_aware() {
let mut config = SplitConfig::default();
config.max_tokens = 50;
config.code_aware = true;
config.chars_per_token = 1.0;
config.min_chunk_size = 5;
let splitter = TileSplit::new(config);
let text = "Some text.\n\n```python\ndef foo():\n return 42\n```\n\nMore text.";
let chunks = splitter.split(text);
assert!(chunks.len() >= 2);
}
#[test]
fn test_split_n() {
let splitter = TileSplit::new(SplitConfig::default());
let text = "One. Two. Three. Four. Five. Six. Seven. Eight.";
let chunks = splitter.split_n(text, 3);
assert_eq!(chunks.len(), 3);
}
#[test]
fn test_split_code() {
let mut config = SplitConfig::default();
config.min_chunk_size = 10;
let splitter = TileSplit::new(config);
let code = "fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n\nfn mul(a: i32, b: i32) -> i32 {\n a * b\n}";
let chunks = splitter.split_code(code);
assert!(chunks.len() >= 2);
assert!(chunks.iter().all(|c| c.chunk_type == ChunkType::Code));
}
#[test]
fn test_stats() {
let splitter = TileSplit::new(SplitConfig::default());
let text = "Hello world. ".repeat(100);
let chunks = splitter.split(&text);
let stats = splitter.stats(&text, &chunks);
assert!(stats.chunks >= 1);
assert!(stats.avg_chunk_chars > 0.0);
}
#[test]
fn test_empty() {
let splitter = TileSplit::new(SplitConfig::default());
assert!(splitter.split("").is_empty());
}
#[test]
fn test_small_text_no_split() {
let splitter = TileSplit::new(SplitConfig::default());
let chunks = splitter.split("Short text.");
assert_eq!(chunks.len(), 1);
}
}