// chunk/merge.rs
//! Token-aware merging for chunkers.
//!
//! This module provides functions to merge text segments based on token counts,
//! equivalent to Chonkie's Cython `merge.pyx`. Used by RecursiveChunker and
//! other chunkers that need to respect token limits.

/// Find merge indices for combining segments within token limits.
///
/// This is the core algorithm used by RecursiveChunker to merge small segments
/// into larger chunks that fit within a token budget.
///
/// # Arguments
///
/// * `token_counts` - Token count for each segment
/// * `chunk_size` - Maximum tokens per merged chunk
/// * `combine_whitespace` - If true, adds +1 token per join for whitespace
///
/// # Returns
///
/// Vector of end indices where merges should occur. Each index marks the
/// exclusive end of a merged chunk.
///
/// # Example
///
/// ```
/// use chunk::find_merge_indices;
///
/// let token_counts = vec![10, 15, 20, 5, 8, 12];
/// let indices = find_merge_indices(&token_counts, 30, false);
/// // Merge [0:2], [2:4], [4:6] -> indices = [2, 4, 6]
/// ```
pub fn find_merge_indices(
    token_counts: &[usize],
    chunk_size: usize,
    combine_whitespace: bool,
) -> Vec<usize> {
    let n = token_counts.len();
    if n == 0 {
        return Vec::new();
    }

    // Prefix sums of raw token counts: prefix[i] = tokens in segments [0..i).
    let prefix: Vec<usize> = std::iter::once(0)
        .chain(token_counts.iter().scan(0usize, |acc, &count| {
            *acc += count;
            Some(*acc)
        }))
        .collect();

    // Cost of merging segments [start..end): raw tokens, plus one whitespace
    // token per join (end - start - 1 joins) when combining with whitespace.
    // Monotone non-decreasing in `end`, which is what makes the binary
    // search below valid.
    let cost = |start: usize, end: usize| -> usize {
        let raw = prefix[end] - prefix[start];
        if combine_whitespace && end > start + 1 {
            raw + (end - start - 1)
        } else {
            raw
        }
    };

    let mut indices = Vec::new();
    let mut start = 0;

    while start < n {
        // Binary search over half-open candidates (start, n] for one past
        // the largest `end` whose merged cost still fits the budget.
        let mut lo = start + 1;
        let mut hi = n + 1;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if cost(start, mid) <= chunk_size {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }

        // `lo` is one past the last fitting end. Always advance by at least
        // one segment so an oversized segment becomes its own chunk instead
        // of stalling the loop.
        let end = lo.saturating_sub(1).clamp(start + 1, n);
        indices.push(end);
        start = end;
    }

    indices
}

/// Compute merged token counts from merge indices.
///
/// Given the original token counts and merge indices from `find_merge_indices`,
/// compute the token count for each merged chunk.
///
/// # Arguments
///
/// * `token_counts` - Original token counts for each segment
/// * `merge_indices` - End indices from `find_merge_indices`
/// * `combine_whitespace` - If true, adds +1 token per join for whitespace
///
/// # Returns
///
/// Vector of token counts for each merged chunk.
pub fn compute_merged_token_counts(
    token_counts: &[usize],
    merge_indices: &[usize],
    combine_whitespace: bool,
) -> Vec<usize> {
    let mut start = 0usize;
    merge_indices
        .iter()
        .map(|&raw_end| {
            // Clamp so an out-of-range index cannot slice past the input.
            let end = raw_end.min(token_counts.len());
            let raw: usize = token_counts[start..end].iter().sum();
            // n segments are joined by (n - 1) whitespace tokens.
            let joins = if combine_whitespace && end > start {
                end - start - 1
            } else {
                0
            };
            start = end;
            raw + joins
        })
        .collect()
}

/// Result of a `merge_splits` operation.
///
/// Both vectors hold one entry per merged chunk, so they always have the
/// same length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeResult {
    /// End indices for each merged chunk (exclusive).
    /// Use with slicing: segments[prev_end..end]
    pub indices: Vec<usize>,
    /// Token count for each merged chunk. Includes the +1-per-join
    /// whitespace tokens when `combine_whitespace` was set.
    pub token_counts: Vec<usize>,
}

