1use anyhow::Result;
2use ck_core::Span;
3use serde::{Deserialize, Serialize};
4
/// Metadata attached to a chunk produced by splitting ("striding") an
/// oversized chunk into overlapping windows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrideInfo {
    /// Identifier of the chunk this stride was cut from, formatted as
    /// "byte_start:byte_end" of the original chunk's span.
    pub original_chunk_id: String,
    /// Zero-based position of this stride within the original chunk.
    pub stride_index: usize,
    /// Total number of strides the original chunk was split into.
    pub total_strides: usize,
    /// Estimated size (in bytes of the chunk text) of the overlap shared
    /// with the previous stride; 0 for the first stride.
    pub overlap_start: usize,
    /// Estimated size (in bytes of the chunk text) of the overlap shared
    /// with the next stride; 0 for the last stride.
    pub overlap_end: usize,
}
19
/// A contiguous piece of a source document: its location, raw text, and a
/// coarse classification of what the piece represents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte and line range of this chunk within the original document.
    pub span: Span,
    /// The chunk's text content.
    pub text: String,
    /// Structural classification (function, class, plain text, ...).
    pub chunk_type: ChunkType,
    /// Present only when this chunk is one stride of a larger chunk that
    /// exceeded the configured token limit; `None` for whole chunks.
    pub stride_info: Option<StrideInfo>,
}
28
/// Coarse structural category of a [`Chunk`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkType {
    /// Plain text or code chunked by the generic line-based strategy.
    Text,
    /// A function-like definition.
    Function,
    /// A class/struct/enum-like type definition.
    Class,
    /// A method defined inside a type.
    Method,
    /// A module-level grouping (modules, impls, traits, declarations).
    Module,
}
37
/// Languages that have a tree-sitter grammar wired up for structural
/// chunking. Anything else falls back to generic line-based chunking.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseableLanguage {
    Python,
    TypeScript,
    // Note: JavaScript is parsed with the TypeScript grammar (see
    // `chunk_language`), which accepts plain JavaScript as well.
    JavaScript,
    Haskell,
    Rust,
    Ruby,
    Go,
}
48
49impl std::fmt::Display for ParseableLanguage {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 let name = match self {
52 ParseableLanguage::Python => "python",
53 ParseableLanguage::TypeScript => "typescript",
54 ParseableLanguage::JavaScript => "javascript",
55 ParseableLanguage::Haskell => "haskell",
56 ParseableLanguage::Rust => "rust",
57 ParseableLanguage::Ruby => "ruby",
58 ParseableLanguage::Go => "go",
59 };
60 write!(f, "{}", name)
61 }
62}
63
64impl TryFrom<ck_core::Language> for ParseableLanguage {
65 type Error = anyhow::Error;
66
67 fn try_from(lang: ck_core::Language) -> Result<Self, Self::Error> {
68 match lang {
69 ck_core::Language::Python => Ok(ParseableLanguage::Python),
70 ck_core::Language::TypeScript => Ok(ParseableLanguage::TypeScript),
71 ck_core::Language::JavaScript => Ok(ParseableLanguage::JavaScript),
72 ck_core::Language::Haskell => Ok(ParseableLanguage::Haskell),
73 ck_core::Language::Rust => Ok(ParseableLanguage::Rust),
74 ck_core::Language::Ruby => Ok(ParseableLanguage::Ruby),
75 ck_core::Language::Go => Ok(ParseableLanguage::Go),
76 _ => Err(anyhow::anyhow!(
77 "Language {:?} is not supported for parsing",
78 lang
79 )),
80 }
81 }
82}
83
/// Chunks `text` using the default [`ChunkConfig`].
///
/// When `language` has a supported tree-sitter grammar, chunks follow code
/// structure (functions, classes, ...); otherwise a generic line-based
/// strategy is used. See [`chunk_text_with_config`] for details.
pub fn chunk_text(text: &str, language: Option<ck_core::Language>) -> Result<Vec<Chunk>> {
    chunk_text_with_config(text, language, &ChunkConfig::default())
}
87
/// Tuning knobs for chunking and for splitting oversized chunks.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum estimated token count allowed per chunk before it is strided.
    pub max_tokens: usize,
    /// Desired token overlap between consecutive strides of a split chunk.
    pub stride_overlap: usize,
    /// When `false`, oversized chunks are left intact (no striding).
    pub enable_striding: bool,
}
98
99impl Default for ChunkConfig {
100 fn default() -> Self {
101 Self {
102 max_tokens: 8192, stride_overlap: 1024, enable_striding: true,
105 }
106 }
107}
108
109pub fn chunk_text_with_config(
110 text: &str,
111 language: Option<ck_core::Language>,
112 config: &ChunkConfig,
113) -> Result<Vec<Chunk>> {
114 tracing::debug!(
115 "Chunking text with language: {:?}, length: {} chars, config: {:?}",
116 language,
117 text.len(),
118 config
119 );
120
121 let result = match language.map(ParseableLanguage::try_from) {
122 Some(Ok(lang)) => {
123 tracing::debug!("Using {} tree-sitter parser", lang);
124 chunk_language(text, lang)
125 }
126 Some(Err(_)) => {
127 tracing::debug!("Language not supported for parsing, using generic chunking strategy");
128 chunk_generic(text)
129 }
130 None => {
131 tracing::debug!("Using generic chunking strategy");
132 chunk_generic(text)
133 }
134 };
135
136 let mut chunks = result?;
137
138 if config.enable_striding {
140 chunks = apply_striding(chunks, config)?;
141 }
142
143 tracing::debug!("Successfully created {} final chunks", chunks.len());
144 Ok(chunks)
145}
146
147fn chunk_generic(text: &str) -> Result<Vec<Chunk>> {
148 let mut chunks = Vec::new();
149 let lines: Vec<&str> = text.lines().collect();
150 let chunk_size = 20;
151 let overlap = 5;
152
153 let mut line_byte_offsets = Vec::with_capacity(lines.len() + 1);
155 line_byte_offsets.push(0);
156 let mut cumulative_offset = 0;
157 for line in &lines {
158 cumulative_offset += line.len() + 1; line_byte_offsets.push(cumulative_offset);
160 }
161
162 let mut i = 0;
163 while i < lines.len() {
164 let end = (i + chunk_size).min(lines.len());
165 let chunk_lines = &lines[i..end];
166 let chunk_text = chunk_lines.join("\n");
167
168 let byte_start = line_byte_offsets[i];
169 let byte_end = byte_start + chunk_text.len();
170
171 chunks.push(Chunk {
172 span: Span {
173 byte_start,
174 byte_end,
175 line_start: i + 1,
176 line_end: end,
177 },
178 text: chunk_text,
179 chunk_type: ChunkType::Text,
180 stride_info: None,
181 });
182
183 i += chunk_size - overlap;
184 if i >= lines.len() {
185 break;
186 }
187 }
188
189 Ok(chunks)
190}
191
192fn chunk_language(text: &str, language: ParseableLanguage) -> Result<Vec<Chunk>> {
193 let mut parser = tree_sitter::Parser::new();
194
195 match language {
196 ParseableLanguage::Python => parser.set_language(&tree_sitter_python::language())?,
197 ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
198 parser.set_language(&tree_sitter_typescript::language_typescript())?
199 }
200 ParseableLanguage::Haskell => parser.set_language(&tree_sitter_haskell::language())?,
201 ParseableLanguage::Rust => parser.set_language(&tree_sitter_rust::language())?,
202 ParseableLanguage::Ruby => parser.set_language(&tree_sitter_ruby::language())?,
203 ParseableLanguage::Go => parser.set_language(&tree_sitter_go::language())?,
204 }
205
206 let tree = parser
207 .parse(text, None)
208 .ok_or_else(|| anyhow::anyhow!("Failed to parse {} code", language))?;
209
210 let mut chunks = Vec::new();
211 let mut cursor = tree.root_node().walk();
212
213 extract_code_chunks(&mut cursor, text, &mut chunks, language);
214
215 if chunks.is_empty() {
216 return chunk_generic(text);
217 }
218
219 Ok(chunks)
220}
221
/// Recursively walks the syntax tree from `cursor`, appending a [`Chunk`] to
/// `chunks` for every node whose kind counts as a definition in `language`.
///
/// Recursion continues into the children of matched nodes, so nested
/// definitions (methods inside a class, items inside a module) each produce
/// their own chunk in addition to the enclosing one.
fn extract_code_chunks(
    cursor: &mut tree_sitter::TreeCursor,
    source: &str,
    chunks: &mut Vec<Chunk>,
    language: ParseableLanguage,
) {
    let node = cursor.node();
    let node_kind = node.kind();

    // Per-language whitelist of tree-sitter node kinds that become chunks.
    let is_chunk = match language {
        ParseableLanguage::Python => {
            matches!(node_kind, "function_definition" | "class_definition")
        }
        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => matches!(
            node_kind,
            "function_declaration" | "class_declaration" | "method_definition" | "arrow_function"
        ),
        ParseableLanguage::Haskell => matches!(
            node_kind,
            "signature"
                | "data_type"
                | "newtype"
                // NOTE(review): "type_synomym" looks misspelled but appears to
                // mirror the node name used by the tree-sitter-haskell grammar;
                // confirm against the grammar version in use before "fixing" it.
                | "type_synomym"
                | "type_family"
                | "class"
                | "instance"
        ),
        ParseableLanguage::Rust => matches!(
            node_kind,
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "trait_item" | "mod_item"
        ),
        ParseableLanguage::Ruby => matches!(
            node_kind,
            "method" | "class" | "module" | "singleton_method"
        ),
        ParseableLanguage::Go => matches!(
            node_kind,
            "function_declaration"
                | "method_declaration"
                | "type_declaration"
                | "var_declaration"
                | "const_declaration"
        ),
    };

    if is_chunk {
        let start_byte = node.start_byte();
        let end_byte = node.end_byte();
        let start_pos = node.start_position();
        let end_pos = node.end_position();

        // Tree-sitter byte offsets index directly into the UTF-8 source.
        let text = &source[start_byte..end_byte];

        // Coarse classification of the node kind. Several arms below (e.g.
        // "def", "defp", "defn", "defmodule", "ns") name kinds that the
        // `is_chunk` filter above never admits for the supported languages;
        // they are currently unreachable and look like leftovers from
        // grammars that are no longer wired up.
        let chunk_type = match node_kind {
            "function_definition"
            | "function_declaration"
            | "arrow_function"
            | "function"
            | "signature"
            | "function_item"
            | "def"
            | "defp"
            | "method"
            | "singleton_method"
            | "defn"
            | "defn-" => ChunkType::Function,
            "class_definition"
            | "class_declaration"
            | "instance_declaration"
            | "class"
            | "instance"
            | "struct_item"
            | "enum_item"
            | "defstruct"
            | "defrecord"
            | "deftype"
            | "type_declaration" => ChunkType::Class,
            "method_definition" | "method_declaration" | "defmacro" => ChunkType::Method,
            "data_type" | "newtype" | "type_synomym" | "type_family" | "impl_item"
            | "trait_item" | "mod_item" | "defmodule" | "module" | "defprotocol" | "ns"
            | "var_declaration" | "const_declaration" => ChunkType::Module,
            _ => ChunkType::Text,
        };

        chunks.push(Chunk {
            span: Span {
                byte_start: start_byte,
                byte_end: end_byte,
                // tree-sitter rows are zero-based; spans use 1-based lines.
                line_start: start_pos.row + 1,
                line_end: end_pos.row + 1,
            },
            text: text.to_string(),
            chunk_type,
            stride_info: None,
        });
    }

    // Depth-first traversal: visit every child, then restore the cursor to
    // this node so the caller's sibling iteration is unaffected.
    if cursor.goto_first_child() {
        loop {
            extract_code_chunks(cursor, source, chunks, language);
            if !cursor.goto_next_sibling() {
                break;
            }
        }
        cursor.goto_parent();
    }
}
329
330fn apply_striding(chunks: Vec<Chunk>, config: &ChunkConfig) -> Result<Vec<Chunk>> {
332 let mut result = Vec::new();
333
334 for chunk in chunks {
335 let estimated_tokens = estimate_tokens(&chunk.text);
336
337 if estimated_tokens <= config.max_tokens {
338 result.push(chunk);
340 } else {
341 tracing::debug!(
343 "Chunk with {} tokens exceeds limit of {}, applying striding",
344 estimated_tokens,
345 config.max_tokens
346 );
347
348 let strided_chunks = stride_large_chunk(chunk, config)?;
349 result.extend(strided_chunks);
350 }
351 }
352
353 Ok(result)
354}
355
356fn stride_large_chunk(chunk: Chunk, config: &ChunkConfig) -> Result<Vec<Chunk>> {
358 let text = &chunk.text;
359 let text_len = text.len();
360
361 let chars_per_token = text_len as f32 / estimate_tokens(text) as f32;
364 let window_chars = ((config.max_tokens as f32 * 0.9) * chars_per_token) as usize; let overlap_chars = (config.stride_overlap as f32 * chars_per_token) as usize;
366 let stride_chars = window_chars.saturating_sub(overlap_chars);
367
368 if stride_chars == 0 {
369 return Err(anyhow::anyhow!("Stride size is too small"));
370 }
371
372 let mut strided_chunks = Vec::new();
373 let original_chunk_id = format!("{}:{}", chunk.span.byte_start, chunk.span.byte_end);
374 let mut start_pos = 0;
375 let mut stride_index = 0;
376
377 let total_strides = if text_len <= window_chars {
379 1
380 } else {
381 ((text_len - overlap_chars) as f32 / stride_chars as f32).ceil() as usize
382 };
383
384 while start_pos < text_len {
385 let end_pos = (start_pos + window_chars).min(text_len);
386 let stride_text = &text[start_pos..end_pos];
387
388 let overlap_start = if stride_index > 0 { overlap_chars } else { 0 };
390 let overlap_end = if end_pos < text_len { overlap_chars } else { 0 };
391
392 let byte_offset_start = chunk.span.byte_start + start_pos;
394 let byte_offset_end = chunk.span.byte_start + end_pos;
395
396 let text_before_start = &text[..start_pos];
398 let line_offset_start = text_before_start.lines().count().saturating_sub(1);
399 let stride_lines = stride_text.lines().count();
400
401 let stride_chunk = Chunk {
402 span: Span {
403 byte_start: byte_offset_start,
404 byte_end: byte_offset_end,
405 line_start: chunk.span.line_start + line_offset_start,
406 line_end: chunk.span.line_start + line_offset_start + stride_lines,
407 },
408 text: stride_text.to_string(),
409 chunk_type: chunk.chunk_type.clone(),
410 stride_info: Some(StrideInfo {
411 original_chunk_id: original_chunk_id.clone(),
412 stride_index,
413 total_strides,
414 overlap_start,
415 overlap_end,
416 }),
417 };
418
419 strided_chunks.push(stride_chunk);
420
421 if end_pos >= text_len {
423 break;
424 }
425
426 start_pos += stride_chars;
427 stride_index += 1;
428 }
429
430 tracing::debug!(
431 "Created {} strides from chunk of {} tokens",
432 strided_chunks.len(),
433 estimate_tokens(text)
434 );
435
436 Ok(strided_chunks)
437}
438
/// Rough token-count estimate for `text`, assuming ~4.5 characters per token.
///
/// Counts Unicode scalar values (not bytes) and rounds up, so any non-empty
/// text estimates to at least one token; empty text estimates to zero.
fn estimate_tokens(text: &str) -> usize {
    const CHARS_PER_TOKEN: f32 = 4.5;

    match text.chars().count() {
        0 => 0,
        n => (n as f32 / CHARS_PER_TOKEN).ceil() as usize,
    }
}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Generic chunking must produce spans whose byte range matches the
    /// chunk text exactly, starting at offset zero.
    #[test]
    fn test_chunk_generic_byte_offsets() {
        let input = "line 1\nline 2\nline 3\nline 4\nline 5";
        let result = chunk_generic(input).unwrap();

        assert!(!result.is_empty());
        assert_eq!(result[0].span.byte_start, 0);

        for chunk in &result {
            let span_len = chunk.span.byte_end - chunk.span.byte_start;
            assert_eq!(span_len, chunk.text.len());
        }
    }

    /// Chunking a 1000-line document should be fast and yield sane line spans.
    #[test]
    fn test_chunk_generic_large_file_performance() {
        let text = (0..1000)
            .map(|i| format!("Line {}: Some content here", i))
            .collect::<Vec<_>>()
            .join("\n");

        let started = std::time::Instant::now();
        let result = chunk_generic(&text).unwrap();
        let elapsed = started.elapsed();

        assert!(
            elapsed.as_millis() < 100,
            "Chunking took too long: {:?}",
            elapsed
        );
        assert!(!result.is_empty());

        for chunk in &result {
            assert!(chunk.span.line_start > 0);
            assert!(chunk.span.line_end >= chunk.span.line_start);
        }
    }

    /// Rust sources should yield class-, module-, and function-level chunks.
    #[test]
    fn test_chunk_rust() {
        let rust_code = r#"
pub struct Calculator {
    memory: f64,
}

impl Calculator {
    pub fn new() -> Self {
        Calculator { memory: 0.0 }
    }

    pub fn add(&mut self, a: f64, b: f64) -> f64 {
        a + b
    }
}

fn main() {
    let calc = Calculator::new();
}

pub mod utils {
    pub fn helper() {}
}
"#;

        let result = chunk_language(rust_code, ParseableLanguage::Rust).unwrap();
        assert!(!result.is_empty());

        let kinds: Vec<&ChunkType> = result.iter().map(|c| &c.chunk_type).collect();
        assert!(kinds.contains(&&ChunkType::Class));
        assert!(kinds.contains(&&ChunkType::Module));
        assert!(kinds.contains(&&ChunkType::Function));
    }

    /// Ruby sources should yield class-, module-, and function-level chunks.
    #[test]
    fn test_chunk_ruby() {
        let ruby_code = r#"
class Calculator
    def initialize
        @memory = 0.0
    end

    def add(a, b)
        a + b
    end

    def self.class_method
        "class method"
    end

    private

    def private_method
        "private"
    end
end

module Utils
    def self.helper
        "helper"
    end
end

def main
    calc = Calculator.new
end
"#;

        let result = chunk_language(ruby_code, ParseableLanguage::Ruby).unwrap();
        assert!(!result.is_empty());

        let kinds: Vec<&ChunkType> = result.iter().map(|c| &c.chunk_type).collect();
        assert!(kinds.contains(&&ChunkType::Class));
        assert!(kinds.contains(&&ChunkType::Module));
        assert!(kinds.contains(&&ChunkType::Function));
    }

    /// Unknown/absent languages must fall back to the generic strategy.
    #[test]
    fn test_language_detection_fallback() {
        let input = "Some text\nwith multiple lines\nto chunk generically";

        let via_api = chunk_text(input, None).unwrap();
        let via_generic = chunk_generic(input).unwrap();

        assert_eq!(via_api.len(), via_generic.len());
        assert_eq!(via_api[0].text, via_generic[0].text);
    }

    /// Go sources should yield module-, class-, function-, and method-level
    /// chunks (declarations map to Module, types to Class).
    #[test]
    fn test_chunk_go() {
        let go_code = r#"
package main

import "fmt"

const Pi = 3.14159

var memory float64

type Calculator struct {
    memory float64
}

type Operation interface {
    Calculate(a, b float64) float64
}

func NewCalculator() *Calculator {
    return &Calculator{memory: 0.0}
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func main() {
    calc := NewCalculator()
}
"#;

        let result = chunk_language(go_code, ParseableLanguage::Go).unwrap();
        assert!(!result.is_empty());

        let kinds: Vec<&ChunkType> = result.iter().map(|c| &c.chunk_type).collect();
        assert!(kinds.contains(&&ChunkType::Module));
        assert!(kinds.contains(&&ChunkType::Class));
        assert!(kinds.contains(&&ChunkType::Function));
        assert!(kinds.contains(&&ChunkType::Method));
    }
}