ck_chunk/
lib.rs

1use anyhow::Result;
2use ck_core::Span;
3use serde::{Deserialize, Serialize};
4
/// Information about chunk striding for large chunks that exceed token limits
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrideInfo {
    /// Identifier of the original chunk before striding; produced by
    /// `stride_large_chunk` as `"byte_start:byte_end"` of the original span.
    pub original_chunk_id: String,
    /// Index of this stride (0-based)
    pub stride_index: usize,
    /// Total number of strides for the original chunk
    pub total_strides: usize,
    /// Size in bytes of the overlap shared with the previous stride
    /// (0 for the first stride).
    /// NOTE(review): despite the name, `stride_large_chunk` stores a *length*
    /// here, not an absolute byte offset — confirm before relying on it.
    pub overlap_start: usize,
    /// Size in bytes of the overlap shared with the next stride
    /// (0 for the last stride); also a length, not an offset.
    pub overlap_end: usize,
}
19
/// A contiguous piece of source text, with its location and classification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte and line range of this chunk within the original text.
    pub span: Span,
    /// The chunk's text content.
    pub text: String,
    /// What kind of construct this chunk represents (function, class, ...).
    pub chunk_type: ChunkType,
    /// Stride information if this chunk was created by striding a larger chunk
    pub stride_info: Option<StrideInfo>,
}
28
/// Classification of a chunk, assigned from tree-sitter node kinds in
/// `extract_code_chunks` (generic line chunks are always `Text`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkType {
    /// Plain text chunk from the generic line-based strategy, or an
    /// unrecognized node kind.
    Text,
    /// Free-standing function definition.
    Function,
    /// Class-like definition (class, struct, enum, interface/type decl).
    Class,
    /// Method attached to a class or receiver.
    Method,
    /// Module-level construct (module, impl, trait, var/const declaration).
    Module,
}
37
/// Languages for which a tree-sitter grammar is wired up in `chunk_language`.
///
/// This is a strict subset of `ck_core::Language`; use the `TryFrom`
/// conversion to check whether a detected language is parseable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseableLanguage {
    Python,
    TypeScript,
    JavaScript,
    Haskell,
    Rust,
    Ruby,
    Go,
}
48
49impl std::fmt::Display for ParseableLanguage {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        let name = match self {
52            ParseableLanguage::Python => "python",
53            ParseableLanguage::TypeScript => "typescript",
54            ParseableLanguage::JavaScript => "javascript",
55            ParseableLanguage::Haskell => "haskell",
56            ParseableLanguage::Rust => "rust",
57            ParseableLanguage::Ruby => "ruby",
58            ParseableLanguage::Go => "go",
59        };
60        write!(f, "{}", name)
61    }
62}
63
64impl TryFrom<ck_core::Language> for ParseableLanguage {
65    type Error = anyhow::Error;
66
67    fn try_from(lang: ck_core::Language) -> Result<Self, Self::Error> {
68        match lang {
69            ck_core::Language::Python => Ok(ParseableLanguage::Python),
70            ck_core::Language::TypeScript => Ok(ParseableLanguage::TypeScript),
71            ck_core::Language::JavaScript => Ok(ParseableLanguage::JavaScript),
72            ck_core::Language::Haskell => Ok(ParseableLanguage::Haskell),
73            ck_core::Language::Rust => Ok(ParseableLanguage::Rust),
74            ck_core::Language::Ruby => Ok(ParseableLanguage::Ruby),
75            ck_core::Language::Go => Ok(ParseableLanguage::Go),
76            _ => Err(anyhow::anyhow!(
77                "Language {:?} is not supported for parsing",
78                lang
79            )),
80        }
81    }
82}
83
/// Chunk `text` into semantic pieces using the default [`ChunkConfig`].
///
/// When `language` converts to a [`ParseableLanguage`], chunks follow code
/// structure (functions, classes, ...); otherwise a generic line-window
/// strategy is used. Convenience wrapper over [`chunk_text_with_config`].
pub fn chunk_text(text: &str, language: Option<ck_core::Language>) -> Result<Vec<Chunk>> {
    chunk_text_with_config(text, language, &ChunkConfig::default())
}
87
/// Configuration for chunking behavior
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum tokens per chunk (for striding); chunks estimated above this
    /// are split into overlapping strides.
    pub max_tokens: usize,
    /// Overlap size for striding (in tokens) between consecutive strides.
    pub stride_overlap: usize,
    /// Enable striding for chunks that exceed max_tokens
    pub enable_striding: bool,
}
98
impl Default for ChunkConfig {
    /// Defaults sized for the Nomic embedding model's 8192-token context.
    fn default() -> Self {
        Self {
            max_tokens: 8192,     // Default to Nomic model limit
            stride_overlap: 1024, // 12.5% overlap
            enable_striding: true,
        }
    }
}
108
109pub fn chunk_text_with_config(
110    text: &str,
111    language: Option<ck_core::Language>,
112    config: &ChunkConfig,
113) -> Result<Vec<Chunk>> {
114    tracing::debug!(
115        "Chunking text with language: {:?}, length: {} chars, config: {:?}",
116        language,
117        text.len(),
118        config
119    );
120
121    let result = match language.map(ParseableLanguage::try_from) {
122        Some(Ok(lang)) => {
123            tracing::debug!("Using {} tree-sitter parser", lang);
124            chunk_language(text, lang)
125        }
126        Some(Err(_)) => {
127            tracing::debug!("Language not supported for parsing, using generic chunking strategy");
128            chunk_generic(text)
129        }
130        None => {
131            tracing::debug!("Using generic chunking strategy");
132            chunk_generic(text)
133        }
134    };
135
136    let mut chunks = result?;
137
138    // Apply striding if enabled and necessary
139    if config.enable_striding {
140        chunks = apply_striding(chunks, config)?;
141    }
142
143    tracing::debug!("Successfully created {} final chunks", chunks.len());
144    Ok(chunks)
145}
146
147fn chunk_generic(text: &str) -> Result<Vec<Chunk>> {
148    let mut chunks = Vec::new();
149    let lines: Vec<&str> = text.lines().collect();
150    let chunk_size = 20;
151    let overlap = 5;
152
153    // Pre-compute cumulative byte offsets for O(1) lookup
154    let mut line_byte_offsets = Vec::with_capacity(lines.len() + 1);
155    line_byte_offsets.push(0);
156    let mut cumulative_offset = 0;
157    for line in &lines {
158        cumulative_offset += line.len() + 1; // +1 for newline
159        line_byte_offsets.push(cumulative_offset);
160    }
161
162    let mut i = 0;
163    while i < lines.len() {
164        let end = (i + chunk_size).min(lines.len());
165        let chunk_lines = &lines[i..end];
166        let chunk_text = chunk_lines.join("\n");
167
168        let byte_start = line_byte_offsets[i];
169        let byte_end = byte_start + chunk_text.len();
170
171        chunks.push(Chunk {
172            span: Span {
173                byte_start,
174                byte_end,
175                line_start: i + 1,
176                line_end: end,
177            },
178            text: chunk_text,
179            chunk_type: ChunkType::Text,
180            stride_info: None,
181        });
182
183        i += chunk_size - overlap;
184        if i >= lines.len() {
185            break;
186        }
187    }
188
189    Ok(chunks)
190}
191
192fn chunk_language(text: &str, language: ParseableLanguage) -> Result<Vec<Chunk>> {
193    let mut parser = tree_sitter::Parser::new();
194
195    match language {
196        ParseableLanguage::Python => parser.set_language(&tree_sitter_python::language())?,
197        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
198            parser.set_language(&tree_sitter_typescript::language_typescript())?
199        }
200        ParseableLanguage::Haskell => parser.set_language(&tree_sitter_haskell::language())?,
201        ParseableLanguage::Rust => parser.set_language(&tree_sitter_rust::language())?,
202        ParseableLanguage::Ruby => parser.set_language(&tree_sitter_ruby::language())?,
203        ParseableLanguage::Go => parser.set_language(&tree_sitter_go::language())?,
204    }
205
206    let tree = parser
207        .parse(text, None)
208        .ok_or_else(|| anyhow::anyhow!("Failed to parse {} code", language))?;
209
210    let mut chunks = Vec::new();
211    let mut cursor = tree.root_node().walk();
212
213    extract_code_chunks(&mut cursor, text, &mut chunks, language);
214
215    if chunks.is_empty() {
216        return chunk_generic(text);
217    }
218
219    Ok(chunks)
220}
221
/// Recursively walk the syntax tree, collecting definition-level nodes
/// (functions, classes, modules, ...) as chunks.
///
/// Nested definitions are collected in addition to their parents, so a method
/// inside a class yields its own chunk and also appears inside the class
/// chunk's text.
fn extract_code_chunks(
    cursor: &mut tree_sitter::TreeCursor,
    source: &str,
    chunks: &mut Vec<Chunk>,
    language: ParseableLanguage,
) {
    let node = cursor.node();
    let node_kind = node.kind();

    // Per-language set of tree-sitter node kinds that mark chunk boundaries.
    let is_chunk = match language {
        ParseableLanguage::Python => {
            matches!(node_kind, "function_definition" | "class_definition")
        }
        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => matches!(
            node_kind,
            "function_declaration" | "class_declaration" | "method_definition" | "arrow_function"
        ),
        // NOTE(review): "type_synomym" mirrors the (misspelled) node name in
        // the tree-sitter-haskell grammar — do not "correct" the spelling
        // without confirming the grammar version in use.
        ParseableLanguage::Haskell => matches!(
            node_kind,
            "signature"
                | "data_type"
                | "newtype"
                | "type_synomym"
                | "type_family"
                | "class"
                | "instance"
        ),
        ParseableLanguage::Rust => matches!(
            node_kind,
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "trait_item" | "mod_item"
        ),
        ParseableLanguage::Ruby => matches!(
            node_kind,
            "method" | "class" | "module" | "singleton_method"
        ),
        ParseableLanguage::Go => matches!(
            node_kind,
            "function_declaration"
                | "method_declaration"
                | "type_declaration"
                | "var_declaration"
                | "const_declaration"
        ),
    };

    if is_chunk {
        let start_byte = node.start_byte();
        let end_byte = node.end_byte();
        let start_pos = node.start_position();
        let end_pos = node.end_position();

        // Safe slice: tree-sitter byte offsets come from parsing this same
        // source, so they land on valid boundaries.
        let text = &source[start_byte..end_byte];

        // Map node kind to a coarse chunk classification. Several kinds below
        // ("def", "defp", "defn", "defmodule", "defprotocol", "ns", ...) are
        // never selected by `is_chunk` above for any current language — they
        // appear to anticipate languages (e.g. Elixir/Clojure) not yet wired
        // in, and are currently unreachable.
        let chunk_type = match node_kind {
            "function_definition"
            | "function_declaration"
            | "arrow_function"
            | "function"
            | "signature"
            | "function_item"
            | "def"
            | "defp"
            | "method"
            | "singleton_method"
            | "defn"
            | "defn-" => ChunkType::Function,
            "class_definition"
            | "class_declaration"
            | "instance_declaration"
            | "class"
            | "instance"
            | "struct_item"
            | "enum_item"
            | "defstruct"
            | "defrecord"
            | "deftype"
            | "type_declaration" => ChunkType::Class,
            "method_definition" | "method_declaration" | "defmacro" => ChunkType::Method,
            "data_type" | "newtype" | "type_synomym" | "type_family" | "impl_item"
            | "trait_item" | "mod_item" | "defmodule" | "module" | "defprotocol" | "ns"
            | "var_declaration" | "const_declaration" => ChunkType::Module,
            _ => ChunkType::Text,
        };

        chunks.push(Chunk {
            span: Span {
                byte_start: start_byte,
                byte_end: end_byte,
                // tree-sitter rows are 0-based; spans use 1-based lines.
                line_start: start_pos.row + 1,
                line_end: end_pos.row + 1,
            },
            text: text.to_string(),
            chunk_type,
            stride_info: None,
        });
    }

    // Depth-first traversal: descend into children, visiting each sibling,
    // then restore the cursor to this node before returning.
    if cursor.goto_first_child() {
        loop {
            extract_code_chunks(cursor, source, chunks, language);
            if !cursor.goto_next_sibling() {
                break;
            }
        }
        cursor.goto_parent();
    }
}
329
330/// Apply striding to chunks that exceed the token limit
331fn apply_striding(chunks: Vec<Chunk>, config: &ChunkConfig) -> Result<Vec<Chunk>> {
332    let mut result = Vec::new();
333
334    for chunk in chunks {
335        let estimated_tokens = estimate_tokens(&chunk.text);
336
337        if estimated_tokens <= config.max_tokens {
338            // Chunk fits within limit, no striding needed
339            result.push(chunk);
340        } else {
341            // Chunk exceeds limit, apply striding
342            tracing::debug!(
343                "Chunk with {} tokens exceeds limit of {}, applying striding",
344                estimated_tokens,
345                config.max_tokens
346            );
347
348            let strided_chunks = stride_large_chunk(chunk, config)?;
349            result.extend(strided_chunks);
350        }
351    }
352
353    Ok(result)
354}
355
356/// Create strided chunks from a large chunk that exceeds token limits
357fn stride_large_chunk(chunk: Chunk, config: &ChunkConfig) -> Result<Vec<Chunk>> {
358    let text = &chunk.text;
359    let text_len = text.len();
360
361    // Calculate stride parameters in characters (approximate)
362    // Use a conservative estimate to ensure we stay under token limits
363    let chars_per_token = text_len as f32 / estimate_tokens(text) as f32;
364    let window_chars = ((config.max_tokens as f32 * 0.9) * chars_per_token) as usize; // 10% buffer
365    let overlap_chars = (config.stride_overlap as f32 * chars_per_token) as usize;
366    let stride_chars = window_chars.saturating_sub(overlap_chars);
367
368    if stride_chars == 0 {
369        return Err(anyhow::anyhow!("Stride size is too small"));
370    }
371
372    let mut strided_chunks = Vec::new();
373    let original_chunk_id = format!("{}:{}", chunk.span.byte_start, chunk.span.byte_end);
374    let mut start_pos = 0;
375    let mut stride_index = 0;
376
377    // Calculate total number of strides
378    let total_strides = if text_len <= window_chars {
379        1
380    } else {
381        ((text_len - overlap_chars) as f32 / stride_chars as f32).ceil() as usize
382    };
383
384    while start_pos < text_len {
385        let end_pos = (start_pos + window_chars).min(text_len);
386        let stride_text = &text[start_pos..end_pos];
387
388        // Calculate overlap information
389        let overlap_start = if stride_index > 0 { overlap_chars } else { 0 };
390        let overlap_end = if end_pos < text_len { overlap_chars } else { 0 };
391
392        // Calculate span for this stride
393        let byte_offset_start = chunk.span.byte_start + start_pos;
394        let byte_offset_end = chunk.span.byte_start + end_pos;
395
396        // Estimate line numbers (approximate)
397        let text_before_start = &text[..start_pos];
398        let line_offset_start = text_before_start.lines().count().saturating_sub(1);
399        let stride_lines = stride_text.lines().count();
400
401        let stride_chunk = Chunk {
402            span: Span {
403                byte_start: byte_offset_start,
404                byte_end: byte_offset_end,
405                line_start: chunk.span.line_start + line_offset_start,
406                line_end: chunk.span.line_start + line_offset_start + stride_lines,
407            },
408            text: stride_text.to_string(),
409            chunk_type: chunk.chunk_type.clone(),
410            stride_info: Some(StrideInfo {
411                original_chunk_id: original_chunk_id.clone(),
412                stride_index,
413                total_strides,
414                overlap_start,
415                overlap_end,
416            }),
417        };
418
419        strided_chunks.push(stride_chunk);
420
421        // Move to next stride
422        if end_pos >= text_len {
423            break;
424        }
425
426        start_pos += stride_chars;
427        stride_index += 1;
428    }
429
430    tracing::debug!(
431        "Created {} strides from chunk of {} tokens",
432        strided_chunks.len(),
433        estimate_tokens(text)
434    );
435
436    Ok(strided_chunks)
437}
438
/// Simple token estimation (matches the one in ck-embed)
///
/// Uses a rough heuristic of ~4.5 characters per token; empty input yields 0.
fn estimate_tokens(text: &str) -> usize {
    match text.chars().count() {
        0 => 0,
        char_count => (char_count as f32 / 4.5).ceil() as usize,
    }
}
449
// Unit tests covering the generic chunker's offset math and the tree-sitter
// chunkers for Rust, Ruby, and Go, plus the generic fallback path.
#[cfg(test)]
mod tests {
    use super::*;

