chunk/
merge.rs

//! Token-aware merging for chunkers.
//!
//! This module provides functions to merge text segments based on token counts,
//! equivalent to Chonkie's Cython `merge.pyx`. Used by RecursiveChunker and
//! other chunkers that need to respect token limits.

/// Find merge indices for combining segments within token limits.
///
/// This is the core algorithm used by chunkers to find optimal merge points
/// based on token counts and chunk size limits. Segments are packed greedily:
/// each chunk takes as many consecutive segments as fit within `chunk_size`.
/// A single segment that by itself exceeds the limit is emitted alone, so the
/// algorithm always makes progress.
///
/// # Arguments
///
/// * `token_counts` - Token count for each segment
/// * `chunk_size` - Maximum tokens per merged chunk
///
/// # Returns
///
/// Vector of end indices where merges should occur. Each index marks the
/// exclusive end of a merged chunk; for non-empty input the last index is
/// always `token_counts.len()`.
///
/// # Example
///
/// ```
/// use chunk::find_merge_indices;
///
/// let token_counts = vec![10, 15, 20, 5, 8, 12];
/// let indices = find_merge_indices(&token_counts, 30);
/// // Merge [0:2], [2:4], [4:6] -> indices = [2, 4, 6]
/// ```
pub fn find_merge_indices(token_counts: &[usize], chunk_size: usize) -> Vec<usize> {
    if token_counts.is_empty() {
        return vec![];
    }

    let n = token_counts.len();

    // Prefix sums: cumulative[i] = total tokens in segments [0, i).
    let mut cumulative: Vec<usize> = Vec::with_capacity(n + 1);
    cumulative.push(0);

    let mut sum = 0usize;
    for &count in token_counts {
        sum += count;
        cumulative.push(sum);
    }

    let mut indices = Vec::new();
    let mut current_pos = 0;

    while current_pos < n {
        let base = cumulative[current_pos];
        // `cumulative` is non-decreasing, so `c - base <= chunk_size` is a
        // monotone predicate and partition_point performs the binary search:
        // it returns the number of leading prefix sums that still fit.
        // The entry at `current_pos` itself (difference 0) always fits, so
        // `fitting >= 1` and the subtraction below cannot underflow.
        let fitting = cumulative[current_pos..].partition_point(|&c| c - base <= chunk_size);

        // Absolute index of the last fitting position; force at least one
        // segment of progress even when that single segment is oversized.
        let index = (current_pos + fitting - 1).max(current_pos + 1).min(n);

        indices.push(index);
        current_pos = index;
    }

    indices
}
77
/// Sum the per-segment token counts over each merged range.
///
/// `merge_indices` holds exclusive end positions; range `i` runs from the
/// previous index (or 0) to `merge_indices[i]`, clamped to the slice length.
fn compute_merged_token_counts(token_counts: &[usize], merge_indices: &[usize]) -> Vec<usize> {
    let mut from = 0usize;
    merge_indices
        .iter()
        .map(|&raw_end| {
            let to = raw_end.min(token_counts.len());
            let total: usize = token_counts[from..to].iter().sum();
            from = to;
            total
        })
        .collect()
}
96
/// Result of a [`merge_splits`] operation.
///
/// The two vectors are parallel: `token_counts[i]` is the token count of
/// `merged[i]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeResult {
    /// Merged text segments
    pub merged: Vec<String>,
    /// Token count for each merged chunk
    pub token_counts: Vec<usize>,
}
105
106/// Merge text segments based on token counts, respecting chunk size limits.
107///
108/// This is the Rust equivalent of Chonkie's Cython `_merge_splits` function.
109/// Performs string concatenation in Rust for optimal performance.
110///
111/// # Arguments
112///
113/// * `splits` - Text segments to merge
114/// * `token_counts` - Token count for each segment
115/// * `chunk_size` - Maximum tokens per merged chunk
116///
117/// # Returns
118///
119/// `MergeResult` containing merged text and token counts.
120///
121/// # Example
122///
123/// ```
124/// use chunk::merge_splits;
125///
126/// let splits = vec!["Hello", "world", "!", "How", "are", "you"];
127/// let token_counts = vec![1, 1, 1, 1, 1, 1];
128/// let result = merge_splits(&splits, &token_counts, 3);
129/// assert_eq!(result.merged, vec!["Helloworld!", "Howareyou"]);
130/// assert_eq!(result.token_counts, vec![3, 3]);
131/// ```
132pub fn merge_splits(
133    splits: &[&str],
134    token_counts: &[usize],
135    chunk_size: usize,
136) -> MergeResult {
137    // Early exit for empty input
138    if splits.is_empty() || token_counts.is_empty() {
139        return MergeResult {
140            merged: vec![],
141            token_counts: vec![],
142        };
143    }
144
145    // If all token counts exceed chunk_size, return segments as-is
146    if token_counts.iter().all(|&c| c > chunk_size) {
147        return MergeResult {
148            merged: splits.iter().map(|s| s.to_string()).collect(),
149            token_counts: token_counts.to_vec(),
150        };
151    }
152
153    let indices = find_merge_indices(token_counts, chunk_size);
154    let merged_counts = compute_merged_token_counts(token_counts, &indices);
155
156    // Build merged strings
157    let mut merged = Vec::with_capacity(indices.len());
158    let mut start = 0;
159
160    for &end in &indices {
161        let end = end.min(splits.len());
162        // Pre-calculate total length for efficient allocation
163        let total_len: usize = splits[start..end].iter().map(|s| s.len()).sum();
164        let mut s = String::with_capacity(total_len);
165        for segment in &splits[start..end] {
166            s.push_str(segment);
167        }
168        merged.push(s);
169        start = end;
170    }
171
172    MergeResult {
173        merged,
174        token_counts: merged_counts,
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_merge_indices_basic() {
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let indices = find_merge_indices(&token_counts, 3);
        assert_eq!(indices, vec![3, 6, 7]);
    }

    #[test]
    fn test_find_merge_indices_large_chunks() {
        let token_counts = vec![10, 15, 20, 5, 8, 12];
        let indices = find_merge_indices(&token_counts, 30);
        assert_eq!(indices, vec![2, 4, 6]);
    }

    #[test]
    fn test_find_merge_indices_empty() {
        let token_counts: Vec<usize> = vec![];
        let indices = find_merge_indices(&token_counts, 10);
        assert!(indices.is_empty());
    }

    #[test]
    fn test_find_merge_indices_single() {
        let token_counts = vec![5];
        let indices = find_merge_indices(&token_counts, 10);
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn test_find_merge_indices_all_large() {
        let token_counts = vec![50, 60, 70];
        let indices = find_merge_indices(&token_counts, 30);
        assert_eq!(indices, vec![1, 2, 3]);
    }

    #[test]
    fn test_find_merge_indices_mixed_oversized() {
        // An oversized segment is emitted alone; its neighbours still merge
        // greedily around it.
        let token_counts = vec![1, 100, 1];
        let indices = find_merge_indices(&token_counts, 10);
        assert_eq!(indices, vec![1, 2, 3]);
    }

    #[test]
    fn test_merge_splits_basic() {
        let splits = vec!["a", "b", "c", "d", "e", "f", "g"];
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let result = merge_splits(&splits, &token_counts, 3);
        assert_eq!(result.merged, vec!["abc", "def", "g"]);
        assert_eq!(result.token_counts, vec![3, 3, 1]);
    }

    #[test]
    fn test_merge_splits_doc_example() {
        // Pins the example shown in the `merge_splits` rustdoc.
        let splits = vec!["Hello", "world", "!", "How", "are", "you"];
        let token_counts = vec![1, 1, 1, 1, 1, 1];
        let result = merge_splits(&splits, &token_counts, 3);
        assert_eq!(result.merged, vec!["Helloworld!", "Howareyou"]);
        assert_eq!(result.token_counts, vec![3, 3]);
    }

    #[test]
    fn test_merge_splits_empty() {
        let splits: Vec<&str> = vec![];
        let token_counts: Vec<usize> = vec![];
        let result = merge_splits(&splits, &token_counts, 10);
        assert!(result.merged.is_empty());
        assert!(result.token_counts.is_empty());
    }

    #[test]
    fn test_merge_splits_all_exceed_limit() {
        let splits = vec!["aaa", "bbb", "ccc"];
        let token_counts = vec![50, 60, 70];
        let result = merge_splits(&splits, &token_counts, 30);
        assert_eq!(result.merged, vec!["aaa", "bbb", "ccc"]);
        assert_eq!(result.token_counts, vec![50, 60, 70]);
    }

    #[test]
    fn test_merge_splits_mixed_oversized() {
        // Only some segments exceed the limit: the oversized one stays alone
        // while the small ones on either side merge where possible.
        let splits = vec!["a", "HUGE", "b"];
        let token_counts = vec![1, 100, 1];
        let result = merge_splits(&splits, &token_counts, 10);
        assert_eq!(result.merged, vec!["a", "HUGE", "b"]);
        assert_eq!(result.token_counts, vec![1, 100, 1]);
    }
}