1use anyhow::Result;
2use ck_core::Span;
3use serde::{Deserialize, Serialize};
4
5pub use ck_embed::TokenEstimator;
7
/// Estimates the token count of `text` by delegating to the shared
/// [`TokenEstimator`] from `ck_embed`, so chunk sizing and the embedding
/// pipeline agree on token counts.
fn estimate_tokens(text: &str) -> usize {
    TokenEstimator::estimate_tokens(text)
}
12
/// Returns `(target_tokens, overlap_tokens)` chunking parameters tuned for
/// the given embedding model.
///
/// Small-context models (BGE family, MiniLM) get small chunks with modest
/// overlap; long-context models (nomic, jina code) get larger chunks.
/// `None` and unrecognized model names fall back to the long-context
/// defaults of `(1024, 200)`.
fn get_model_chunk_config(model_name: Option<&str>) -> (usize, usize) {
    // Default to the model used when no explicit model is configured.
    let model = model_name.unwrap_or("nomic-embed-text-v1.5");

    match model {
        // Small-context models: keep chunks well under the window.
        // (The BGE small/base/large and MiniLM arms previously duplicated
        // the same tuple; merged into a single arm.)
        "BAAI/bge-small-en-v1.5"
        | "sentence-transformers/all-MiniLM-L6-v2"
        | "BAAI/bge-base-en-v1.5"
        | "BAAI/bge-large-en-v1.5" => (400, 80),

        // Long-context models can take much larger chunks.
        "nomic-embed-text-v1" | "nomic-embed-text-v1.5" | "jina-embeddings-v2-base-code" => {
            (1024, 200)
        }

        // Unknown models: assume long-context defaults.
        _ => (1024, 200),
    }
}
39
/// Metadata attached to a chunk produced by splitting ("striding") an
/// oversized chunk into overlapping windows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrideInfo {
    /// Identifier of the source chunk, formatted as `"byte_start:byte_end"`.
    pub original_chunk_id: String,
    /// Zero-based position of this stride within the original chunk.
    pub stride_index: usize,
    /// Estimated total number of strides produced from the original chunk.
    pub total_strides: usize,
    /// Bytes shared with the previous stride (0 for the first stride).
    pub overlap_start: usize,
    /// Bytes shared with the next stride (0 for the last stride).
    pub overlap_end: usize,
}
54
/// A contiguous piece of a source text, the unit handed to embedding/indexing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte and line range of this chunk within the original text.
    pub span: Span,
    /// The chunk's raw text content.
    pub text: String,
    /// Syntactic category of the chunk (function, class, plain text, ...).
    pub chunk_type: ChunkType,
    /// Present only when this chunk is one stride of a larger chunk.
    pub stride_info: Option<StrideInfo>,
}
63
/// Coarse syntactic category assigned to a chunk by the tree-sitter
/// extractors, or [`ChunkType::Text`] for generic line-based chunks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkType {
    /// Generic chunk from line-window splitting (no syntax information).
    Text,
    /// Function-like definition.
    Function,
    /// Class/struct/enum-like definition.
    Class,
    /// Method definition.
    Method,
    /// Module/namespace/interface-level definition.
    Module,
}
72
/// Languages with a tree-sitter grammar wired up in this crate; anything
/// else falls back to generic line-based chunking.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseableLanguage {
    Python,
    TypeScript,
    JavaScript,
    Haskell,
    Rust,
    Ruby,
    Go,
    CSharp,
}
84
85impl std::fmt::Display for ParseableLanguage {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 let name = match self {
88 ParseableLanguage::Python => "python",
89 ParseableLanguage::TypeScript => "typescript",
90 ParseableLanguage::JavaScript => "javascript",
91 ParseableLanguage::Haskell => "haskell",
92 ParseableLanguage::Rust => "rust",
93 ParseableLanguage::Ruby => "ruby",
94 ParseableLanguage::Go => "go",
95 ParseableLanguage::CSharp => "csharp",
96 };
97 write!(f, "{}", name)
98 }
99}
100
101impl TryFrom<ck_core::Language> for ParseableLanguage {
102 type Error = anyhow::Error;
103
104 fn try_from(lang: ck_core::Language) -> Result<Self, Self::Error> {
105 match lang {
106 ck_core::Language::Python => Ok(ParseableLanguage::Python),
107 ck_core::Language::TypeScript => Ok(ParseableLanguage::TypeScript),
108 ck_core::Language::JavaScript => Ok(ParseableLanguage::JavaScript),
109 ck_core::Language::Haskell => Ok(ParseableLanguage::Haskell),
110 ck_core::Language::Rust => Ok(ParseableLanguage::Rust),
111 ck_core::Language::Ruby => Ok(ParseableLanguage::Ruby),
112 ck_core::Language::Go => Ok(ParseableLanguage::Go),
113 ck_core::Language::CSharp => Ok(ParseableLanguage::CSharp),
114 _ => Err(anyhow::anyhow!(
115 "Language {:?} is not supported for parsing",
116 lang
117 )),
118 }
119 }
120}
121
/// Chunks `text` using the default [`ChunkConfig`]; see
/// [`chunk_text_with_config`] for custom settings.
pub fn chunk_text(text: &str, language: Option<ck_core::Language>) -> Result<Vec<Chunk>> {
    chunk_text_with_config(text, language, &ChunkConfig::default())
}
125
/// Controls how chunks are sized and how oversized chunks are split.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum estimated tokens per chunk before striding is applied.
    pub max_tokens: usize,
    /// Token overlap between consecutive strides of an oversized chunk.
    pub stride_overlap: usize,
    /// When true, chunks exceeding `max_tokens` are split into strides.
    pub enable_striding: bool,
}
136
137impl Default for ChunkConfig {
138 fn default() -> Self {
139 Self {
140 max_tokens: 8192, stride_overlap: 1024, enable_striding: true,
143 }
144 }
145}
146
147pub fn chunk_text_with_model(
149 text: &str,
150 language: Option<ck_core::Language>,
151 model_name: Option<&str>,
152) -> Result<Vec<Chunk>> {
153 let (target_tokens, overlap_tokens) = get_model_chunk_config(model_name);
154
155 let config = ChunkConfig {
157 max_tokens: target_tokens,
158 stride_overlap: overlap_tokens,
159 enable_striding: true,
160 };
161
162 chunk_text_with_config_and_model(text, language, &config, model_name)
163}
164
/// Chunks `text` with an explicit [`ChunkConfig`] and no model-specific
/// budgeting (no model name is forwarded to the underlying chunkers).
pub fn chunk_text_with_config(
    text: &str,
    language: Option<ck_core::Language>,
    config: &ChunkConfig,
) -> Result<Vec<Chunk>> {
    chunk_text_with_config_and_model(text, language, config, None)
}
172
173fn chunk_text_with_config_and_model(
174 text: &str,
175 language: Option<ck_core::Language>,
176 config: &ChunkConfig,
177 model_name: Option<&str>,
178) -> Result<Vec<Chunk>> {
179 tracing::debug!(
180 "Chunking text with language: {:?}, length: {} chars, config: {:?}",
181 language,
182 text.len(),
183 config
184 );
185
186 let result = match language.map(ParseableLanguage::try_from) {
187 Some(Ok(lang)) => {
188 tracing::debug!("Using {} tree-sitter parser", lang);
189 chunk_language_with_model(text, lang, model_name)
190 }
191 Some(Err(_)) => {
192 tracing::debug!("Language not supported for parsing, using generic chunking strategy");
193 chunk_generic_with_token_config(text, model_name)
194 }
195 None => {
196 tracing::debug!("Using generic chunking strategy");
197 chunk_generic_with_token_config(text, model_name)
198 }
199 };
200
201 let mut chunks = result?;
202
203 if config.enable_striding {
205 chunks = apply_striding(chunks, config)?;
206 }
207
208 tracing::debug!("Successfully created {} final chunks", chunks.len());
209 Ok(chunks)
210}
211
/// Line-window chunking with the default (long-context) token budgets.
fn chunk_generic(text: &str) -> Result<Vec<Chunk>> {
    chunk_generic_with_token_config(text, None)
}
215
216fn chunk_generic_with_token_config(text: &str, model_name: Option<&str>) -> Result<Vec<Chunk>> {
217 let mut chunks = Vec::new();
218 let lines: Vec<&str> = text.lines().collect();
219
220 let (target_tokens, overlap_tokens) = get_model_chunk_config(model_name);
222
223 let avg_tokens_per_line = 10.0; let target_lines = ((target_tokens as f32) / avg_tokens_per_line) as usize;
227 let overlap_lines = ((overlap_tokens as f32) / avg_tokens_per_line) as usize;
228
229 let chunk_size = target_lines.max(5); let overlap = overlap_lines.max(1); let mut line_byte_offsets = Vec::with_capacity(lines.len() + 1);
234 line_byte_offsets.push(0);
235 let mut cumulative_offset = 0;
236 for line in &lines {
237 cumulative_offset += line.len() + 1; line_byte_offsets.push(cumulative_offset);
239 }
240
241 let mut i = 0;
242 while i < lines.len() {
243 let end = (i + chunk_size).min(lines.len());
244 let chunk_lines = &lines[i..end];
245 let chunk_text = chunk_lines.join("\n");
246
247 let byte_start = line_byte_offsets[i];
248 let byte_end = byte_start + chunk_text.len();
249
250 chunks.push(Chunk {
251 span: Span {
252 byte_start,
253 byte_end,
254 line_start: i + 1,
255 line_end: end,
256 },
257 text: chunk_text,
258 chunk_type: ChunkType::Text,
259 stride_info: None,
260 });
261
262 i += chunk_size - overlap;
263 if i >= lines.len() {
264 break;
265 }
266 }
267
268 Ok(chunks)
269}
270
271fn chunk_language(text: &str, language: ParseableLanguage) -> Result<Vec<Chunk>> {
272 let mut parser = tree_sitter::Parser::new();
273
274 match language {
275 ParseableLanguage::Python => parser.set_language(&tree_sitter_python::language())?,
276 ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
277 parser.set_language(&tree_sitter_typescript::language_typescript())?
278 }
279 ParseableLanguage::Haskell => parser.set_language(&tree_sitter_haskell::language())?,
280 ParseableLanguage::Rust => parser.set_language(&tree_sitter_rust::language())?,
281 ParseableLanguage::Ruby => parser.set_language(&tree_sitter_ruby::language())?,
282 ParseableLanguage::Go => parser.set_language(&tree_sitter_go::language())?,
283 ParseableLanguage::CSharp => parser.set_language(&tree_sitter_c_sharp::language())?,
284 }
285
286 let tree = parser
287 .parse(text, None)
288 .ok_or_else(|| anyhow::anyhow!("Failed to parse {} code", language))?;
289
290 let mut chunks = Vec::new();
291 let mut cursor = tree.root_node().walk();
292
293 extract_code_chunks(&mut cursor, text, &mut chunks, language);
294
295 if chunks.is_empty() {
296 return chunk_generic(text);
297 }
298
299 Ok(chunks)
300}
301
/// Tree-sitter chunking entry point that accepts a model name.
///
/// NOTE(review): `_model_name` is currently ignored — syntactic chunks are
/// produced the same way regardless of model; oversized chunks are split
/// later by the striding pass in the pipeline.
fn chunk_language_with_model(
    text: &str,
    language: ParseableLanguage,
    _model_name: Option<&str>,
) -> Result<Vec<Chunk>> {
    chunk_language(text, language)
}
312
/// Recursively walks the syntax tree at `cursor`, appending a [`Chunk`] for
/// every node whose kind is in the per-language allowlist below.
///
/// The walk continues into the children of matched nodes, so nested
/// definitions (e.g. a method inside a class) produce additional,
/// overlapping chunks.
fn extract_code_chunks(
    cursor: &mut tree_sitter::TreeCursor,
    source: &str,
    chunks: &mut Vec<Chunk>,
    language: ParseableLanguage,
) {
    let node = cursor.node();
    let node_kind = node.kind();

    // Per-language allowlist of tree-sitter node kinds worth extracting.
    let is_chunk = match language {
        ParseableLanguage::Python => {
            matches!(node_kind, "function_definition" | "class_definition")
        }
        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => matches!(
            node_kind,
            "function_declaration" | "class_declaration" | "method_definition" | "arrow_function"
        ),
        // NOTE(review): "type_synomym" mirrors the grammar's own node
        // spelling — verify against the tree-sitter-haskell version in use.
        ParseableLanguage::Haskell => matches!(
            node_kind,
            "signature"
                | "data_type"
                | "newtype"
                | "type_synomym"
                | "type_family"
                | "class"
                | "instance"
        ),
        ParseableLanguage::Rust => matches!(
            node_kind,
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "trait_item" | "mod_item"
        ),
        ParseableLanguage::Ruby => matches!(
            node_kind,
            "method" | "class" | "module" | "singleton_method"
        ),
        ParseableLanguage::Go => matches!(
            node_kind,
            "function_declaration"
                | "method_declaration"
                | "type_declaration"
                | "var_declaration"
                | "const_declaration"
        ),
        ParseableLanguage::CSharp => matches!(
            node_kind,
            "method_declaration"
                | "class_declaration"
                | "interface_declaration"
                | "variable_declaration"
        ),
    };

    if is_chunk {
        let start_byte = node.start_byte();
        let end_byte = node.end_byte();
        let start_pos = node.start_position();
        let end_pos = node.end_position();

        let text = &source[start_byte..end_byte];

        // Map the node kind to a coarse chunk category. These lists include
        // kinds not present in any allowlist above (e.g. "defmodule",
        // "defn") — presumably reserved for languages to be wired up later;
        // they are unreachable until `is_chunk` can match them.
        let chunk_type = match node_kind {
            "function_definition"
            | "function_declaration"
            | "arrow_function"
            | "function"
            | "signature"
            | "function_item"
            | "def"
            | "defp"
            | "method"
            | "singleton_method"
            | "defn"
            | "defn-" => ChunkType::Function,
            "class_definition"
            | "class_declaration"
            | "instance_declaration"
            | "class"
            | "instance"
            | "struct_item"
            | "enum_item"
            | "defstruct"
            | "defrecord"
            | "deftype"
            | "type_declaration" => ChunkType::Class,
            "method_definition" | "method_declaration" | "defmacro" => ChunkType::Method,
            "data_type"
            | "newtype"
            | "type_synomym"
            | "type_family"
            | "impl_item"
            | "trait_item"
            | "mod_item"
            | "defmodule"
            | "module"
            | "defprotocol"
            | "interface_declaration"
            | "ns"
            | "var_declaration"
            | "const_declaration"
            | "variable_declaration" => ChunkType::Module,
            _ => ChunkType::Text,
        };

        chunks.push(Chunk {
            span: Span {
                byte_start: start_byte,
                byte_end: end_byte,
                // tree-sitter rows are 0-based; spans use 1-based lines.
                line_start: start_pos.row + 1,
                line_end: end_pos.row + 1,
            },
            text: text.to_string(),
            chunk_type,
            stride_info: None,
        });
    }

    // Depth-first traversal: visit every child, then restore the cursor to
    // the parent so the caller's position is unchanged.
    if cursor.goto_first_child() {
        loop {
            extract_code_chunks(cursor, source, chunks, language);
            if !cursor.goto_next_sibling() {
                break;
            }
        }
        cursor.goto_parent();
    }
}
439
440fn apply_striding(chunks: Vec<Chunk>, config: &ChunkConfig) -> Result<Vec<Chunk>> {
442 let mut result = Vec::new();
443
444 for chunk in chunks {
445 let estimated_tokens = estimate_tokens(&chunk.text);
446
447 if estimated_tokens <= config.max_tokens {
448 result.push(chunk);
450 } else {
451 tracing::debug!(
453 "Chunk with {} tokens exceeds limit of {}, applying striding",
454 estimated_tokens,
455 config.max_tokens
456 );
457
458 let strided_chunks = stride_large_chunk(chunk, config)?;
459 result.extend(strided_chunks);
460 }
461 }
462
463 Ok(result)
464}
465
466fn stride_large_chunk(chunk: Chunk, config: &ChunkConfig) -> Result<Vec<Chunk>> {
468 let text = &chunk.text;
469 let text_len = text.len();
470
471 let chars_per_token = text_len as f32 / estimate_tokens(text) as f32;
474 let window_chars = ((config.max_tokens as f32 * 0.9) * chars_per_token) as usize; let overlap_chars = (config.stride_overlap as f32 * chars_per_token) as usize;
476 let stride_chars = window_chars.saturating_sub(overlap_chars);
477
478 if stride_chars == 0 {
479 return Err(anyhow::anyhow!("Stride size is too small"));
480 }
481
482 let mut strided_chunks = Vec::new();
483 let original_chunk_id = format!("{}:{}", chunk.span.byte_start, chunk.span.byte_end);
484 let mut start_pos = 0;
485 let mut stride_index = 0;
486
487 let total_strides = if text_len <= window_chars {
489 1
490 } else {
491 ((text_len - overlap_chars) as f32 / stride_chars as f32).ceil() as usize
492 };
493
494 while start_pos < text_len {
495 let end_pos = (start_pos + window_chars).min(text_len);
496 let stride_text = &text[start_pos..end_pos];
497
498 let overlap_start = if stride_index > 0 { overlap_chars } else { 0 };
500 let overlap_end = if end_pos < text_len { overlap_chars } else { 0 };
501
502 let byte_offset_start = chunk.span.byte_start + start_pos;
504 let byte_offset_end = chunk.span.byte_start + end_pos;
505
506 let text_before_start = &text[..start_pos];
508 let line_offset_start = text_before_start.lines().count().saturating_sub(1);
509 let stride_lines = stride_text.lines().count();
510
511 let stride_chunk = Chunk {
512 span: Span {
513 byte_start: byte_offset_start,
514 byte_end: byte_offset_end,
515 line_start: chunk.span.line_start + line_offset_start,
516 line_end: chunk.span.line_start + line_offset_start + stride_lines,
517 },
518 text: stride_text.to_string(),
519 chunk_type: chunk.chunk_type.clone(),
520 stride_info: Some(StrideInfo {
521 original_chunk_id: original_chunk_id.clone(),
522 stride_index,
523 total_strides,
524 overlap_start,
525 overlap_end,
526 }),
527 };
528
529 strided_chunks.push(stride_chunk);
530
531 if end_pos >= text_len {
533 break;
534 }
535
536 start_pos += stride_chars;
537 stride_index += 1;
538 }
539
540 tracing::debug!(
541 "Created {} strides from chunk of {} tokens",
542 strided_chunks.len(),
543 estimate_tokens(text)
544 );
545
546 Ok(strided_chunks)
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Spans from generic chunking must exactly delimit their chunk text.
    #[test]
    fn test_chunk_generic_byte_offsets() {
        let text = "line 1\nline 2\nline 3\nline 4\nline 5";
        let chunks = chunk_generic(text).unwrap();

        assert!(!chunks.is_empty());

        // The first chunk starts at byte 0 of the input.
        assert_eq!(chunks[0].span.byte_start, 0);

        // Every span's byte length equals the length of its text.
        for chunk in &chunks {
            let expected_len = chunk.text.len();
            let actual_len = chunk.span.byte_end - chunk.span.byte_start;
            assert_eq!(actual_len, expected_len);
        }
    }

    /// Chunking 1000 lines should complete quickly (guards against
    /// accidental quadratic behavior) and yield 1-based, non-inverted
    /// line ranges.
    #[test]
    fn test_chunk_generic_large_file_performance() {
        let lines: Vec<String> = (0..1000)
            .map(|i| format!("Line {}: Some content here", i))
            .collect();
        let text = lines.join("\n");

        let start = std::time::Instant::now();
        let chunks = chunk_generic(&text).unwrap();
        let duration = start.elapsed();

        assert!(
            duration.as_millis() < 100,
            "Chunking took too long: {:?}",
            duration
        );
        assert!(!chunks.is_empty());

        for chunk in &chunks {
            assert!(chunk.span.line_start > 0);
            assert!(chunk.span.line_end >= chunk.span.line_start);
        }
    }

    /// Rust sources should yield Class (struct), Module (impl/mod) and
    /// Function chunks.
    #[test]
    fn test_chunk_rust() {
        let rust_code = r#"
pub struct Calculator {
    memory: f64,
}

impl Calculator {
    pub fn new() -> Self {
        Calculator { memory: 0.0 }
    }

    pub fn add(&mut self, a: f64, b: f64) -> f64 {
        a + b
    }
}

fn main() {
    let calc = Calculator::new();
}

pub mod utils {
    pub fn helper() {}
}
"#;

        let chunks = chunk_language(rust_code, ParseableLanguage::Rust).unwrap();
        assert!(!chunks.is_empty());

        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Function));
    }

    /// Ruby sources should yield Class, Module and Function chunks
    /// (Ruby "method" nodes map to ChunkType::Function).
    #[test]
    fn test_chunk_ruby() {
        let ruby_code = r#"
class Calculator
  def initialize
    @memory = 0.0
  end

  def add(a, b)
    a + b
  end

  def self.class_method
    "class method"
  end

  private

  def private_method
    "private"
  end
end

module Utils
  def self.helper
    "helper"
  end
end

def main
  calc = Calculator.new
end
"#;

        let chunks = chunk_language(ruby_code, ParseableLanguage::Ruby).unwrap();
        assert!(!chunks.is_empty());

        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Function));
    }

    /// With no language given, chunking must match the generic splitter.
    #[test]
    fn test_language_detection_fallback() {
        let generic_text = "Some text\nwith multiple lines\nto chunk generically";

        let chunks_unknown = chunk_text(generic_text, None).unwrap();
        let chunks_generic = chunk_generic(generic_text).unwrap();

        // Both paths should agree on chunk count and content.
        assert_eq!(chunks_unknown.len(), chunks_generic.len());
        assert_eq!(chunks_unknown[0].text, chunks_generic[0].text);
    }

    /// Go sources should yield Module (var/const), Class (type),
    /// Function and Method chunks.
    #[test]
    fn test_chunk_go() {
        let go_code = r#"
package main

import "fmt"

const Pi = 3.14159

var memory float64

type Calculator struct {
    memory float64
}

type Operation interface {
    Calculate(a, b float64) float64
}

func NewCalculator() *Calculator {
    return &Calculator{memory: 0.0}
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func main() {
    calc := NewCalculator()
}
"#;

        let chunks = chunk_language(go_code, ParseableLanguage::Go).unwrap();
        assert!(!chunks.is_empty());

        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Function));
        assert!(chunk_types.contains(&&ChunkType::Method));
    }

    /// C# sources should yield Module (interface), Class and Method chunks.
    #[test]
    fn test_chunk_csharp() {
        let csharp_code = r#"
namespace Calculator;

public interface ICalculator
{
    double Add(double x, double y);
}

public class Calculator
{
    public static const double PI = 3.14159;
    private double _memory;

    public Calculator()
    {
        _memory = 0.0;
    }

    public double Add(double x, double y)
    {
        return x + y;
    }

    public static void Main(string[] args)
    {
        var calc = new Calculator();
    }
}
"#;

        let chunks = chunk_language(csharp_code, ParseableLanguage::CSharp).unwrap();
        assert!(!chunks.is_empty());

        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Method));
    }
}