// chunk/merge.rs

1//! Token-aware merging for chunkers.
2//!
3//! This module provides functions to merge text segments based on token counts,
4//! equivalent to Chonkie's Cython `merge.pyx`. Used by RecursiveChunker and
5//! other chunkers that need to respect token limits.
6
/// Find merge indices for combining segments within token limits.
///
/// This is the core algorithm used by RecursiveChunker to merge small segments
/// into larger chunks that fit within a token budget.
///
/// # Arguments
///
/// * `token_counts` - Token count for each segment
/// * `chunk_size` - Maximum tokens per merged chunk
/// * `combine_whitespace` - If true, adds +1 token per join for whitespace
///
/// # Returns
///
/// Vector of end indices where merges should occur. Each index marks the
/// exclusive end of a merged chunk. A segment that alone exceeds `chunk_size`
/// still becomes its own single-segment chunk (the algorithm always advances
/// by at least one segment).
///
/// # Example
///
/// ```
/// use chunk::find_merge_indices;
///
/// let token_counts = vec![10, 15, 20, 5, 8, 12];
/// let indices = find_merge_indices(&token_counts, 30, false);
/// // Merge [0:2], [2:4], [4:6] -> indices = [2, 4, 6]
/// ```
pub fn find_merge_indices(
    token_counts: &[usize],
    chunk_size: usize,
    combine_whitespace: bool,
) -> Vec<usize> {
    if token_counts.is_empty() {
        return vec![];
    }

    let n = token_counts.len();

    // Weighted prefix sums. When `combine_whitespace` is set, the +1-per-join
    // whitespace cost is folded into the prefix itself:
    //
    //   weighted[i] = sum(token_counts[..i]) + i
    //
    // so the fit condition
    //
    //   raw(start..end) + (end - start - 1) <= chunk_size
    //
    // becomes the monotone predicate
    //
    //   weighted[end] <= weighted[start] + chunk_size + 1.
    //
    // Without whitespace, `weighted` is the plain prefix sum and the +1 bonus
    // is 0. Folding the adjustment in once here removes the per-comparison
    // branching the search would otherwise need.
    let bonus = usize::from(combine_whitespace);
    let mut weighted: Vec<usize> = Vec::with_capacity(n + 1);
    weighted.push(0);
    let mut sum = 0usize;
    for (i, &count) in token_counts.iter().enumerate() {
        sum += count;
        weighted.push(sum + bonus * (i + 1));
    }

    let mut indices = Vec::new();
    let mut start = 0;

    while start < n {
        let limit = weighted[start] + chunk_size + bonus;
        // std binary search: first candidate end whose cumulative weight no
        // longer fits; every end before it fits within the budget.
        let fit = weighted[start + 1..=n].partition_point(|&w| w <= limit);
        // Always advance by at least one segment so an oversized segment
        // becomes its own chunk instead of the loop stalling forever.
        let end = start + fit.max(1);
        indices.push(end);
        start = end;
    }

    indices
}
101
/// Compute merged token counts from merge indices.
///
/// Given the original token counts and merge indices from `find_merge_indices`,
/// compute the token count for each merged chunk.
///
/// # Arguments
///
/// * `token_counts` - Original token counts for each segment
/// * `merge_indices` - End indices from `find_merge_indices`
/// * `combine_whitespace` - If true, adds +1 token per join for whitespace
///
/// # Returns
///
/// Vector of token counts for each merged chunk.
pub fn compute_merged_token_counts(
    token_counts: &[usize],
    merge_indices: &[usize],
    combine_whitespace: bool,
) -> Vec<usize> {
    let mut start = 0usize;
    merge_indices
        .iter()
        .map(|&raw_end| {
            // Clamp so a trailing index past the input cannot over-slice.
            let end = raw_end.min(token_counts.len());
            let raw: usize = token_counts[start..end].iter().sum();
            // Joining k segments with whitespace costs k - 1 extra tokens.
            let joins = if combine_whitespace && end > start {
                end - start - 1
            } else {
                0
            };
            start = end;
            raw + joins
        })
        .collect()
}
143
/// Result of a `merge_splits` operation.
///
/// Pairs the chunk boundaries with per-chunk token totals so callers can
/// slice the original segments without re-counting tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeResult {
    /// End indices for each merged chunk (exclusive).
    /// Use with slicing: segments[prev_end..end]
    pub indices: Vec<usize>,
    /// Token count for each merged chunk. Includes +1 per whitespace join
    /// when `combine_whitespace` was set in the producing call.
    pub token_counts: Vec<usize>,
}
153
154/// Merge segments based on token counts, respecting chunk size limits.
155///
156/// This is the Rust equivalent of Chonkie's Cython `_merge_splits` function.
157/// Returns indices for slicing the original segments, rather than copying strings.
158///
159/// # Arguments
160///
161/// * `token_counts` - Token count for each segment
162/// * `chunk_size` - Maximum tokens per merged chunk
163/// * `combine_whitespace` - If true, join with whitespace (+1 token per join)
164///
165/// # Returns
166///
167/// `MergeResult` containing:
168/// - `indices`: End indices for slicing segments
169/// - `token_counts`: Token count for each merged chunk
170///
171/// # Example
172///
173/// ```
174/// use chunk::merge_splits;
175///
176/// // segments = ["Hello", "world", "!", "How", "are", "you", "?"]
177/// let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
178/// let result = merge_splits(&token_counts, 5, true);
179///
180/// // Use indices to slice: segments[0..3], segments[3..6], segments[6..7]
181/// // chunk_size=5 allows 3 segments + 2 whitespace joins = 5 tokens per chunk
182/// assert_eq!(result.indices, vec![3, 6, 7]);
183/// assert_eq!(result.token_counts, vec![5, 5, 1]); // includes whitespace tokens
184/// ```
185pub fn merge_splits(
186    token_counts: &[usize],
187    chunk_size: usize,
188    combine_whitespace: bool,
189) -> MergeResult {
190    // Early exit for empty input
191    if token_counts.is_empty() {
192        return MergeResult {
193            indices: vec![],
194            token_counts: vec![],
195        };
196    }
197
198    // If all token counts exceed chunk_size, return one chunk per segment
199    if token_counts.iter().all(|&c| c > chunk_size) {
200        let indices: Vec<usize> = (1..=token_counts.len()).collect();
201        return MergeResult {
202            indices,
203            token_counts: token_counts.to_vec(),
204        };
205    }
206
207    let indices = find_merge_indices(token_counts, chunk_size, combine_whitespace);
208    let merged_counts = compute_merged_token_counts(token_counts, &indices, combine_whitespace);
209
210    MergeResult {
211        indices,
212        token_counts: merged_counts,
213    }
214}
215
#[cfg(test)]
mod tests {
    //! Unit tests covering the merge algorithm's happy path, whitespace
    //! accounting, empty input, and segments that exceed the chunk budget.

