use std::cell::RefCell;
use std::sync::OnceLock;
use once_cell::sync::Lazy;
use rayon::prelude::*;
use regex::Regex;
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
use super::types::{CHUNK_OVERLAP_TOKENS, MAX_CODE_PREVIEW_TOKENS};
use crate::embedding::{TeiClient, TeiError};
const MIN_CHUNK_TOKENS: usize = 10;
/// Converts a 0-based array index into the matching 1-based line number.
#[inline]
#[must_use]
pub fn index_to_line(index: usize) -> usize {
    1 + index
}
/// Converts a 1-based line number into a 0-based array index.
///
/// Line `0` is not a valid 1-based line; it is mapped to index `0` rather
/// than underflowing.
#[inline]
#[must_use]
pub fn line_to_index(line: usize) -> usize {
    match line {
        0 => 0,
        n => n - 1,
    }
}
/// Tokenizer used when counting tokens.
///
/// Serialized with snake_case variant names; `Cl100kBase` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerType {
/// Displayed as "cl100k_base (GPT-4)".
#[default]
Cl100kBase,
/// Displayed as "p50k_base (GPT-3)".
P50kBase,
/// Displayed as "r50k_base (Codex)".
R50kBase,
/// Displayed as "Qwen3 (TEI server)"; the only variant for which
/// `requires_tei()` returns `true`.
Qwen3,
/// Displayed as "CharEstimate (Python parity)"; the only variant for which
/// `is_estimation()` returns `true`.
CharEstimate,
}
impl TokenizerType {
    /// Returns the human-readable label for this tokenizer.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Cl100kBase => "cl100k_base (GPT-4)",
            Self::P50kBase => "p50k_base (GPT-3)",
            Self::R50kBase => "r50k_base (Codex)",
            Self::Qwen3 => "Qwen3 (TEI server)",
            Self::CharEstimate => "CharEstimate (Python parity)",
        }
    }

    /// Returns `true` only for [`TokenizerType::Qwen3`].
    #[must_use]
    pub fn requires_tei(&self) -> bool {
        *self == Self::Qwen3
    }

    /// Returns `true` only for [`TokenizerType::CharEstimate`].
    #[must_use]
    pub fn is_estimation(&self) -> bool {
        *self == Self::CharEstimate
    }

    /// Returns the fractional variance assumed between this tokenizer and Qwen3.
    ///
    /// The value is a fraction (e.g. `0.10` for 10%); `Qwen3` itself is `0.0`.
    #[must_use]
    pub fn variance_vs_qwen3(&self) -> f64 {
        match *self {
            Self::Qwen3 => 0.0,
            Self::CharEstimate => 0.05,
            Self::Cl100kBase => 0.10,
            Self::R50kBase => 0.12,
            Self::P50kBase => 0.15,
        }
    }
}
/// Formats the tokenizer using the same label as [`TokenizerType::name`].
impl std::fmt::Display for TokenizerType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
/// Settings that control how code is split into chunks.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
/// Nominal token budget per chunk (see `effective_max_tokens` for the
/// budget after safety margins).
pub max_tokens: usize,
/// Token budget for the overlap between consecutive chunks.
pub overlap_tokens: usize,
/// Tokenizer the budget is measured with.
pub tokenizer: TokenizerType,
/// Extra fractional safety margin; `with_variance_margin` clamps it to
/// `0.0..=0.5`, but direct field assignment is not clamped.
pub variance_margin: f64,
}
impl Default for ChunkConfig {
fn default() -> Self {
Self {
max_tokens: MAX_CODE_PREVIEW_TOKENS,
overlap_tokens: CHUNK_OVERLAP_TOKENS,
tokenizer: TokenizerType::default(),
variance_margin: 0.0, }
}
}
impl ChunkConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the configuration with `max_tokens` replaced.
    #[must_use]
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        Self { max_tokens, ..self }
    }

    /// Returns the configuration with `overlap_tokens` replaced.
    #[must_use]
    pub fn with_overlap_tokens(self, overlap_tokens: usize) -> Self {
        Self {
            overlap_tokens,
            ..self
        }
    }

    /// Returns the configuration with the tokenizer replaced.
    #[must_use]
    pub fn with_tokenizer(self, tokenizer: TokenizerType) -> Self {
        Self { tokenizer, ..self }
    }

    /// Returns the configuration with the variance margin replaced.
    ///
    /// The margin is clamped to `0.0..=0.5`.
    #[must_use]
    pub fn with_variance_margin(self, margin: f64) -> Self {
        Self {
            variance_margin: margin.clamp(0.0, 0.5),
            ..self
        }
    }

    /// Returns the token budget after applying safety margins.
    ///
    /// Tokenizers that require TEI use `max_tokens` unchanged. All others have
    /// the budget reduced by the configured margin plus the tokenizer's own
    /// variance, and the result never drops below `MIN_CHUNK_TOKENS`.
    #[must_use]
    pub fn effective_max_tokens(&self) -> usize {
        if self.tokenizer.requires_tei() {
            return self.max_tokens;
        }
        let total_margin = self.variance_margin + self.tokenizer.variance_vs_qwen3();
        let scaled = self.max_tokens as f64 * (1.0 - total_margin);
        std::cmp::max(scaled as usize, MIN_CHUNK_TOKENS)
    }
}
/// Process-wide tokenizer selection; it can be assigned at most once.
static GLOBAL_TOKENIZER_TYPE: OnceLock<TokenizerType> = OnceLock::new();

/// Sets the process-wide tokenizer type.
///
/// Returns `true` if this call stored the value, `false` if a value had
/// already been set (in which case the existing value is kept).
pub fn set_tokenizer_type(tokenizer_type: TokenizerType) -> bool {
    match GLOBAL_TOKENIZER_TYPE.set(tokenizer_type) {
        Ok(()) => true,
        Err(_) => false,
    }
}