    // Byte offsets from chunk_generic must line up with each chunk's text.
    #[test]
    fn test_chunk_generic_byte_offsets() {
        // Test that byte offsets are calculated correctly using O(n) algorithm
        let text = "line 1\nline 2\nline 3\nline 4\nline 5";
        let chunks = chunk_generic(text).unwrap();

        assert!(!chunks.is_empty());

        // First chunk should start at byte 0
        assert_eq!(chunks[0].span.byte_start, 0);

        // Each chunk's byte_end should match the actual text length
        for chunk in &chunks {
            let expected_len = chunk.text.len();
            let actual_len = chunk.span.byte_end - chunk.span.byte_start;
            assert_eq!(actual_len, expected_len);
        }
    }

    // Guards against regressions back to quadratic offset computation.
    #[test]
    fn test_chunk_generic_large_file_performance() {
        // Create a large text to ensure O(n) performance
        let lines: Vec<String> = (0..1000)
            .map(|i| format!("Line {}: Some content here", i))
            .collect();
        let text = lines.join("\n");

        let start = std::time::Instant::now();
        let chunks = chunk_generic(&text).unwrap();
        let duration = start.elapsed();

        // Should complete quickly even for 1000 lines
        assert!(
            duration.as_millis() < 100,
            "Chunking took too long: {:?}",
            duration
        );
        assert!(!chunks.is_empty());

        // Verify chunks have correct line numbers
        for chunk in &chunks {
            assert!(chunk.span.line_start > 0);
            assert!(chunk.span.line_end >= chunk.span.line_start);
        }
    }