    use super::*;

    #[test]
    fn test_find_merge_indices_basic() {
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let indices = find_merge_indices(&token_counts, 3, false);
        // Should merge into groups of 3 tokens
        assert_eq!(indices, vec![3, 6, 7]);
    }

    #[test]
    fn test_find_merge_indices_with_whitespace() {
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let indices = find_merge_indices(&token_counts, 3, true);
        // With whitespace: (n-1) joins add (n-1) whitespace tokens
        // Chunk 0..2: raw=2, whitespace=1, total=3 <= 3. Fits.
        // Chunk 2..4: raw=2, whitespace=1, total=3 <= 3. Fits.
        // Chunk 4..6: raw=2, whitespace=1, total=3 <= 3. Fits.
        // Chunk 6..7: raw=1, whitespace=0, total=1 <= 3. Fits.
        assert_eq!(indices, vec![2, 4, 6, 7]);
    }

    #[test]
    fn test_find_merge_indices_large_chunks() {
        let token_counts = vec![10, 15, 20, 5, 8, 12];
        let indices = find_merge_indices(&token_counts, 30, false);
        // 10+15=25 < 30, 10+15+20=45 > 30 -> merge at 2
        // 20 < 30, 20+5=25 < 30, 20+5+8=33 > 30 -> merge at 4
        // 8+12=20 < 30 -> merge at 6
        assert_eq!(indices, vec![2, 4, 6]);
    }

    #[test]
    fn test_find_merge_indices_empty() {
        // Empty input yields no merge boundaries.
        let token_counts: Vec<usize> = vec![];
        let indices = find_merge_indices(&token_counts, 10, false);
        assert!(indices.is_empty());
    }

    #[test]
    fn test_find_merge_indices_single() {
        // One segment under the budget -> one chunk ending at index 1.
        let token_counts = vec![5];
        let indices = find_merge_indices(&token_counts, 10, false);
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn test_find_merge_indices_all_large() {
        // All segments exceed chunk_size
        let token_counts = vec![50, 60, 70];
        let indices = find_merge_indices(&token_counts, 30, false);
        // Each segment becomes its own chunk
        assert_eq!(indices, vec![1, 2, 3]);
    }

    #[test]
    fn test_merge_splits_basic() {
        // Indices and per-chunk counts must agree with find_merge_indices.
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let result = merge_splits(&token_counts, 3, false);
        assert_eq!(result.indices, vec![3, 6, 7]);
        assert_eq!(result.token_counts, vec![3, 3, 1]);
    }

    #[test]
    fn test_merge_splits_with_whitespace() {
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let result = merge_splits(&token_counts, 5, true);
        // With whitespace tokens added for joins
        // [1,1] + 1 whitespace = 3, [1,1] + 1 = 3, etc.
        assert_eq!(result.indices, vec![3, 6, 7]);
        assert_eq!(result.token_counts, vec![5, 5, 1]);
    }

    #[test]
    fn test_merge_splits_empty() {
        // Empty input yields an empty result on both fields.
        let token_counts: Vec<usize> = vec![];
        let result = merge_splits(&token_counts, 10, false);
        assert!(result.indices.is_empty());
        assert!(result.token_counts.is_empty());
    }

    #[test]
    fn test_merge_splits_all_exceed_limit() {
        // Oversized segments pass through one-per-chunk with raw counts.
        let token_counts = vec![50, 60, 70];
        let result = merge_splits(&token_counts, 30, false);
        assert_eq!(result.indices, vec![1, 2, 3]);
        assert_eq!(result.token_counts, vec![50, 60, 70]);
    }

    #[test]
    fn test_compute_merged_token_counts() {
        let token_counts = vec![10, 15, 20, 5, 8, 12];
        let indices = vec![2, 4, 6];
        let merged = compute_merged_token_counts(&token_counts, &indices, false);
        assert_eq!(merged, vec![25, 25, 20]); // 10+15, 20+5, 8+12
    }

    #[test]
    fn test_compute_merged_token_counts_with_whitespace() {
        let token_counts = vec![1, 1, 1];
        let indices = vec![3];
        let merged = compute_merged_token_counts(&token_counts, &indices, true);
        // 1+1+1 + 2 whitespace tokens = 5
        assert_eq!(merged, vec![5]);
    }
}