1crate::ix!();
3
4#[derive(Builder,Getters,Clone,Debug,Serialize,Deserialize)]
5#[builder(setter(into))]
6#[getset(get="pub")]
7pub struct BatchOutputData {
8 responses: Vec<BatchResponseRecord>,
9}
10
// SAFETY: NOTE(review): `BatchOutputData` holds only a `Vec<BatchResponseRecord>`.
// These manual impls are sound only if `BatchResponseRecord` is itself `Send` and `Sync`.
// If it is, the compiler already provides both auto traits and these `unsafe impl`s are
// redundant. If it is not, they are unsound, because they assert thread-safety the field
// type does not have.
// TODO confirm the bounds on `BatchResponseRecord`, and prefer deleting these impls.
unsafe impl Send for BatchOutputData {}
unsafe impl Sync for BatchOutputData {}
13
14impl BatchOutputData {
15
16 pub fn len(&self) -> usize {
17 self.responses.len()
18 }
19
20 pub fn new(responses: Vec<BatchResponseRecord>) -> Self {
21 Self { responses }
22 }
23
24 pub fn request_ids(&self) -> Vec<CustomRequestId> {
25 self.responses.iter().map(|r| r.custom_id().clone()).collect()
26 }
27
28 pub fn iter(&self) -> std::slice::Iter<BatchResponseRecord> {
30 self.responses.iter()
31 }
32}
33
34#[async_trait]
35impl LoadFromFile for BatchOutputData {
36
37 type Error = JsonParseError;
38
39 async fn load_from_file(
40 file_path: impl AsRef<Path> + Send,
41 ) -> Result<Self, Self::Error> {
42
43 let file = File::open(file_path).await?;
44 let reader = BufReader::new(file);
45
46 let mut lines = reader.lines();
47 let mut responses = Vec::new();
48
49 while let Some(line) = lines.next_line().await? {
50 let response_record: BatchResponseRecord = serde_json::from_str(&line)?;
51 responses.push(response_record);
52 }
53
54 Ok(BatchOutputData::new(responses))
55 }
56}
57
58impl From<Vec<BatchOutputData>> for BatchOutputData {
59 fn from(batch_outputs: Vec<BatchOutputData>) -> Self {
60 let aggregated_responses = batch_outputs
62 .into_iter()
63 .flat_map(|output_data| output_data.responses)
64 .collect();
65 BatchOutputData::new(aggregated_responses)
66 }
67}
68
69impl<'a> IntoIterator for &'a BatchOutputData {
70 type Item = &'a BatchResponseRecord;
71 type IntoIter = std::slice::Iter<'a, BatchResponseRecord>;
72
73 fn into_iter(self) -> Self::IntoIter {
74 self.responses.iter()
75 }
76}
77
78impl IntoIterator for BatchOutputData {
79 type Item = BatchResponseRecord;
80 type IntoIter = std::vec::IntoIter<BatchResponseRecord>;
81
82 fn into_iter(self) -> Self::IntoIter {
83 self.responses.into_iter()
84 }
85}
86
#[cfg(test)]
mod batch_output_data_tests {
    // Unit tests for BatchOutputData: construction, accessors, iteration,
    // aggregation via `From<Vec<_>>`, and loading from an NDJSON file.
    use super::*;
    use tempfile::NamedTempFile;
    use std::io::Write;
    use tokio::runtime::Runtime;

    #[traced_test]
    fn should_create_new_batch_output_data() {
        info!("Testing construction of BatchOutputData using new.");

        let records = vec![
            BatchResponseRecord::mock_with_code("output-1", 200),
            BatchResponseRecord::mock_with_code("output-2", 400),
        ];
        let output_data = BatchOutputData::new(records.clone());

        pretty_assert_eq!(output_data.len(), records.len(), "Length should match the number of records.");
        debug!("BatchOutputData created with length: {}", output_data.len());

        let retrieved = output_data.responses();
        pretty_assert_eq!(retrieved.len(), records.len(), "responses() should return the same number of records.");
        trace!("responses() returned: {} items", retrieved.len());
    }

    #[traced_test]
    fn should_return_request_ids_correctly() {
        info!("Testing request_ids() for BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("req-1", 200),
            BatchResponseRecord::mock_with_code("req-2", 200),
        ];
        let output_data = BatchOutputData::new(records);

        let ids = output_data.request_ids();
        trace!("Extracted request IDs: {:?}", ids);

        pretty_assert_eq!(ids.len(), 2, "Should have two request IDs.");
        assert!(ids.contains(&CustomRequestId::new("req-1")));
        assert!(ids.contains(&CustomRequestId::new("req-2")));
    }

    #[traced_test]
    fn should_iterate_responses() {
        info!("Testing the iter() method of BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("iter-1", 200),
            BatchResponseRecord::mock_with_code("iter-2", 200),
        ];
        let output_data = BatchOutputData::new(records.clone());

        // Compare ids rather than only counting, so a reordered or duplicated record is caught too.
        let iterated_ids: Vec<CustomRequestId> = output_data
            .iter()
            .map(|record| {
                trace!("Iterating record custom_id: {}", record.custom_id());
                record.custom_id().clone()
            })
            .collect();
        let expected_ids: Vec<CustomRequestId> = records
            .iter()
            .map(|record| record.custom_id().clone())
            .collect();
        pretty_assert_eq!(iterated_ids, expected_ids, "Should iterate over all response records, in order.");
    }