154/// Merge segments based on token counts, respecting chunk size limits.
155///
156/// This is the Rust equivalent of Chonkie's Cython `_merge_splits` function.
157/// Returns indices for slicing the original segments, rather than copying strings.
158///
159/// # Arguments
160///
161/// * `token_counts` - Token count for each segment
162/// * `chunk_size` - Maximum tokens per merged chunk
163/// * `combine_whitespace` - If true, join with whitespace (+1 token per join)
164///
165/// # Returns
166///
167/// `MergeResult` containing:
168/// - `indices`: End indices for slicing segments
169/// - `token_counts`: Token count for each merged chunk
170///
171/// # Example
172///
173/// ```
174/// use chunk::merge_splits;
175///
176/// // segments = ["Hello", "world", "!", "How", "are", "you", "?"]
177/// let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
178/// let result = merge_splits(&token_counts, 5, true);
179///
180/// // Use indices to slice: segments[0..3], segments[3..6], segments[6..7]
181/// // chunk_size=5 allows 3 segments + 2 whitespace joins = 5 tokens per chunk
182/// assert_eq!(result.indices, vec![3, 6, 7]);
183/// assert_eq!(result.token_counts, vec![5, 5, 1]); // includes whitespace tokens
184/// ```
185pub fn merge_splits(
186 token_counts: &[usize],
187 chunk_size: usize,
188 combine_whitespace: bool,
189) -> MergeResult {
190 // Early exit for empty input
191 if token_counts.is_empty() {
192 return MergeResult {
193 indices: vec![],
194 token_counts: vec![],
195 };
196 }
197
198 // If all token counts exceed chunk_size, return one chunk per segment
199 if token_counts.iter().all(|&c| c > chunk_size) {
200 let indices: Vec<usize> = (1..=token_counts.len()).collect();
201 return MergeResult {
202 indices,
203 token_counts: token_counts.to_vec(),
204 };
205 }
206
207 let indices = find_merge_indices(token_counts, chunk_size, combine_whitespace);
208 let merged_counts = compute_merged_token_counts(token_counts, &indices, combine_whitespace);
209
210 MergeResult {
211 indices,
212 token_counts: merged_counts,
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_merge_indices_basic() {
        // Seven 1-token segments, budget 3: groups of three tokens each.
        assert_eq!(
            find_merge_indices(&[1, 1, 1, 1, 1, 1, 1], 3, false),
            vec![3, 6, 7]
        );
    }

    #[test]
    fn test_find_merge_indices_with_whitespace() {
        // Each join costs one whitespace token, so a chunk of two segments
        // totals 2 raw + 1 join = 3: pairs fit exactly, trailing singleton.
        assert_eq!(
            find_merge_indices(&[1, 1, 1, 1, 1, 1, 1], 3, true),
            vec![2, 4, 6, 7]
        );
    }

    #[test]
    fn test_find_merge_indices_large_chunks() {
        // 10+15=25 fits, +20 overflows; 20+5=25 fits, +8 overflows; 8+12=20 fits.
        assert_eq!(
            find_merge_indices(&[10, 15, 20, 5, 8, 12], 30, false),
            vec![2, 4, 6]
        );
    }

    #[test]
    fn test_find_merge_indices_empty() {
        assert!(find_merge_indices(&[], 10, false).is_empty());
    }

    #[test]
    fn test_find_merge_indices_single() {
        assert_eq!(find_merge_indices(&[5], 10, false), vec![1]);
    }

    #[test]
    fn test_find_merge_indices_all_large() {
        // Every segment exceeds the budget: each becomes its own chunk.
        assert_eq!(find_merge_indices(&[50, 60, 70], 30, false), vec![1, 2, 3]);
    }

    #[test]
    fn test_merge_splits_basic() {
        let result = merge_splits(&[1, 1, 1, 1, 1, 1, 1], 3, false);
        assert_eq!(result.indices, vec![3, 6, 7]);
        assert_eq!(result.token_counts, vec![3, 3, 1]);
    }

    #[test]
    fn test_merge_splits_with_whitespace() {
        // Three segments per chunk: 3 raw + 2 whitespace joins = 5 tokens.
        let result = merge_splits(&[1, 1, 1, 1, 1, 1, 1], 5, true);
        assert_eq!(result.indices, vec![3, 6, 7]);
        assert_eq!(result.token_counts, vec![5, 5, 1]);
    }

    #[test]
    fn test_merge_splits_empty() {
        let result = merge_splits(&[], 10, false);
        assert!(result.indices.is_empty());
        assert!(result.token_counts.is_empty());
    }

    #[test]
    fn test_merge_splits_all_exceed_limit() {
        let result = merge_splits(&[50, 60, 70], 30, false);
        assert_eq!(result.indices, vec![1, 2, 3]);
        assert_eq!(result.token_counts, vec![50, 60, 70]);
    }

    #[test]
    fn test_compute_merged_token_counts() {
        // 10+15, 20+5, 8+12
        assert_eq!(
            compute_merged_token_counts(&[10, 15, 20, 5, 8, 12], &[2, 4, 6], false),
            vec![25, 25, 20]
        );
    }

    #[test]
    fn test_compute_merged_token_counts_with_whitespace() {
        // 1+1+1 raw plus 2 whitespace join tokens = 5.
        assert_eq!(
            compute_merged_token_counts(&[1, 1, 1], &[3], true),
            vec![5]
        );
    }
}
323}