1use anyhow::Result;
2use ck_core::Span;
3use serde::{Deserialize, Serialize};
4
5mod query_chunker;
6
7pub use ck_embed::TokenEstimator;
9
/// Estimates the token count of `text` via the shared `ck_embed`
/// `TokenEstimator`, so chunk sizing agrees with what the embedder counts.
fn estimate_tokens(text: &str) -> usize {
    TokenEstimator::estimate_tokens(text)
}
14
/// Returns `(target_tokens, overlap_tokens)` chunking limits for the given
/// embedding model name.
///
/// `None` and unrecognized model names fall back to the nomic-style default
/// of `(1024, 200)`.
pub fn get_model_chunk_config(model_name: Option<&str>) -> (usize, usize) {
    // Small-context BGE / MiniLM models get tighter windows; everything
    // else (including the nomic and jina code models) uses the default.
    match model_name.unwrap_or("nomic-embed-text-v1.5") {
        "BAAI/bge-small-en-v1.5"
        | "sentence-transformers/all-MiniLM-L6-v2"
        | "BAAI/bge-base-en-v1.5"
        | "BAAI/bge-large-en-v1.5" => (400, 80),
        _ => (1024, 200),
    }
}
41
/// Records how a strided sub-chunk relates to the oversized chunk it was
/// carved from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrideInfo {
    /// Identifier of the source chunk (byte offsets formatted as "start:end").
    pub original_chunk_id: String,
    /// Zero-based position of this stride within the source chunk.
    pub stride_index: usize,
    /// Total number of strides produced from the source chunk.
    pub total_strides: usize,
    /// Characters at the start of this stride shared with the previous stride (0 for the first).
    pub overlap_start: usize,
    /// Characters at the end of this stride shared with the next stride (0 for the last).
    pub overlap_end: usize,
}
56
/// Structural and size metadata attached to every chunk.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChunkMetadata {
    /// Names of enclosing definitions, outermost first.
    pub ancestry: Vec<String>,
    /// `ancestry` joined with "::", or `None` when ancestry is empty.
    pub breadcrumb: Option<String>,
    /// Comment/attribute/decorator text attached before the chunk.
    pub leading_trivia: Vec<String>,
    /// Comment text attached after the chunk.
    pub trailing_trivia: Vec<String>,
    /// Length of the chunk text in bytes.
    pub byte_length: usize,
    /// Heuristic token count of the chunk text.
    pub estimated_tokens: usize,
}
66
67impl ChunkMetadata {
68 fn from_context(
69 text: &str,
70 ancestry: Vec<String>,
71 leading_trivia: Vec<String>,
72 trailing_trivia: Vec<String>,
73 ) -> Self {
74 let breadcrumb = if ancestry.is_empty() {
75 None
76 } else {
77 Some(ancestry.join("::"))
78 };
79
80 Self {
81 ancestry,
82 breadcrumb,
83 leading_trivia,
84 trailing_trivia,
85 byte_length: text.len(),
86 estimated_tokens: estimate_tokens(text),
87 }
88 }
89
90 fn from_text(text: &str) -> Self {
91 Self {
92 ancestry: Vec::new(),
93 breadcrumb: None,
94 leading_trivia: Vec::new(),
95 trailing_trivia: Vec::new(),
96 byte_length: text.len(),
97 estimated_tokens: estimate_tokens(text),
98 }
99 }
100
101 fn with_updated_text(&self, text: &str) -> Self {
102 let mut cloned = self.clone();
103 cloned.byte_length = text.len();
104 cloned.estimated_tokens = estimate_tokens(text);
105 cloned
106 }
107}
108
/// A contiguous piece of a source file, with its location, classification,
/// optional striding provenance, and structural metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte and line range of the chunk within the original text.
    pub span: Span,
    /// The chunk's text content.
    pub text: String,
    /// Coarse classification (function, class, plain text, ...).
    pub chunk_type: ChunkType,
    /// Present only when this chunk was produced by striding an oversized chunk.
    pub stride_info: Option<StrideInfo>,
    /// Ancestry, trivia, and size metadata.
    pub metadata: ChunkMetadata,
}
118
/// Coarse classification of what a chunk represents.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkType {
    /// Plain text, or code that matched no recognized definition kind.
    Text,
    /// Free-standing function definition.
    Function,
    /// Class/struct/enum-like type definition.
    Class,
    /// Function nested inside a class-like container.
    Method,
    /// Module-, namespace-, impl-, or declaration-level construct.
    Module,
}
127
/// Languages for which a tree-sitter grammar is available, enabling
/// structure-aware chunking instead of generic line-based chunking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParseableLanguage {
    Python,
    TypeScript,
    JavaScript,
    Haskell,
    Rust,
    Ruby,
    Go,
    CSharp,
    Zig,
}
140
141impl std::fmt::Display for ParseableLanguage {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 let name = match self {
144 ParseableLanguage::Python => "python",
145 ParseableLanguage::TypeScript => "typescript",
146 ParseableLanguage::JavaScript => "javascript",
147 ParseableLanguage::Haskell => "haskell",
148 ParseableLanguage::Rust => "rust",
149 ParseableLanguage::Ruby => "ruby",
150 ParseableLanguage::Go => "go",
151 ParseableLanguage::CSharp => "csharp",
152 ParseableLanguage::Zig => "zig",
153 };
154 write!(f, "{}", name)
155 }
156}
157
158impl TryFrom<ck_core::Language> for ParseableLanguage {
159 type Error = anyhow::Error;
160
161 fn try_from(lang: ck_core::Language) -> Result<Self, Self::Error> {
162 match lang {
163 ck_core::Language::Python => Ok(ParseableLanguage::Python),
164 ck_core::Language::TypeScript => Ok(ParseableLanguage::TypeScript),
165 ck_core::Language::JavaScript => Ok(ParseableLanguage::JavaScript),
166 ck_core::Language::Haskell => Ok(ParseableLanguage::Haskell),
167 ck_core::Language::Rust => Ok(ParseableLanguage::Rust),
168 ck_core::Language::Ruby => Ok(ParseableLanguage::Ruby),
169 ck_core::Language::Go => Ok(ParseableLanguage::Go),
170 ck_core::Language::CSharp => Ok(ParseableLanguage::CSharp),
171 ck_core::Language::Zig => Ok(ParseableLanguage::Zig),
172 _ => Err(anyhow::anyhow!(
173 "Language {:?} is not supported for parsing",
174 lang
175 )),
176 }
177 }
178}
179
/// Chunks `text` with the default `ChunkConfig` (8192-token limit,
/// 1024-token stride overlap, striding enabled).
pub fn chunk_text(text: &str, language: Option<ck_core::Language>) -> Result<Vec<Chunk>> {
    chunk_text_with_config(text, language, &ChunkConfig::default())
}
183
/// Tunable limits for chunking and striding.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum estimated tokens a chunk may hold before it is strided.
    pub max_tokens: usize,
    /// Token overlap between consecutive strides of an oversized chunk.
    pub stride_overlap: usize,
    /// When false, oversized chunks are passed through unsplit.
    pub enable_striding: bool,
}
194
impl Default for ChunkConfig {
    /// Default limits: 8192-token chunks, 1024-token stride overlap,
    /// striding enabled.
    fn default() -> Self {
        Self {
            max_tokens: 8192,
            stride_overlap: 1024,
            enable_striding: true,
        }
    }
}
204
205pub fn chunk_text_with_model(
207 text: &str,
208 language: Option<ck_core::Language>,
209 model_name: Option<&str>,
210) -> Result<Vec<Chunk>> {
211 let (target_tokens, overlap_tokens) = get_model_chunk_config(model_name);
212
213 let config = ChunkConfig {
215 max_tokens: target_tokens,
216 stride_overlap: overlap_tokens,
217 enable_striding: true,
218 };
219
220 chunk_text_with_config_and_model(text, language, &config, model_name)
221}
222
/// Chunks `text` with an explicit configuration and no model-specific
/// tuning for the generic fallback chunker.
pub fn chunk_text_with_config(
    text: &str,
    language: Option<ck_core::Language>,
    config: &ChunkConfig,
) -> Result<Vec<Chunk>> {
    chunk_text_with_config_and_model(text, language, config, None)
}
230
/// Core chunking pipeline: picks a language-aware or generic strategy,
/// then optionally strides chunks whose estimated tokens exceed
/// `config.max_tokens`.
fn chunk_text_with_config_and_model(
    text: &str,
    language: Option<ck_core::Language>,
    config: &ChunkConfig,
    model_name: Option<&str>,
) -> Result<Vec<Chunk>> {
    tracing::debug!(
        "Chunking text with language: {:?}, length: {} chars, config: {:?}",
        language,
        text.len(),
        config
    );

    // Three cases: language with parser support, language without parser
    // support, and no language hint at all. The latter two both fall back
    // to generic line-based chunking (with distinct log messages).
    let result = match language.map(ParseableLanguage::try_from) {
        Some(Ok(lang)) => {
            tracing::debug!("Using {} tree-sitter parser", lang);
            chunk_language_with_model(text, lang, model_name)
        }
        Some(Err(_)) => {
            tracing::debug!("Language not supported for parsing, using generic chunking strategy");
            chunk_generic_with_token_config(text, model_name)
        }
        None => {
            tracing::debug!("Using generic chunking strategy");
            chunk_generic_with_token_config(text, model_name)
        }
    };

    let mut chunks = result?;

    // Split any oversized chunks into overlapping strides.
    if config.enable_striding {
        chunks = apply_striding(chunks, config)?;
    }

    tracing::debug!("Successfully created {} final chunks", chunks.len());
    Ok(chunks)
}
269
/// Generic line-based chunking with the default (no-model) token limits.
fn chunk_generic(text: &str) -> Result<Vec<Chunk>> {
    chunk_generic_with_token_config(text, None)
}
273
274fn chunk_generic_with_token_config(text: &str, model_name: Option<&str>) -> Result<Vec<Chunk>> {
275 let mut chunks = Vec::new();
276 let lines: Vec<&str> = text.lines().collect();
277
278 let (target_tokens, overlap_tokens) = get_model_chunk_config(model_name);
280
281 let avg_tokens_per_line = 10.0; let target_lines = ((target_tokens as f32) / avg_tokens_per_line) as usize;
285 let overlap_lines = ((overlap_tokens as f32) / avg_tokens_per_line) as usize;
286
287 let chunk_size = target_lines.max(5); let overlap = overlap_lines.max(1); let mut line_byte_offsets = Vec::with_capacity(lines.len() + 1);
292 line_byte_offsets.push(0);
293 let mut cumulative_offset = 0;
294 let mut byte_pos = 0;
295
296 for line in lines.iter() {
297 cumulative_offset += line.len();
298
299 let line_end_pos = byte_pos + line.len();
301 let newline_len = if line_end_pos < text.len() && text.as_bytes()[line_end_pos] == b'\r' {
302 if line_end_pos + 1 < text.len() && text.as_bytes()[line_end_pos + 1] == b'\n' {
303 2 } else {
305 1 }
307 } else if line_end_pos < text.len() && text.as_bytes()[line_end_pos] == b'\n' {
308 1 } else {
310 0 };
312
313 cumulative_offset += newline_len;
314 byte_pos = cumulative_offset;
315 line_byte_offsets.push(cumulative_offset);
316 }
317
318 let mut i = 0;
319 while i < lines.len() {
320 let end = (i + chunk_size).min(lines.len());
321 let chunk_lines = &lines[i..end];
322 let chunk_text = chunk_lines.join("\n");
323 let byte_start = line_byte_offsets[i];
324 let byte_end = line_byte_offsets[end];
325 let metadata = ChunkMetadata::from_text(&chunk_text);
326
327 chunks.push(Chunk {
328 span: Span {
329 byte_start,
330 byte_end,
331 line_start: i + 1,
332 line_end: end,
333 },
334 text: chunk_text,
335 chunk_type: ChunkType::Text,
336 stride_info: None,
337 metadata,
338 });
339
340 i += chunk_size - overlap;
341 if i >= lines.len() {
342 break;
343 }
344 }
345
346 Ok(chunks)
347}
348
/// Returns the tree-sitter grammar for `language`.
///
/// JavaScript is parsed with the TypeScript grammar.
pub(crate) fn tree_sitter_language(language: ParseableLanguage) -> Result<tree_sitter::Language> {
    let ts_language = match language {
        ParseableLanguage::Python => tree_sitter_python::LANGUAGE,
        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
            tree_sitter_typescript::LANGUAGE_TYPESCRIPT
        }
        ParseableLanguage::Haskell => tree_sitter_haskell::LANGUAGE,
        ParseableLanguage::Rust => tree_sitter_rust::LANGUAGE,
        ParseableLanguage::Ruby => tree_sitter_ruby::LANGUAGE,
        ParseableLanguage::Go => tree_sitter_go::LANGUAGE,
        ParseableLanguage::CSharp => tree_sitter_c_sharp::LANGUAGE,
        ParseableLanguage::Zig => tree_sitter_zig::LANGUAGE,
    };

    Ok(ts_language.into())
}
365
/// Chunks source code using the language's tree-sitter grammar.
///
/// Prefers query-based chunking; falls back to the legacy CST walk when
/// queries yield nothing, and to generic line chunking when neither finds a
/// chunk. Haskell bindings are merged across clauses, and source bytes not
/// covered by any chunk are back-filled as text chunks.
fn chunk_language(text: &str, language: ParseableLanguage) -> Result<Vec<Chunk>> {
    let mut parser = tree_sitter::Parser::new();
    let ts_language = tree_sitter_language(language)?;
    parser.set_language(&ts_language)?;

    let tree = parser
        .parse(text, None)
        .ok_or_else(|| anyhow::anyhow!("Failed to parse {} code", language))?;

    // Query-based chunking first; legacy cursor walk only if it produced
    // nothing at all.
    let mut chunks = match query_chunker::chunk_with_queries(language, ts_language, &tree, text)? {
        Some(query_chunks) if !query_chunks.is_empty() => query_chunks,
        _ => {
            let mut legacy_chunks = Vec::new();
            let mut cursor = tree.walk();
            extract_code_chunks(&mut cursor, text, &mut legacy_chunks, language);
            legacy_chunks
        }
    };

    if chunks.is_empty() {
        return chunk_generic(text);
    }

    // Haskell splits one binding across signature + equation chunks; merge
    // them back into a single chunk per binding.
    if language == ParseableLanguage::Haskell {
        chunks = merge_haskell_functions(chunks, text);
    }

    // Cover any source bytes the parser-derived chunks missed.
    chunks = fill_gaps(chunks, text);

    Ok(chunks)
}
399
400fn fill_gaps(mut chunks: Vec<Chunk>, text: &str) -> Vec<Chunk> {
404 if chunks.is_empty() {
405 return chunks;
406 }
407
408 chunks.sort_by_key(|c| c.span.byte_start);
410
411 let mut result = Vec::new();
412 let mut last_end = 0;
413
414 let mut gaps = Vec::new();
416
417 for chunk in &chunks {
418 if last_end < chunk.span.byte_start {
419 let gap_start = last_end;
421 let gap_text = &text[gap_start..chunk.span.byte_start];
422
423 let mut current_byte = gap_start;
425 let mut segment_start = gap_start;
426
427 for line in gap_text.split('\n') {
428 let line_start_in_gap = current_byte - gap_start;
429 let _line_end_in_gap = line_start_in_gap + line.len();
430
431 if line.trim().is_empty() {
432 if segment_start < current_byte {
434 let segment_text = &text[segment_start..current_byte];
435 if !segment_text.trim().is_empty() {
436 gaps.push((segment_start, current_byte));
437 }
438 }
439 segment_start = current_byte + line.len() + 1;
441 }
442
443 current_byte += line.len() + 1; }
445
446 if segment_start < chunk.span.byte_start {
448 let remaining = &text[segment_start..chunk.span.byte_start];
449 if !remaining.trim().is_empty() {
450 gaps.push((segment_start, chunk.span.byte_start));
451 }
452 }
453 }
454 last_end = chunk.span.byte_end;
455 }
456
457 if last_end < text.len() {
459 let gap_text = &text[last_end..];
460 if !gap_text.trim().is_empty() {
461 gaps.push((last_end, text.len()));
462 }
463 }
464
465 let combined_gaps = gaps;
466
467 let mut gap_idx = 0;
469
470 for chunk in chunks {
471 while gap_idx < combined_gaps.len() && combined_gaps[gap_idx].1 <= chunk.span.byte_start {
473 let (gap_start, gap_end) = combined_gaps[gap_idx];
474 let gap_text = &text[gap_start..gap_end];
475
476 let line_start = text[..gap_start].matches('\n').count() + 1;
478 let newlines_up_to_end = text[..gap_end].matches('\n').count();
481 let line_end = if newlines_up_to_end >= line_start - 1 {
482 newlines_up_to_end.max(line_start)
483 } else {
484 line_start
485 };
486
487 let gap_chunk = Chunk {
488 text: gap_text.to_string(),
489 span: Span {
490 byte_start: gap_start,
491 byte_end: gap_end,
492 line_start,
493 line_end,
494 },
495 chunk_type: ChunkType::Text,
496 metadata: ChunkMetadata::from_text(gap_text),
497 stride_info: None,
498 };
499 result.push(gap_chunk);
500 gap_idx += 1;
501 }
502
503 result.push(chunk.clone());
504 }
505
506 while gap_idx < combined_gaps.len() {
508 let (gap_start, gap_end) = combined_gaps[gap_idx];
509 let gap_text = &text[gap_start..gap_end];
510
511 let line_start = text[..gap_start].matches('\n').count() + 1;
513 let newlines_up_to_end = text[..gap_end].matches('\n').count();
515 let line_end = if newlines_up_to_end >= line_start - 1 {
516 newlines_up_to_end.max(line_start)
517 } else {
518 line_start
519 };
520
521 let gap_chunk = Chunk {
522 text: gap_text.to_string(),
523 span: Span {
524 byte_start: gap_start,
525 byte_end: gap_end,
526 line_start,
527 line_end,
528 },
529 chunk_type: ChunkType::Text,
530 metadata: ChunkMetadata::from_text(gap_text),
531 stride_info: None,
532 };
533 result.push(gap_chunk);
534 gap_idx += 1;
535 }
536
537 result
538}
539
540fn merge_haskell_functions(chunks: Vec<Chunk>, source: &str) -> Vec<Chunk> {
542 if chunks.is_empty() {
543 return chunks;
544 }
545
546 let mut merged = Vec::new();
547 let mut i = 0;
548
549 while i < chunks.len() {
550 let chunk = &chunks[i];
551
552 let trimmed = chunk.text.trim();
554 if trimmed.is_empty()
555 || trimmed.starts_with("--")
556 || trimmed.starts_with("{-")
557 || !chunk.text.contains(|c: char| c.is_alphanumeric())
558 {
559 i += 1;
560 continue;
561 }
562
563 let is_signature = chunk.text.contains("::");
566 let function_name = if is_signature {
567 chunk
569 .text
570 .split("::")
571 .next()
572 .and_then(|s| s.split_whitespace().next())
573 .map(|s| s.to_string())
574 } else {
575 extract_haskell_function_name(&chunk.text)
576 };
577
578 if function_name.is_none() {
579 merged.push(chunk.clone());
581 i += 1;
582 continue;
583 }
584
585 let name = function_name.unwrap();
586 let group_start = chunk.span.byte_start;
587 let mut group_end = chunk.span.byte_end;
588 let line_start = chunk.span.line_start;
589 let mut line_end = chunk.span.line_end;
590 let mut trailing_trivia = chunk.metadata.trailing_trivia.clone();
591
592 let mut j = i + 1;
594 while j < chunks.len() {
595 let next_chunk = &chunks[j];
596
597 let next_trimmed = next_chunk.text.trim();
599 if next_trimmed.starts_with("--") || next_trimmed.starts_with("{-") {
600 j += 1;
601 continue;
602 }
603
604 let next_is_signature = next_chunk.text.contains("::");
605 let next_name = if next_is_signature {
606 next_chunk
607 .text
608 .split("::")
609 .next()
610 .and_then(|s| s.split_whitespace().next())
611 .map(|s| s.to_string())
612 } else {
613 extract_haskell_function_name(&next_chunk.text)
614 };
615
616 if next_name == Some(name.clone()) {
617 group_end = next_chunk.span.byte_end;
619 line_end = next_chunk.span.line_end;
620 trailing_trivia = next_chunk.metadata.trailing_trivia.clone();
621 j += 1;
622 } else {
623 break;
624 }
625 }
626
627 let merged_text = source.get(group_start..group_end).unwrap_or("").to_string();
629 let mut metadata = chunk.metadata.with_updated_text(&merged_text);
630 metadata.trailing_trivia = trailing_trivia;
631
632 merged.push(Chunk {
633 span: Span {
634 byte_start: group_start,
635 byte_end: group_end,
636 line_start,
637 line_end,
638 },
639 text: merged_text,
640 chunk_type: ChunkType::Function,
641 stride_info: None,
642 metadata,
643 });
644
645 i = j; }
647
648 merged
649}
650
/// Pulls a plausible function name from the start of a Haskell chunk.
///
/// Takes the first whitespace-delimited word, strips trailing characters
/// that are not alphanumeric, `_`, or `'`, and accepts the result only when
/// it looks like a value-level binding (starts lowercase or with `_`).
fn extract_haskell_function_name(text: &str) -> Option<String> {
    let first_word = text
        .trim()
        .split_whitespace()
        .next()?
        .trim_end_matches(|c: char| !c.is_alphanumeric() && c != '_' && c != '\'');

    // An empty word yields `None` here via the failed `chars().next()?`.
    let first_char = first_word.chars().next()?;
    (first_char.is_lowercase() || first_char == '_').then(|| first_word.to_string())
}
675
/// Language-aware chunking entry point; the model name is currently
/// ignored and chunking is delegated directly to [`chunk_language`].
fn chunk_language_with_model(
    text: &str,
    language: ParseableLanguage,
    _model_name: Option<&str>,
) -> Result<Vec<Chunk>> {
    chunk_language(text, language)
}
686
687fn extract_code_chunks(
688 cursor: &mut tree_sitter::TreeCursor,
689 source: &str,
690 chunks: &mut Vec<Chunk>,
691 language: ParseableLanguage,
692) {
693 let node = cursor.node();
694
695 let should_skip = if language == ParseableLanguage::Haskell && node.kind() == "function" {
698 let mut current = node.parent();
700 while let Some(parent) = current {
701 if parent.kind() == "signature" {
702 return; }
704 current = parent.parent();
705 }
706 false
707 } else {
708 false
709 };
710
711 if !should_skip
712 && let Some(initial_chunk_type) = chunk_type_for_node(language, &node)
713 && let Some(chunk) = build_chunk(node, source, initial_chunk_type, language)
714 {
715 let is_duplicate = chunks.iter().any(|existing| {
716 existing.span.byte_start == chunk.span.byte_start
717 && existing.span.byte_end == chunk.span.byte_end
718 });
719
720 if !is_duplicate {
721 chunks.push(chunk);
722 }
723 }
724
725 let should_recurse = !(language == ParseableLanguage::Haskell && node.kind() == "signature");
727
728 if should_recurse && cursor.goto_first_child() {
729 loop {
730 extract_code_chunks(cursor, source, chunks, language);
731 if !cursor.goto_next_sibling() {
732 break;
733 }
734 }
735 cursor.goto_parent();
736 }
737}
738
/// Decides whether `node` should become a chunk for `language`, and with
/// what initial `ChunkType`. Returns `None` for node kinds that are not
/// chunk-worthy (or are excluded by a language-specific rule).
fn chunk_type_for_node(
    language: ParseableLanguage,
    node: &tree_sitter::Node<'_>,
) -> Option<ChunkType> {
    let kind = node.kind();

    // Per-language allowlist of chunk-worthy node kinds.
    let supported = match language {
        ParseableLanguage::Python => matches!(kind, "function_definition" | "class_definition"),
        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => matches!(
            kind,
            "function_declaration" | "class_declaration" | "method_definition" | "arrow_function"
        ),
        ParseableLanguage::Haskell => matches!(
            kind,
            "function"
                | "signature"
                | "data_type"
                | "newtype"
                | "type_synonym"
                | "type_family"
                | "class"
                | "instance"
        ),
        ParseableLanguage::Rust => matches!(
            kind,
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "trait_item" | "mod_item"
        ),
        ParseableLanguage::Ruby => {
            matches!(kind, "method" | "class" | "module" | "singleton_method")
        }
        ParseableLanguage::Go => matches!(
            kind,
            "function_declaration"
                | "method_declaration"
                | "type_declaration"
                | "var_declaration"
                | "const_declaration"
        ),
        ParseableLanguage::CSharp => matches!(
            kind,
            "method_declaration"
                | "class_declaration"
                | "interface_declaration"
                | "variable_declaration"
        ),
        ParseableLanguage::Zig => matches!(
            kind,
            "function_declaration"
                | "test_declaration"
                | "variable_declaration"
                | "struct_declaration"
                | "enum_declaration"
                | "union_declaration"
                | "opaque_declaration"
                | "error_set_declaration"
                | "comptime_declaration"
        ),
    };

    if !supported {
        return None;
    }

    // Language-specific exclusions for otherwise-supported kinds.
    match language {
        // Go: skip var/const declarations that are local to a block.
        ParseableLanguage::Go
            if matches!(node.kind(), "var_declaration" | "const_declaration")
                && node.parent().is_some_and(|p| p.kind() == "block") =>
        {
            return None;
        }
        // C#: only field-like variable declarations get their own chunk.
        ParseableLanguage::CSharp if node.kind() == "variable_declaration" => {
            if !is_csharp_field_like(*node) {
                return None;
            }
        }
        _ => {}
    }

    Some(classify_chunk_kind(kind))
}
819
820fn classify_chunk_kind(kind: &str) -> ChunkType {
821 match kind {
822 "function_definition"
823 | "function_declaration"
824 | "arrow_function"
825 | "function"
826 | "function_item"
827 | "def"
828 | "defp"
829 | "defn"
830 | "defn-"
831 | "method"
832 | "singleton_method" => ChunkType::Function,
833 "signature" => ChunkType::Function, "class_definition"
835 | "class_declaration"
836 | "instance_declaration"
837 | "class"
838 | "instance"
839 | "struct_item"
840 | "enum_item"
841 | "defstruct"
842 | "defrecord"
843 | "deftype"
844 | "type_declaration"
845 | "struct_declaration"
846 | "enum_declaration"
847 | "union_declaration"
848 | "opaque_declaration"
849 | "error_set_declaration" => ChunkType::Class,
850 "method_definition" | "method_declaration" | "defmacro" => ChunkType::Method,
851 "data_type"
852 | "newtype"
853 | "type_synonym"
854 | "type_family"
855 | "impl_item"
856 | "trait_item"
857 | "mod_item"
858 | "defmodule"
859 | "module"
860 | "defprotocol"
861 | "interface_declaration"
862 | "ns"
863 | "var_declaration"
864 | "const_declaration"
865 | "variable_declaration"
866 | "test_declaration"
867 | "comptime_declaration" => ChunkType::Module,
868 _ => ChunkType::Text,
869 }
870}
871
/// Builds a [`Chunk`] from a tree-sitter node: widens the node where the
/// language calls for it, absorbs leading/trailing trivia, and records
/// ancestry metadata. Returns `None` for empty or out-of-range spans.
pub(crate) fn build_chunk(
    node: tree_sitter::Node<'_>,
    source: &str,
    initial_type: ChunkType,
    language: ParseableLanguage,
) -> Option<Chunk> {
    // Possibly widen the node (e.g. TS/JS arrow functions) before slicing.
    let target_node = adjust_node_for_language(node, language);
    let (byte_start, start_row, leading_segments) =
        extend_with_leading_trivia(target_node, language, source);
    // Trailing trivia is recorded in metadata but does not extend the span.
    let trailing_segments = collect_trailing_trivia(target_node, language, source);

    let byte_end = target_node.end_byte();
    let end_pos = target_node.end_position();

    // Reject inverted or out-of-bounds ranges.
    if byte_start >= byte_end || byte_end > source.len() {
        return None;
    }

    let text = source.get(byte_start..byte_end)?.to_string();

    if text.trim().is_empty() {
        return None;
    }

    // Reclassify Function -> Method when nested in a class-like container.
    let chunk_type = adjust_chunk_type_for_context(target_node, initial_type, language);
    let ancestry = collect_ancestry(target_node, language, source);
    let leading_trivia = segments_to_strings(&leading_segments, source);
    let trailing_trivia = segments_to_strings(&trailing_segments, source);
    let metadata = ChunkMetadata::from_context(&text, ancestry, leading_trivia, trailing_trivia);

    Some(Chunk {
        span: Span {
            byte_start,
            byte_end,
            // tree-sitter rows are 0-based; spans use 1-based lines.
            line_start: start_row + 1,
            line_end: end_pos.row + 1,
        },
        text,
        chunk_type,
        stride_info: None,
        metadata,
    })
}
915
916fn adjust_node_for_language(
917 node: tree_sitter::Node<'_>,
918 language: ParseableLanguage,
919) -> tree_sitter::Node<'_> {
920 match language {
921 ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
922 if node.kind() == "arrow_function" {
923 return expand_arrow_function_context(node);
924 }
925 node
926 }
927 _ => node,
928 }
929}
930
931fn expand_arrow_function_context(mut node: tree_sitter::Node<'_>) -> tree_sitter::Node<'_> {
932 const PARENTS: &[&str] = &[
933 "parenthesized_expression",
934 "variable_declarator",
935 "variable_declaration",
936 "lexical_declaration",
937 "assignment_expression",
938 "expression_statement",
939 "public_field_definition",
940 "export_statement",
941 ];
942
943 while let Some(parent) = node.parent() {
944 let kind = parent.kind();
945 if PARENTS.contains(&kind) {
946 node = parent;
947 continue;
948 }
949 break;
950 }
951
952 node
953}
954
/// Byte range of a comment/attribute node attached to a chunk as trivia.
#[derive(Clone, Copy)]
struct TriviaSegment {
    /// Inclusive start byte offset within the source text.
    start_byte: usize,
    /// Exclusive end byte offset within the source text.
    end_byte: usize,
}
960
/// Extends a chunk's start backwards over attachable preceding siblings
/// (comments, attributes, decorators) separated only by whitespace.
///
/// Returns the new start byte, the new 0-based start row, and the absorbed
/// trivia segments in source order.
fn extend_with_leading_trivia(
    node: tree_sitter::Node<'_>,
    language: ParseableLanguage,
    source: &str,
) -> (usize, usize, Vec<TriviaSegment>) {
    let mut start_byte = node.start_byte();
    let mut start_row = node.start_position().row;
    let mut current = node;
    let mut segments = Vec::new();

    // Walk preceding siblings until a non-trivia node or a gap containing
    // non-whitespace text.
    while let Some(prev) = current.prev_sibling() {
        if should_attach_leading_trivia(language, &prev)
            && only_whitespace_between(source, prev.end_byte(), start_byte)
        {
            start_byte = prev.start_byte();
            start_row = prev.start_position().row;
            segments.push(TriviaSegment {
                start_byte: prev.start_byte(),
                end_byte: prev.end_byte(),
            });
            current = prev;
            continue;
        }
        break;
    }

    // Segments were collected back-to-front; restore source order.
    segments.reverse();
    (start_byte, start_row, segments)
}
990
991fn should_attach_leading_trivia(language: ParseableLanguage, node: &tree_sitter::Node<'_>) -> bool {
992 let kind = node.kind();
993 if kind == "comment" {
994 return true;
995 }
996
997 match language {
998 ParseableLanguage::Rust => kind == "attribute_item",
999 ParseableLanguage::Python => kind == "decorator",
1000 ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => kind == "decorator",
1001 ParseableLanguage::CSharp => matches!(kind, "attribute_list" | "attribute"),
1002 _ => false,
1003 }
1004}
1005
/// Collects comment siblings immediately following `node` (separated only
/// by whitespace) as trailing trivia segments, in source order.
fn collect_trailing_trivia(
    node: tree_sitter::Node<'_>,
    language: ParseableLanguage,
    source: &str,
) -> Vec<TriviaSegment> {
    let mut segments = Vec::new();
    let mut current = node;
    let mut previous_end = node.end_byte();

    // Walk following siblings until a non-comment node or a gap containing
    // non-whitespace text.
    while let Some(next) = current.next_sibling() {
        if should_attach_trailing_trivia(language, &next)
            && only_whitespace_between(source, previous_end, next.start_byte())
        {
            segments.push(TriviaSegment {
                start_byte: next.start_byte(),
                end_byte: next.end_byte(),
            });
            previous_end = next.end_byte();
            current = next;
            continue;
        }
        break;
    }

    segments
}
1032
/// Only comment nodes are absorbed as trailing trivia, for every language.
fn should_attach_trailing_trivia(
    _language: ParseableLanguage,
    node: &tree_sitter::Node<'_>,
) -> bool {
    node.kind() == "comment"
}
1039
1040fn segments_to_strings(segments: &[TriviaSegment], source: &str) -> Vec<String> {
1041 let mut result = Vec::new();
1042
1043 for segment in segments {
1044 if let Some(text) = source
1045 .get(segment.start_byte..segment.end_byte)
1046 .map(|s| s.to_string())
1047 {
1048 result.push(text);
1049 }
1050 }
1051
1052 result
1053}
1054
/// Collects display names of chunk-worthy ancestors of `node`, outermost
/// first, for breadcrumb construction.
fn collect_ancestry(
    mut node: tree_sitter::Node<'_>,
    language: ParseableLanguage,
    source: &str,
) -> Vec<String> {
    let mut parts = Vec::new();

    // Walk up the tree, recording only ancestors that would themselves be
    // chunks and that have a resolvable display name.
    while let Some(parent) = node.parent() {
        if let Some(parent_chunk_type) = chunk_type_for_node(language, &parent)
            && let Some(name) = display_name_for_node(parent, language, source, parent_chunk_type)
        {
            parts.push(name);
        }
        node = parent;
    }

    // Collected innermost-first; breadcrumbs read outermost-first.
    parts.reverse();
    parts
}
1074
/// Resolves a human-readable name for a definition node, preferring the
/// grammar's `name` field and falling back to per-language identifier
/// searches.
fn display_name_for_node(
    node: tree_sitter::Node<'_>,
    language: ParseableLanguage,
    source: &str,
    chunk_type: ChunkType,
) -> Option<String> {
    // Most grammars expose the definition name as a named field.
    if let Some(name_node) = node.child_by_field_name("name") {
        return text_for_node(name_node, source);
    }

    match language {
        ParseableLanguage::Rust => rust_display_name(node, source, chunk_type),
        ParseableLanguage::Python => find_identifier(node, source, &["identifier"]),
        ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => find_identifier(
            node,
            source,
            &["identifier", "type_identifier", "property_identifier"],
        ),
        ParseableLanguage::Haskell => {
            // Haskell nodes may expose a variable instead; as a last resort
            // fall back to the first word of the node's text.
            find_identifier(node, source, &["identifier", "type_identifier", "variable"])
                .or_else(|| first_word_of_node(node, source))
        }
        ParseableLanguage::Ruby => find_identifier(node, source, &["identifier"]),
        ParseableLanguage::Go => find_identifier(node, source, &["identifier", "type_identifier"]),
        ParseableLanguage::CSharp => find_identifier(node, source, &["identifier"]),
        ParseableLanguage::Zig => find_identifier(node, source, &["identifier"]),
    }
}
1103
/// Display name for Rust nodes that lack a `name` field: `impl` blocks
/// become "Type (impl Trait)" / "impl Trait", named modules use their
/// identifier, and anything else falls back to an identifier search.
fn rust_display_name(
    node: tree_sitter::Node<'_>,
    source: &str,
    chunk_type: ChunkType,
) -> Option<String> {
    match node.kind() {
        "impl_item" => {
            let mut parts = Vec::new();
            // Self type of the impl, e.g. `Foo` in `impl Foo` / `impl T for Foo`.
            if let Some(ty) = node.child_by_field_name("type")
                && let Some(text) = text_for_node(ty, source)
            {
                parts.push(text);
            }
            // Trait being implemented, when this is a trait impl.
            if let Some(trait_node) = node.child_by_field_name("trait")
                && let Some(text) = text_for_node(trait_node, source)
            {
                if let Some(last) = parts.first() {
                    parts[0] = format!("{} (impl {})", last, text.trim());
                } else {
                    parts.push(format!("impl {}", text.trim()));
                }
            }
            if parts.is_empty() {
                find_identifier(node, source, &["identifier"])
            } else {
                Some(parts.remove(0))
            }
        }
        "mod_item" if chunk_type == ChunkType::Module => {
            find_identifier(node, source, &["identifier"])
        }
        _ => find_identifier(node, source, &["identifier", "type_identifier"]),
    }
}
1138
1139fn find_identifier(
1140 node: tree_sitter::Node<'_>,
1141 source: &str,
1142 candidate_kinds: &[&str],
1143) -> Option<String> {
1144 let mut cursor = node.walk();
1145 for child in node.children(&mut cursor) {
1146 if candidate_kinds.contains(&child.kind())
1147 && let Some(text) = text_for_node(child, source)
1148 {
1149 return Some(text.trim().to_string());
1150 }
1151 }
1152 None
1153}
1154
1155fn first_word_of_node(node: tree_sitter::Node<'_>, source: &str) -> Option<String> {
1156 let text = text_for_node(node, source)?;
1157 text.split_whitespace().next().map(|s| {
1158 s.trim_end_matches(|c: char| !c.is_alphanumeric() && c != '_')
1159 .to_string()
1160 })
1161}
1162
/// UTF-8 text of `node` within `source`, or `None` when the node's byte
/// range does not decode cleanly.
fn text_for_node(node: tree_sitter::Node<'_>, source: &str) -> Option<String> {
    node.utf8_text(source.as_bytes())
        .ok()
        .map(|s| s.to_string())
}
1168
/// Reports whether `source[start..end]` contains only whitespace.
/// Degenerate ranges (empty, inverted, or past the end of `source`) are
/// treated as vacuously whitespace-only.
fn only_whitespace_between(source: &str, start: usize, end: usize) -> bool {
    if start >= end || end > source.len() {
        return true;
    }

    !source[start..end].contains(|c: char| !c.is_whitespace())
}
1176
1177fn adjust_chunk_type_for_context(
1178 node: tree_sitter::Node<'_>,
1179 chunk_type: ChunkType,
1180 language: ParseableLanguage,
1181) -> ChunkType {
1182 if chunk_type != ChunkType::Function {
1183 return chunk_type;
1184 }
1185
1186 if is_method_context(node, language) {
1187 ChunkType::Method
1188 } else {
1189 chunk_type
1190 }
1191}
1192
1193fn is_method_context(node: tree_sitter::Node<'_>, language: ParseableLanguage) -> bool {
1194 const PYTHON_CONTAINERS: &[&str] = &["class_definition"];
1195 const TYPESCRIPT_CONTAINERS: &[&str] = &["class_body", "class_declaration"];
1196 const RUBY_CONTAINERS: &[&str] = &["class", "module"];
1197 const RUST_CONTAINERS: &[&str] = &["impl_item", "trait_item"];
1198
1199 match language {
1200 ParseableLanguage::Python => ancestor_has_kind(node, PYTHON_CONTAINERS),
1201 ParseableLanguage::TypeScript | ParseableLanguage::JavaScript => {
1202 ancestor_has_kind(node, TYPESCRIPT_CONTAINERS)
1203 }
1204 ParseableLanguage::Ruby => ancestor_has_kind(node, RUBY_CONTAINERS),
1205 ParseableLanguage::Rust => ancestor_has_kind(node, RUST_CONTAINERS),
1206 ParseableLanguage::Go => false,
1207 ParseableLanguage::CSharp => false,
1208 ParseableLanguage::Haskell => false,
1209 ParseableLanguage::Zig => false,
1210 }
1211}
1212
1213fn ancestor_has_kind(node: tree_sitter::Node<'_>, kinds: &[&str]) -> bool {
1214 let mut current = node;
1215 while let Some(parent) = current.parent() {
1216 if kinds.contains(&parent.kind()) {
1217 return true;
1218 }
1219 current = parent;
1220 }
1221 false
1222}
1223
1224fn is_csharp_field_like(node: tree_sitter::Node<'_>) -> bool {
1225 if let Some(parent) = node.parent() {
1226 return matches!(
1227 parent.kind(),
1228 "field_declaration" | "event_field_declaration"
1229 );
1230 }
1231 false
1232}
1233
/// Passes through chunks whose estimated token count fits the budget and
/// replaces oversized chunks with overlapping strides from
/// `stride_large_chunk`.
fn apply_striding(chunks: Vec<Chunk>, config: &ChunkConfig) -> Result<Vec<Chunk>> {
    let mut result = Vec::new();

    for chunk in chunks {
        let estimated_tokens = estimate_tokens(&chunk.text);

        if estimated_tokens <= config.max_tokens {
            // Fits as-is; keep unchanged.
            result.push(chunk);
        } else {
            tracing::debug!(
                "Chunk with {} tokens exceeds limit of {}, applying striding",
                estimated_tokens,
                config.max_tokens
            );

            let strided_chunks = stride_large_chunk(chunk, config)?;
            result.extend(strided_chunks);
        }
    }

    Ok(result)
}
1259
/// Splits an over-budget chunk into overlapping windows ("strides").
///
/// Window and overlap sizes are derived from `config.max_tokens` /
/// `config.stride_overlap` via the chunk's observed characters-per-token
/// ratio. Returns an error when the overlap is at least as large as the
/// window (the loop could not advance).
fn stride_large_chunk(chunk: Chunk, config: &ChunkConfig) -> Result<Vec<Chunk>> {
    let text = &chunk.text;

    // Nothing to split: pass the empty chunk through unchanged.
    if text.is_empty() {
        return Ok(vec![chunk]);
    }

    // All windowing below is done in *characters* (not bytes) so multi-byte
    // UTF-8 sequences are never cut in half.
    let char_count = text.chars().count();
    let estimated_tokens = estimate_tokens(text);
    let chars_per_token = if estimated_tokens == 0 {
        // Fallback ratio for degenerate inputs (e.g. whitespace-only text
        // that estimates to zero tokens); avoids a division by zero.
        4.5
    } else {
        char_count as f32 / estimated_tokens as f32
    };
    // Window is 90% of the token budget converted to characters; the 10%
    // headroom guards against the token estimate undercounting.
    let window_chars = ((config.max_tokens as f32 * 0.9) * chars_per_token) as usize;
    let overlap_chars = (config.stride_overlap as f32 * chars_per_token) as usize;
    // Each stride advances by (window - overlap) characters.
    let stride_chars = window_chars.saturating_sub(overlap_chars);

    if stride_chars == 0 {
        // overlap >= window would make the loop below spin forever.
        return Err(anyhow::anyhow!("Stride size is too small"));
    }

    // Precompute char -> byte offsets so each window can be sliced in O(1).
    let char_byte_indices: Vec<(usize, char)> = text.char_indices().collect();
    let mut strided_chunks = Vec::new();
    // Strides from the same source chunk share this id so they can be
    // re-associated later.
    let original_chunk_id = format!("{}:{}", chunk.span.byte_start, chunk.span.byte_end);
    let mut start_char_idx = 0;
    let mut stride_index = 0;

    // Predicted number of strides the loop will produce (recorded in each
    // stride's StrideInfo).
    let total_strides = if char_count <= window_chars {
        1
    } else {
        ((char_count - overlap_chars) as f32 / stride_chars as f32).ceil() as usize
    };

    while start_char_idx < char_count {
        let end_char_idx = (start_char_idx + window_chars).min(char_count);

        // Map character positions back to byte offsets for slicing.
        let start_byte_pos = char_byte_indices[start_char_idx].0;
        let end_byte_pos = if end_char_idx < char_count {
            char_byte_indices[end_char_idx].0
        } else {
            text.len()
        };

        let stride_text = &text[start_byte_pos..end_byte_pos];

        // First stride has no leading overlap; last has no trailing overlap.
        let overlap_start = if stride_index > 0 { overlap_chars } else { 0 };
        let overlap_end = if end_char_idx < char_count {
            overlap_chars
        } else {
            0
        };

        // Spans are reported relative to the original file, so offset by the
        // parent chunk's start position.
        let byte_offset_start = chunk.span.byte_start + start_byte_pos;
        let byte_offset_end = chunk.span.byte_start + end_byte_pos;

        // Derive line numbers from the newlines preceding / inside the window.
        let text_before_start = &text[..start_byte_pos];
        let line_offset_start = text_before_start.lines().count().saturating_sub(1);
        let stride_lines = stride_text.lines().count();
        let metadata = chunk.metadata.with_updated_text(stride_text);

        let stride_chunk = Chunk {
            span: Span {
                byte_start: byte_offset_start,
                byte_end: byte_offset_end,
                line_start: chunk.span.line_start + line_offset_start,
                line_end: chunk.span.line_start
                    + line_offset_start
                    + stride_lines.saturating_sub(1),
            },
            text: stride_text.to_string(),
            chunk_type: chunk.chunk_type.clone(),
            stride_info: Some(StrideInfo {
                original_chunk_id: original_chunk_id.clone(),
                stride_index,
                total_strides,
                overlap_start,
                overlap_end,
            }),
            metadata,
        };

        strided_chunks.push(stride_chunk);

        // Stop once the window reached the end of the text.
        if end_char_idx >= char_count {
            break;
        }

        start_char_idx += stride_chars;
        stride_index += 1;
    }

    tracing::debug!(
        "Created {} strides from chunk of {} tokens",
        strided_chunks.len(),
        estimate_tokens(text)
    );

    Ok(strided_chunks)
}
1375
1376#[cfg(test)]
1379mod tests {
1380 use super::*;
1381
1382 fn canonicalize_spans(
1383 mut spans: Vec<(usize, usize, ChunkType)>,
1384 ) -> Vec<(usize, usize, ChunkType)> {
1385 fn chunk_type_order(chunk_type: &ChunkType) -> u8 {
1386 match chunk_type {
1387 ChunkType::Text => 0,
1388 ChunkType::Function => 1,
1389 ChunkType::Class => 2,
1390 ChunkType::Method => 3,
1391 ChunkType::Module => 4,
1392 }
1393 }
1394
1395 spans.sort_by(|a, b| {
1396 let order_a = chunk_type_order(&a.2);
1397 let order_b = chunk_type_order(&b.2);
1398 order_a
1399 .cmp(&order_b)
1400 .then_with(|| a.0.cmp(&b.0))
1401 .then_with(|| a.1.cmp(&b.1))
1402 });
1403
1404 let mut result: Vec<(usize, usize, ChunkType)> = Vec::new();
1405 for (start, end, ty) in spans {
1406 if let Some(last) = result.last_mut()
1407 && last.0 == start
1408 && last.2 == ty
1409 {
1410 if end > last.1 {
1411 last.1 = end;
1412 }
1413 continue;
1414 }
1415 result.push((start, end, ty));
1416 }
1417
1418 result
1419 }
1420
1421 fn assert_query_parity(language: ParseableLanguage, source: &str) {
1422 let mut parser = tree_sitter::Parser::new();
1423 let ts_language = tree_sitter_language(language).expect("language");
1424 parser.set_language(&ts_language).expect("set language");
1425 let tree = parser.parse(source, None).expect("parse source");
1426
1427 let query_chunks = query_chunker::chunk_with_queries(language, ts_language, &tree, source)
1428 .expect("query execution")
1429 .expect("queries available");
1430
1431 let mut legacy_chunks = Vec::new();
1432 let mut cursor = tree.walk();
1433 extract_code_chunks(&mut cursor, source, &mut legacy_chunks, language);
1434
1435 let query_spans = canonicalize_spans(
1436 query_chunks
1437 .iter()
1438 .map(|chunk| {
1439 (
1440 chunk.span.byte_start,
1441 chunk.span.byte_end,
1442 chunk.chunk_type.clone(),
1443 )
1444 })
1445 .collect(),
1446 );
1447 let legacy_spans = canonicalize_spans(
1448 legacy_chunks
1449 .iter()
1450 .map(|chunk| {
1451 (
1452 chunk.span.byte_start,
1453 chunk.span.byte_end,
1454 chunk.chunk_type.clone(),
1455 )
1456 })
1457 .collect(),
1458 );
1459
1460 assert_eq!(query_spans, legacy_spans);
1461 }
1462
    #[test]
    fn test_chunk_generic_byte_offsets() {
        // Generic chunking of a small multi-line text: spans must be
        // byte-accurate so downstream slicing of the original text works.
        let text = "line 1\nline 2\nline 3\nline 4\nline 5";
        let chunks = chunk_generic(text).unwrap();

        assert!(!chunks.is_empty());

        // The first chunk starts at the very beginning of the input.
        assert_eq!(chunks[0].span.byte_start, 0);

        // Every chunk's span length must match its text length exactly.
        for chunk in &chunks {
            let expected_len = chunk.text.len();
            let actual_len = chunk.span.byte_end - chunk.span.byte_start;
            assert_eq!(actual_len, expected_len);
        }
    }
1481
1482 #[test]
1483 fn test_chunk_generic_large_file_performance() {
1484 let lines: Vec<String> = (0..1000)
1486 .map(|i| format!("Line {}: Some content here", i))
1487 .collect();
1488 let text = lines.join("\n");
1489
1490 let start = std::time::Instant::now();
1491 let chunks = chunk_generic(&text).unwrap();
1492 let duration = start.elapsed();
1493
1494 assert!(
1496 duration.as_millis() < 100,
1497 "Chunking took too long: {:?}",
1498 duration
1499 );
1500 assert!(!chunks.is_empty());
1501
1502 for chunk in &chunks {
1504 assert!(chunk.span.line_start > 0);
1505 assert!(chunk.span.line_end >= chunk.span.line_start);
1506 }
1507 }
1508
    #[test]
    fn test_chunk_rust() {
        // Rust source containing a struct, an impl block, a free function and
        // a module — each should be captured by the AST-based chunker.
        let rust_code = r#"
pub struct Calculator {
    memory: f64,
}

impl Calculator {
    pub fn new() -> Self {
        Calculator { memory: 0.0 }
    }

    pub fn add(&mut self, a: f64, b: f64) -> f64 {
        a + b
    }
}

fn main() {
    let calc = Calculator::new();
}

pub mod utils {
    pub fn helper() {}
}
"#;

        let chunks = chunk_language(rust_code, ParseableLanguage::Rust).unwrap();
        assert!(!chunks.is_empty());

        // Expect all three structural kinds to be represented.
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Function));
    }