    #[traced_test]
    fn should_iterate_with_into_iter_borrowed() {
        info!("Testing IntoIterator for borrowed BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("borrowed-1", 200),
            BatchResponseRecord::mock_with_code("borrowed-2", 200),
        ];
        let output_data = BatchOutputData::new(records.clone());

        let mut count = 0;
        for record in &output_data {
            trace!("Borrowed iteration on custom_id: {}", record.custom_id());
            count += 1;
        }
        pretty_assert_eq!(count, records.len(), "Should iterate all records in borrowed form.");
    }

    #[traced_test]
    fn should_iterate_with_into_iter_owned() {
        info!("Testing IntoIterator for owned BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("owned-1", 200),
            BatchResponseRecord::mock_with_code("owned-2", 200),
        ];
        let output_data = BatchOutputData::new(records.clone());

        let mut count = 0;
        for record in output_data {
            trace!("Owned iteration on custom_id: {}", record.custom_id());
            count += 1;
        }
        pretty_assert_eq!(count, records.len(), "Should yield all records when owned iteration is used.");
    }

    #[traced_test]
    fn should_convert_from_multiple_batch_output_data() {
        info!("Testing the 'From<Vec<BatchOutputData>>' implementation.");

        let batch_1 = BatchOutputData::new(vec![
            BatchResponseRecord::mock_with_code("multi-1", 200),
        ]);
        let batch_2 = BatchOutputData::new(vec![
            BatchResponseRecord::mock_with_code("multi-2", 400),
            BatchResponseRecord::mock_with_code("multi-3", 400),
        ]);

        let combined = BatchOutputData::from(vec![batch_1, batch_2]);
        pretty_assert_eq!(combined.len(), 3, "Expected combined vector length of 3.");
        debug!("Combined length is: {}", combined.len());

        let ids = combined.request_ids();
        trace!("Combined request IDs: {:?}", ids);
        // Check the actual ids and their order, not just the count, so the
        // "distinct ids, batches concatenated in order" claim is really verified.
        pretty_assert_eq!(
            ids,
            vec![
                CustomRequestId::new("multi-1"),
                CustomRequestId::new("multi-2"),
                CustomRequestId::new("multi-3"),
            ],
            "Combined request IDs should be each batch's IDs concatenated in order."
        );
    }

    #[traced_test]
    fn should_handle_empty_new_batch_output_data() {
        info!("Testing empty BatchOutputData creation.");

        let output_data = BatchOutputData::new(vec![]);
        pretty_assert_eq!(output_data.len(), 0, "Expected no records in empty BatchOutputData.");

        let iter_count = output_data.iter().count();
        pretty_assert_eq!(iter_count, 0, "Iterator should yield none for empty data.");
        let ids = output_data.request_ids();
        assert!(ids.is_empty(), "No IDs should be returned for empty data.");
    }

    #[traced_test]
    fn should_handle_from_empty_vec_of_batch_output_data() {
        info!("Testing 'From<Vec<BatchOutputData>>' with an empty list.");

        let empty_vec: Vec<BatchOutputData> = vec![];
        let result = BatchOutputData::from(empty_vec);

        pretty_assert_eq!(result.len(), 0, "Should produce empty BatchOutputData from empty vector.");
        trace!("No data aggregated, as expected.");
    }

    #[traced_test]
    fn should_fail_load_from_file_with_invalid_json() {
        info!("Testing load_from_file failure scenario with malformed JSON.");

        let mut temp_file = NamedTempFile::new().expect("Failed to create NamedTempFile.");
        writeln!(temp_file, "{{ invalid json }}").expect("Failed to write malformed JSON line");

        let rt = Runtime::new().expect("Failed to create tokio runtime.");
        let result = rt.block_on(async {
            BatchOutputData::load_from_file(temp_file.path()).await
        });

        // `expect_err` asserts the failure and hands back the error for logging.
        let err = result.expect_err("Should fail when invalid JSON is encountered.");
        // This failure is the expected outcome, so log it at debug level, not error level.
        debug!("Received expected error for malformed JSON: {:?}", err);
    }

    #[traced_test]
    fn should_load_from_file_successfully() {
        info!("Testing load_from_file with a mock file in NDJSON format (one JSON object per line).");

        let line_1 = r#"{"id":"batch_req_file-1","custom_id":"file-1","response":{"status_code":200,"request_id":"resp_req_file-1","body":{"id":"success-id","object":"chat.completion","created":0,"model":"test-model","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}},"error":null}"#;

        let line_2 = r#"{"id":"batch_req_file-2","custom_id":"file-2","response":{"status_code":400,"request_id":"resp_req_file-2","body":{"error":{"message":"Error for file-2","type":"test_error","param":null,"code":null},"object":"error"}},"error":null}"#;

        let mut temp_file = NamedTempFile::new().expect("Failed to create NamedTempFile.");
        writeln!(temp_file, "{}", line_1).expect("Failed to write line_1");
        writeln!(temp_file, "{}", line_2).expect("Failed to write line_2");

        let rt = Runtime::new().expect("Failed to create tokio runtime.");
        let result = rt.block_on(async {
            BatchOutputData::load_from_file(temp_file.path()).await
        });

        // `expect` includes the underlying error in the panic message, unlike `assert!(is_ok())`.
        let loaded_data = result.expect("Expected successful load from file.");
        pretty_assert_eq!(loaded_data.len(), 2, "Should load exactly 2 records.");
        debug!("Loaded {} records from file.", loaded_data.len());

        let ids = loaded_data.request_ids();
        trace!("Loaded request IDs: {:?}", ids);
        assert!(ids.contains(&CustomRequestId::new("file-1")));
        assert!(ids.contains(&CustomRequestId::new("file-2")));
    }
}