batch_mode_batch_triple/
ensure_input_matches_output_and_error.rs

1// ---------------- [ File: batch-mode-batch-triple/src/ensure_input_matches_output_and_error.rs ]
2crate::ix!();
3
4impl BatchFileTriple {
5
6    pub async fn ensure_input_matches_output_and_error(
7        &self,
8    ) -> Result<(), BatchValidationError> {
9        let input_data  = load_input_file(self.input().as_ref().unwrap()).await?;
10        let output_data = load_output_file(self.output().as_ref().unwrap()).await?;
11        let error_data  = load_error_file(self.error().as_ref().unwrap()).await?;
12
13        let input_ids:  HashSet<_> = input_data.request_ids().into_iter().collect();
14        let output_ids: HashSet<_> = output_data.request_ids().into_iter().collect();
15        let error_ids:  HashSet<_> = error_data.request_ids().into_iter().collect();
16
17        let combined_ids: HashSet<_> = output_ids.union(&error_ids).cloned().collect();
18
19        if input_ids != combined_ids {
20            return Err(BatchValidationError::RequestIdsMismatch {
21                index:      self.index().clone(),
22                input_ids:  Some(input_ids),
23                output_ids: Some(output_ids),
24                error_ids:  Some(error_ids),
25            });
26        }
27
28        info!("for our batch triple {:#?}, we have now ensured the input request ids match the combined request ids from the output and error files",self);
29
30        Ok(())
31    }
32}
33
#[cfg(test)]
mod batch_file_triple_ensure_input_matches_output_and_error_exhaustive_tests {
    use super::*;
    use tempfile::NamedTempFile;
    use std::io::Write;
    use tokio::runtime::Runtime;
    use tracing::*;

    /// Token replaced by the request's custom id in the record templates below.
    const CUSTOM_ID_PLACEHOLDER: &str = "<CUSTOM_ID>";

    /// Single-line output-file record (status_code 200) for one request.
    const SUCCESS_RECORD_TEMPLATE: &str = r#"{"id":"batch_req_<CUSTOM_ID>","custom_id":"<CUSTOM_ID>","response":{"status_code":200,"request_id":"resp_req_<CUSTOM_ID>","body":{"id":"success-id","object":"chat.completion","created":0,"model":"test-model","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0,"prompt_tokens_details":null,"completion_tokens_details":null},"system_fingerprint":null}},"error":null}"#;

    /// Single-line error-file record (status_code 400) for one request.
    const ERROR_RECORD_TEMPLATE: &str = r#"{"id":"batch_req_<CUSTOM_ID>","custom_id":"<CUSTOM_ID>","response":{"status_code":400,"request_id":"resp_req_<CUSTOM_ID>","body":{"error":{"message":"Error for <CUSTOM_ID>","type":"test_error","param":null,"code":null}}},"error":null}"#;

    /// Writes each line (newline-terminated) to a fresh temp file and returns it; the file is
    /// deleted when the returned handle is dropped.
    fn temp_file_with_lines(lines: impl IntoIterator<Item = String>) -> NamedTempFile {
        let mut file = NamedTempFile::new().expect("Failed to create temp file");
        for line in lines {
            writeln!(file, "{}", line).expect("Failed to write line to temp file");
        }
        file
    }

    /// Input file containing one serialized mock request per custom id.
    fn input_file(custom_ids: &[&str]) -> NamedTempFile {
        temp_file_with_lines(custom_ids.iter().map(|custom_id| {
            let request = LanguageModelBatchAPIRequest::mock(custom_id);
            serde_json::to_string(&request).expect("Failed to serialize mock request")
        }))
    }

    /// Output file containing one success record per custom id.
    fn output_file(custom_ids: &[&str]) -> NamedTempFile {
        temp_file_with_lines(
            custom_ids
                .iter()
                .map(|custom_id| SUCCESS_RECORD_TEMPLATE.replace(CUSTOM_ID_PLACEHOLDER, custom_id)),
        )
    }

    /// Error file containing one error record per custom id.
    fn error_file(custom_ids: &[&str]) -> NamedTempFile {
        temp_file_with_lines(
            custom_ids
                .iter()
                .map(|custom_id| ERROR_RECORD_TEMPLATE.replace(CUSTOM_ID_PLACEHOLDER, custom_id)),
        )
    }

    /// Builds a triple over the three temp files and runs the check on a fresh tokio runtime.
    fn run_validation(
        input:  &NamedTempFile,
        output: &NamedTempFile,
        error:  &NamedTempFile,
    ) -> Result<(), BatchValidationError> {
        let triple = BatchFileTriple::new_direct(
            &BatchIndex::Usize(1),
            Some(input.path().to_path_buf()),
            Some(output.path().to_path_buf()),
            Some(error.path().to_path_buf()),
            None,
            Arc::new(MockBatchWorkspace::default()),
        );

        let rt = Runtime::new().expect("Failed to create tokio runtime");
        rt.block_on(async { triple.ensure_input_matches_output_and_error().await })
    }

    #[traced_test]
    fn ensure_input_matches_output_and_error_succeeds_when_ids_match() {
        info!("Starting test: ensure_input_matches_output_and_error_succeeds_when_ids_match");

        // id-1 succeeded and id-2 failed, so only the *union* of the output and error files
        // covers the input: this fails if the implementation ignores either file.
        let input  = input_file(&["id-1", "id-2"]);
        let output = output_file(&["id-1"]);
        let error  = error_file(&["id-2"]);

        let res = run_validation(&input, &output, &error);

        debug!(
            "Result of ensure_input_matches_output_and_error: {:?}",
            res
        );
        assert!(
            res.is_ok(),
            "Expected matching request IDs to succeed for input vs (output + error)"
        );

        info!("Finished test: ensure_input_matches_output_and_error_succeeds_when_ids_match");
    }

    #[traced_test]
    fn ensure_input_matches_output_and_error_fails_when_input_id_is_unaccounted_for() {
        info!("Starting test: ensure_input_matches_output_and_error_fails_when_input_id_is_unaccounted_for");

        // id-3 was requested but appears in neither the output nor the error file.
        let input  = input_file(&["id-1", "id-2", "id-3"]);
        let output = output_file(&["id-1"]);
        let error  = error_file(&["id-2"]);

        let res = run_validation(&input, &output, &error);

        debug!("Result of ensure_input_matches_output_and_error: {:?}", res);
        assert!(
            matches!(res, Err(BatchValidationError::RequestIdsMismatch { .. })),
            "Expected RequestIdsMismatch when an input id is missing from output and error, got {:?}",
            res
        );

        info!("Finished test: ensure_input_matches_output_and_error_fails_when_input_id_is_unaccounted_for");
    }

    #[traced_test]
    fn ensure_input_matches_output_and_error_fails_when_output_has_unknown_id() {
        info!("Starting test: ensure_input_matches_output_and_error_fails_when_output_has_unknown_id");

        // id-9 appears in the output file but was never requested in the input file.
        let input  = input_file(&["id-1", "id-2"]);
        let output = output_file(&["id-1", "id-9"]);
        let error  = error_file(&["id-2"]);

        let res = run_validation(&input, &output, &error);

        debug!("Result of ensure_input_matches_output_and_error: {:?}", res);
        assert!(
            matches!(res, Err(BatchValidationError::RequestIdsMismatch { .. })),
            "Expected RequestIdsMismatch when output contains an id absent from input, got {:?}",
            res
        );

        info!("Finished test: ensure_input_matches_output_and_error_fails_when_output_has_unknown_id");
    }
}