1544
    #[test]
    fn test_rust_query_matches_legacy() {
        // Parity check for nested Rust items: a module containing a struct
        // and impl block, plus a top-level function.
        let source = r#"
        mod sample {
            struct Thing;

            impl Thing {
                fn new() -> Self { Self }
                fn helper(&self) {}
            }
        }

        fn util() {}
        "#;

        assert_query_parity(ParseableLanguage::Rust, source);
    }
1562
    #[test]
    fn test_python_query_matches_legacy() {
        // Parity check covering a decorated classmethod, a plain function and
        // an async function.
        let source = r#"
class Example:
    @classmethod
    def build(cls):
        return cls()


def helper():
    return 1


async def async_helper():
    return 2
"#;

        assert_query_parity(ParseableLanguage::Python, source);
    }
1582
    #[test]
    fn test_chunk_ruby() {
        // Ruby source with a class (instance, class-level and private
        // methods), a module, and a top-level method definition.
        let ruby_code = r#"
class Calculator
  def initialize
    @memory = 0.0
  end

  def add(a, b)
    a + b
  end

  def self.class_method
    "class method"
  end

  private

  def private_method
    "private"
  end
end

module Utils
  def self.helper
    "helper"
  end
end

def main
  calc = Calculator.new
end
"#;

        let chunks = chunk_language(ruby_code, ParseableLanguage::Ruby).unwrap();
        assert!(!chunks.is_empty());

        // Class, module and function chunks must all be present.
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Function));
    }