/// Returns the process-wide tokenizer type, or `Cl100kBase` if none was set.
#[must_use]
pub fn get_tokenizer_type() -> TokenizerType {
    match GLOBAL_TOKENIZER_TYPE.get() {
        Some(&chosen) => chosen,
        None => TokenizerType::Cl100kBase,
    }
}
thread_local! {
    /// Per-thread `cl100k_base` encoder.
    ///
    /// Holds `None` when the encoder could not be constructed; a warning is
    /// logged and callers fall back to estimation.
    static THREAD_TOKENIZER: RefCell<Option<CoreBPE>> = RefCell::new(
        tiktoken_rs::cl100k_base()
            .map_err(|e| {
                tracing::warn!(
                    "Failed to initialize thread-local tokenizer: {}, falling back to estimation",
                    e
                );
            })
            .ok()
    );
}
/// Counts the tokens in `text` according to the global tokenizer type.
///
/// Empty input is `0`. `CharEstimate` and `Qwen3` use the Unicode-aware
/// estimate. The three BPE variants all use the thread-local encoder, falling
/// back to the same estimate when that encoder is unavailable.
#[must_use]
pub fn count_tokens(text: &str) -> usize {
    if text.is_empty() {
        return 0;
    }

    let use_bpe = match get_tokenizer_type() {
        TokenizerType::CharEstimate | TokenizerType::Qwen3 => false,
        TokenizerType::Cl100kBase | TokenizerType::P50kBase | TokenizerType::R50kBase => true,
    };
    if !use_bpe {
        return estimate_tokens_unicode_aware(text);
    }

    THREAD_TOKENIZER.with(|cell| {
        let encoder = cell.borrow();
        encoder.as_ref().map_or_else(
            || estimate_tokens_unicode_aware(text),
            |bpe| bpe.encode_ordinary(text).len(),
        )
    })
}
/// Rough token estimate: one token per four bytes, rounded up.
#[inline]
#[must_use]
#[allow(dead_code)]
fn estimate_tokens(text: &str) -> usize {
    let byte_len = text.len();
    byte_len / 4 + usize::from(byte_len % 4 != 0)
}
/// Estimates the token count of `text` without running a tokenizer.
///
/// Characters are weighted by class: ASCII at 1/4 token, CJK at 1.5, emoji at
/// 3, and every other non-ASCII character at 1/2. Empty input is `0`; any
/// non-empty input yields at least `1`.
#[must_use]
pub fn estimate_tokens_unicode_aware(text: &str) -> usize {
    if text.is_empty() {
        return 0;
    }

    let ascii = count_ascii_bytes_simd(text.as_bytes());

    // Only non-ASCII code points need classification.
    let (mut cjk, mut emoji, mut other) = (0_usize, 0_usize, 0_usize);
    for code_point in text.chars().map(u32::from).filter(|&cp| cp >= 128) {
        if is_cjk_char(code_point) {
            cjk += 1;
        } else if is_emoji_char(code_point) {
            emoji += 1;
        } else {
            other += 1;
        }
    }

    let estimated =
        ascii as f64 / 4.0 + cjk as f64 * 1.5 + emoji as f64 * 3.0 + other as f64 / 2.0;
    std::cmp::max(estimated as usize, 1)
}
/// Counts the bytes in `bytes` whose value is below 128.
///
/// Inputs shorter than 64 bytes are counted with a scalar scan. Longer inputs
/// go through the architecture-specific helper first, then the remaining tail
/// is counted with the scalar scan.
#[inline]
fn count_ascii_bytes_simd(bytes: &[u8]) -> usize {
    let scalar_count = |slice: &[u8]| slice.iter().filter(|b| b.is_ascii()).count();

    if bytes.len() < 64 {
        return scalar_count(bytes);
    }

    // Number of leading bytes already handled by the vectorized helper.
    let mut consumed = 0;
    #[cfg(target_arch = "x86_64")]
    let vectorized = count_ascii_x86_64(bytes, &mut consumed);
    #[cfg(target_arch = "aarch64")]
    let vectorized = count_ascii_aarch64(bytes, &mut consumed);
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let vectorized: usize = 0;

    vectorized + scalar_count(&bytes[consumed..])
}
/// Counts ASCII bytes (high bit clear) in `bytes` starting at `*offset`,
/// processing whole 32-byte (AVX2) or 16-byte (SSE2) blocks.
///
/// `*offset` is advanced past every block that was processed; any remaining
/// tail bytes are left for the caller. Returns the number of ASCII bytes found
/// in the processed blocks.
#[cfg(target_arch = "x86_64")]
#[inline]
fn count_ascii_x86_64(bytes: &[u8], offset: &mut usize) -> usize {
use std::arch::x86_64::*;
let len = bytes.len();
let mut count: usize = 0;
let ptr = bytes.as_ptr();
if is_x86_feature_detected!("avx2") {
// SAFETY: AVX2 support was just confirmed at runtime, and the loop guard
// ensures 32 bytes are readable at `ptr + *offset`; the load is unaligned.
unsafe {
// i8::MIN is 0x80: only the high bit of each byte lane is set.
let high_bit_mask = _mm256_set1_epi8(i8::MIN);
while *offset + 32 <= len {
let chunk = _mm256_loadu_si256(ptr.add(*offset).cast::<__m256i>());
let high_bits = _mm256_and_si256(chunk, high_bit_mask);
// A lane compares equal to zero exactly when its high bit was clear.
let is_ascii = _mm256_cmpeq_epi8(high_bits, _mm256_setzero_si256());
// One mask bit per lane; popcount gives the ASCII count for the block.
let mask = _mm256_movemask_epi8(is_ascii) as u32;
count += mask.count_ones() as usize;
*offset += 32;
}
}
} else if is_x86_feature_detected!("sse2") {
// SAFETY: SSE2 support was just confirmed at runtime, and the loop guard
// ensures 16 bytes are readable at `ptr + *offset`; the load is unaligned.
unsafe {
let high_bit_mask = _mm_set1_epi8(i8::MIN);
while *offset + 16 <= len {
let chunk = _mm_loadu_si128(ptr.add(*offset).cast::<__m128i>());
let high_bits = _mm_and_si128(chunk, high_bit_mask);
let is_ascii = _mm_cmpeq_epi8(high_bits, _mm_setzero_si128());
// Truncate to the 16 meaningful mask bits before counting.
let mask = _mm_movemask_epi8(is_ascii) as u16;
count += mask.count_ones() as usize;
*offset += 16;
}
}
}
count
}
/// Counts ASCII bytes (high bit clear) in `bytes` starting at `*offset`,
/// processing whole 16-byte NEON blocks.
///
/// `*offset` is advanced past every block that was processed; any remaining
/// tail bytes are left for the caller. Returns the number of ASCII bytes found
/// in the processed blocks.
#[cfg(target_arch = "aarch64")]
#[inline]
fn count_ascii_aarch64(bytes: &[u8], offset: &mut usize) -> usize {
use std::arch::aarch64::*;
let len = bytes.len();
let mut count: usize = 0;
let ptr = bytes.as_ptr();
// SAFETY: the loop guard ensures 16 bytes are readable at `ptr + *offset`.
// NOTE(review): relies on NEON being available on the aarch64 target — confirm.
unsafe {
let high_bit_mask = vdupq_n_u8(0x80);
let zero = vdupq_n_u8(0);
while *offset + 16 <= len {
let chunk = vld1q_u8(ptr.add(*offset));
let high_bits = vandq_u8(chunk, high_bit_mask);
// Lanes with a clear high bit compare equal to zero.
let is_ascii = vceqq_u8(high_bits, zero);
// Reduce each matching lane to 1 so the horizontal add yields a count
// (at most 16, which fits in the u8 result).
let ascii_bytes = vandq_u8(is_ascii, vdupq_n_u8(1));
count += vaddvq_u8(ascii_bytes) as usize;
*offset += 16;
}
}
count
}
/// Returns `true` when `code` lies in one of the code-point ranges this module
/// weights as CJK.
#[inline]
fn is_cjk_char(code: u32) -> bool {
    matches!(
        code,
        0x3000..=0x303F
            | 0x3040..=0x309F
            | 0x30A0..=0x30FF
            | 0x3400..=0x4DBF
            | 0x4E00..=0x9FFF
            | 0xAC00..=0xD7AF
            | 0xF900..=0xFAFF
            | 0x20000..=0x2A6DF
            | 0x2A700..=0x2B73F
            | 0x2B740..=0x2B81F
    )
}
/// Returns `true` when `code` lies in one of the code-point ranges this module
/// weights as emoji.
///
/// The U+1F600..=U+1F64F range is not listed separately because it is fully
/// contained in U+1F300..=U+1F9FF.
#[inline]
fn is_emoji_char(code: u32) -> bool {
    matches!(
        code,
        0x2600..=0x26FF
            | 0x2700..=0x27BF
            | 0xFE00..=0xFE0F
            | 0x1F000..=0x1F02F
            | 0x1F300..=0x1F9FF
            | 0x1FA00..=0x1FA6F
    )
}
/// Counts the tokens in `text` by asking the TEI client.
///
/// Empty input returns `Ok(0)` without contacting the client.
///
/// # Errors
///
/// Propagates any error returned by `client.count_tokens`.
pub async fn count_tokens_tei(text: &str, client: &TeiClient) -> Result<usize, TeiError> {
    match text {
        "" => Ok(0),
        non_empty => client.count_tokens(non_empty).await,
    }
}
/// Counts the tokens for each entry of `texts` by asking the TEI client.
///
/// An empty slice returns an empty vector without contacting the client.
///
/// # Errors
///
/// Propagates any error returned by `client.count_tokens_batch`.
pub async fn count_tokens_batch_tei(
    texts: &[&str],
    client: &TeiClient,
) -> Result<Vec<usize>, TeiError> {
    match texts {
        [] => Ok(Vec::new()),
        batch => client.count_tokens_batch(batch).await,
    }
}
/// Compares the local token count of `text` with the TEI count.
///
/// Returns `(local_count, tei_count, variance)`, where `variance` is the
/// absolute difference as a percentage of the TEI count (`0.0` when the TEI
/// count is zero).
///
/// # Errors
///
/// Propagates any error from the TEI token count request.
pub async fn compare_tokenizer_counts(
    text: &str,
    client: &TeiClient,
) -> Result<(usize, usize, f64), TeiError> {
    let local_count = count_tokens(text);
    let tei_count = count_tokens_tei(text, client).await?;

    let variance = match tei_count {
        0 => 0.0,
        reference => {
            let difference = (local_count as f64 - reference as f64).abs();
            difference / reference as f64 * 100.0
        }
    };

    Ok((local_count, tei_count, variance))
}
/// Counts the lines in `content` as "number of `\n` bytes plus one".
///
/// Empty input is `0`. A trailing newline counts as starting one more (empty)
/// line, so `"a\n"` is `2` while `"a"` is `1`. This intentionally differs from
/// `str::lines().count()`, which ignores the final line ending.
///
/// Implemented as a plain byte scan on stable Rust instead of the nightly-only
/// `std::simd` API; the optimizer is free to vectorize this loop.
#[inline]
#[must_use]
pub fn count_lines(content: &str) -> usize {
    if content.is_empty() {
        return 0;
    }
    let newlines = content.bytes().filter(|&byte| byte == b'\n').count();
    newlines + 1
}
/// Truncates `text` so that it holds at most `max_tokens` tokens.
///
/// Returns an empty string when `text` is empty or `max_tokens` is `0`, and
/// returns `text` unchanged when it already fits. Otherwise the text is
/// encoded with the thread-local BPE encoder and the first `max_tokens` tokens
/// are decoded back. When the encoder is unavailable, or decoding the token
/// prefix fails, a character-budget fallback (about four bytes per token, cut
/// back to the last whitespace) is used instead.
#[must_use]
pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
    /// Fallback truncation: keep at most `max_tokens * 4` bytes, never cutting
    /// inside a UTF-8 character, then trim back to the last whitespace if any.
    fn truncate_by_char_estimate(text: &str, max_tokens: usize) -> String {
        let max_bytes = max_tokens.saturating_mul(4);
        if text.len() <= max_bytes {
            return text.to_string();
        }
        // Slicing at an arbitrary byte index panics on multi-byte characters,
        // so back off to the nearest char boundary (index 0 always is one).
        let mut end = max_bytes;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        let truncated = &text[..end];
        match truncated.rfind(char::is_whitespace) {
            Some(pos) => truncated[..pos].to_string(),
            None => truncated.to_string(),
        }
    }

    if text.is_empty() || max_tokens == 0 {
        return String::new();
    }
    if count_tokens(text) <= max_tokens {
        return text.to_string();
    }

    THREAD_TOKENIZER.with(|tokenizer| match tokenizer.borrow().as_ref() {
        Some(bpe) => {
            let tokens = bpe.encode_ordinary(text);
            // `count_tokens` may have used an estimate, so re-check with the
            // real token count before cutting.
            if tokens.len() <= max_tokens {
                return text.to_string();
            }
            bpe.decode(tokens[..max_tokens].to_vec())
                .unwrap_or_else(|_| truncate_by_char_estimate(text, max_tokens))
        }
        None => truncate_by_char_estimate(text, max_tokens),
    })
}
/// A contiguous piece of source text together with its position and optional
/// extracted metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
/// Identifier of the chunk; empty until assigned (the chunkers in this
/// module assign `chunk_N`).
pub chunk_id: String,
/// Text of the chunk.
pub content: String,
/// Token count of `content`, computed when the chunk is constructed.
pub token_count: usize,
/// First line of the chunk, 1-indexed.
pub start_line: usize,
/// Last line of the chunk, 1-indexed and inclusive.
pub end_line: usize,
/// Start offset of the chunk within the source text.
pub start_char: usize,
/// End offset of the chunk within the source text.
pub end_char: usize,
/// Zero-based position of this chunk among its siblings.
pub chunk_index: usize,
/// Number of sibling chunks; `1` means the chunk stands alone.
pub chunk_total: usize,
/// Optional reference to a parent entity.
pub parent_ref: Option<String>,
/// Optional signature; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub signature: Option<String>,
/// Optional docstring; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub docstring: Option<String>,
/// Names of calls associated with the chunk; omitted from serialized
/// output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub calls: Vec<String>,
}
impl Chunk {
#[must_use]
pub fn new(
content: String,
start_line: usize,
end_line: usize,
start_char: usize,
end_char: usize,
) -> Self {
debug_assert!(
start_line >= 1,
"start_line must be 1-indexed (>= 1), got {}. \
Use index_to_line() to convert from 0-indexed array indices.",
start_line
);
debug_assert!(
end_line >= start_line,
"end_line ({}) must be >= start_line ({}). \
For single-line chunks, use start_line == end_line.",
end_line,
start_line
);
let token_count = count_tokens(&content);
Self {
chunk_id: String::new(),
content,
token_count,
start_line,
end_line,
start_char,
end_char,
chunk_index: 0,
chunk_total: 1,
parent_ref: None,
signature: None,
docstring: None,
calls: Vec::new(),
}
}
#[must_use]
pub fn with_id(mut self, id: impl Into<String>) -> Self {
self.chunk_id = id.into();
self
}
#[must_use]
pub fn with_parent(mut self, parent: impl Into<String>) -> Self {
self.parent_ref = Some(parent.into());
self
}
#[must_use]
pub fn with_chunk_info(mut self, index: usize, total: usize) -> Self {
self.chunk_index = index;
self.chunk_total = total;
self
}
#[must_use]
pub fn with_metadata(
mut self,
signature: Option<String>,
docstring: Option<String>,
calls: Vec<String>,
) -> Self {
self.signature = signature;
self.docstring = docstring;
self.calls = calls;
self
}
#[must_use]
pub fn with_signature(mut self, signature: impl Into<String>) -> Self {
self.signature = Some(signature.into());
self
}
#[must_use]
pub fn with_docstring(mut self, docstring: impl Into<String>) -> Self {
self.docstring = Some(docstring.into());
self
}
#[must_use]
pub fn with_calls(mut self, calls: Vec<String>) -> Self {
self.calls = calls;
self
}
#[must_use]
pub fn is_standalone(&self) -> bool {
self.chunk_total == 1
}
#[must_use]
pub fn has_metadata(&self) -> bool {
self.signature.is_some() || self.docstring.is_some() || !self.calls.is_empty()
}
#[must_use]
pub fn line_count(&self) -> usize {
self.end_line.saturating_sub(self.start_line) + 1
}
#[must_use]
pub fn contains_line(&self, line: usize) -> bool {
line >= self.start_line && line <= self.end_line
}
}
/// Kind of split point detected in source text.
///
/// The derived `Ord` follows declaration order, so later variants compare
/// greater; `find_best_split` picks the candidate with the greatest kind.
/// Do not reorder the variants without checking that ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum BoundaryKind {
/// A line containing only whitespace.
BlankLine,
/// A line that starts with a recognized comment marker.
Comment,
/// A non-blank line indented less deeply than the line before it.
BlockEnd,
/// A line that starts a function definition.
FunctionStart,
/// A line that starts a class-like definition (class/struct/interface/trait/type/impl).
ClassStart,
}
/// A candidate split point found by `detect_boundaries`.
#[derive(Debug, Clone)]
struct Boundary {
/// Zero-based index of the line the boundary sits on.
line_idx: usize,
/// Offset of that line taken from the line-offset table (0 when the table
/// has no entry). Currently stored but not read.
#[allow(dead_code)]
char_offset: usize,
/// What kind of boundary this is.
kind: BoundaryKind,
}
/// Lazily compiled boundary regexes, built once on first use.
static BOUNDARY_PATTERNS: Lazy<BoundaryPatterns> = Lazy::new(BoundaryPatterns::new);
/// Compiled regular expressions used to classify a single line.
struct BoundaryPatterns {
/// Matches the start of a function definition.
function_start: Regex,
/// Matches the start of a class-like definition.
class_start: Regex,
/// Matches a line that begins with a comment marker.
comment: Regex,
/// Matches a line that is empty or whitespace only.
blank: Regex,
}
impl BoundaryPatterns {
fn new() -> Self {
Self {
function_start: Regex::new(r"^\s*(pub\s+)?(async\s+)?(def|fn|func|function|fun)\s+\w+")
.expect("valid regex"),
class_start: Regex::new(
r"^\s*(pub\s+)?(?:(?:class|struct|interface|trait|type)\s+\w+|impl\s*(?:<(?:[^<>]|<[^>]*>)*>)?\s*\w+)"
).expect("valid regex"),
comment: Regex::new(
r#"^\s*(?:#(?:$|[^!\[#])|//|/\*|"""|'''|--(?:$|[^>])|\(\*|%(?:$|[^%=])|;|<!--|!(?:\s|$)|'(?:$|[^']))"#
).expect("valid regex"),
blank: Regex::new(r"^\s*$").expect("valid regex"),
}
}
}
/// Scans `lines` and returns every candidate split point, in line order.
///
/// Each line yields at most one pattern-based boundary, tested in the order
/// class start, function start, blank line, comment. Independently, a
/// non-blank line whose indent depth is smaller than the previous line's also
/// yields a `BlockEnd` boundary, pushed after the pattern-based one.
/// `line_offsets` supplies the recorded offset for each line (0 when missing).
fn detect_boundaries(lines: &[&str], line_offsets: &[usize]) -> Vec<Boundary> {
    let patterns = &*BOUNDARY_PATTERNS;
    let mut found = Vec::new();
    let mut previous_line: Option<&str> = None;

    for (line_idx, &line) in lines.iter().enumerate() {
        let char_offset = line_offsets.get(line_idx).copied().unwrap_or(0);

        let pattern_kind = if patterns.class_start.is_match(line) {
            Some(BoundaryKind::ClassStart)
        } else if patterns.function_start.is_match(line) {
            Some(BoundaryKind::FunctionStart)
        } else if patterns.blank.is_match(line) {
            Some(BoundaryKind::BlankLine)
        } else if patterns.comment.is_match(line) {
            Some(BoundaryKind::Comment)
        } else {
            None
        };
        if let Some(kind) = pattern_kind {
            found.push(Boundary {
                line_idx,
                char_offset,
                kind,
            });
        }

        if let Some(prev) = previous_line {
            let dedented = get_indent_depth(prev) > get_indent_depth(line);
            if dedented && !line.trim().is_empty() {
                found.push(Boundary {
                    line_idx,
                    char_offset,
                    kind: BoundaryKind::BlockEnd,
                });
            }
        }
        previous_line = Some(line);
    }

    found
}
/// Returns the indentation depth of `line` in units of four columns.
///
/// A tab counts as four columns and any other leading whitespace character as
/// one; the total is divided by four, rounding down. Empty or whitespace-only
/// lines have depth `0`.
#[inline]
fn get_indent_depth(line: &str) -> usize {
    let mut columns = 0_usize;
    for ch in line.chars() {
        if !ch.is_whitespace() {
            return columns / 4;
        }
        columns += if ch == '\t' { 4 } else { 1 };
    }
    // No non-whitespace character was found.
    0
}
/// Splits `code` into chunks of at most `max_tokens` tokens, using the default
/// overlap between consecutive chunks.
#[must_use]
pub fn chunk_code(code: &str, max_tokens: usize) -> Vec<Chunk> {
    let overlap_tokens = CHUNK_OVERLAP_TOKENS;
    chunk_code_with_overlap(code, max_tokens, overlap_tokens)
}
/// Splits `code` into chunks of roughly `max_tokens` tokens with
/// `overlap_tokens` of overlap between consecutive chunks.
///
/// Empty input yields no chunks. Input that already fits yields a single
/// chunk covering the whole text, whose `end_line` comes from `count_lines`.
/// Anything larger is split along detected boundaries.
#[must_use]
pub fn chunk_code_with_overlap(code: &str, max_tokens: usize, overlap_tokens: usize) -> Vec<Chunk> {
    if code.is_empty() {
        return Vec::new();
    }

    if count_tokens(code) > max_tokens {
        let lines: Vec<&str> = code.lines().collect();
        let line_offsets = build_line_offsets(&lines, code);
        let boundaries = detect_boundaries(&lines, &line_offsets);
        return chunk_with_boundaries(
            code,
            &lines,
            &line_offsets,
            &boundaries,
            max_tokens,
            overlap_tokens,
        );
    }

    let whole = Chunk::new(code.to_string(), 1, count_lines(code), 0, code.len());
    vec![whole]
}
/// Returns the byte offset in `code` at which each entry of `lines` starts.
///
/// `lines` is expected to be `code.lines()` collected. `str::lines` strips
/// both `\n` and `\r\n`, so the terminator width is read from `code` itself
/// rather than assumed to be one byte; this keeps offsets correct for CRLF
/// input.
fn build_line_offsets(lines: &[&str], code: &str) -> Vec<usize> {
    let bytes = code.as_bytes();
    let mut offsets = Vec::with_capacity(lines.len());
    let mut cursor = 0_usize;

    for line in lines {
        offsets.push(cursor);
        cursor += line.len();
        // Step over whichever terminator actually follows this line, if any.
        match (bytes.get(cursor), bytes.get(cursor + 1)) {
            (Some(b'\r'), Some(b'\n')) => cursor += 2,
            (Some(b'\n'), _) => cursor += 1,
            _ => {}
        }
    }

    offsets
}
/// Splits `code` into chunks, preferring to cut at detected boundaries.
///
/// Lines are accumulated until adding the next one would exceed `max_tokens`;
/// the chunk is then cut at the position chosen by `find_best_split`, and the
/// next chunk starts far enough back to include up to `overlap_tokens` of
/// overlap. Chunks are numbered (`chunk_index`, `chunk_total`, `chunk_id`)
/// once all of them are known. If the result is a single chunk that is still
/// over budget, it is handed to `handle_oversized_chunk`.
///
/// `lines` and `line_offsets` must describe `code`; `boundaries` holds
/// zero-based line indices into `lines`.
fn chunk_with_boundaries(
    code: &str,
    lines: &[&str],
    line_offsets: &[usize],
    boundaries: &[Boundary],
    max_tokens: usize,
    overlap_tokens: usize,
) -> Vec<Chunk> {
    let mut chunks = Vec::new();

    // Token cost of every line, including its newline for all but the last.
    let last_line_idx = lines.len().saturating_sub(1);
    let line_token_counts: Vec<usize> = lines
        .par_iter()
        .enumerate()
        .map(|(idx, line)| {
            if idx < last_line_idx {
                count_tokens(&format!("{}\n", line))
            } else {
                count_tokens(line)
            }
        })
        .collect();

    // Constant-time "is this line a boundary?" lookup instead of rescanning
    // the boundary list for every line.
    let mut is_boundary = vec![false; lines.len()];
    for boundary in boundaries {
        if let Some(flag) = is_boundary.get_mut(boundary.line_idx) {
            *flag = true;
        }
    }

    let mut chunk_start_line = 0;
    let mut current_tokens = 0;
    let mut last_good_boundary: Option<usize> = None;

    for (line_idx, &line_tokens) in line_token_counts.iter().enumerate() {
        if current_tokens + line_tokens > max_tokens && line_idx > chunk_start_line {
            let split_line =
                find_best_split(chunk_start_line, line_idx, last_good_boundary, boundaries);
            chunks.push(create_chunk_from_range(
                code,
                lines,
                line_offsets,
                chunk_start_line,
                split_line,
            ));

            let (overlap_lines, _) =
                calculate_overlap(lines, chunk_start_line, split_line, overlap_tokens);
            chunk_start_line = split_line.saturating_sub(overlap_lines.len());
            // The new chunk already holds every line from its start up to (but
            // not including) the current one. When the cut landed on an earlier
            // boundary, that includes lines after the cut, which must stay in
            // the running total or later chunks overshoot the budget.
            current_tokens = line_token_counts[chunk_start_line..line_idx].iter().sum();
            last_good_boundary = None;
        }

        if is_boundary[line_idx] {
            last_good_boundary = Some(line_idx);
        }
        current_tokens += line_tokens;
    }

    if chunk_start_line < lines.len() {
        chunks.push(create_chunk_from_range(
            code,
            lines,
            line_offsets,
            chunk_start_line,
            lines.len(),
        ));
    }

    let total_chunks = chunks.len();
    for (i, chunk) in chunks.iter_mut().enumerate() {
        chunk.chunk_index = i;
        chunk.chunk_total = total_chunks;
        chunk.chunk_id = format!("chunk_{}", i + 1);
    }

    if chunks.len() == 1 && chunks[0].token_count > max_tokens {
        return handle_oversized_chunk(&chunks[0], max_tokens, overlap_tokens);
    }
    chunks
}
/// Chooses the line index at which to cut the current chunk.
///
/// Preference order:
/// 1. `last_boundary`, if it lies strictly between `chunk_start` and
///    `current_line`;
/// 2. otherwise the boundary in that open interval with the greatest
///    `BoundaryKind`;
/// 3. otherwise `current_line` itself.
fn find_best_split(
    chunk_start: usize,
    current_line: usize,
    last_boundary: Option<usize>,
    boundaries: &[Boundary],
) -> usize {
    let inside_chunk = |idx: usize| chunk_start < idx && idx < current_line;

    if let Some(recent) = last_boundary.filter(|&idx| inside_chunk(idx)) {
        return recent;
    }

    boundaries
        .iter()
        .filter(|b| inside_chunk(b.line_idx))
        .max_by_key(|b| b.kind)
        .map_or(current_line, |best| best.line_idx)
}
/// Builds a chunk from the half-open line range `start_idx..end_idx`
/// (zero-based indices into `lines`).
///
/// The content is the slice of `code` between the recorded offsets of the two
/// lines; an `end_idx` at or past the last line extends to the end of `code`.
/// The chunk's `start_line` is `start_idx + 1` and its `end_line` is `end_idx`,
/// i.e. the 1-based number of the last included line.
fn create_chunk_from_range(
    code: &str,
    lines: &[&str],
    line_offsets: &[usize],
    start_idx: usize,
    end_idx: usize,
) -> Chunk {
    let start_char = match line_offsets.get(start_idx) {
        Some(&offset) => offset,
        None => 0,
    };

    let unclamped_end = if end_idx < lines.len() {
        line_offsets.get(end_idx).copied().unwrap_or(code.len())
    } else {
        code.len()
    };
    let end_char = unclamped_end.min(code.len());

    let first_line = index_to_line(start_idx);
    Chunk::new(
        code[start_char..end_char].to_string(),
        first_line,
        end_idx,
        start_char,
        end_char,
    )
}
/// Selects the trailing lines of `chunk_start_line..split_line` that fit in
/// `overlap_tokens`.
///
/// Lines are considered from the end of the range backwards; each costs its
/// token count plus one (for the newline). Selection stops at the first line
/// that would exceed the budget. Returns the chosen line indices in ascending
/// order together with their total cost.
fn calculate_overlap(
    lines: &[&str],
    chunk_start_line: usize,
    split_line: usize,
    overlap_tokens: usize,
) -> (Vec<usize>, usize) {
    let mut overlap_lines = Vec::new();
    let mut token_count = 0;

    for line_idx in (chunk_start_line..split_line).rev() {
        let line = lines.get(line_idx).copied().unwrap_or("");
        let line_tokens = count_tokens(line) + 1;
        if token_count + line_tokens > overlap_tokens {
            break;
        }
        // Collected back-to-front; reversed once below instead of paying for a
        // front insertion on every iteration.
        overlap_lines.push(line_idx);
        token_count += line_tokens;
    }

    overlap_lines.reverse();
    (overlap_lines, token_count)
}
fn handle_oversized_chunk(chunk: &Chunk, max_tokens: usize, overlap_tokens: usize) -> Vec<Chunk> {
let lines: Vec<&str> = chunk.content.lines().collect();
if lines.len() <= 1 {
let truncated = truncate_to_tokens(&chunk.content, max_tokens);
return vec![Chunk::new(
truncated,
chunk.start_line,
chunk.start_line,
chunk.start_char,
chunk.start_char + chunk.content.len().min(max_tokens * 4),
)];
}
let line_offsets = build_line_offsets(&lines, &chunk.content);
let boundaries = detect_boundaries(&lines, &line_offsets);
if !boundaries.is_empty() {
let reduced_max = (max_tokens * 3 / 4).max(MIN_CHUNK_TOKENS);
if reduced_max < max_tokens {
let sub_chunks = chunk_with_boundaries(
&chunk.content,
&lines,
&line_offsets,
&boundaries,
reduced_max,
overlap_tokens,
);
if sub_chunks.len() > 1 {
return sub_chunks
.into_iter()
.map(|mut c| {
c.start_line += chunk.start_line - 1;
c.end_line += chunk.start_line - 1;
c.start_char += chunk.start_char;
c.end_char += chunk.start_char;
c
})
.collect();
}
}
}
force_split_by_lines(
&chunk.content,
chunk.start_line,
chunk.start_char,
max_tokens,
)
}
/// Splits `content` into chunks purely by line, ignoring boundaries.
///
/// Lines are accumulated until the next one would exceed `max_tokens`, at
/// which point the accumulated text is emitted as a chunk. A single line that
/// is itself larger than `max_tokens` is emitted on its own, truncated to the
/// budget. `base_line` (1-based) and `base_char` position the first line of
/// `content` within the original text. Chunks are numbered once all of them
/// are known.
///
/// NOTE(review): offsets assume a one-byte `\n` terminator and rebuilt chunk
/// text uses `\n`; CRLF input is not handled here — confirm whether it can
/// reach this path.
fn force_split_by_lines(
    content: &str,
    base_line: usize,
    base_char: usize,
    max_tokens: usize,
) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    let mut current_content = String::new();
    let mut current_tokens = 0;
    let mut current_start_line = base_line;
    let mut current_start_char = base_char;
    let mut char_offset = 0;

    let has_trailing_newline = content.ends_with('\n');
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    for (i, line) in lines.iter().enumerate() {
        let is_last_line = i + 1 == total_lines;
        let should_add_newline = !is_last_line || has_trailing_newline;
        let line_with_optional_newline = if should_add_newline {
            format!("{}\n", line)
        } else {
            (*line).to_string()
        };
        let line_tokens = count_tokens(&line_with_optional_newline);
        let line_byte_len = line.len() + usize::from(should_add_newline);

        // Flush what has been accumulated if this line does not fit. This runs
        // before the oversized-line check so that an oversized line is never
        // appended whole to a fresh accumulator.
        if current_tokens + line_tokens > max_tokens && !current_content.is_empty() {
            chunks.push(Chunk::new(
                std::mem::take(&mut current_content),
                current_start_line,
                base_line + i - 1,
                current_start_char,
                base_char + char_offset,
            ));
            current_tokens = 0;
            current_start_line = base_line + i;
            current_start_char = base_char + char_offset;
        }

        // A line that alone exceeds the budget becomes its own truncated
        // chunk. The accumulator is necessarily empty here: any pending text
        // was flushed just above.
        if line_tokens > max_tokens {
            let truncated = truncate_to_tokens(line, max_tokens);
            chunks.push(Chunk::new(
                truncated,
                base_line + i,
                base_line + i,
                base_char + char_offset,
                base_char + char_offset + line.len(),
            ));
            char_offset += line_byte_len;
            current_start_line = base_line + i + 1;
            current_start_char = base_char + char_offset;
            continue;
        }

        current_content.push_str(&line_with_optional_newline);
        current_tokens += line_tokens;
        char_offset += line_byte_len;
    }

    if !current_content.is_empty() {
        let end_line = base_line + total_lines.saturating_sub(1);
        chunks.push(Chunk::new(
            current_content,
            current_start_line,
            end_line,
            current_start_char,
            base_char + content.len(),
        ));
    }

    let total = chunks.len();
    for (i, chunk) in chunks.iter_mut().enumerate() {
        chunk.chunk_index = i;
        chunk.chunk_total = total;
        chunk.chunk_id = format!("chunk_{}", i + 1);
    }
    chunks
}
/// Splits `code` into chunks using the module's default token limit.
#[must_use]
pub fn chunk_code_default(code: &str) -> Vec<Chunk> {
    let max_tokens = MAX_CODE_PREVIEW_TOKENS;
    chunk_code(code, max_tokens)
}
/// Returns `true` when `code` holds more than `max_tokens` tokens.
#[must_use]
pub fn needs_chunking(code: &str, max_tokens: usize) -> bool {
    let token_total = count_tokens(code);
    max_tokens < token_total
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_count_tokens_empty() {
assert_eq!(count_tokens(""), 0);
}
#[test]
fn test_count_tokens_simple() {
let count = count_tokens("hello world");
assert!(count > 0);
assert!(count < 10);
}
#[test]
fn test_count_tokens_code() {
let code = "def hello():\n print('world')\n";
let count = count_tokens(code);
assert!(count > 0);
assert!(count < 50);
}
#[test]
fn test_count_lines_empty() {
assert_eq!(count_lines(""), 0);
}
#[test]
fn test_count_lines_single_without_newline() {
assert_eq!(count_lines("a"), 1);
assert_eq!(count_lines("hello world"), 1);
}
#[test]
fn test_count_lines_single_with_newline() {
assert_eq!(count_lines("a\n"), 2);
assert_eq!(count_lines("hello\n"), 2);
}
#[test]
fn test_count_lines_multiple_without_trailing() {
assert_eq!(count_lines("a\nb"), 2);
assert_eq!(count_lines("a\nb\nc"), 3);
assert_eq!(count_lines("line1\nline2\nline3"), 3);
}
#[test]
fn test_count_lines_multiple_with_trailing() {
assert_eq!(count_lines("a\nb\n"), 3);
assert_eq!(count_lines("a\nb\nc\n"), 4);
assert_eq!(count_lines("line1\nline2\nline3\n"), 4);
}
#[test]
fn test_count_lines_only_newlines() {
assert_eq!(count_lines("\n"), 2); assert_eq!(count_lines("\n\n"), 3); assert_eq!(count_lines("\n\n\n"), 4); assert_eq!(count_lines("\n\n\n\n"), 5);
}
#[test]
fn test_count_lines_mixed_empty_lines() {
assert_eq!(count_lines("a\n\nb"), 3); assert_eq!(count_lines("a\n\n\nb"), 4); assert_eq!(count_lines("a\n\nb\n"), 4); }
#[test]
fn test_count_lines_vs_lines_iterator() {
let s1 = "a\nb\nc";
assert_eq!(count_lines(s1), 3);
assert_eq!(s1.lines().count(), 3);
let s2 = "a\nb\nc\n";
assert_eq!(count_lines(s2), 4);
assert_eq!(s2.lines().count(), 3);
let s3 = "\n\n\n";
assert_eq!(count_lines(s3), 4);
assert_eq!(s3.lines().count(), 3); }
#[test]
fn test_count_lines_chunk_end_line_accuracy() {
let code = "line1\nline2\nline3\n";
let chunks = chunk_code(code, 1000);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].start_line, 1);
assert_eq!(
chunks[0].end_line, 4,
"end_line should account for trailing newline"
);
}
#[test]
fn test_truncate_to_tokens_empty() {
assert_eq!(truncate_to_tokens("", 100), "");
}
#[test]
fn test_truncate_to_tokens_fits() {
let text = "hello world";
let result = truncate_to_tokens(text, 100);
assert_eq!(result, text);
}
#[test]
fn test_truncate_to_tokens_truncates() {
let text = "hello world this is a longer text that should be truncated";
let result = truncate_to_tokens(text, 3);
assert!(result.len() < text.len());
assert!(!result.is_empty());
}
#[test]
fn test_chunk_empty() {
let chunks = chunk_code("", 100);
assert!(chunks.is_empty());
}
#[test]
fn test_chunk_fits_single() {
let code = "def hello():\n pass\n";
let chunks = chunk_code(code, 1000);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].content, code);
assert_eq!(chunks[0].start_line, 1);
assert!(chunks[0].is_standalone());
}
#[test]
fn test_chunk_splits_at_functions() {
let code = r#"def func1():
"""First function."""
pass
def func2():
"""Second function."""
pass
def func3():
"""Third function."""
pass
"#;
let chunks = chunk_code(code, 20);
assert!(chunks.len() > 1);
for chunk in &chunks {
assert!(
chunk.token_count <= 50,
"Chunk {} has {} tokens (expected <= 50)",
chunk.chunk_index,
chunk.token_count
);
}
}
#[test]
fn test_chunk_preserves_content() {
let code = "line1\nline2\nline3\nline4\nline5\n";
let chunks = chunk_code(code, 5);
let all_lines: std::collections::HashSet<_> = code.lines().collect();
let chunked_lines: std::collections::HashSet<_> =
chunks.iter().flat_map(|c| c.content.lines()).collect();
for line in &all_lines {
assert!(
chunked_lines.contains(line),
"Line '{}' missing from chunks",
line
);
}
}
#[test]
fn test_chunk_metadata() {
let code = "a\nb\nc\nd\ne\nf\ng\nh\n";
let chunks = chunk_code(code, 3);
for (i, chunk) in chunks.iter().enumerate() {
assert_eq!(chunk.chunk_index, i);
assert_eq!(chunk.chunk_total, chunks.len());
}
}
#[test]
fn test_get_indent_depth() {
assert_eq!(get_indent_depth(""), 0);
assert_eq!(get_indent_depth(" code"), 1);
assert_eq!(get_indent_depth(" code"), 2);
assert_eq!(get_indent_depth("\tcode"), 1);
assert_eq!(get_indent_depth("\t\tcode"), 2);
assert_eq!(get_indent_depth(" \tcode"), 1); }
#[test]
fn test_boundary_detection() {
let code = "def func():\n pass\n\nclass MyClass:\n pass\n";
let lines: Vec<&str> = code.lines().collect();
let offsets = build_line_offsets(&lines, code);
let boundaries = detect_boundaries(&lines, &offsets);
let kinds: Vec<_> = boundaries.iter().map(|b| b.kind).collect();
assert!(kinds.contains(&BoundaryKind::FunctionStart));
assert!(kinds.contains(&BoundaryKind::BlankLine));
assert!(kinds.contains(&BoundaryKind::ClassStart));
}
fn is_detected_as_comment(line: &str) -> bool {
let lines = vec![line];
let offsets = vec![0usize];
let boundaries = detect_boundaries(&lines, &offsets);
boundaries.iter().any(|b| b.kind == BoundaryKind::Comment)
}
#[test]
fn test_comment_detection_python_ruby_shell() {
assert!(is_detected_as_comment(" # Python comment"));
assert!(is_detected_as_comment("# Ruby comment"));
assert!(is_detected_as_comment(" # Shell comment"));
assert!(is_detected_as_comment("#comment without space"));
assert!(!is_detected_as_comment("#!/usr/bin/env python"));
assert!(!is_detected_as_comment("#!bash"));
assert!(!is_detected_as_comment("#[derive(Debug)]"));
assert!(!is_detected_as_comment(" #[cfg(test)]"));
assert!(!is_detected_as_comment("## Section header decorator"));
}
#[test]
fn test_comment_detection_c_style() {
assert!(is_detected_as_comment(" // C++ comment"));
assert!(is_detected_as_comment("// single line"));
assert!(is_detected_as_comment(" /* block comment start */"));
assert!(is_detected_as_comment("/* multi"));
}
#[test]
fn test_comment_detection_python_docstrings() {
assert!(is_detected_as_comment(r#" """Docstring""""#));
assert!(is_detected_as_comment(r#""""Triple quote doc""""#));
assert!(is_detected_as_comment(" '''Single quote doc'''"));
assert!(is_detected_as_comment("'''"));
}
#[test]
fn test_comment_detection_sql_haskell_lua() {
assert!(is_detected_as_comment(" -- SQL comment"));
assert!(is_detected_as_comment("-- Haskell comment"));
assert!(is_detected_as_comment(" -- Lua comment"));
assert!(is_detected_as_comment("--compact"));
assert!(!is_detected_as_comment(" -->"));
assert!(!is_detected_as_comment("-->closing tag"));
}
#[test]
fn test_comment_detection_ocaml_pascal() {
assert!(is_detected_as_comment(" (* OCaml comment *)"));
assert!(is_detected_as_comment("(* Pascal comment *)"));
assert!(is_detected_as_comment(" (* nested (* comment *) *)"));
assert!(is_detected_as_comment("(*"));
}
#[test]
fn test_comment_detection_latex_erlang_matlab() {
assert!(is_detected_as_comment(" % LaTeX comment"));
assert!(is_detected_as_comment("% Erlang comment"));
assert!(is_detected_as_comment(" % MATLAB comment"));
assert!(!is_detected_as_comment(" %% format spec"));
assert!(!is_detected_as_comment("%=modulo"));
}
#[test]
fn test_comment_detection_lisp_scheme() {
assert!(is_detected_as_comment(" ; Lisp comment"));
assert!(is_detected_as_comment("; Scheme comment"));
assert!(is_detected_as_comment(";;; Section comment"));
assert!(is_detected_as_comment(" ; Clojure comment"));
}
#[test]
fn test_comment_detection_html_xml() {
assert!(is_detected_as_comment(" <!-- HTML comment -->"));
assert!(is_detected_as_comment("<!-- XML comment"));
assert!(is_detected_as_comment(" <!--"));
}
#[test]
fn test_comment_detection_fortran() {
assert!(is_detected_as_comment(" ! Fortran comment"));
assert!(is_detected_as_comment("! comment"));
assert!(!is_detected_as_comment(" !="));
assert!(!is_detected_as_comment("!important"));
}
#[test]
fn test_comment_detection_basic_vb() {
assert!(is_detected_as_comment(" ' VB comment"));
assert!(is_detected_as_comment("' BASIC comment"));
assert!(is_detected_as_comment("'''")); }
#[test]
fn test_comment_detection_not_trailing() {
assert!(!is_detected_as_comment(" code() // trailing comment"));
assert!(!is_detected_as_comment("x = 1 # trailing"));
assert!(!is_detected_as_comment("SELECT * FROM foo -- trailing"));
}
#[test]
fn test_comment_detection_empty_and_whitespace() {
assert!(!is_detected_as_comment(""));
assert!(!is_detected_as_comment(" "));
assert!(!is_detected_as_comment("\t\t"));
}
#[test]
fn test_single_level_dedent_block_end_detection() {
let code = r#"def foo():
if x:
pass
bar()
if y:
baz()
return result"#;
let lines: Vec<&str> = code.lines().collect();
let offsets = build_line_offsets(&lines, code);
let boundaries = detect_boundaries(&lines, &offsets);
let block_ends: Vec<_> = boundaries
.iter()
.filter(|b| b.kind == BoundaryKind::BlockEnd)
.collect();
assert!(
block_ends.len() >= 2,
"Should detect at least 2 single-level dedents as BlockEnd, found {}. \
Boundaries: {:?}",
block_ends.len(),
block_ends.iter().map(|b| b.line_idx).collect::<Vec<_>>()
);
let block_end_lines: Vec<_> = block_ends.iter().map(|b| b.line_idx).collect();
assert!(
block_end_lines.contains(&3),
"bar() at line 3 should be detected as BlockEnd. Found: {:?}",
block_end_lines
);
assert!(
block_end_lines.contains(&6),
"return result at line 6 should be detected as BlockEnd. Found: {:?}",
block_end_lines
);
}
/// Checks that a dedent spanning more than one indentation level is still
/// reported as a `BlockEnd` boundary.
#[test]
fn test_multi_level_dedent_still_detected() {
    // Indentation is significant here: `bar()` (index 4) drops two levels
    // relative to `pass`. A flattened fixture contains no dedent at all, so the
    // nesting must be spelled out in the literal.
    let code = r#"def foo():
    if x:
        if y:
            pass
    bar()"#;
    let lines: Vec<&str> = code.lines().collect();
    let offsets = build_line_offsets(&lines, code);
    let boundaries = detect_boundaries(&lines, &offsets);

    // 0-based line indices of every BlockEnd boundary.
    let block_ends: Vec<_> = boundaries
        .iter()
        .filter(|b| b.kind == BoundaryKind::BlockEnd)
        .map(|b| b.line_idx)
        .collect();

    assert!(
        block_ends.contains(&4),
        "Multi-level dedent at line 4 should be detected. Found: {:?}",
        block_ends
    );
}
/// `needs_chunking` is expected to be false for a short text under a generous
/// budget and true for a long text under a tiny one.
#[test]
fn test_needs_chunking() {
    let tiny_text = "short";
    let oversized_text = "x".repeat(1000);

    assert!(!needs_chunking(tiny_text, 100));
    assert!(needs_chunking(&oversized_text, 10));
}
/// Consecutive chunks from `chunk_code_with_overlap` must not leave a gap
/// between one chunk's last line and the next chunk's first line.
#[test]
fn test_chunk_with_overlap() {
    let source = "line1\nline2\nline3\nline4\nline5\n";
    let pieces = chunk_code_with_overlap(source, 5, 2);

    // `windows(2)` yields nothing for fewer than two chunks, so a single-chunk
    // result passes without any comparison.
    for (earlier_idx, pair) in pieces.windows(2).enumerate() {
        let earlier_end = pair[0].end_line;
        let later_start = pair[1].start_line;
        assert!(
            later_start <= earlier_end + 1,
            "No overlap between chunks {} and {}",
            earlier_idx,
            earlier_idx + 1
        );
    }
}
/// A small input with a trailing newline is expected to come back as one chunk
/// spanning lines 1 through 4.
#[test]
fn test_chunk_line_numbers() {
    let source = "line1\nline2\nline3\n";
    let produced = chunk_code(source, 1000);

    assert_eq!(produced.len(), 1);
    let only = &produced[0];
    assert_eq!(only.start_line, 1);
    assert_eq!(only.end_line, 4);
}
/// Pins the inclusive `end_line` convention for chunks produced by `chunk_code`
/// across three inputs: no trailing newline, trailing newline, single line.
#[test]
fn test_end_line_is_inclusive() {
    // Case 1: three lines, no trailing newline.
    {
        let produced = chunk_code("line1\nline2\nline3", 1000);
        assert_eq!(produced.len(), 1);
        let only = &produced[0];
        assert_eq!(only.start_line, 1, "start_line should be 1 (first line)");
        assert_eq!(only.end_line, 3, "end_line should be 3 (last line, INCLUSIVE)");
        assert_eq!(only.line_count(), 3, "line_count should be 3");
        assert!(only.contains_line(1), "Line 1 should be in chunk");
        assert!(only.contains_line(2), "Line 2 should be in chunk");
        assert!(only.contains_line(3), "Line 3 should be in chunk (INCLUSIVE)");
        assert!(!only.contains_line(0), "Line 0 should NOT be in chunk");
        assert!(!only.contains_line(4), "Line 4 should NOT be in chunk");
    }

    // Case 2: same text with a trailing newline.
    {
        let produced = chunk_code("line1\nline2\nline3\n", 1000);
        assert_eq!(produced.len(), 1);
        let only = &produced[0];
        assert_eq!(only.start_line, 1);
        assert_eq!(
            only.end_line, 4,
            "Trailing newline creates line 4 (empty), which should be INCLUDED"
        );
        assert_eq!(only.line_count(), 4, "Should count all 4 lines");
    }

    // Case 3: a single line.
    {
        let produced = chunk_code("only one line", 1000);
        assert_eq!(produced.len(), 1);
        let only = &produced[0];
        assert_eq!(only.start_line, 1);
        assert_eq!(
            only.end_line, 1,
            "Single-line chunk: end_line == start_line"
        );
        assert_eq!(only.line_count(), 1);
        assert!(only.contains_line(1));
        assert!(!only.contains_line(2));
    }
}
/// A single very long line must still yield at least one chunk, and the first
/// chunk must either fit the token budget or be shorter than the input.
#[test]
fn test_chunk_handles_long_single_line() {
    let giant_line = "x".repeat(10000);
    let produced = chunk_code(&giant_line, 100);
    assert!(!produced.is_empty());

    let head = &produced[0];
    let fits_budget = head.token_count <= 100;
    let was_shortened = head.content.len() < giant_line.len();
    assert!(fits_budget || was_shortened);
}
/// Spot-checks `estimate_tokens` on an 8-character, an empty and a 3-character input.
#[test]
fn test_estimate_tokens() {
    let expectations = [("12345678", 2), ("", 0), ("abc", 1)];
    for (text, expected_tokens) in expectations {
        assert_eq!(estimate_tokens(text), expected_tokens);
    }
}
/// Chunking a Python function with a long docstring must keep the docstring's
/// opening text in the first chunk.
#[test]
fn test_chunk_python_with_docstrings() {
    const PYTHON_SNIPPET: &str = r#"def process_data(items):
"""Process a list of items.
This function takes items and processes them
in a very important way.
Args:
items: List of items to process
Returns:
Processed items
"""
result = []
for item in items:
result.append(transform(item))
return result
"#;

    let produced = chunk_code(PYTHON_SNIPPET, 100);
    assert!(!produced.is_empty());

    let leading = &produced[0];
    assert!(
        leading.content.contains("Process a list"),
        "Docstring should be in first chunk"
    );
}
/// Smoke test: chunking a small Rust snippet with a 30-token budget yields at
/// least one chunk.
#[test]
fn test_chunk_rust_code() {
    const RUST_SNIPPET: &str = r#"fn main() {
println!("Hello");
}
pub fn helper() {
// Helper function
}
impl MyStruct {
fn method(&self) {
// Method body
}
}
"#;

    let produced = chunk_code(RUST_SNIPPET, 30);
    assert!(!produced.is_empty());
}
/// The `class_start` boundary pattern must recognise Rust `impl` headers in a
/// range of generic, lifetime, visibility and indentation variants.
#[test]
fn test_rust_impl_generics_boundary_detection() {
    let patterns = &*BOUNDARY_PATTERNS;

    // (header line, message reported when the pattern fails to match it)
    let impl_headers = [
        ("impl Foo {", "Simple impl should match"),
        ("impl<T> Foo<T> {", "impl<T> should match"),
        ("impl<T: Clone> Foo<T> for Bar {", "impl<T: Clone> should match"),
        (
            "impl<'a, T: 'a + Clone> Foo<'a, T> {",
            "impl with lifetime should match",
        ),
        (
            "impl<T: Iterator<Item=Foo>> Bar {",
            "impl with nested generics should match",
        ),
        ("pub impl<T> Foo<T> {", "pub impl<T> should match"),
        (
            "impl<T, U> Pair<T, U> {",
            "impl with multiple type params should match",
        ),
        (" impl<T> Foo<T> {", "Indented impl<T> should match"),
    ];

    for (header, failure_note) in impl_headers {
        assert!(patterns.class_start.is_match(header), "{}", failure_note);
    }
}
/// Checks that `detect_boundaries` reports a `ClassStart` boundary for each
/// generic `impl` block in a realistic Rust snippet.
#[test]
fn test_rust_impl_generics_in_chunking() {
    // The blank lines are load-bearing: they place the two `impl` headers at
    // 0-based line indices 4 and 14, which is what the assertions below check.
    // With the blank lines removed the headers fall on indices 3 and 11.
    let code = r#"fn setup() {
    // Setup code
}

impl<T: Clone> Container<T> {
    fn new(value: T) -> Self {
        Self { value }
    }

    fn get(&self) -> &T {
        &self.value
    }
}

impl<T: Default> Default for Container<T> {
    fn default() -> Self {
        Self::new(T::default())
    }
}
"#;
    let lines: Vec<&str> = code.lines().collect();
    let offsets = build_line_offsets(&lines, code);
    let boundaries = detect_boundaries(&lines, &offsets);

    let class_boundaries: Vec<_> = boundaries
        .iter()
        .filter(|b| b.kind == BoundaryKind::ClassStart)
        .collect();
    assert!(
        class_boundaries.len() >= 2,
        "Should detect at least 2 impl blocks with generics, found {}",
        class_boundaries.len()
    );

    // 0-based line indices at which a ClassStart was reported.
    let impl_lines: Vec<_> = class_boundaries.iter().map(|b| b.line_idx).collect();
    assert!(
        impl_lines.contains(&4),
        "impl<T: Clone> Container<T> should be detected at line 4, got {:?}",
        impl_lines
    );
    assert!(
        impl_lines.contains(&14),
        "impl<T: Default> Default for Container<T> should be detected at line 14, got {:?}",
        impl_lines
    );
}
/// Smoke test: chunking a small TypeScript snippet with a 40-token budget
/// yields at least one chunk.
#[test]
fn test_chunk_typescript_code() {
    const TYPESCRIPT_SNIPPET: &str = r#"function processData(data: string[]): void {
console.log(data);
}
async function fetchData(): Promise<string> {
return "data";
}
class DataProcessor {
private data: string[];
constructor() {
this.data = [];
}
}
"#;

    let produced = chunk_code(TYPESCRIPT_SNIPPET, 40);
    assert!(!produced.is_empty());
}
/// Regression guard: chunking 2000 tiny functions must finish in under five
/// seconds, produce more than ten chunks, and leave no gap between neighbours.
#[test]
fn test_calculate_overlap_only_scans_current_chunk() {
    // Build the synthetic input before the timer starts.
    let large_code: String = (0..2000)
        .map(|i| format!("def func_{}():\n pass\n\n", i))
        .collect();

    let timer = std::time::Instant::now();
    let produced = chunk_code(&large_code, 100);
    let elapsed = timer.elapsed();

    assert!(
        elapsed.as_secs() < 5,
        "Chunking took too long ({:?}), possible O(n^2) regression",
        elapsed
    );
    assert!(produced.len() > 10, "Should produce multiple chunks");

    for (earlier_idx, pair) in produced.windows(2).enumerate() {
        let later_idx = earlier_idx + 1;
        let earlier_end = pair[0].end_line;
        let later_start = pair[1].start_line;
        assert!(
            later_start <= earlier_end + 1,
            "Chunk {} start ({}) should overlap with chunk {} end ({})",
            later_idx, later_start, earlier_idx, earlier_end
        );
    }
}
/// A chunk built with `Chunk::new` alone carries no signature, docstring or calls.
#[test]
fn test_chunk_new_has_empty_metadata() {
    let bare = Chunk::new(String::from("def foo(): pass"), 1, 1, 0, 15);

    assert!(bare.signature.is_none());
    assert!(bare.docstring.is_none());
    assert!(bare.calls.is_empty());
    assert!(!bare.has_metadata());
}
/// `with_metadata` must store the signature, docstring and call list it is given.
#[test]
fn test_chunk_with_metadata() {
    let signature = Some(String::from("def foo() -> None"));
    let docstring = Some(String::from("Does foo things"));
    let calls = vec![String::from("bar"), String::from("baz")];

    let enriched = Chunk::new(String::from("def foo(): pass"), 1, 1, 0, 15)
        .with_metadata(signature, docstring, calls);

    assert_eq!(enriched.signature, Some(String::from("def foo() -> None")));
    assert_eq!(enriched.docstring, Some(String::from("Does foo things")));
    assert_eq!(enriched.calls, vec!["bar", "baz"]);
    assert!(enriched.has_metadata());
}
/// Each single-field builder (`with_signature`, `with_docstring`, `with_calls`)
/// must set its own field.
#[test]
fn test_chunk_individual_metadata_setters() {
    let base = Chunk::new(String::from("code"), 1, 1, 0, 4);
    let built = base
        .with_signature("fn foo()")
        .with_docstring("A function")
        .with_calls(vec![String::from("helper")]);

    assert_eq!(built.signature, Some(String::from("fn foo()")));
    assert_eq!(built.docstring, Some(String::from("A function")));
    assert_eq!(built.calls, vec!["helper"]);
}
/// A chunk with metadata must serialize its metadata fields to JSON and
/// deserialize back to the same values.
#[test]
fn test_chunk_serialization() {
    let original = Chunk::new(String::from("def foo(): bar()"), 1, 1, 0, 16)
        .with_id("test::foo#chunk1")
        .with_parent("foo")
        .with_chunk_info(0, 2)
        .with_metadata(
            Some(String::from("def foo()")),
            Some(String::from("Test function")),
            vec![String::from("bar")],
        );

    let encoded = serde_json::to_string(&original).expect("serialization should succeed");
    for expected_fragment in ["signature", "docstring", "calls", "def foo()"] {
        assert!(encoded.contains(expected_fragment));
    }

    let decoded: Chunk = serde_json::from_str(&encoded).expect("deserialization should succeed");
    assert_eq!(decoded.signature, original.signature);
    assert_eq!(decoded.docstring, original.docstring);
    assert_eq!(decoded.calls, original.calls);
}
/// A chunk without metadata must not emit the metadata keys when serialized.
#[test]
fn test_chunk_serialization_skips_empty_fields() {
    let bare = Chunk::new(String::from("def foo(): pass"), 1, 1, 0, 15);
    let encoded = serde_json::to_string(&bare).expect("serialization should succeed");

    // `calls` is matched with its quotes so only a JSON key counts.
    for absent_fragment in ["signature", "docstring", "\"calls\""] {
        assert!(!encoded.contains(absent_fragment));
    }
}
/// Setting any single metadata field is enough for `has_metadata` to be true.
#[test]
fn test_chunk_has_metadata_partial() {
    let only_signature = Chunk::new(String::from("code"), 1, 1, 0, 4).with_signature("fn foo()");
    assert!(only_signature.has_metadata());

    let only_docstring = Chunk::new(String::from("code"), 1, 1, 0, 4).with_docstring("A doc");
    assert!(only_docstring.has_metadata());

    let only_calls =
        Chunk::new(String::from("code"), 1, 1, 0, 4).with_calls(vec![String::from("bar")]);
    assert!(only_calls.has_metadata());
}
/// Concatenating the pieces from `force_split_by_lines` must reproduce the
/// input length and neither add nor drop a trailing newline.
#[test]
fn test_force_split_preserves_trailing_newline() {
    // Input without a trailing newline.
    {
        let content_no_newline = "line1\nline2\nline3";
        let pieces = force_split_by_lines(content_no_newline, 1, 0, 5);
        let rejoined = pieces
            .iter()
            .map(|piece| piece.content.as_str())
            .collect::<String>();
        assert!(
            !rejoined.ends_with('\n'),
            "Content without trailing newline should not gain one. Got: {:?}",
            rejoined
        );
        assert_eq!(
            rejoined.len(),
            content_no_newline.len(),
            "Reconstructed length should match original"
        );
    }

    // Input with a trailing newline.
    {
        let content_with_newline = "line1\nline2\nline3\n";
        let pieces = force_split_by_lines(content_with_newline, 1, 0, 5);
        let rejoined = pieces
            .iter()
            .map(|piece| piece.content.as_str())
            .collect::<String>();
        assert!(
            rejoined.ends_with('\n'),
            "Content with trailing newline should preserve it. Got: {:?}",
            rejoined
        );
        assert_eq!(
            rejoined.len(),
            content_with_newline.len(),
            "Reconstructed length should match original"
        );
    }
}
/// A single line without a newline must come back as one unchanged piece.
#[test]
fn test_force_split_single_line_no_trailing_newline() {
    let input = "single_line_content";
    let pieces = force_split_by_lines(input, 1, 0, 100);

    assert_eq!(pieces.len(), 1);
    let only = &pieces[0];
    assert_eq!(only.content, input);
    assert!(!only.content.ends_with('\n'));
}
/// A single newline-terminated line must come back as one piece that still
/// ends with the newline.
#[test]
fn test_force_split_single_line_with_trailing_newline() {
    let input = "single_line_content\n";
    let pieces = force_split_by_lines(input, 1, 0, 100);

    assert_eq!(pieces.len(), 1);
    let only = &pieces[0];
    assert_eq!(only.content, input);
    assert!(only.content.ends_with('\n'));
}
/// Chunking with a very small token budget (4) must still return chunks, and
/// none of them may be empty.
#[test]
fn test_recursive_reduction_does_not_reach_zero_max_tokens() {
    const PYTHON_SNIPPET: &str = r#"def very_long_function():
"""A function that will exceed small token limits."""
x = 1
y = 2
z = x + y
return z
"#;

    let produced = chunk_code(PYTHON_SNIPPET, 4);
    assert!(
        !produced.is_empty(),
        "Chunking should produce at least one chunk"
    );
    for piece in produced.iter() {
        assert!(
            !piece.content.is_empty(),
            "Each chunk should have non-empty content"
        );
    }
}
/// `chunk_code` must return something for budgets at, above and below
/// `MIN_CHUNK_TOKENS`.
#[test]
fn test_chunk_code_with_min_chunk_tokens_boundary() {
    let source = "def foo():\n pass\n\ndef bar():\n pass\n";

    // Exactly the minimum, comfortably above it, then below it.
    for budget in [MIN_CHUNK_TOKENS, MIN_CHUNK_TOKENS + 5, 5] {
        let produced = chunk_code(source, budget);
        assert!(!produced.is_empty());
    }
}
/// `handle_oversized_chunk` must return at least one piece for a five-line
/// chunk when called with the arguments 8 and 2.
#[test]
fn test_handle_oversized_chunk_respects_min_tokens() {
    let body = "line1\nline2\nline3\nline4\nline5\n";
    let oversized = Chunk::new(String::from(body), 1, 5, 0, body.len());

    let pieces = handle_oversized_chunk(&oversized, 8, 2);
    assert!(!pieces.is_empty());
}
/// Pins the Unicode-aware estimate for ASCII input, the empty string and a
/// single character.
#[test]
fn test_estimate_tokens_unicode_aware_ascii() {
    let ascii_estimate = estimate_tokens_unicode_aware("hello world");
    assert_eq!(
        ascii_estimate, 2,
        "ASCII text should estimate ~4 chars per token"
    );

    let empty_estimate = estimate_tokens_unicode_aware("");
    assert_eq!(empty_estimate, 0);

    let single_char_estimate = estimate_tokens_unicode_aware("a");
    assert_eq!(single_char_estimate, 1);
}
/// Four CJK ideographs are expected to estimate to six tokens.
#[test]
fn test_estimate_tokens_unicode_aware_cjk() {
    let four_ideographs = "\u{4E00}\u{4E01}\u{4E02}\u{4E03}";
    assert_eq!(
        estimate_tokens_unicode_aware(four_ideographs),
        6,
        "CJK text should estimate ~1.5 tokens per char"
    );
}
/// Five ASCII letters followed by two CJK ideographs are expected to estimate
/// to four tokens.
#[test]
fn test_estimate_tokens_unicode_aware_mixed() {
    let ascii_then_cjk = "hello\u{4E00}\u{4E01}";
    let estimate = estimate_tokens_unicode_aware(ascii_then_cjk);
    assert_eq!(estimate, 4);
}
/// Two emoji are expected to estimate to six tokens.
#[test]
fn test_estimate_tokens_unicode_aware_emoji() {
    let two_emoji = "\u{1F600}\u{1F601}";
    assert_eq!(
        estimate_tokens_unicode_aware(two_emoji),
        6,
        "Emoji should estimate ~3 tokens each"
    );
}
/// Four Cyrillic letters are expected to estimate to two tokens.
#[test]
fn test_estimate_tokens_unicode_aware_other_unicode() {
    let four_cyrillic = "\u{0410}\u{0411}\u{0412}\u{0413}";
    assert_eq!(
        estimate_tokens_unicode_aware(four_cyrillic),
        2,
        "Other Unicode should estimate ~2 chars per token"
    );
}
/// Pins the accessor values reported for `TokenizerType::CharEstimate`.
#[test]
fn test_tokenizer_type_char_estimate_variant() {
    let variant = TokenizerType::CharEstimate;

    // Display name.
    assert_eq!(variant.name(), "CharEstimate (Python parity)");
    // It is an estimator and does not need a TEI server.
    assert!(!variant.requires_tei());
    assert!(variant.is_estimation());
    // Compared exactly: the accessor returns this literal.
    assert_eq!(variant.variance_vs_qwen3(), 0.05);
}
/// `is_cjk_char` must accept the first and last code points of four tested
/// ranges and reject two plain ASCII code points.
#[test]
fn test_is_cjk_char() {
    let accepted = [
        0x4E00, 0x9FFF, // first range under test
        0x3040, 0x309F, // second range under test
        0x30A0, 0x30FF, // third range under test
        0xAC00, 0xD7AF, // fourth range under test
    ];
    for code_point in accepted {
        assert!(is_cjk_char(code_point));
    }

    // 'A' and the space character.
    for code_point in [0x0041, 0x0020] {
        assert!(!is_cjk_char(code_point));
    }
}
/// `is_emoji_char` must accept three sample emoji/symbol code points and
/// reject an ASCII letter and a CJK ideograph.
#[test]
fn test_is_emoji_char() {
    for code_point in [0x1F600, 0x1F601, 0x2600] {
        assert!(is_emoji_char(code_point));
    }
    for code_point in [0x0041, 0x4E00] {
        assert!(!is_emoji_char(code_point));
    }
}
/// Inputs shorter than 64 bytes: every byte is ASCII, so the count must equal
/// the length.
#[test]
fn test_count_ascii_bytes_simd_small() {
    let greeting = b"hello world";
    assert_eq!(count_ascii_bytes_simd(greeting), 11);

    assert_eq!(count_ascii_bytes_simd(b""), 0);
    assert_eq!(count_ascii_bytes_simd(b"x"), 1);

    // One byte short of 64.
    let sixty_three = [b'a'; 63];
    assert_eq!(count_ascii_bytes_simd(&sixty_three), 63);
}
/// All-ASCII inputs of 100, 64 and 128 bytes must be counted in full.
#[test]
fn test_count_ascii_bytes_simd_medium() {
    let hundred = [b'a'; 100];
    assert_eq!(count_ascii_bytes_simd(&hundred), 100);

    let sixty_four = [b'x'; 64];
    assert_eq!(count_ascii_bytes_simd(&sixty_four), 64);

    let one_twenty_eight = [b'z'; 128];
    assert_eq!(count_ascii_bytes_simd(&one_twenty_eight), 128);
}
/// All-ASCII inputs of 1 000 and 10 000 bytes must be counted in full.
#[test]
fn test_count_ascii_bytes_simd_large() {
    let thousand = [b'a'; 1000];
    assert_eq!(count_ascii_bytes_simd(&thousand), 1000);

    let ten_thousand = [b'b'; 10000];
    assert_eq!(count_ascii_bytes_simd(&ten_thousand), 10000);
}
/// Inputs made only of bytes >= 0x80 must yield a count of zero.
#[test]
fn test_count_ascii_bytes_simd_non_ascii() {
    let high_bytes: [u8; 8] = [0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0];
    assert_eq!(count_ascii_bytes_simd(&high_bytes), 0);

    let hundred_high_bytes = [0xFFu8; 100];
    assert_eq!(count_ascii_bytes_simd(&hundred_high_bytes), 0);
}
/// In mixed input only the bytes below 0x80 may be counted.
#[test]
fn test_count_ascii_bytes_simd_mixed() {
    // "hello" + the two bytes 0xC3 0xA9 + "world": ten ASCII bytes out of twelve.
    let short_mixed: &[u8] = b"hello\xC3\xA9world";
    assert_eq!(count_ascii_bytes_simd(short_mixed), 10);

    // 50 ASCII bytes, 20 non-ASCII bytes, 50 ASCII bytes.
    let long_mixed: Vec<u8> = std::iter::repeat(b'a')
        .take(50)
        .chain(std::iter::repeat(0xFF).take(20))
        .chain(std::iter::repeat(b'b').take(50))
        .collect();
    assert_eq!(count_ascii_bytes_simd(&long_mixed), 100);
}
/// On a real UTF-8 string the count must equal a plain byte-by-byte reference count.
#[test]
fn test_count_ascii_bytes_simd_utf8_string() {
    let sample = "Hello World 123 \u{4E00}\u{4E01} \u{1F600}";
    let raw = sample.as_bytes();

    // Reference: a byte is ASCII exactly when it is below 128.
    let reference_count = raw.iter().filter(|byte| byte.is_ascii()).count();

    assert_eq!(count_ascii_bytes_simd(raw), reference_count);
}
/// `get_tokenizer_type` must return one of the known `TokenizerType` variants.
#[test]
fn test_get_tokenizer_type_default() {
    let current = get_tokenizer_type();
    // `matches!` only evaluates to a bool; without `assert!` the result is
    // thrown away and the test checks nothing.
    assert!(
        matches!(
            current,
            TokenizerType::Cl100kBase
                | TokenizerType::CharEstimate
                | TokenizerType::Qwen3
                | TokenizerType::P50kBase
                | TokenizerType::R50kBase
        ),
        "get_tokenizer_type returned an unexpected variant: {:?}",
        current
    );
}
/// `index_to_line` maps a 0-based index to the 1-based line number one above it.
#[test]
fn test_index_to_line_conversion() {
    // (0-based index, expected 1-based line, failure message)
    let expectations = [
        (0, 1, "Index 0 should map to line 1"),
        (1, 2, "Index 1 should map to line 2"),
        (9, 10, "Index 9 should map to line 10"),
        (99, 100, "Index 99 should map to line 100"),
    ];
    for (index, expected_line, failure_note) in expectations {
        assert_eq!(index_to_line(index), expected_line, "{}", failure_note);
    }
}
/// `line_to_index` maps a 1-based line number to the 0-based index one below
/// it, with line 0 saturating to index 0.
#[test]
fn test_line_to_index_conversion() {
    // (1-based line, expected 0-based index, failure message)
    let expectations = [
        (1, 0, "Line 1 should map to index 0"),
        (2, 1, "Line 2 should map to index 1"),
        (10, 9, "Line 10 should map to index 9"),
        (100, 99, "Line 100 should map to index 99"),
        // Line 0 is not a valid line number; the subtraction saturates.
        (0, 0, "Invalid line 0 should saturate to index 0"),
    ];
    for (line, expected_index, failure_note) in expectations {
        assert_eq!(line_to_index(line), expected_index, "{}", failure_note);
    }
}
/// Converting an index to a line and back must return the original index for
/// every index in `0..100`.
#[test]
fn test_roundtrip_index_line_conversion() {
    for original in 0..100 {
        let as_line = index_to_line(original);
        let recovered = line_to_index(as_line);
        assert_eq!(
            recovered, original,
            "Roundtrip failed: index {} -> line {} -> index {}",
            original, as_line, recovered
        );
    }
}
/// A three-line input must be reported as lines 1 through 3 with a line count of 3.
#[test]
fn test_line_numbers_are_1_indexed() {
    let produced = chunk_code("line1\nline2\nline3", 1000);
    assert_eq!(produced.len(), 1, "Should produce single chunk");

    let only = &produced[0];
    assert_eq!(only.start_line, 1, "First line should be 1, not 0");
    assert_eq!(
        only.end_line, 3,
        "Last line should be 3 (3 lines, 1-indexed inclusive)"
    );
    assert_eq!(
        only.line_count(),
        3,
        "Line count should be 3 (end - start + 1)"
    );
}
/// A one-line input must produce a chunk whose start and end are both line 1.
#[test]
fn test_single_line_chunk_line_numbers() {
    let produced = chunk_code("single line", 1000);
    assert_eq!(produced.len(), 1);

    let only = &produced[0];
    assert_eq!(only.start_line, 1, "Single line: start should be 1");
    assert_eq!(only.end_line, 1, "Single line: end should be 1 (inclusive)");
    assert_eq!(only.line_count(), 1, "Single line: count should be 1");
}
/// With several chunks, the first must start at line 1 and no chunk may start
/// more than one line past the end of its predecessor.
#[test]
fn test_multi_chunk_line_numbers_continuous() {
    // Twenty tiny functions, enough to exceed the 50-token budget several times.
    let source: String = (1..=20)
        .map(|i| format!("def function_{}():\n pass\n\n", i))
        .collect();

    let produced = chunk_code(&source, 50);
    assert!(produced.len() > 1, "Should produce multiple chunks");
    assert_eq!(
        produced[0].start_line, 1,
        "First chunk should start at line 1"
    );

    for (earlier_idx, pair) in produced.windows(2).enumerate() {
        let later_idx = earlier_idx + 1;
        let earlier_end = pair[0].end_line;
        let later_start = pair[1].start_line;
        assert!(
            later_start <= earlier_end + 1,
            "Chunk {} start ({}) should be <= prev chunk {} end ({}) + 1",
            later_idx,
            later_start,
            earlier_idx,
            earlier_end
        );
    }
}
/// `contains_line` on a chunk covering lines 5..=7 must accept exactly those
/// lines and reject its neighbours and line 0.
#[test]
fn test_contains_line_uses_1_indexed() {
    let span = Chunk::new(String::from("a\nb\nc"), 5, 7, 0, 5);

    let inside = [
        (5, "Line 5 should be in chunk [5,7]"),
        (6, "Line 6 should be in chunk [5,7]"),
        (7, "Line 7 should be in chunk [5,7]"),
    ];
    for (line, failure_note) in inside {
        assert!(span.contains_line(line), "{}", failure_note);
    }

    let outside = [
        (4, "Line 4 should NOT be in chunk [5,7]"),
        (8, "Line 8 should NOT be in chunk [5,7]"),
        (0, "Line 0 (invalid) should NOT be in any chunk"),
    ];
    for (line, failure_note) in outside {
        assert!(!span.contains_line(line), "{}", failure_note);
    }
}
/// `create_chunk_from_range` called with the range 0 to 3 must yield lines
/// 1..=3 and exclude the fourth line's text.
#[test]
fn test_create_chunk_from_range_converts_correctly() {
    let source = "line0\nline1\nline2\nline3";
    let split_lines: Vec<&str> = source.lines().collect();
    let offsets = build_line_offsets(&split_lines, source);

    let built = create_chunk_from_range(source, &split_lines, &offsets, 0, 3);

    assert_eq!(built.start_line, 1, "Internal index 0 should become line 1");
    assert_eq!(
        built.end_line, 3,
        "Internal exclusive end 3 should become inclusive line 3"
    );
    assert_eq!(built.line_count(), 3, "Should contain 3 lines");

    for included in ["line0", "line1", "line2"] {
        assert!(built.content.contains(included));
    }
    assert!(!built.content.contains("line3"));
}
/// Pieces from `force_split_by_lines` must start at the supplied base line and
/// never report a line range that runs backwards.
#[test]
fn test_force_split_preserves_1_indexed_base() {
    let base_line = 10;
    let pieces = force_split_by_lines("a\nb\nc\nd\ne", base_line, 0, 2);

    assert!(!pieces.is_empty(), "Should produce chunks");
    assert_eq!(
        pieces[0].start_line, 10,
        "First sub-chunk should start at base_line {}",
        base_line
    );

    for (position, piece) in pieces.iter().enumerate() {
        let (first, last) = (piece.start_line, piece.end_line);
        assert!(
            first >= base_line,
            "Chunk {} start_line {} should be >= base_line {}",
            position,
            first,
            base_line
        );
        assert!(
            last >= first,
            "Chunk {} end_line {} should be >= start_line {}",
            position,
            last,
            first
        );
    }
}
/// In debug builds, constructing a chunk with `start_line == 0` must panic
/// with the 1-indexing message.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "start_line must be 1-indexed")]
fn test_chunk_new_rejects_zero_start_line() {
    // The constructor call itself is expected to panic.
    let _rejected = Chunk::new(String::from("content"), 0, 1, 0, 7);
}
/// In debug builds, constructing a chunk whose `end_line` precedes its
/// `start_line` must panic with a message mentioning `end_line`.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "end_line")]
fn test_chunk_new_rejects_end_before_start() {
    // The constructor call itself is expected to panic.
    let _rejected = Chunk::new(String::from("content"), 5, 3, 0, 7);
}
}