ck_chunk/
lib.rs

1use anyhow::Result;
2use ck_core::Span;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Chunk {
7    pub span: Span,
8    pub text: String,
9    pub chunk_type: ChunkType,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13pub enum ChunkType {
14    Text,
15    Function,
16    Class,
17    Method,
18    Module,
19}
20
/// Languages for which this crate configures a tree-sitter grammar.
///
/// Any other language is chunked with the generic line-window strategy instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseableLanguage {
    /// Python source.
    Python,
    /// TypeScript source.
    TypeScript,
    /// JavaScript source.
    JavaScript,
    /// Haskell source.
    Haskell,
    /// Rust source.
    Rust,
    /// Ruby source.
    Ruby,
    /// Go source.
    Go,
}
31
32impl std::fmt::Display for ParseableLanguage {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        let name = match self {
35            ParseableLanguage::Python => "python",
36            ParseableLanguage::TypeScript => "typescript",
37            ParseableLanguage::JavaScript => "javascript",
38            ParseableLanguage::Haskell => "haskell",
39            ParseableLanguage::Rust => "rust",
40            ParseableLanguage::Ruby => "ruby",
41            ParseableLanguage::Go => "go",
42        };
43        write!(f, "{}", name)
44    }
45}
46
47impl TryFrom<ck_core::Language> for ParseableLanguage {
48    type Error = anyhow::Error;
49
50    fn try_from(lang: ck_core::Language) -> Result<Self, Self::Error> {
51        match lang {
52            ck_core::Language::Python => Ok(ParseableLanguage::Python),
53            ck_core::Language::TypeScript => Ok(ParseableLanguage::TypeScript),
54            ck_core::Language::JavaScript => Ok(ParseableLanguage::JavaScript),
55            ck_core::Language::Haskell => Ok(ParseableLanguage::Haskell),
56            ck_core::Language::Rust => Ok(ParseableLanguage::Rust),
57            ck_core::Language::Ruby => Ok(ParseableLanguage::Ruby),
58            ck_core::Language::Go => Ok(ParseableLanguage::Go),
59            _ => Err(anyhow::anyhow!(
60                "Language {:?} is not supported for parsing",
61                lang
62            )),
63        }
64    }
65}
66
67pub fn chunk_text(text: &str, language: Option<ck_core::Language>) -> Result<Vec<Chunk>> {
68    tracing::debug!(
69        "Chunking text with language: {:?}, length: {} chars",
70        language,
71        text.len()
72    );
73
74    let result = match language.map(ParseableLanguage::try_from) {
75        Some(Ok(lang)) => {
76            tracing::debug!("Using {} tree-sitter parser", lang);
77            chunk_language(text, lang)
78        }
79        Some(Err(_)) => {
80            tracing::debug!("Language not supported for parsing, using generic chunking strategy");
81            chunk_generic(text)
82        }
83        None => {
84            tracing::debug!("Using generic chunking strategy");
85            chunk_generic(text)
86        }
87    };
88
89    match &result {
90        Ok(chunks) => tracing::debug!("Successfully created {} chunks", chunks.len()),
91        Err(e) => tracing::warn!("Chunking failed: {}", e),
92    }
93
94    result
95}
96
97fn chunk_generic(text: &str) -> Result<Vec<Chunk>> {
98    let mut chunks = Vec::new();
99    let lines: Vec<&str> = text.lines().collect();
100    let chunk_size = 20;
101    let overlap = 5;
102
103    // Pre-compute cumulative byte offsets for O(1) lookup
104    let mut line_byte_offsets = Vec::with_capacity(lines.len() + 1);
105    line_byte_offsets.push(0);
106    let mut cumulative_offset = 0;
107    for line in &lines {
108        cumulative_offset += line.len() + 1; // +1 for newline
109        line_byte_offsets.push(cumulative_offset);
110    }
111
112    let mut i = 0;
113    while i < lines.len() {
114        let end = (i + chunk_size).min(lines.len());
115        let chunk_lines = &lines[i..end];
116        let chunk_text = chunk_lines.join("\n");
117
118        let byte_start = line_byte_offsets[i];
119        let byte_end = byte_start + chunk_text.len();
120
121        chunks.push(Chunk {
122            span: Span {
123                byte_start,
124                byte_end,
125                line_start: i + 1,
126                line_end: end,
127            },
128            text: chunk_text,
129            chunk_type: ChunkType::Text,
130        });
131
132        i += chunk_size - overlap;
133        if i >= lines.len() {
134            break;
135        }
136    }
137
138    Ok(chunks)
139}
140
141fn chunk_language(text: &str, language: ParseableLanguage) -> Result<Vec<Chunk>> {
142    let mut parser = tree_sitter::Parser::new();
143
144    match language {
145        ParseableLanguage::Python => parser.set_language(&tree_sitter_python::language())?,
146        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
147            parser.set_language(&tree_sitter_typescript::language_typescript())?
148        }
149        ParseableLanguage::Haskell => parser.set_language(&tree_sitter_haskell::language())?,
150        ParseableLanguage::Rust => parser.set_language(&tree_sitter_rust::language())?,
151        ParseableLanguage::Ruby => parser.set_language(&tree_sitter_ruby::language())?,
152        ParseableLanguage::Go => parser.set_language(&tree_sitter_go::language())?,
153    }
154
155    let tree = parser
156        .parse(text, None)
157        .ok_or_else(|| anyhow::anyhow!("Failed to parse {} code", language))?;
158
159    let mut chunks = Vec::new();
160    let mut cursor = tree.root_node().walk();
161
162    extract_code_chunks(&mut cursor, text, &mut chunks, language);
163
164    if chunks.is_empty() {
165        return chunk_generic(text);
166    }
167
168    Ok(chunks)
169}
170
171fn extract_code_chunks(
172    cursor: &mut tree_sitter::TreeCursor,
173    source: &str,
174    chunks: &mut Vec<Chunk>,
175    language: ParseableLanguage,
176) {
177    let node = cursor.node();
178    let node_kind = node.kind();
179
180    let is_chunk = match language {
181        ParseableLanguage::Python => {
182            matches!(node_kind, "function_definition" | "class_definition")
183        }
184        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => matches!(
185            node_kind,
186            "function_declaration" | "class_declaration" | "method_definition" | "arrow_function"
187        ),
188        ParseableLanguage::Haskell => matches!(
189            node_kind,
190            "signature"
191                | "data_type"
192                | "newtype"
193                | "type_synomym"
194                | "type_family"
195                | "class"
196                | "instance"
197        ),
198        ParseableLanguage::Rust => matches!(
199            node_kind,
200            "function_item" | "impl_item" | "struct_item" | "enum_item" | "trait_item" | "mod_item"
201        ),
202        ParseableLanguage::Ruby => matches!(
203            node_kind,
204            "method" | "class" | "module" | "singleton_method"
205        ),
206        ParseableLanguage::Go => matches!(
207            node_kind,
208            "function_declaration"
209                | "method_declaration"
210                | "type_declaration"
211                | "var_declaration"
212                | "const_declaration"
213        ),
214    };
215
216    if is_chunk {
217        let start_byte = node.start_byte();
218        let end_byte = node.end_byte();
219        let start_pos = node.start_position();
220        let end_pos = node.end_position();
221
222        let text = &source[start_byte..end_byte];
223
224        let chunk_type = match node_kind {
225            "function_definition"
226            | "function_declaration"
227            | "arrow_function"
228            | "function"
229            | "signature"
230            | "function_item"
231            | "def"
232            | "defp"
233            | "method"
234            | "singleton_method"
235            | "defn"
236            | "defn-" => ChunkType::Function,
237            "class_definition"
238            | "class_declaration"
239            | "instance_declaration"
240            | "class"
241            | "instance"
242            | "struct_item"
243            | "enum_item"
244            | "defstruct"
245            | "defrecord"
246            | "deftype"
247            | "type_declaration" => ChunkType::Class,
248            "method_definition" | "method_declaration" | "defmacro" => ChunkType::Method,
249            "data_type" | "newtype" | "type_synomym" | "type_family" | "impl_item"
250            | "trait_item" | "mod_item" | "defmodule" | "module" | "defprotocol" | "ns"
251            | "var_declaration" | "const_declaration" => ChunkType::Module,
252            _ => ChunkType::Text,
253        };
254
255        chunks.push(Chunk {
256            span: Span {
257                byte_start: start_byte,
258                byte_end: end_byte,
259                line_start: start_pos.row + 1,
260                line_end: end_pos.row + 1,
261            },
262            text: text.to_string(),
263            chunk_type,
264        });
265    }
266
267    if cursor.goto_first_child() {
268        loop {
269            extract_code_chunks(cursor, source, chunks, language);
270            if !cursor.goto_next_sibling() {
271                break;
272            }
273        }
274        cursor.goto_parent();
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when at least one chunk carries the `wanted` type.
    fn has_chunk_type(chunks: &[Chunk], wanted: ChunkType) -> bool {
        chunks.iter().any(|chunk| chunk.chunk_type == wanted)
    }

    /// The span of every generic chunk must be as long as the chunk's text.
    #[test]
    fn test_chunk_generic_byte_offsets() {
        let text = "line 1\nline 2\nline 3\nline 4\nline 5";
        let chunks = chunk_generic(text).unwrap();

        assert!(!chunks.is_empty());

        // The first chunk begins at the very start of the input.
        assert_eq!(chunks[0].span.byte_start, 0);

        for Chunk { span, text, .. } in &chunks {
            let span_len = span.byte_end - span.byte_start;
            assert_eq!(span_len, text.len());
        }
    }

    /// Generic chunking of a 1000-line input stays fast and yields sane line ranges.
    #[test]
    fn test_chunk_generic_large_file_performance() {
        let text = (0..1000)
            .map(|i| format!("Line {}: Some content here", i))
            .collect::<Vec<String>>()
            .join("\n");

        let started = std::time::Instant::now();
        let chunks = chunk_generic(&text).unwrap();
        let duration = started.elapsed();

        // Guard against accidentally super-linear offset computation.
        assert!(
            duration.as_millis() < 100,
            "Chunking took too long: {:?}",
            duration
        );
        assert!(!chunks.is_empty());

        // Line numbers are 1-based and each range is non-empty.
        for chunk in &chunks {
            assert!(chunk.span.line_start > 0);
            assert!(chunk.span.line_end >= chunk.span.line_start);
        }
    }

    /// Rust input yields struct, impl/mod and function chunks.
    #[test]
    fn test_chunk_rust() {
        let rust_code = r#"
pub struct Calculator {
    memory: f64,
}

impl Calculator {
    pub fn new() -> Self {
        Calculator { memory: 0.0 }
    }
    
    pub fn add(&mut self, a: f64, b: f64) -> f64 {
        a + b
    }
}

fn main() {
    let calc = Calculator::new();
}

pub mod utils {
    pub fn helper() {}
}
"#;

        let chunks = chunk_language(rust_code, ParseableLanguage::Rust).unwrap();
        assert!(!chunks.is_empty());

        assert!(has_chunk_type(&chunks, ChunkType::Class)); // struct
        assert!(has_chunk_type(&chunks, ChunkType::Module)); // impl and mod
        assert!(has_chunk_type(&chunks, ChunkType::Function)); // functions
    }

    /// Ruby input yields class, module and method chunks.
    #[test]
    fn test_chunk_ruby() {
        let ruby_code = r#"
class Calculator
  def initialize
    @memory = 0.0
  end

  def add(a, b)
    a + b
  end

  def self.class_method
    "class method"
  end

  private

  def private_method
    "private"
  end
end

module Utils
  def self.helper
    "helper"
  end
end

def main
  calc = Calculator.new
end
"#;

        let chunks = chunk_language(ruby_code, ParseableLanguage::Ruby).unwrap();
        assert!(!chunks.is_empty());

        assert!(has_chunk_type(&chunks, ChunkType::Class)); // class
        assert!(has_chunk_type(&chunks, ChunkType::Module)); // module
        assert!(has_chunk_type(&chunks, ChunkType::Function)); // methods
    }

    /// Without a language, `chunk_text` behaves exactly like the generic chunker.
    #[test]
    fn test_language_detection_fallback() {
        let generic_text = "Some text\nwith multiple lines\nto chunk generically";

        let via_entry_point = chunk_text(generic_text, None).unwrap();
        let via_generic = chunk_generic(generic_text).unwrap();

        assert_eq!(via_entry_point.len(), via_generic.len());
        assert_eq!(via_entry_point[0].text, via_generic[0].text);
    }

    /// Go input yields const/var, type, function and method chunks.
    #[test]
    fn test_chunk_go() {
        let go_code = r#"
package main

import "fmt"

const Pi = 3.14159

var memory float64

type Calculator struct {
    memory float64
}

type Operation interface {
    Calculate(a, b float64) float64
}

func NewCalculator() *Calculator {
    return &Calculator{memory: 0.0}
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func main() {
    calc := NewCalculator()
}
"#;

        let chunks = chunk_language(go_code, ParseableLanguage::Go).unwrap();
        assert!(!chunks.is_empty());

        assert!(has_chunk_type(&chunks, ChunkType::Module)); // const and var
        assert!(has_chunk_type(&chunks, ChunkType::Class)); // struct and interface
        assert!(has_chunk_type(&chunks, ChunkType::Function)); // functions
        assert!(has_chunk_type(&chunks, ChunkType::Method)); // methods
    }
}
463}