1626
1627 #[test]
1628 fn test_language_detection_fallback() {
1629 let generic_text = "Some text\nwith multiple lines\nto chunk generically";
1631
1632 let chunks_unknown = chunk_text(generic_text, None).unwrap();
1633 let chunks_generic = chunk_generic(generic_text).unwrap();
1634
1635 assert_eq!(chunks_unknown.len(), chunks_generic.len());
1637 assert_eq!(chunks_unknown[0].text, chunks_generic[0].text);
1638 }
1639
    #[test]
    fn test_chunk_go() {
        // Go source covering consts, vars, struct and interface types, free
        // functions and a method with a receiver.
        let go_code = r#"
package main

import "fmt"

const Pi = 3.14159

var memory float64

type Calculator struct {
    memory float64
}

type Operation interface {
    Calculate(a, b float64) float64
}

func NewCalculator() *Calculator {
    return &Calculator{memory: 0.0}
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func main() {
    calc := NewCalculator()
}
"#;

        let chunks = chunk_language(go_code, ParseableLanguage::Go).unwrap();
        assert!(!chunks.is_empty());

        // Module-level items, type declarations, free functions and receiver
        // methods must each be classified.
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Function));
        assert!(chunk_types.contains(&&ChunkType::Method));
    }
1682
    #[test]
    // NOTE(review): ignored; the reason is not recorded here — presumably
    // arrow-function context capture is not yet stable. Confirm before
    // re-enabling.
    #[ignore]
    fn test_chunk_typescript_arrow_context() {
        // Arrow functions should be chunked together with their surrounding
        // declaration and leading comments, never as a bare `() => ...`.
        let ts_code = r#"
// Utility function
export const util = () => {
    // comment about util
    return 42;
};

export class Example {
    // leading comment for method
    constructor() {}

    // Another comment
    run = () => {
        return util();
    };
}

const compute = (x: number) => x * 2;
"#;

        let chunks = chunk_language(ts_code, ParseableLanguage::TypeScript).unwrap();

        let util_chunk = chunks
            .iter()
            .find(|chunk| chunk.text.contains("export const util"))
            .expect("Expected chunk for util arrow function");
        assert_eq!(util_chunk.chunk_type, ChunkType::Function);
        assert!(
            util_chunk.text.contains("// Utility function"),
            "expected leading comment to be included"
        );
        assert!(util_chunk.text.contains("export const util ="));

        let method_chunk = chunks
            .iter()
            .find(|chunk| {
                chunk.chunk_type == ChunkType::Method && chunk.text.contains("run = () =>")
            })
            .expect("Expected chunk for class field arrow function");

        assert_eq!(method_chunk.chunk_type, ChunkType::Method);
        assert!(
            method_chunk.text.contains("// Another comment"),
            "expected inline comment to be included"
        );

        let compute_chunk = chunks
            .iter()
            .find(|chunk| chunk.text.contains("const compute"))
            .expect("Expected chunk for compute arrow function");
        assert_eq!(compute_chunk.chunk_type, ChunkType::Function);
        assert!(
            compute_chunk
                .text
                .contains("const compute = (x: number) => x * 2;")
        );

        // No chunk may start at a bare arrow-function parameter list.
        assert!(
            chunks
                .iter()
                .all(|chunk| !chunk.text.trim_start().starts_with("() =>"))
        );
        assert!(
            chunks
                .iter()
                .all(|chunk| !chunk.text.trim_start().starts_with("(x: number) =>"))
        );
    }
