batch_mode_batch_client/
wait_for_batch_completion.rs

1// ---------------- [ File: batch-mode-batch-client/src/wait_for_batch_completion.rs ]
2crate::ix!();
3
4#[async_trait]
5impl<E> WaitForBatchCompletion for OpenAIClientHandle<E>
6where
7    E: Debug + Send + Sync + From<OpenAIClientError>
8{
9    type Error = E;
10
11    async fn wait_for_batch_completion(&self, batch_id: &str)
12        -> Result<Batch, Self::Error>
13    {
14        info!("waiting for batch completion: batch_id={}", batch_id);
15
16        loop {
17            let batch = self.retrieve_batch(batch_id).await?;
18
19            match batch.status {
20                BatchStatus::Completed => return Ok(batch),
21                BatchStatus::Failed => {
22                    // Return an error: 
23                    let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
24                        message: "Batch failed".to_owned(),
25                        r#type: None,
26                        param:  None,
27                        code:   None,
28                    });
29                    return Err(E::from(openai_err));
30                }
31                _ => {
32                    println!("Batch status: {:?}", batch.status);
33                    tokio::time::sleep(std::time::Duration::from_secs(20)).await;
34                }
35            }
36        }
37    }
38}
39
#[cfg(test)]
mod wait_for_batch_completion_tests {
    use super::*;
    use tracing::{debug, info, trace};

    /// Returns a batch with the given `id` whose status is already `Completed`.
    fn completed_batch(id: &str) -> Batch {
        Batch {
            id:                 id.to_string(),
            status:             BatchStatus::Completed,
            input_file_id:      "some_file".to_string(),
            completion_window:  "24h".to_string(),
            object:             "batch".to_string(),
            endpoint:           "/v1/chat/completions".to_string(),
            errors:             None,
            output_file_id:     None,
            error_file_id:      None,
            created_at:         0,
            in_progress_at:     None,
            expires_at:         None,
            finalizing_at:      None,
            completed_at:       None,
            failed_at:          None,
            expired_at:         None,
            cancelling_at:      None,
            cancelled_at:       None,
            request_counts:     None,
            metadata:           None,
        }
    }

    #[traced_test]
    async fn test_wait_for_batch_completion_immediate_success() {
        info!("Beginning test_wait_for_batch_completion_immediate_success");
        trace!("Constructing mock client that immediately returns a completed batch...");
        let batch_id = "immediate_success_id";

        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();

        // Seed the mock's batch store so the very first retrieval reports `Completed`.
        // The write guard is a temporary and is released at the end of this statement.
        mock_client
            .batches()
            .write()
            .unwrap()
            .insert(batch_id.to_string(), completed_batch(batch_id));

        debug!("Mock client built: {:?}", mock_client);

        trace!("Calling wait_for_batch_completion on mock_client with batch_id={}", batch_id);
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        // The batch was inserted as Completed above, so no polling should occur.
        let batch = result.expect(
            "Expected wait_for_batch_completion to succeed if the batch is already completed"
        );
        pretty_assert_eq!(
            batch.status,
            BatchStatus::Completed,
            "Batch status should be Completed"
        );
        info!("test_wait_for_batch_completion_immediate_success passed.");
    }

    #[traced_test]
    async fn test_wait_for_batch_completion_immediate_failure() {
        info!("Beginning test_wait_for_batch_completion_immediate_failure");
        trace!("Constructing mock client that immediately returns a failed batch...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "immediate_failure_id";

        trace!("Calling wait_for_batch_completion on mock_client with batch_id={}", batch_id);
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        // Presumably the mock reports "immediate_failure_id" as Failed on the first
        // retrieval, so we expect an error.
        assert!(
            result.is_err(),
            "Expected wait_for_batch_completion to return error if the batch is failed at once"
        );
        info!("test_wait_for_batch_completion_immediate_failure passed.");
    }

    #[traced_test]
    async fn test_wait_for_batch_completion_eventual_failure() {
        info!("Beginning test_wait_for_batch_completion_eventual_failure");
        trace!("Constructing mock client that simulates in-progress followed by failure...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "eventual_failure_id";

        trace!("Calling wait_for_batch_completion expecting multiple in-progress checks before failure");
        // NOTE(review): every in-progress observation makes wait_for_batch_completion sleep
        // for its real 20 s poll interval, so this test is slow.
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        // The mock's retrieve logic is expected to move this batch from InProgress to
        // Failed, so the wait should eventually end with an Err.
        assert!(
            result.is_err(),
            "Expected wait_for_batch_completion to error after an eventual failure status"
        );
        info!("test_wait_for_batch_completion_eventual_failure passed.");
    }

    #[traced_test]
    async fn test_wait_for_batch_completion_openai_error() {
        info!("Beginning test_wait_for_batch_completion_openai_error");
        trace!("Constructing mock client that simulates an OpenAI error during retrieve_batch...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client built: {:?}", mock_client);

        let batch_id = "trigger_api_error";

        trace!("Calling wait_for_batch_completion expecting an OpenAI error from retrieve_batch");
        let result = mock_client.wait_for_batch_completion(batch_id).await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        // The mock is expected to return an OpenAI error for "trigger_api_error" on the first
        // retrieval; wait_for_batch_completion propagates it via `?`.
        assert!(
            result.is_err(),
            "Expected wait_for_batch_completion to fail due to an OpenAI error in retrieve_batch"
        );
        info!("test_wait_for_batch_completion_openai_error passed.");
    }

    #[traced_test]
    async fn test_wait_for_batch_completion_eventual_success() {
        info!("Beginning test_wait_for_batch_completion_eventual_success");

        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();

        // Configure the mock so "eventual_success_id" starts InProgress and later
        // reports Completed.
        mock_client.configure_inprogress_then_complete_with("eventual_success_id", /*want_output=*/false, /*want_error=*/false);

        info!("Calling wait_for_batch_completion expecting multiple in-progress checks before completion");
        let result = mock_client.wait_for_batch_completion("eventual_success_id").await;
        debug!("Result from wait_for_batch_completion: {:?}", result);

        let final_batch = result.expect(
            "Expected wait_for_batch_completion to succeed after in-progress statuses"
        );
        pretty_assert_eq!(final_batch.status, BatchStatus::Completed);
        info!("test_wait_for_batch_completion_eventual_success passed.");
    }
}