1use serde::{Deserialize, Serialize};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Chunk {
23 pub text: String,
24 pub start_char: usize,
25 pub end_char: usize,
26 pub token_estimate: usize,
27 pub chunk_type: ChunkType,
28 pub metadata: HashMap<String, String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32pub enum ChunkType {
33 Text,
34 Code,
35 Heading,
36 List,
37 Table,
38 Paragraph,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct SplitConfig {
44 pub max_tokens: usize, pub overlap_tokens: usize, pub min_chunk_size: usize, pub respect_sentences: bool, pub respect_paragraphs: bool,
49 pub respect_headings: bool,
50 pub code_aware: bool, pub chars_per_token: f64, }
53
54impl Default for SplitConfig {
55 fn default() -> Self {
56 Self { max_tokens: 256, overlap_tokens: 32, min_chunk_size: 50,
57 respect_sentences: true, respect_paragraphs: true,
58 respect_headings: true, code_aware: true, chars_per_token: 4.0 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct SplitStats {
65 pub input_chars: usize,
66 pub input_tokens_est: usize,
67 pub chunks: usize,
68 pub avg_chunk_tokens: f64,
69 pub avg_chunk_chars: f64,
70 pub max_chunk_chars: usize,
71 pub min_chunk_chars: usize,
72}
73
74pub struct TileSplit {
76 config: SplitConfig,
77}
78
79impl TileSplit {
80 pub fn new(config: SplitConfig) -> Self {
81 Self { config }
82 }
83
84 pub fn split(&self, text: &str) -> Vec<Chunk> {
86 if text.is_empty() { return Vec::new(); }
87
88 let max_chars = (self.config.max_tokens as f64 * self.config.chars_per_token) as usize;
89 let overlap_chars = (self.config.overlap_tokens as f64 * self.config.chars_per_token) as usize;
90
91 if text.len() <= max_chars {
92 return vec![Chunk {
93 text: text.to_string(), start_char: 0, end_char: text.len(),
94 token_estimate: Self::estimate_tokens(text, self.config.chars_per_token),
95 chunk_type: self.detect_type(text), metadata: HashMap::new()
96 }];
97 }
98
99 let segments = self.find_segments(text);
101 let mut chunks = Vec::new();
102 let mut buffer = String::new();
103 let mut buffer_start = 0;
104 let mut buffer_tokens = 0;
105
106 for segment in &segments {
107 let seg_tokens = Self::estimate_tokens(segment, self.config.chars_per_token);
108
109 if buffer_tokens + seg_tokens > self.config.max_tokens && !buffer.is_empty() {
110 let chunk_text = buffer.trim().to_string();
112 if chunk_text.len() >= self.config.min_chunk_size {
113 chunks.push(Chunk {
114 text: chunk_text.clone(), start_char: buffer_start,
115 end_char: buffer_start + buffer.len(),
116 token_estimate: buffer_tokens,
117 chunk_type: self.detect_type(&chunk_text),
118 metadata: HashMap::new()
119 });
120 }
121 let overlap_start = if buffer.len() > overlap_chars {
123 buffer.len() - overlap_chars
124 } else { 0 };
125 buffer = buffer[overlap_start..].to_string();
126 buffer_start = buffer_start + overlap_start;
127 buffer_tokens = Self::estimate_tokens(&buffer, self.config.chars_per_token);
128 }
129
130 buffer.push_str(segment);
131 buffer.push('\n');
132 buffer_tokens += seg_tokens;
133 }
134
135 let chunk_text = buffer.trim().to_string();
137 if chunk_text.len() >= self.config.min_chunk_size {
138 chunks.push(Chunk {
139 text: chunk_text.clone(), start_char: buffer_start,
140 end_char: buffer_start + buffer.len(),
141 token_estimate: buffer_tokens,
142 chunk_type: self.detect_type(&chunk_text),
143 metadata: HashMap::new()
144 });
145 }
146
147 chunks
148 }
149
150 pub fn split_n(&self, text: &str, n: usize) -> Vec<Chunk> {
152 if n <= 1 { return self.split(text); }
153 let chunk_size = text.len() / n;
154 let mut chunks = Vec::new();
155 let mut pos = 0;
156 for i in 0..n {
157 let end = if i == n - 1 { text.len() } else {
158 let mut boundary = pos + chunk_size;
159 if self.config.respect_sentences {
161 if let Some(idx) = text[pos..].find(". ") {
162 let candidate = pos + idx + 2;
163 if candidate <= pos + chunk_size + 50 {
164 boundary = candidate;
165 }
166 }
167 }
168 boundary.min(text.len())
169 };
170 let chunk_text = text[pos..end].trim().to_string();
171 chunks.push(Chunk {
172 text: chunk_text.clone(), start_char: pos, end_char: end,
173 token_estimate: Self::estimate_tokens(&chunk_text, self.config.chars_per_token),
174 chunk_type: self.detect_type(&chunk_text), metadata: HashMap::new()
175 });
176 pos = end;
177 }
178 chunks
179 }
180
181 pub fn split_by(&self, text: &str, delimiter: &str) -> Vec<Chunk> {
183 let mut chunks = Vec::new();
184 let mut pos = 0;
185 for part in text.split(delimiter) {
186 let trimmed = part.trim();
187 if trimmed.len() >= self.config.min_chunk_size {
188 let end = pos + part.len();
189 chunks.push(Chunk {
190 text: trimmed.to_string(), start_char: pos, end_char: end,
191 token_estimate: Self::estimate_tokens(trimmed, self.config.chars_per_token),
192 chunk_type: self.detect_type(trimmed), metadata: HashMap::from([("delimiter".into(), delimiter.to_string())])
193 });
194 }
195 pos += part.len() + delimiter.len();
196 }
197 chunks
198 }
199
200 pub fn split_code(&self, code: &str) -> Vec<Chunk> {
202 let mut chunks = Vec::new();
203 let mut pos = 0;
204 let mut brace_depth: usize = 0;
205 let mut block_start = 0;
206 let mut in_string = false;
207 let mut string_char = ' ';
208
209 for (i, c) in code.char_indices() {
210 match c {
211 '"' | '\'' if !in_string => { in_string = true; string_char = c; }
212 c if in_string && c == string_char => { in_string = false; }
213 '{' if !in_string => {
214 if brace_depth == 0 { block_start = pos; }
215 brace_depth += 1;
216 }
217 '}' if !in_string => {
218 brace_depth = brace_depth.saturating_sub(1);
219 if brace_depth == 0 && i > block_start {
220 let block = code[block_start..=i].trim().to_string();
221 if block.len() >= self.config.min_chunk_size {
222 chunks.push(Chunk {
223 text: block.clone(), start_char: block_start, end_char: i + 1,
224 token_estimate: Self::estimate_tokens(&block, self.config.chars_per_token),
225 chunk_type: ChunkType::Code, metadata: HashMap::new()
226 });
227 }
228 pos = i + 1;
229 }
230 }
231 '\n' if !in_string && brace_depth == 0 => {
232 let line = code[pos..i].trim();
233 if line.len() >= self.config.min_chunk_size {
234 chunks.push(Chunk {
235 text: line.to_string(), start_char: pos, end_char: i,
236 token_estimate: Self::estimate_tokens(line, self.config.chars_per_token),
237 chunk_type: ChunkType::Code, metadata: HashMap::new()
238 });
239 }
240 pos = i + 1;
241 }
242 _ => {}
243 }
244 }
245 chunks
246 }
247
248 fn find_segments(&self, text: &str) -> Vec<String> {
249 let mut segments = Vec::new();
250
251 if self.config.code_aware {
252 let mut in_code = false;
254 let mut code_buf = String::new();
255 let mut text_buf = String::new();
256
257 for line in text.lines() {
258 if line.trim().starts_with("```") {
259 if in_code {
260 code_buf.push_str(line);
261 code_buf.push('\n');
262 segments.push(code_buf.clone());
263 code_buf.clear();
264 in_code = false;
265 } else {
266 if !text_buf.is_empty() {
267 segments.extend(self.split_text_segments(&text_buf));
269 text_buf.clear();
270 }
271 code_buf.push_str(line);
272 code_buf.push('\n');
273 in_code = true;
274 }
275 } else if in_code {
276 code_buf.push_str(line);
277 code_buf.push('\n');
278 } else {
279 text_buf.push_str(line);
280 text_buf.push('\n');
281 }
282 }
283 if !text_buf.is_empty() {
284 segments.extend(self.split_text_segments(&text_buf));
285 }
286 if !code_buf.is_empty() {
287 segments.push(code_buf);
288 }
289 } else {
290 segments.extend(self.split_text_segments(text));
291 }
292
293 if segments.is_empty() {
294 segments.push(text.to_string());
295 }
296 segments
297 }
298
299 fn split_text_segments(&self, text: &str) -> Vec<String> {
300 if self.config.respect_headings {
301 let mut segments = Vec::new();
302 let mut current = String::new();
303 for line in text.lines() {
304 if line.starts_with('#') && !current.is_empty() {
305 segments.push(current.trim().to_string());
306 current.clear();
307 }
308 current.push_str(line);
309 current.push('\n');
310 }
311 if !current.is_empty() {
312 segments.push(current.trim().to_string());
313 }
314 return segments;
315 }
316
317 if self.config.respect_paragraphs {
318 return text.split("\n\n")
319 .filter(|s| !s.trim().is_empty())
320 .map(|s| s.to_string())
321 .collect();
322 }
323
324 if self.config.respect_sentences {
325 return text.split_inclusive(". ")
326 .filter(|s| s.trim().len() >= 10)
327 .map(|s| s.to_string())
328 .collect();
329 }
330
331 vec![text.to_string()]
332 }
333
334 fn detect_type(&self, text: &str) -> ChunkType {
335 let trimmed = text.trim();
336 if trimmed.starts_with("```") || trimmed.contains("fn ") || trimmed.contains("def ")
337 || trimmed.contains("function ") || trimmed.contains("class ") {
338 return ChunkType::Code;
339 }
340 if trimmed.starts_with('#') { return ChunkType::Heading; }
341 if trimmed.lines().all(|l| l.trim().starts_with("- ") || l.trim().starts_with("* ")
342 || l.trim().starts_with("• ")) { return ChunkType::List; }
343 if trimmed.contains('|') && trimmed.lines().filter(|l| l.contains('|')).count() >= 2 {
344 return ChunkType::Table;
345 }
346 if trimmed.lines().count() <= 2 { return ChunkType::Paragraph; }
347 ChunkType::Text
348 }
349
350 fn estimate_tokens(text: &str, chars_per_token: f64) -> usize {
351 (text.len() as f64 / chars_per_token).ceil() as usize
352 }
353
354 pub fn stats(&self, text: &str, chunks: &[Chunk]) -> SplitStats {
356 let input_tokens = Self::estimate_tokens(text, self.config.chars_per_token);
357 let chunk_tokens: Vec<usize> = chunks.iter().map(|c| c.token_estimate).collect();
358 let chunk_chars: Vec<usize> = chunks.iter().map(|c| c.text.len()).collect();
359 SplitStats {
360 input_chars: text.len(), input_tokens_est: input_tokens,
361 chunks: chunks.len(),
362 avg_chunk_tokens: if chunk_tokens.is_empty() { 0.0 } else { chunk_tokens.iter().sum::<usize>() as f64 / chunks.len() as f64 },
363 avg_chunk_chars: if chunk_chars.is_empty() { 0.0 } else { chunk_chars.iter().sum::<usize>() as f64 / chunks.len() as f64 },
364 max_chunk_chars: chunk_chars.iter().cloned().max().unwrap_or(0),
365 min_chunk_chars: chunk_chars.iter().cloned().min().unwrap_or(0),
366 }
367 }
368}
369
370use std::collections::HashMap;
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a splitter with the stock configuration.
    fn default_splitter() -> TileSplit {
        TileSplit::new(SplitConfig::default())
    }

    /// A short multi-sentence text yields at least one chunk.
    #[test]
    fn test_basic_split() {
        let chunks =
            default_splitter().split("Hello world. This is a test. Another sentence here.");
        assert!(!chunks.is_empty());
    }

    /// With a tight budget, text containing a fenced block yields several chunks.
    #[test]
    fn test_code_aware() {
        let splitter = TileSplit::new(SplitConfig {
            max_tokens: 50,
            code_aware: true,
            chars_per_token: 1.0,
            min_chunk_size: 5,
            ..SplitConfig::default()
        });
        let text = "Some text.\n\n```python\ndef foo():\n return 42\n```\n\nMore text.";
        let chunks = splitter.split(text);
        assert!(chunks.len() >= 2);
    }

    /// `split_n` returns exactly the requested number of chunks.
    #[test]
    fn test_split_n() {
        let text = "One. Two. Three. Four. Five. Six. Seven. Eight.";
        let chunks = default_splitter().split_n(text, 3);
        assert_eq!(chunks.len(), 3);
    }

    /// Two functions become at least two chunks, all tagged as code.
    #[test]
    fn test_split_code() {
        let splitter = TileSplit::new(SplitConfig {
            min_chunk_size: 10,
            ..SplitConfig::default()
        });
        let code = "fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n\nfn mul(a: i32, b: i32) -> i32 {\n a * b\n}";
        let chunks = splitter.split_code(code);
        assert!(chunks.len() >= 2);
        assert!(chunks.iter().all(|c| c.chunk_type == ChunkType::Code));
    }

    /// Statistics over a long input report at least one non-empty chunk.
    #[test]
    fn test_stats() {
        let splitter = default_splitter();
        let text = "Hello world. ".repeat(100);
        let chunks = splitter.split(&text);
        let stats = splitter.stats(&text, &chunks);
        assert!(stats.chunks >= 1);
        assert!(stats.avg_chunk_chars > 0.0);
    }

    /// Empty input produces no chunks.
    #[test]
    fn test_empty() {
        assert!(default_splitter().split("").is_empty());
    }

    /// Text that fits the budget is returned as a single chunk.
    #[test]
    fn test_small_text_no_split() {
        let chunks = default_splitter().split("Short text.");
        assert_eq!(chunks.len(), 1);
    }
}