crate::ix!();
/// Splits `requests` into batches of at most `requests_per_batch` items,
/// returning an enumerated chunk iterator over the input slice.
///
/// When `throw_on_too_small_batch` is set and the input fits into a single
/// batch of fewer than `MIN_NONTRIVIAL_BATCH_LEN` requests, the call is
/// rejected with [`LanguageModelBatchCreationError::TrivialBatchSizeBlocked`].
///
/// # Errors
///
/// Returns `TrivialBatchSizeBlocked` as described above.
///
/// # Panics
///
/// Panics if `requests_per_batch` is zero (inherited from `slice::chunks`).
pub fn construct_batches(
    requests: &[LanguageModelBatchAPIRequest],
    requests_per_batch: usize,
    throw_on_too_small_batch: bool,
) -> Result<Enumerate<Chunks<'_,LanguageModelBatchAPIRequest>>,LanguageModelBatchCreationError> {
    /// A lone batch smaller than this is considered too trivial to submit.
    const MIN_NONTRIVIAL_BATCH_LEN: usize = 32;

    // Count batches without consuming an iterator: `Chunks` implements
    // `ExactSizeIterator`, so `len()` here is O(1) and side-effect free.
    // (The previous implementation called `nth(0)` on the iterator it later
    // logged `len()` of, which made the log report one batch too few.)
    let batch_count = requests.chunks(requests_per_batch).len();

    if batch_count == 1 && throw_on_too_small_batch {
        // With exactly one batch, its size is simply the total request count.
        let only_batch_len = requests.len();
        if only_batch_len < MIN_NONTRIVIAL_BATCH_LEN {
            return Err(LanguageModelBatchCreationError::TrivialBatchSizeBlocked { len: only_batch_len });
        }
    }

    info!(
        "Constructing {} batch(es), each with max {} items",
        batch_count,
        requests_per_batch
    );

    Ok(requests.chunks(requests_per_batch).enumerate())
}
#[cfg(test)]
mod construct_batches_exhaustive_tests {
    use super::*;

    /// Builds `count` mock requests labeled with their index.
    fn build_requests(count: usize) -> Vec<LanguageModelBatchAPIRequest> {
        let mut requests = Vec::with_capacity(count);
        for c in 0..count {
            requests.push(LanguageModelBatchAPIRequest::mock(&format!("{c}")));
        }
        requests
    }

    /// An empty request slice must produce zero batches.
    #[traced_test]
    async fn empty_requests_returns_no_batches() {
        trace!("===== BEGIN TEST: empty_requests_returns_no_batches =====");
        let requests = build_requests(0);
        let per_batch = 5;
        let strict = false;
        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            per_batch,
            strict
        );
        let batches: Vec<_> =
            construct_batches(&requests, per_batch, strict).unwrap().collect();
        debug!("Number of batches returned: {}", batches.len());
        pretty_assert_eq!(batches.len(), 0, "Expected no batches for empty requests");
        trace!("===== END TEST: empty_requests_returns_no_batches =====");
    }

    /// Exactly 32 requests in one batch must pass even with the strict flag on.
    #[traced_test]
    async fn single_batch_at_least_32_no_panic_with_flag() {
        trace!("===== BEGIN TEST: single_batch_at_least_32_no_panic_with_flag =====");
        let requests = build_requests(32);
        let per_batch = 40;
        let strict = true;
        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            per_batch,
            strict
        );
        let batches: Vec<_> =
            construct_batches(&requests, per_batch, strict).unwrap().collect();
        debug!("Number of batches returned: {}", batches.len());
        pretty_assert_eq!(batches.len(), 1, "Expected exactly one batch");
        let only = &batches[0].1;
        debug!("Size of the single batch: {}", only.len());
        pretty_assert_eq!(only.len(), 32, "Batch should contain 32 requests");
        trace!("===== END TEST: single_batch_at_least_32_no_panic_with_flag =====");
    }

    /// A small single batch is fine when the strict flag is off.
    #[traced_test]
    async fn single_batch_under_32_no_panic_without_flag() {
        trace!("===== BEGIN TEST: single_batch_under_32_no_panic_without_flag =====");
        let requests = build_requests(10);
        let per_batch = 10;
        let strict = false;
        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            per_batch,
            strict
        );
        let batches: Vec<_> =
            construct_batches(&requests, per_batch, strict).unwrap().collect();
        debug!("Number of batches returned: {}", batches.len());
        pretty_assert_eq!(batches.len(), 1, "Expected exactly one batch");
        let only = &batches[0].1;
        debug!("Size of the single batch: {}", only.len());
        pretty_assert_eq!(only.len(), 10, "Batch should contain 10 requests");
        trace!("===== END TEST: single_batch_under_32_no_panic_without_flag =====");
    }

    /// 50 requests at 20 per batch: two full batches plus a remainder of 10.
    #[traced_test]
    async fn multiple_batches_with_remainder() {
        trace!("===== BEGIN TEST: multiple_batches_with_remainder =====");
        let requests = build_requests(50);
        let per_batch = 20;
        let strict = false;
        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            per_batch,
            strict
        );
        let batches: Vec<_> =
            construct_batches(&requests, per_batch, strict).unwrap().collect();
        debug!("Number of batches returned: {}", batches.len());
        pretty_assert_eq!(batches.len(), 3, "Expected 3 batches total");
        let mut sizes: Vec<usize> = Vec::with_capacity(batches.len());
        for (_, chunk) in &batches {
            sizes.push(chunk.len());
        }
        debug!("Batch sizes: {:?}", sizes);
        pretty_assert_eq!(sizes, vec![20, 20, 10], "Unexpected chunk sizes");
        trace!("===== END TEST: multiple_batches_with_remainder =====");
    }

    /// 40 requests at 10 per batch divide evenly into four equal batches.
    #[traced_test]
    async fn multiple_batches_exact_division() {
        trace!("===== BEGIN TEST: multiple_batches_exact_division =====");
        let requests = build_requests(40);
        let per_batch = 10;
        let strict = false;
        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            per_batch,
            strict
        );
        let batches: Vec<_> =
            construct_batches(&requests, per_batch, strict).unwrap().collect();
        debug!("Number of batches returned: {}", batches.len());
        pretty_assert_eq!(batches.len(), 4, "Expected 4 batches total");
        for (index, entry) in batches.iter().enumerate() {
            let chunk = &entry.1;
            debug!("Batch index {} has size {}", index, chunk.len());
            pretty_assert_eq!(
                chunk.len(),
                10,
                "Expected each batch to have exactly 10 requests"
            );
        }
        trace!("===== END TEST: multiple_batches_exact_division =====");
    }
}