batch_mode_batch_schema/
batch_output_data.rs

1// ---------------- [ File: batch-mode-batch-schema/src/batch_output_data.rs ]
2crate::ix!();
3
/// The parsed contents of a batch output: one [`BatchResponseRecord`] per
/// response, kept in the order they were supplied to [`BatchOutputData::new`]
/// or read from the file by `load_from_file`.
///
/// `#[builder(setter(into))]` makes the derived builder's setters accept
/// convertible values (assumed derive_builder semantics), and
/// `#[getset(get="pub")]` generates the public `responses()` getter.
#[derive(Builder,Getters,Clone,Debug,Serialize,Deserialize)]
#[builder(setter(into))]
#[getset(get="pub")]
pub struct BatchOutputData {
    /// Response records, in stored order.
    responses: Vec<BatchResponseRecord>,
}
10
// SAFETY: NOTE(review) — no justification is given for these impls.
// `BatchOutputData` holds only a `Vec<BatchResponseRecord>`, and `Vec<T>` is
// `Send`/`Sync` exactly when `T` is. So these impls are sound only if
// `BatchResponseRecord` is itself `Send + Sync`. If it is, the compiler already
// provides both auto traits and these `unsafe impl`s are redundant (consider
// removing them). If it is not, they are unsound. TODO confirm.
unsafe impl Send for BatchOutputData {}
unsafe impl Sync for BatchOutputData {}
13
14impl BatchOutputData {
15
16    pub fn len(&self) -> usize {
17        self.responses.len()
18    }
19
20    pub fn new(responses: Vec<BatchResponseRecord>) -> Self {
21        Self { responses }
22    }
23
24    pub fn request_ids(&self) -> Vec<CustomRequestId> {
25        self.responses.iter().map(|r| r.custom_id().clone()).collect()
26    }
27
28    /// Returns an iterator over the BatchResponseRecord elements.
29    pub fn iter(&self) -> std::slice::Iter<BatchResponseRecord> {
30        self.responses.iter()
31    }
32}
33
34#[async_trait]
35impl LoadFromFile for BatchOutputData {
36
37    type Error = JsonParseError;
38
39    async fn load_from_file(
40        file_path: impl AsRef<Path> + Send,
41    ) -> Result<Self, Self::Error> {
42
43        let file   = File::open(file_path).await?;
44        let reader = BufReader::new(file);
45
46        let mut lines     = reader.lines();
47        let mut responses = Vec::new();
48
49        while let Some(line) = lines.next_line().await? {
50            let response_record: BatchResponseRecord = serde_json::from_str(&line)?;
51            responses.push(response_record);
52        }
53
54        Ok(BatchOutputData::new(responses))
55    }
56}
57
58impl From<Vec<BatchOutputData>> for BatchOutputData {
59    fn from(batch_outputs: Vec<BatchOutputData>) -> Self {
60        // Flatten the responses from all BatchOutputData instances into a single vector.
61        let aggregated_responses = batch_outputs
62            .into_iter()
63            .flat_map(|output_data| output_data.responses)
64            .collect();
65        BatchOutputData::new(aggregated_responses)
66    }
67}
68
69impl<'a> IntoIterator for &'a BatchOutputData {
70    type Item = &'a BatchResponseRecord;
71    type IntoIter = std::slice::Iter<'a, BatchResponseRecord>;
72
73    fn into_iter(self) -> Self::IntoIter {
74        self.responses.iter()
75    }
76}
77
78impl IntoIterator for BatchOutputData {
79    type Item = BatchResponseRecord;
80    type IntoIter = std::vec::IntoIter<BatchResponseRecord>;
81
82    fn into_iter(self) -> Self::IntoIter {
83        self.responses.into_iter()
84    }
85}
86
#[cfg(test)]
mod batch_output_data_tests {
    use super::*;
    use tempfile::NamedTempFile;
    use std::io::Write;
    use tokio::runtime::Runtime;