    // Rust source should yield struct/impl/fn/mod chunks via tree-sitter.
    #[test]
    fn test_chunk_rust() {
        let rust_code = r#"
pub struct Calculator {
    memory: f64,
}

impl Calculator {
    pub fn new() -> Self {
        Calculator { memory: 0.0 }
    }
    
    pub fn add(&mut self, a: f64, b: f64) -> f64 {
        a + b
    }
}

fn main() {
    let calc = Calculator::new();
}

pub mod utils {
    pub fn helper() {}
}
"#;

        let chunks = chunk_language(rust_code, ParseableLanguage::Rust).unwrap();
        assert!(!chunks.is_empty());

        // Should find struct, impl, functions, and module
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Class)); // struct
        assert!(chunk_types.contains(&&ChunkType::Module)); // impl and mod
        assert!(chunk_types.contains(&&ChunkType::Function)); // functions
    }

    // Ruby source should yield class/module/method chunks via tree-sitter.
    #[test]
    fn test_chunk_ruby() {
        let ruby_code = r#"
class Calculator
  def initialize
    @memory = 0.0
  end

  def add(a, b)
    a + b
  end

  def self.class_method
    "class method"
  end

  private

  def private_method
    "private"
  end
end

module Utils
  def self.helper
    "helper"
  end
end

def main
  calc = Calculator.new
end
"#;

        let chunks = chunk_language(ruby_code, ParseableLanguage::Ruby).unwrap();
        assert!(!chunks.is_empty());

        // Should find class, module, and methods
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Class)); // class
        assert!(chunk_types.contains(&&ChunkType::Module)); // module
        assert!(chunk_types.contains(&&ChunkType::Function)); // methods
    }

    // With no language hint, chunk_text must match chunk_generic exactly.
    #[test]
    fn test_language_detection_fallback() {
        // Test that unknown languages fall back to generic chunking
        let generic_text = "Some text\nwith multiple lines\nto chunk generically";

        let chunks_unknown = chunk_text(generic_text, None).unwrap();
        let chunks_generic = chunk_generic(generic_text).unwrap();

        // Should produce the same result
        assert_eq!(chunks_unknown.len(), chunks_generic.len());
        assert_eq!(chunks_unknown[0].text, chunks_generic[0].text);
    }

    // Go source should yield const/var/type/func/method chunks via tree-sitter.
    #[test]
    fn test_chunk_go() {
        let go_code = r#"
package main

import "fmt"

const Pi = 3.14159

var memory float64

type Calculator struct {
    memory float64
}

type Operation interface {
    Calculate(a, b float64) float64
}

func NewCalculator() *Calculator {
    return &Calculator{memory: 0.0}
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func main() {
    calc := NewCalculator()
}
"#;

        let chunks = chunk_language(go_code, ParseableLanguage::Go).unwrap();
        assert!(!chunks.is_empty());

        // Should find const, var, type declarations, functions, and methods
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Module)); // const and var
        assert!(chunk_types.contains(&&ChunkType::Class)); // struct and interface
        assert!(chunk_types.contains(&&ChunkType::Function)); // functions
        assert!(chunk_types.contains(&&ChunkType::Method)); // methods
    }
}