batch_mode_batch_scribe/
construct_batches.rs1crate::ix!();
3
4pub fn construct_batches(
6 requests: &[LanguageModelBatchAPIRequest],
7 requests_per_batch: usize,
8 throw_on_too_small_batch: bool,
9
10) -> Result<Enumerate<Chunks<'_,LanguageModelBatchAPIRequest>>,LanguageModelBatchCreationError> {
11
12 let mut batches = requests.chunks(requests_per_batch).enumerate();
13
14 if batches.len() == 1 && throw_on_too_small_batch {
16 let only_batch_len = batches.nth(0).unwrap().1.len();
17 if only_batch_len < 32 {
18 return Err(LanguageModelBatchCreationError::TrivialBatchSizeBlocked { len: only_batch_len });
19 }
20 }
21 info!(
22 "Constructing {} batch(es), each with max {} items",
23 batches.len(),
24 requests_per_batch
25 );
26
27 Ok(requests.chunks(requests_per_batch).enumerate())
29}
30
#[cfg(test)]
mod construct_batches_exhaustive_tests {
    //! Tests for `construct_batches`: batch counts, batch sizes, and the
    //! single-small-batch guard controlled by `throw_on_too_small_batch`.
    use super::*;

    /// Builds `count` mock requests, each tagged with its index rendered as a string.
    fn build_requests(count: usize) -> Vec<LanguageModelBatchAPIRequest> {
        (0..count).map(|c| LanguageModelBatchAPIRequest::mock(&format!("{c}"))).collect()
    }

    /// An empty request slice produces zero batches and no error.
    #[traced_test]
    async fn empty_requests_returns_no_batches() {
        trace!("===== BEGIN TEST: empty_requests_returns_no_batches =====");
        let requests = build_requests(0);
        let requests_per_batch = 5;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 0, "Expected no batches for empty requests");

        trace!("===== END TEST: empty_requests_returns_no_batches =====");
    }

    /// A single batch of exactly 32 requests passes the guard even with the flag set.
    #[traced_test]
    async fn single_batch_at_least_32_no_panic_with_flag() {
        trace!("===== BEGIN TEST: single_batch_at_least_32_no_panic_with_flag =====");
        let requests = build_requests(32);
        let requests_per_batch = 40;
        let throw_on_too_small_batch = true;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 1, "Expected exactly one batch");
        let first_batch = &result[0].1;
        debug!("Size of the single batch: {}", first_batch.len());
        pretty_assert_eq!(first_batch.len(), 32, "Batch should contain 32 requests");

        trace!("===== END TEST: single_batch_at_least_32_no_panic_with_flag =====");
    }

    /// A single small batch is accepted when the flag is off.
    #[traced_test]
    async fn single_batch_under_32_no_panic_without_flag() {
        trace!("===== BEGIN TEST: single_batch_under_32_no_panic_without_flag =====");
        let requests = build_requests(10);
        let requests_per_batch = 10;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 1, "Expected exactly one batch");
        let first_batch = &result[0].1;
        debug!("Size of the single batch: {}", first_batch.len());
        pretty_assert_eq!(first_batch.len(), 10, "Batch should contain 10 requests");

        trace!("===== END TEST: single_batch_under_32_no_panic_without_flag =====");
    }

    /// With the flag set, a single batch just below the 32-item minimum is rejected
    /// and the error reports the offending length.
    #[traced_test]
    async fn single_batch_under_32_errors_with_flag() {
        trace!("===== BEGIN TEST: single_batch_under_32_errors_with_flag =====");
        let requests = build_requests(31);
        let requests_per_batch = 40;
        let throw_on_too_small_batch = true;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        // Matched explicitly (instead of `unwrap_err`) so the Ok type need not implement Debug.
        match construct_batches(&requests, requests_per_batch, throw_on_too_small_batch) {
            Err(LanguageModelBatchCreationError::TrivialBatchSizeBlocked { len }) => {
                debug!("Rejected single batch of size {}", len);
                pretty_assert_eq!(len, 31, "Error should report the single batch's length");
            }
            Err(other) => panic!("Unexpected error variant: {:?}", other),
            Ok(_) => panic!("Expected TrivialBatchSizeBlocked for a single batch of 31 requests"),
        }

        trace!("===== END TEST: single_batch_under_32_errors_with_flag =====");
    }

    /// Uneven division leaves a shorter trailing batch.
    #[traced_test]
    async fn multiple_batches_with_remainder() {
        trace!("===== BEGIN TEST: multiple_batches_with_remainder =====");
        let requests = build_requests(50);
        let requests_per_batch = 20;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 3, "Expected 3 batches total");

        let sizes: Vec<usize> = result.iter().map(|(_, chunk)| chunk.len()).collect();
        debug!("Batch sizes: {:?}", sizes);
        pretty_assert_eq!(sizes, vec![20, 20, 10], "Unexpected chunk sizes");

        trace!("===== END TEST: multiple_batches_with_remainder =====");
    }

    /// The minimum-size guard applies only when there is a single batch: a small
    /// trailing batch among several is accepted even with the flag set.
    #[traced_test]
    async fn multiple_batches_small_remainder_ok_with_flag() {
        trace!("===== BEGIN TEST: multiple_batches_small_remainder_ok_with_flag =====");
        let requests = build_requests(50);
        let requests_per_batch = 40;
        let throw_on_too_small_batch = true;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        let sizes: Vec<usize> = result.iter().map(|(_, chunk)| chunk.len()).collect();
        debug!("Batch sizes: {:?}", sizes);
        pretty_assert_eq!(sizes, vec![40, 10], "Unexpected chunk sizes");

        let indices: Vec<usize> = result.iter().map(|(index, _)| *index).collect();
        pretty_assert_eq!(indices, vec![0, 1], "Batch indices should be sequential from 0");

        trace!("===== END TEST: multiple_batches_small_remainder_ok_with_flag =====");
    }

    /// Exact division yields equally sized batches.
    #[traced_test]
    async fn multiple_batches_exact_division() {
        trace!("===== BEGIN TEST: multiple_batches_exact_division =====");
        let requests = build_requests(40);
        let requests_per_batch = 10;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 4, "Expected 4 batches total");

        for (index, (_, chunk)) in result.iter().enumerate() {
            debug!("Batch index {} has size {}", index, chunk.len());
            pretty_assert_eq!(
                chunk.len(),
                10,
                "Expected each batch to have exactly 10 requests"
            );
        }

        trace!("===== END TEST: multiple_batches_exact_division =====");
    }
}