    /// Runs `BatchOutputData::load_from_file` to completion on a fresh tokio
    /// runtime, so the synchronous tests below can exercise the async loader.
    fn load_blocking(
        path: &std::path::Path,
    ) -> Result<BatchOutputData, <BatchOutputData as LoadFromFile>::Error> {
        let rt = Runtime::new().expect("Failed to create tokio runtime.");
        rt.block_on(BatchOutputData::load_from_file(path))
    }

    #[traced_test]
    fn should_create_new_batch_output_data() {
        info!("Testing construction of BatchOutputData using new.");

        let records = vec![
            BatchResponseRecord::mock_with_code("output-1", 200),
            BatchResponseRecord::mock_with_code("output-2", 400),
        ];
        let output_data = BatchOutputData::new(records.clone());

        pretty_assert_eq!(output_data.len(), records.len(), "Length should match the number of records.");
        debug!("BatchOutputData created with length: {}", output_data.len());

        let retrieved = output_data.responses();
        pretty_assert_eq!(retrieved.len(), records.len(), "responses() should return the same number of records.");
        trace!("responses() returned: {} items", retrieved.len());
    }

    #[traced_test]
    fn should_return_request_ids_correctly() {
        info!("Testing request_ids() for BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("req-1", 200),
            BatchResponseRecord::mock_with_code("req-2", 200),
        ];
        let output_data = BatchOutputData::new(records);

        let ids = output_data.request_ids();
        trace!("Extracted request IDs: {:?}", ids);

        pretty_assert_eq!(ids.len(), 2, "Should have two request IDs.");
        assert!(ids.contains(&CustomRequestId::new("req-1")));
        assert!(ids.contains(&CustomRequestId::new("req-2")));
    }

    #[traced_test]
    fn should_iterate_responses() {
        info!("Testing the iter() method of BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("iter-1", 200),
            BatchResponseRecord::mock_with_code("iter-2", 200),
        ];
        let output_data = BatchOutputData::new(records.clone());

        let mut count = 0;
        for record in output_data.iter() {
            trace!("Iterating record custom_id: {}", record.custom_id());
            count += 1;
        }
        pretty_assert_eq!(count, records.len(), "Should iterate over all response records.");
    }

    #[traced_test]
    fn should_iterate_with_into_iter_borrowed() {
        info!("Testing IntoIterator for borrowed BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("borrowed-1", 200),
            BatchResponseRecord::mock_with_code("borrowed-2", 200),
        ];
        let output_data = BatchOutputData::new(records.clone());

        let mut count = 0;
        for record in &output_data {
            trace!("Borrowed iteration on custom_id: {}", record.custom_id());
            count += 1;
        }
        pretty_assert_eq!(count, records.len(), "Should iterate all records in borrowed form.");
    }

    #[traced_test]
    fn should_iterate_with_into_iter_owned() {
        info!("Testing IntoIterator for owned BatchOutputData.");

        let records = vec![
            BatchResponseRecord::mock_with_code("owned-1", 200),
            BatchResponseRecord::mock_with_code("owned-2", 200),
        ];
        let output_data = BatchOutputData::new(records.clone());

        let mut count = 0;
        for record in output_data {
            trace!("Owned iteration on custom_id: {}", record.custom_id());
            count += 1;
        }
        pretty_assert_eq!(count, records.len(), "Should yield all records when owned iteration is used.");
    }

    #[traced_test]
    fn should_convert_from_multiple_batch_output_data() {
        info!("Testing the 'From<Vec<BatchOutputData>>' implementation.");

        let batch_1 = BatchOutputData::new(vec![
            BatchResponseRecord::mock_with_code("multi-1", 200),
        ]);
        let batch_2 = BatchOutputData::new(vec![
            BatchResponseRecord::mock_with_code("multi-2", 400),
            BatchResponseRecord::mock_with_code("multi-3", 400),
        ]);

        let combined = BatchOutputData::from(vec![batch_1, batch_2]);
        pretty_assert_eq!(combined.len(), 3, "Expected combined vector length of 3.");
        debug!("Combined length is: {}", combined.len());

        let ids = combined.request_ids();
        trace!("Combined request IDs: {:?}", ids);
        pretty_assert_eq!(ids.len(), 3, "Should have 3 request IDs total.");
        // A length check alone would not catch a record duplicated at the
        // expense of another, so confirm every input id survived the merge.
        for expected in ["multi-1", "multi-2", "multi-3"] {
            assert!(
                ids.contains(&CustomRequestId::new(expected)),
                "Missing request ID {} after merge.",
                expected
            );
        }
    }

