/// Greedily groups consecutive splits into chunks whose summed token count
/// stays within `chunk_size`, returning the exclusive end index of each chunk.
///
/// Chunk `i` covers `token_counts[start..indices[i]]`, where `start` is the
/// previous returned index (or `0` for the first chunk). The returned indices
/// are strictly increasing and, for non-empty input, the last one is always
/// `token_counts.len()`.
///
/// Each chunk is extended as far as possible while its total stays
/// `<= chunk_size`. A split whose own count already exceeds `chunk_size` is
/// still emitted as a one-element chunk, so the walk always makes progress.
///
/// For example, seven splits of one token each with `chunk_size = 3` yield
/// `[3, 6, 7]`, i.e. chunks of 3, 3 and 1 splits.
///
/// Returns an empty vector when `token_counts` is empty.
///
/// # Panics
///
/// In debug builds, panics if the sum of all counts overflows `usize`.
pub fn find_merge_indices(token_counts: &[usize], chunk_size: usize) -> Vec<usize> {
    let n = token_counts.len();

    // prefix[i] is the total of token_counts[..i]. It is non-decreasing, so the
    // "chunk still fits" predicate below is monotone and can be bisected.
    let mut prefix = Vec::with_capacity(n + 1);
    prefix.push(0usize);
    prefix.extend(token_counts.iter().scan(0usize, |running, &count| {
        *running += count;
        Some(*running)
    }));

    let mut indices = Vec::new();
    let mut start = 0;
    while start < n {
        let base = prefix[start];
        // How many splits after `start` can join this chunk without the
        // total exceeding `chunk_size`.
        let fitting = prefix[start + 1..].partition_point(|&total| total - base <= chunk_size);
        // Always take at least one split so an oversized split becomes its own chunk.
        start += fitting.max(1);
        indices.push(start);
    }
    indices
}
77
/// Sums `token_counts` over each chunk described by `merge_indices`.
///
/// `merge_indices` holds exclusive chunk end positions: chunk `i` spans from
/// the previous end (or `0`) up to `merge_indices[i]`. Each end is clamped to
/// `token_counts.len()`, so an end past the slice simply stops at its tail.
/// Returns one total per entry of `merge_indices`; an empty `merge_indices`
/// produces an empty vector.
///
/// Panics if the (clamped) ends ever decrease, because the resulting slice
/// range would be inverted.
fn compute_merged_token_counts(token_counts: &[usize], merge_indices: &[usize]) -> Vec<usize> {
    let mut cursor = 0;
    merge_indices
        .iter()
        .map(|&boundary| {
            let stop = boundary.min(token_counts.len());
            let total = token_counts[cursor..stop].iter().sum::<usize>();
            cursor = stop;
            total
        })
        .collect()
}
96
/// Output of `merge_splits`: the merged text chunks and their token totals.
///
/// When `merge_splits` is given equal-length inputs, the two vectors are
/// parallel: `token_counts[i]` is the token total of `merged[i]`.
/// `Default` yields the empty result (no chunks).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MergeResult {
    /// Merged chunks, each the concatenation of one or more consecutive input splits.
    pub merged: Vec<String>,
    /// Token total for each chunk in `merged`, in the same order.
    pub token_counts: Vec<usize>,
}
105
106pub fn merge_splits(
133 splits: &[&str],
134 token_counts: &[usize],
135 chunk_size: usize,
136) -> MergeResult {
137 if splits.is_empty() || token_counts.is_empty() {
139 return MergeResult {
140 merged: vec![],
141 token_counts: vec![],
142 };
143 }
144
145 if token_counts.iter().all(|&c| c > chunk_size) {
147 return MergeResult {
148 merged: splits.iter().map(|s| s.to_string()).collect(),
149 token_counts: token_counts.to_vec(),
150 };
151 }
152
153 let indices = find_merge_indices(token_counts, chunk_size);
154 let merged_counts = compute_merged_token_counts(token_counts, &indices);
155
156 let mut merged = Vec::with_capacity(indices.len());
158 let mut start = 0;
159
160 for &end in &indices {
161 let end = end.min(splits.len());
162 let total_len: usize = splits[start..end].iter().map(|s| s.len()).sum();
164 let mut s = String::with_capacity(total_len);
165 for segment in &splits[start..end] {
166 s.push_str(segment);
167 }
168 merged.push(s);
169 start = end;
170 }
171
172 MergeResult {
173 merged,
174 token_counts: merged_counts,
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_merge_indices_basic() {
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let indices = find_merge_indices(&token_counts, 3);
        assert_eq!(indices, vec![3, 6, 7]);
    }

    #[test]
    fn test_find_merge_indices_large_chunks() {
        let token_counts = vec![10, 15, 20, 5, 8, 12];
        let indices = find_merge_indices(&token_counts, 30);
        assert_eq!(indices, vec![2, 4, 6]);
    }

    #[test]
    fn test_find_merge_indices_empty() {
        let token_counts: Vec<usize> = vec![];
        let indices = find_merge_indices(&token_counts, 10);
        assert!(indices.is_empty());
    }

    #[test]
    fn test_find_merge_indices_single() {
        let token_counts = vec![5];
        let indices = find_merge_indices(&token_counts, 10);
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn test_find_merge_indices_all_large() {
        let token_counts = vec![50, 60, 70];
        let indices = find_merge_indices(&token_counts, 30);
        assert_eq!(indices, vec![1, 2, 3]);
    }

    #[test]
    fn test_find_merge_indices_zero_chunk_size() {
        // Zero-token splits still fit in a zero-sized chunk; anything larger
        // is emitted as its own chunk.
        let token_counts = vec![0, 0, 5];
        let indices = find_merge_indices(&token_counts, 0);
        assert_eq!(indices, vec![2, 3]);
    }

    #[test]
    fn test_find_merge_indices_exact_fit() {
        // A total exactly equal to chunk_size still fits.
        let token_counts = vec![10, 20];
        let indices = find_merge_indices(&token_counts, 30);
        assert_eq!(indices, vec![2]);
    }

    #[test]
    fn test_compute_merged_token_counts() {
        let token_counts = vec![10, 15, 20, 5, 8, 12];
        let merged = compute_merged_token_counts(&token_counts, &[2, 4, 6]);
        assert_eq!(merged, vec![25, 25, 20]);
        assert!(compute_merged_token_counts(&token_counts, &[]).is_empty());
    }

    #[test]
    fn test_merge_splits_basic() {
        let splits = vec!["a", "b", "c", "d", "e", "f", "g"];
        let token_counts = vec![1, 1, 1, 1, 1, 1, 1];
        let result = merge_splits(&splits, &token_counts, 3);
        assert_eq!(result.merged, vec!["abc", "def", "g"]);
        assert_eq!(result.token_counts, vec![3, 3, 1]);
    }

    #[test]
    fn test_merge_splits_empty() {
        let splits: Vec<&str> = vec![];
        let token_counts: Vec<usize> = vec![];
        let result = merge_splits(&splits, &token_counts, 10);
        assert!(result.merged.is_empty());
        assert!(result.token_counts.is_empty());
    }

    #[test]
    fn test_merge_splits_all_exceed_limit() {
        let splits = vec!["aaa", "bbb", "ccc"];
        let token_counts = vec![50, 60, 70];
        let result = merge_splits(&splits, &token_counts, 30);
        assert_eq!(result.merged, vec!["aaa", "bbb", "ccc"]);
        assert_eq!(result.token_counts, vec![50, 60, 70]);
    }

    #[test]
    fn test_merge_splits_mixed_oversized() {
        // An oversized split in the middle stands alone; its neighbours still merge.
        let splits = vec!["a", "bb", "c", "d"];
        let token_counts = vec![5, 50, 5, 5];
        let result = merge_splits(&splits, &token_counts, 10);
        assert_eq!(result.merged, vec!["a", "bb", "cd"]);
        assert_eq!(result.token_counts, vec![5, 50, 10]);
    }

    #[test]
    fn test_merge_splits_everything_fits() {
        let splits = vec!["x", "y", "z"];
        let token_counts = vec![1, 2, 3];
        let result = merge_splits(&splits, &token_counts, 100);
        assert_eq!(result.merged, vec!["xyz"]);
        assert_eq!(result.token_counts, vec![6]);
    }
}