batch_mode_batch_scribe/
construct_batches.rs

1// ---------------- [ File: batch-mode-batch-scribe/src/construct_batches.rs ]
2crate::ix!();
3
4/// Break requests into workable batches.
5pub fn construct_batches(
6    requests:                 &[LanguageModelBatchAPIRequest], 
7    requests_per_batch:       usize,
8    throw_on_too_small_batch: bool,
9
10) -> Result<Enumerate<Chunks<'_,LanguageModelBatchAPIRequest>>,LanguageModelBatchCreationError> {
11
12    let mut batches = requests.chunks(requests_per_batch).enumerate();
13
14    // If there's exactly 1 chunk, and it's under 32, panic:
15    if batches.len() == 1 && throw_on_too_small_batch {
16        let only_batch_len = batches.nth(0).unwrap().1.len();
17        if only_batch_len < 32 {
18            return Err(LanguageModelBatchCreationError::TrivialBatchSizeBlocked { len: only_batch_len });
19        }
20    }
21    info!(
22        "Constructing {} batch(es), each with max {} items",
23        batches.len(),
24        requests_per_batch
25    );
26
27    // Rebuild the enumerator, because we consumed it with nth(0).
28    Ok(requests.chunks(requests_per_batch).enumerate())
29}
30
#[cfg(test)]
mod construct_batches_exhaustive_tests {
    use super::*;

    /// Build `count` mock requests, each tagged with its index.
    /// (Replace with actual construction logic as needed.)
    fn build_requests(count: usize) -> Vec<LanguageModelBatchAPIRequest> {
        (0..count).map(|c| LanguageModelBatchAPIRequest::mock(&format!("{c}"))).collect()
    }

    #[traced_test]
    async fn empty_requests_returns_no_batches() {
        trace!("===== BEGIN TEST: empty_requests_returns_no_batches =====");
        let requests = build_requests(0);
        let requests_per_batch = 5;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 0, "Expected no batches for empty requests");

        trace!("===== END TEST: empty_requests_returns_no_batches =====");
    }

    #[traced_test]
    async fn single_batch_at_least_32_no_panic_with_flag() {
        trace!("===== BEGIN TEST: single_batch_at_least_32_no_panic_with_flag =====");
        // 1 chunk, size exactly 32 — the boundary value that must be accepted.
        let requests = build_requests(32);
        let requests_per_batch = 40;
        let throw_on_too_small_batch = true;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 1, "Expected exactly one batch");
        let first_batch = &result[0].1;
        debug!("Size of the single batch: {}", first_batch.len());
        pretty_assert_eq!(first_batch.len(), 32, "Batch should contain 32 requests");

        trace!("===== END TEST: single_batch_at_least_32_no_panic_with_flag =====");
    }

    #[traced_test]
    async fn single_batch_under_32_no_panic_without_flag() {
        trace!("===== BEGIN TEST: single_batch_under_32_no_panic_without_flag =====");
        // 1 chunk, size under 32, but throw_on_too_small_batch=false,
        // so the trivial-size guard must stay dormant.
        let requests = build_requests(10);
        let requests_per_batch = 10;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 1, "Expected exactly one batch");
        let first_batch = &result[0].1;
        debug!("Size of the single batch: {}", first_batch.len());
        pretty_assert_eq!(first_batch.len(), 10, "Batch should contain 10 requests");

        trace!("===== END TEST: single_batch_under_32_no_panic_without_flag =====");
    }

    #[traced_test]
    async fn single_batch_under_32_errors_with_flag() {
        trace!("===== BEGIN TEST: single_batch_under_32_errors_with_flag =====");
        // The previously untested error branch: exactly one chunk, fewer than
        // 32 requests, and the guard flag enabled must yield
        // TrivialBatchSizeBlocked carrying the offending length.
        let requests = build_requests(10);
        let requests_per_batch = 10;
        let throw_on_too_small_batch = true;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        );

        assert!(
            matches!(
                result,
                Err(LanguageModelBatchCreationError::TrivialBatchSizeBlocked { len: 10 })
            ),
            "Expected TrivialBatchSizeBlocked with len 10 for a lone under-32 batch"
        );

        trace!("===== END TEST: single_batch_under_32_errors_with_flag =====");
    }

    #[traced_test]
    async fn multiple_batches_with_remainder() {
        trace!("===== BEGIN TEST: multiple_batches_with_remainder =====");
        // e.g. 50 requests, 20 per batch => 3 batches: 20, 20, 10
        let requests = build_requests(50);
        let requests_per_batch = 20;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 3, "Expected 3 batches total");

        let sizes: Vec<usize> = result.iter().map(|(_, chunk)| chunk.len()).collect();
        debug!("Batch sizes: {:?}", sizes);
        pretty_assert_eq!(sizes, vec![20, 20, 10], "Unexpected chunk sizes");

        trace!("===== END TEST: multiple_batches_with_remainder =====");
    }

    #[traced_test]
    async fn multiple_batches_exact_division() {
        trace!("===== BEGIN TEST: multiple_batches_exact_division =====");
        // e.g. 40 requests, 10 per batch => 4 batches
        let requests = build_requests(40);
        let requests_per_batch = 10;
        let throw_on_too_small_batch = false;

        trace!(
            "Constructing batches with {} requests, {} per batch, throw_on_too_small_batch={}",
            requests.len(),
            requests_per_batch,
            throw_on_too_small_batch
        );

        let result: Vec<_> = construct_batches(
            &requests,
            requests_per_batch,
            throw_on_too_small_batch
        ).unwrap().collect();

        debug!("Number of batches returned: {}", result.len());
        pretty_assert_eq!(result.len(), 4, "Expected 4 batches total");

        for (index, (_, chunk)) in result.iter().enumerate() {
            debug!("Batch index {} has size {}", index, chunk.len());
            pretty_assert_eq!(
                chunk.len(),
                10,
                "Expected each batch to have exactly 10 requests"
            );
        }

        trace!("===== END TEST: multiple_batches_exact_division =====");
    }
}
193}