    #[traced_test]
    fn should_handle_empty_new_batch_output_data() {
        info!("Testing empty BatchOutputData creation.");

        let output_data = BatchOutputData::new(vec![]);
        pretty_assert_eq!(output_data.len(), 0, "Expected no records in empty BatchOutputData.");

        let iter_count = output_data.iter().count();
        pretty_assert_eq!(iter_count, 0, "Iterator should yield none for empty data.");
        let ids = output_data.request_ids();
        assert!(ids.is_empty(), "No IDs should be returned for empty data.");
    }

    #[traced_test]
    fn should_handle_from_empty_vec_of_batch_output_data() {
        info!("Testing 'From<Vec<BatchOutputData>>' with an empty list.");

        let empty_vec: Vec<BatchOutputData> = vec![];
        let result = BatchOutputData::from(empty_vec);

        pretty_assert_eq!(result.len(), 0, "Should produce empty BatchOutputData from empty vector.");
        trace!("No data aggregated, as expected.");
    }

    #[traced_test]
    fn should_fail_load_from_file_with_invalid_json() {
        info!("Testing load_from_file failure scenario with malformed JSON.");

        let mut temp_file = NamedTempFile::new().expect("Failed to create NamedTempFile.");
        // Write invalid JSON
        writeln!(temp_file, "{{ invalid json }}").unwrap();

        let result = load_blocking(temp_file.path());

        assert!(result.is_err(), "Should fail when invalid JSON is encountered.");
        // This error is the expected outcome, so log it below error level to
        // keep genuine failures easy to spot in test output.
        debug!("Received expected error for malformed JSON: {:?}", result.err());
    }

    #[traced_test]
    fn should_load_from_file_successfully() {
        info!("Testing load_from_file with a mock file in NDJSON format (one JSON object per line).");

        // Put each JSON record on exactly one line (no multi-line objects).
        // This is critical because our code parses each line as one complete JSON object.

        // Single-line JSON 1
        let line_1 = r#"{"id":"batch_req_file-1","custom_id":"file-1","response":{"status_code":200,"request_id":"resp_req_file-1","body":{"id":"success-id","object":"chat.completion","created":0,"model":"test-model","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}},"error":null}"#;

        // Single-line JSON 2 (has status_code=400 and an "error" object).
        let line_2 = r#"{"id":"batch_req_file-2","custom_id":"file-2","response":{"status_code":400,"request_id":"resp_req_file-2","body":{"error":{"message":"Error for file-2","type":"test_error","param":null,"code":null},"object":"error"}},"error":null}"#;

        // Create temp file and write these two lines
        let mut temp_file = NamedTempFile::new().expect("Failed to create NamedTempFile.");
        writeln!(temp_file, "{}", line_1).expect("Failed to write line_1");
        writeln!(temp_file, "{}", line_2).expect("Failed to write line_2");

        // Parse using our load_from_file method. `expect` reports the actual
        // error on failure, unlike an `is_ok()` assertion followed by `unwrap()`.
        let loaded_data = load_blocking(temp_file.path())
            .expect("Expected successful load from file.");
        pretty_assert_eq!(loaded_data.len(), 2, "Should load exactly 2 records.");
        debug!("Loaded {} records from file.", loaded_data.len());

        let ids = loaded_data.request_ids();
        trace!("Loaded request IDs: {:?}", ids);
        assert!(ids.contains(&CustomRequestId::new("file-1")));
        assert!(ids.contains(&CustomRequestId::new("file-2")));
    }
}