batch_mode_batch_client/
wait_for_batch_completion.rs1crate::ix!();
3
4#[async_trait]
5impl<E> WaitForBatchCompletion for OpenAIClientHandle<E>
6where
7 E: Debug + Send + Sync + From<OpenAIClientError>
8{
9 type Error = E;
10
11 async fn wait_for_batch_completion(&self, batch_id: &str)
12 -> Result<Batch, Self::Error>
13 {
14 info!("waiting for batch completion: batch_id={}", batch_id);
15
16 loop {
17 let batch = self.retrieve_batch(batch_id).await?;
18
19 match batch.status {
20 BatchStatus::Completed => return Ok(batch),
21 BatchStatus::Failed => {
22 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
24 message: "Batch failed".to_owned(),
25 r#type: None,
26 param: None,
27 code: None,
28 });
29 return Err(E::from(openai_err));
30 }
31 _ => {
32 println!("Batch status: {:?}", batch.status);
33 tokio::time::sleep(std::time::Duration::from_secs(20)).await;
34 }
35 }
36 }
37 }
38}
39
#[cfg(test)]
mod wait_for_batch_completion_tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::Arc;
    use tracing::{debug, error, info, trace, warn};

    /// A batch that is already stored as `Completed` in the mock must be
    /// returned successfully by `wait_for_batch_completion`.
    #[traced_test]
    async fn test_wait_for_batch_completion_immediate_success() {
        info!("Beginning test_wait_for_batch_completion_immediate_success");
        trace!("Constructing mock client that immediately returns a completed batch...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();

        // Seed the mock with an already-completed batch. The write guard lives
        // in its own scope so it is released before the `.await` below.
        {
            let mut guard = mock_client.batches().write().unwrap();
            guard.insert(
                "immediate_success_id".to_string(),
                Batch {
                    id: "immediate_success_id".to_string(),
                    status: BatchStatus::Completed,
                    input_file_id: "some_file".to_string(),
                    completion_window: "24h".to_string(),
                    object: "batch".to_string(),
                    endpoint: "/v1/chat/completions".to_string(),
                    errors: None,
                    output_file_id: None,
                    error_file_id: None,
                    created_at: 0,
                    in_progress_at: None,
                    expires_at: None,
                    finalizing_at: None,
                    completed_at: None,
                    failed_at: None,
                    expired_at: None,
                    cancelling_at: None,
                    cancelled_at: None,
                    request_counts: None,
                    metadata: None,
                },
            );
        }

        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "immediate_success_id";

        trace!("Calling wait_for_batch_completion on mock_client with batch_id={}", batch_id);
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        assert!(
            result.is_ok(),
            "Expected wait_for_batch_completion to succeed if the batch is already completed"
        );
        let batch = result.unwrap();
        pretty_assert_eq!(
            batch.status,
            BatchStatus::Completed,
            "Batch status should be Completed"
        );
        info!("test_wait_for_batch_completion_immediate_success passed.");
    }

    /// Waiting on the id `"immediate_failure_id"` with a default-built mock
    /// must yield an error.
    ///
    /// NOTE(review): presumably the mock maps this id to a failed batch —
    /// confirm against the mock implementation.
    #[traced_test]
    async fn test_wait_for_batch_completion_immediate_failure() {
        info!("Beginning test_wait_for_batch_completion_immediate_failure");
        trace!("Constructing mock client that immediately returns a failed batch...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "immediate_failure_id";

        trace!("Calling wait_for_batch_completion on mock_client with batch_id={}", batch_id);
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        assert!(
            result.is_err(),
            "Expected wait_for_batch_completion to return error if the batch is failed at once"
        );
        info!("test_wait_for_batch_completion_immediate_failure passed.");
    }

    /// Waiting on the id `"eventual_failure_id"` with a default-built mock
    /// must yield an error.
    ///
    /// NOTE(review): presumably the mock reports in-progress statuses for this
    /// id before a failure — confirm against the mock implementation.
    #[traced_test]
    async fn test_wait_for_batch_completion_eventual_failure() {
        info!("Beginning test_wait_for_batch_completion_eventual_failure");
        trace!("Constructing mock client that simulates in-progress followed by failure...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "eventual_failure_id";

        trace!("Calling wait_for_batch_completion expecting multiple in-progress checks before failure");
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        assert!(
            result.is_err(),
            "Expected wait_for_batch_completion to error after an eventual failure status"
        );
        info!("test_wait_for_batch_completion_eventual_failure passed.");
    }

    /// Waiting on the id `"trigger_api_error"` with a default-built mock must
    /// yield an error.
    ///
    /// NOTE(review): presumably the mock fails `retrieve_batch` for this id —
    /// confirm against the mock implementation.
    #[traced_test]
    async fn test_wait_for_batch_completion_openai_error() {
        info!("Beginning test_wait_for_batch_completion_openai_error");
        trace!("Constructing mock client that simulates an OpenAI error during retrieve_batch...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "trigger_api_error";

        trace!("Calling wait_for_batch_completion expecting an OpenAI error from retrieve_batch");
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        assert!(
            result.is_err(),
            "Expected wait_for_batch_completion to fail due to an OpenAI error in retrieve_batch"
        );
        info!("test_wait_for_batch_completion_openai_error passed.");
    }

    /// After configuring the mock via `configure_inprogress_then_complete_with`
    /// for `"eventual_success_id"`, waiting on that id must succeed with a
    /// `Completed` batch.
    ///
    /// NOTE(review): the meaning of the two `false` flags is defined by the
    /// mock and is not visible here — confirm before changing them.
    #[traced_test]
    async fn test_wait_for_batch_completion_eventual_success() {
        info!("Beginning test_wait_for_batch_completion_eventual_success");

        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();

        mock_client.configure_inprogress_then_complete_with("eventual_success_id", false, false);

        info!("Calling wait_for_batch_completion expecting multiple in-progress checks before completion");
        let result = mock_client.wait_for_batch_completion("eventual_success_id").await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        assert!(
            result.is_ok(),
            "Expected wait_for_batch_completion to succeed after in-progress statuses"
        );
        let final_batch = result.unwrap();
        pretty_assert_eq!(final_batch.status, BatchStatus::Completed);
        info!("test_wait_for_batch_completion_eventual_success passed.");
    }
}