1756
    #[test]
    // NOTE(review): ignored — the query and legacy chunkers apparently do not
    // yet agree on TypeScript arrow functions. TODO confirm before
    // re-enabling.
    #[ignore]
    fn test_typescript_query_matches_legacy() {
        let source = r#"
export const util = () => {
    return 42;
};

export class Example {
    run = () => {
        return util();
    };
}

const compute = (x: number) => x * 2;
"#;

        assert_query_parity(ParseableLanguage::TypeScript, source);
    }
1779
    #[test]
    fn test_ruby_query_matches_legacy() {
        // Parity check for a Ruby class with instance and class-level methods.
        let source = r#"
class Calculator
  def initialize
    @memory = 0.0
  end

  def add(a, b)
    a + b
  end

  def self.class_method
    "class method"
  end
end
"#;

        assert_query_parity(ParseableLanguage::Ruby, source);
    }
1800
    #[test]
    fn test_go_query_matches_legacy() {
        // Parity check for Go top-level declarations and a receiver method.
        let source = r#"
package main

import "fmt"

const Pi = 3.14159

var memory float64

type Calculator struct {
    memory float64
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func Helper() {}
"#;

        assert_query_parity(ParseableLanguage::Go, source);
    }
1825
    #[test]
    fn test_haskell_query_matches_legacy() {
        // Parity check for Haskell data/type-family/class/instance
        // declarations and a multi-equation function.
        let source = r#"
module Example where

data Shape
    = Circle Float
    | Square Float

type family Area a

class Printable a where
    printValue :: a -> String

instance Printable Shape where
    printValue (Circle _) = "circle"
    printValue (Square _) = "square"

shapeDescription :: Shape -> String
shapeDescription (Circle r) = "circle of radius " ++ show r
shapeDescription (Square s) = "square of side " ++ show s
"#;

        assert_query_parity(ParseableLanguage::Haskell, source);
    }
1851
    #[test]
    fn test_csharp_query_matches_legacy() {
        // Parity check for a C# file-scoped namespace, interface and class
        // with fields, a constructor and a method.
        let source = r#"
namespace Calculator;

public interface ICalculator
{
    double Add(double x, double y);
}

public class Calculator
{
    public static double PI = 3.14159;
    private double _memory;

    public Calculator()
    {
        _memory = 0.0;
    }

    public double Add(double x, double y)
    {
        return x + y;
    }
}
"#;

        assert_query_parity(ParseableLanguage::CSharp, source);
    }
1881
    #[test]
    fn test_zig_query_matches_legacy() {
        // Parity check for a Zig struct container with methods plus a test
        // block.
        let source = r#"
const std = @import("std");

const Calculator = struct {
    memory: f64,

    pub fn init() Calculator {
        return Calculator{ .memory = 0.0 };
    }

    pub fn add(self: *Calculator, a: f64, b: f64) f64 {
        return a + b;
    }
};

test "calculator addition" {
    var calc = Calculator.init();
    const result = calc.add(2.0, 3.0);
    try std.testing.expect(result == 5.0);
}
"#;

        assert_query_parity(ParseableLanguage::Zig, source);
    }
1908
    #[test]
    fn test_chunk_zig() {
        // Zig source exercising every container kind (struct, enum, union,
        // opaque, error set), free functions, a comptime block and tests.
        let zig_code = r#"
const std = @import("std");

const Calculator = struct {
    memory: f64,

    pub fn init() Calculator {
        return Calculator{ .memory = 0.0 };
    }

    pub fn add(self: *Calculator, a: f64, b: f64) f64 {
        const result = a + b;
        self.memory = result;
        return result;
    }
};

const Color = enum {
    Red,
    Green,
    Blue,
};

const Value = union(enum) {
    int: i32,
    float: f64,
};

const Handle = opaque {};

const MathError = error{
    DivisionByZero,
    Overflow,
};

pub fn multiply(a: i32, b: i32) i32 {
    return a * b;
}

pub fn divide(a: i32, b: i32) MathError!i32 {
    if (b == 0) return error.DivisionByZero;
    return @divTrunc(a, b);
}

comptime {
    @compileLog("Compile-time validation");
}

pub fn main() !void {
    var calc = Calculator.init();
    const result = calc.add(2.0, 3.0);
    std.debug.print("Result: {}\n", .{result});
}

test "calculator addition" {
    var calc = Calculator.init();
    const result = calc.add(2.0, 3.0);
    try std.testing.expect(result == 5.0);
}

test "multiply function" {
    const result = multiply(3, 4);
    try std.testing.expect(result == 12);
}
"#;

        let chunks = chunk_language(zig_code, ParseableLanguage::Zig).unwrap();
        assert!(!chunks.is_empty());

        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();

        // Count each classification to check the fixture is fully mapped.
        let class_count = chunk_types
            .iter()
            .filter(|&&t| t == &ChunkType::Class)
            .count();
        let function_count = chunk_types
            .iter()
            .filter(|&&t| t == &ChunkType::Function)
            .count();
        let module_count = chunk_types
            .iter()
            .filter(|&&t| t == &ChunkType::Module)
            .count();

        assert!(
            class_count >= 5,
            "Expected at least 5 Class chunks (struct, enum, union, opaque, error set), found {}",
            class_count
        );

        assert!(
            function_count >= 3,
            "Expected at least 3 functions (multiply, divide, main), found {}",
            function_count
        );

        assert!(
            module_count >= 4,
            "Expected at least 4 module-type chunks (const std, comptime, 2 tests), found {}",
            module_count
        );

        assert!(
            chunk_types.contains(&&ChunkType::Class),
            "Expected to find Class chunks"
        );
        assert!(
            chunk_types.contains(&&ChunkType::Function),
            "Expected to find Function chunks"
        );
        assert!(
            chunk_types.contains(&&ChunkType::Module),
            "Expected to find Module chunks"
        );
    }
2026
    #[test]
    fn test_chunk_csharp() {
        // C# source with a file-scoped namespace, an interface, and a class
        // containing fields, a constructor and methods.
        let csharp_code = r#"
namespace Calculator;

public interface ICalculator
{
    double Add(double x, double y);
}

public class Calculator
{
    public static const double PI = 3.14159;
    private double _memory;

    public Calculator()
    {
        _memory = 0.0;
    }

    public double Add(double x, double y)
    {
        return x + y;
    }

    public static void Main(string[] args)
    {
        var calc = new Calculator();
    }
}
"#;

        let chunks = chunk_language(csharp_code, ParseableLanguage::CSharp).unwrap();
        assert!(!chunks.is_empty());

        // Namespace, type and member chunks must all be present.
        let chunk_types: Vec<&ChunkType> = chunks.iter().map(|c| &c.chunk_type).collect();
        assert!(chunk_types.contains(&&ChunkType::Module));
        assert!(chunk_types.contains(&&ChunkType::Class));
        assert!(chunk_types.contains(&&ChunkType::Method));
    }
2068
2069 #[test]
2070 fn test_stride_large_chunk_empty_text() {
2071 let empty_chunk = Chunk {
2073 span: Span {
2074 byte_start: 0,
2075 byte_end: 0,
2076 line_start: 1,
2077 line_end: 1,
2078 },
2079 text: String::new(), chunk_type: ChunkType::Text,
2081 stride_info: None,
2082 metadata: ChunkMetadata::from_text(""),
2083 };
2084
2085 let config = ChunkConfig::default();
2086 let result = stride_large_chunk(empty_chunk.clone(), &config);
2087
2088 assert!(result.is_ok());
2090 let chunks = result.unwrap();
2091 assert_eq!(chunks.len(), 1);
2092 assert_eq!(chunks[0].text, "");
2093 }
2094
2095 #[test]
2096 fn test_stride_large_chunk_zero_token_estimate() {
2097 let chunk = Chunk {
2099 span: Span {
2100 byte_start: 0,
2101 byte_end: 5,
2102 line_start: 1,
2103 line_end: 1,
2104 },
2105 text: " ".to_string(), chunk_type: ChunkType::Text,
2107 stride_info: None,
2108 metadata: ChunkMetadata::from_text(" "),
2109 };
2110
2111 let config = ChunkConfig::default();
2112 let result = stride_large_chunk(chunk, &config);
2113
2114 assert!(result.is_ok());
2116 }
2117
    #[test]
    fn test_strided_chunk_line_calculation() {
        // 50 long lines force striding under a small token budget; the line
        // numbers on each stride must stay consistent with its content.
        let long_text = (1..=50)
            .map(|i| format!("This is a longer line {} with more content to ensure token count is high enough", i))
            .collect::<Vec<_>>()
            .join("\n");

        let metadata = ChunkMetadata::from_text(&long_text);
        let chunk = Chunk {
            span: Span {
                byte_start: 0,
                byte_end: long_text.len(),
                line_start: 1,
                line_end: 50,
            },
            text: long_text,
            chunk_type: ChunkType::Text,
            stride_info: None,
            metadata,
        };

        let config = ChunkConfig {
            max_tokens: 100,
            stride_overlap: 10,
            ..Default::default()
        };

        let result = stride_large_chunk(chunk, &config);
        if let Err(e) = &result {
            eprintln!("Stride error: {}", e);
        }
        assert!(result.is_ok());

        let chunks = result.unwrap();
        assert!(
            chunks.len() > 1,
            "Should create multiple chunks when striding"
        );

        for chunk in chunks {
            // Spans must be well-ordered.
            assert!(chunk.span.line_end >= chunk.span.line_start);

            // The reported span may exceed the visible line count by at most
            // one (e.g. when a stride ends exactly at a line boundary).
            let line_count = chunk.text.lines().count();
            if line_count > 0 {
                let calculated_line_span = chunk.span.line_end - chunk.span.line_start + 1;

                assert!(
                    calculated_line_span <= line_count + 1,
                    "Line span {} should not exceed content lines {} by more than 1",
                    calculated_line_span,
                    line_count
                );
            }
        }
    }
2176
    #[test]
    fn test_gap_filling_coverage() {
        // For each language, chunking must cover every non-whitespace byte of
        // the input — comments and imports between definitions included.
        let test_cases = vec![
            (
                ParseableLanguage::Rust,
                r#"// This is a test file with imports at the top
use std::collections::HashMap;
use std::sync::Arc;

// A comment between imports and code
const VERSION: &str = "1.0.0";

// Main function
fn main() {
    println!("Hello, world!");
}

// Some trailing content
// that should be indexed
"#,
            ),
            (
                ParseableLanguage::Python,
                r#"# Imports at the top
import os
import sys

# Some constant
VERSION = "1.0.0"

# Main function
def main():
    print("Hello, world!")

# Trailing comment
# should be indexed
"#,
            ),
            (
                ParseableLanguage::TypeScript,
                r#"// Imports at the top
import { foo } from 'bar';

// Some constant
const VERSION = "1.0.0";

// Main function
function main() {
    console.log("Hello, world!");
}

// Trailing comment
// should be indexed
"#,
            ),
        ];

        for (language, code) in test_cases {
            eprintln!("\n=== Testing {} ===", language);
            let chunks = chunk_language(code, language).unwrap();

            // Mark every byte covered by at least one chunk span.
            let mut covered_bytes = vec![false; code.len()];
            for chunk in &chunks {
                for item in covered_bytes
                    .iter_mut()
                    .take(chunk.span.byte_end)
                    .skip(chunk.span.byte_start)
                {
                    *item = true;
                }
            }

            // Whitespace gaps are acceptable; any other uncovered byte is not.
            let uncovered_non_ws: Vec<usize> = covered_bytes
                .iter()
                .enumerate()
                .filter(|(i, covered)| !**covered && !code.as_bytes()[*i].is_ascii_whitespace())
                .map(|(i, _)| i)
                .collect();

            // Diagnostic dump so coverage failures are debuggable from output.
            if !uncovered_non_ws.is_empty() {
                eprintln!("\n=== UNCOVERED NON-WHITESPACE for {} ===", language);
                eprintln!("Total bytes: {}", code.len());
                eprintln!("Uncovered non-whitespace: {}", uncovered_non_ws.len());

                for &pos in uncovered_non_ws.iter().take(10) {
                    let context_start = pos.saturating_sub(20);
                    let context_end = (pos + 20).min(code.len());
                    eprintln!(
                        "Uncovered at byte {}: {:?}",
                        pos,
                        &code[context_start..context_end]
                    );
                }

                eprintln!("\n=== CHUNKS ===");
                for (i, chunk) in chunks.iter().enumerate() {
                    eprintln!(
                        "Chunk {}: {:?} bytes {}-{} (len {})",
                        i,
                        chunk.chunk_type,
                        chunk.span.byte_start,
                        chunk.span.byte_end,
                        chunk.span.byte_end - chunk.span.byte_start
                    );
                    eprintln!("  Text: {:?}", &chunk.text[..chunk.text.len().min(60)]);
                }
            }

            assert!(
                uncovered_non_ws.is_empty(),
                "{}: Expected all non-whitespace covered but found {} uncovered non-whitespace bytes",
                language,
                uncovered_non_ws.len()
            );
        }
    }
2296
2297 #[test]
2298 fn test_web_server_file_coverage() {
2299 let code = std::fs::read_to_string("../examples/code/web_server.rs")
2301 .expect("Failed to read web_server.rs");
2302
2303 let chunks = chunk_language(&code, ParseableLanguage::Rust).unwrap();
2304
2305 let mut covered = vec![false; code.len()];
2307 for chunk in &chunks {
2308 for item in covered
2309 .iter_mut()
2310 .take(chunk.span.byte_end)
2311 .skip(chunk.span.byte_start)
2312 {
2313 *item = true;
2314 }
2315 }
2316
2317 let uncovered_non_whitespace: Vec<(usize, char)> = covered
2319 .iter()
2320 .enumerate()
2321 .filter(|(i, covered)| !**covered && !code.as_bytes()[*i].is_ascii_whitespace())
2322 .map(|(i, _)| (i, code.chars().nth(i).unwrap_or('?')))
2323 .collect();
2324
2325 if !uncovered_non_whitespace.is_empty() {
2326 eprintln!("\n=== WEB_SERVER.RS UNCOVERED NON-WHITESPACE ===");
2327 eprintln!("File size: {} bytes", code.len());
2328 eprintln!("Total chunks: {}", chunks.len());
2329 eprintln!(
2330 "Uncovered non-whitespace: {}",
2331 uncovered_non_whitespace.len()
2332 );
2333
2334 for &(pos, ch) in uncovered_non_whitespace.iter().take(10) {
2335 let start = pos.saturating_sub(30);
2336 let end = (pos + 30).min(code.len());
2337 eprintln!(
2338 "\nUncovered '{}' at byte {}: {:?}",
2339 ch,
2340 pos,
2341 &code[start..end]
2342 );
2343 }
2344
2345 eprintln!("\n=== CHUNKS ===");
2346 for (i, chunk) in chunks.iter().enumerate().take(20) {
2347 eprintln!(
2348 "Chunk {}: {:?} bytes {}-{} lines {}-{}",
2349 i,
2350 chunk.chunk_type,
2351 chunk.span.byte_start,
2352 chunk.span.byte_end,
2353 chunk.span.line_start,
2354 chunk.span.line_end
2355 );
2356 }
2357 }
2358
2359 assert!(
2360 uncovered_non_whitespace.is_empty(),
2361 "Expected all non-whitespace content covered but found {} uncovered non-whitespace bytes",
2362 uncovered_non_whitespace.len()
2363 );
2364 }
2365
    #[test]
    fn test_haskell_function_chunking() {
        // Haskell functions with multiple equations: the type signature and
        // all equations should land in a single chunk per function.
        let haskell_code = r#"
factorial :: Integer -> Integer
factorial 0 = 1
factorial n = n * factorial (n - 1)

fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci 1 = 1
fibonacci n = fibonacci (n - 1) + fibonacci (n - 2)
"#;

        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_haskell::LANGUAGE.into())
            .unwrap();
        let tree = parser.parse(haskell_code, None).unwrap();

        // Debug helper: dump the parse tree so chunking mismatches can be
        // diagnosed directly from test output.
        fn walk(node: tree_sitter::Node, _src: &str, depth: usize) {
            let kind = node.kind();
            let start = node.start_position();
            let end = node.end_position();
            eprintln!(
                "{}{:30} L{}-{}",
                " ".repeat(depth),
                kind,
                start.row + 1,
                end.row + 1
            );

            let mut cursor = node.walk();
            if cursor.goto_first_child() {
                loop {
                    walk(cursor.node(), _src, depth + 1);
                    if !cursor.goto_next_sibling() {
                        break;
                    }
                }
            }
        }

        eprintln!("\n=== TREE STRUCTURE ===");
        walk(tree.root_node(), haskell_code, 0);
        eprintln!("=== END TREE ===\n");

        let chunks = chunk_language(haskell_code, ParseableLanguage::Haskell).unwrap();

        eprintln!("\n=== CHUNKS ===");
        for (i, chunk) in chunks.iter().enumerate() {
            eprintln!(
                "Chunk {}: {:?} L{}-{}",
                i, chunk.chunk_type, chunk.span.line_start, chunk.span.line_end
            );
            eprintln!("  Text: {:?}", chunk.text);
        }
        eprintln!("=== END CHUNKS ===\n");

        assert!(!chunks.is_empty(), "Should find chunks in Haskell code");

        // The factorial chunk must include signature, base case and the
        // recursive case together.
        let factorial_chunk = chunks.iter().find(|c| c.text.contains("factorial 0 = 1"));
        assert!(
            factorial_chunk.is_some(),
            "Should find factorial function body"
        );

        let fac = factorial_chunk.unwrap();
        assert!(
            fac.text.contains("factorial :: Integer -> Integer"),
            "Should include type signature"
        );
        assert!(
            fac.text.contains("factorial 0 = 1"),
            "Should include base case"
        );
        assert!(
            fac.text.contains("factorial n = n * factorial (n - 1)"),
            "Should include recursive case"
        );
    }
2448}