batch_mode_batch_client/
check_batch_status_online.rs

1// ---------------- [ File: batch-mode-batch-client/src/check_batch_status_online.rs ]
2crate::ix!();
3
4// Now implement it for `BatchFileTriple`:
5#[async_trait]
6impl<E> CheckBatchStatusOnline<E> for BatchFileTriple
7where
8    E: From<BatchDownloadError>
9        + From<OpenAIClientError>
10        + From<BatchMetadataError>
11        + From<std::io::Error>
12        + std::fmt::Debug,
13{
14    async fn check_batch_status_online(
15        &self,
16        client: &dyn LanguageModelClientInterface<E>,
17    ) -> Result<BatchOnlineStatus, E> {
18        info!("checking batch status online");
19
20        // Use the associated metadata file if set, else fallback
21        let metadata_filename: PathBuf = if let Some(path) = self.associated_metadata() {
22            path.clone()
23        } else {
24            self.effective_metadata_filename()
25        };
26        debug!("Using metadata file: {:?}", metadata_filename);
27
28        let mut metadata = BatchMetadata::load_from_file(&metadata_filename).await?;
29        let batch_id = metadata.batch_id().to_string();
30
31        let batch = client.retrieve_batch(&batch_id).await?;
32        match batch.status {
33            BatchStatus::Completed => {
34                // Only if completed do we store these IDs into the metadata:
35                metadata.set_output_file_id(batch.output_file_id.clone());
36                metadata.set_error_file_id(batch.error_file_id.clone());
37                metadata.save_to_file(&metadata_filename).await?;
38
39                Ok(BatchOnlineStatus::from(&batch))
40            }
41            BatchStatus::Failed => {
42                Err(BatchDownloadError::BatchFailed { batch_id }.into())
43            }
44            BatchStatus::Validating
45            | BatchStatus::InProgress
46            | BatchStatus::Finalizing => {
47                Err(BatchDownloadError::BatchStillProcessing { batch_id }.into())
48            }
49            _ => {
50                Err(BatchDownloadError::UnknownBatchStatus {
51                    batch_id,
52                    status: batch.status.clone(),
53                }
54                .into())
55            }
56        }
57    }
58}
59
//-----------------------------------------------------
// Test module
//-----------------------------------------------------
#[cfg(test)]
mod check_batch_status_online_tests {
    // Tests for `CheckBatchStatusOnline` on `BatchFileTriple`, driven by the
    // in-memory mock client.
    //
    // Besides the returned status, each test checks the on-disk metadata file:
    // - for a completed batch, the reported output/error file ids must be written
    //   back;
    // - for every other status, the file must be left byte-identical.
    use super::*;
    use futures::executor::block_on;
    use tempfile::tempdir;
    use tracing::{debug, error, info, trace, warn};
    use std::fs;

    /// Output file id reported by mock batches that produced output.
    const OUTPUT_FILE_ID: &str = "mock_output_file_id";
    /// Error file id reported by mock batches that produced an error file.
    const ERROR_FILE_ID: &str = "mock_err_file_id";

    /// Builds a mock `Batch` with the given id, status and file ids. Every other
    /// field is a fixed placeholder.
    fn make_batch(
        batch_id: &str,
        status: BatchStatus,
        output_file_id: Option<&str>,
        error_file_id: Option<&str>,
    ) -> Batch {
        Batch {
            id: batch_id.to_string(),
            object: "batch".to_string(),
            endpoint: "/v1/chat/completions".to_string(),
            errors: None,
            input_file_id: "input_file_id".to_string(),
            completion_window: "24h".to_string(),
            status,
            output_file_id: output_file_id.map(str::to_string),
            error_file_id: error_file_id.map(str::to_string),
            created_at: 0,
            in_progress_at: None,
            expires_at: None,
            finalizing_at: None,
            completed_at: None,
            failed_at: None,
            expired_at: None,
            cancelling_at: None,
            cancelled_at: None,
            request_counts: None,
            metadata: None,
        }
    }

    /// Writes metadata for `batch_id` with no output/error file ids to
    /// `metadata_path`, and returns a triple pointing at that file.
    async fn write_metadata_and_build_triple(
        metadata_path: std::path::PathBuf,
        batch_id: &str,
    ) -> BatchFileTriple {
        let metadata = BatchMetadataBuilder::default()
            .batch_id(batch_id.to_string())
            .input_file_id("input_file_id".to_string())
            .output_file_id(None)
            .error_file_id(None)
            .build()
            .unwrap();
        info!("Saving metadata at {:?}", metadata_path);
        metadata.save_to_file(&metadata_path).await.unwrap();

        trace!("Creating BatchFileTriple with known metadata path...");
        let mut triple = BatchFileTriple::new_for_test_with_metadata_path(metadata_path.clone());
        triple.set_metadata_path(Some(metadata_path));
        triple
    }

    /// Runs the check for a batch expected to be completed.
    ///
    /// Asserts that the call succeeds, then returns the online status together
    /// with the metadata file contents as persisted afterwards.
    async fn check_completed_batch(
        client: &dyn LanguageModelClientInterface<MockBatchClientError>,
        batch_id: &str,
    ) -> (BatchOnlineStatus, String) {
        let tmpdir = tempdir().unwrap();
        let metadata_path = tmpdir.path().join("metadata.json");
        let triple = write_metadata_and_build_triple(metadata_path.clone(), batch_id).await;

        trace!("Calling check_batch_status_online...");
        let result = triple.check_batch_status_online(client).await;
        debug!("Result from check_batch_status_online: {:?}", result);
        assert!(
            result.is_ok(),
            "Should return Ok(...) for completed batch {}: {:?}",
            batch_id,
            result
        );

        // Read before `tmpdir` is dropped at the end of this function.
        let persisted = fs::read_to_string(&metadata_path).unwrap();
        (result.unwrap(), persisted)
    }

    /// Runs the check for a batch expected to be rejected.
    ///
    /// Asserts that an error is returned and that the metadata file was not
    /// rewritten.
    async fn assert_rejected_without_touching_metadata(
        client: &dyn LanguageModelClientInterface<MockBatchClientError>,
        batch_id: &str,
        reason: &str,
    ) {
        let tmpdir = tempdir().unwrap();
        let metadata_path = tmpdir.path().join("metadata.json");
        let triple = write_metadata_and_build_triple(metadata_path.clone(), batch_id).await;
        let before = fs::read_to_string(&metadata_path).unwrap();

        trace!("Calling check_batch_status_online...");
        let result = triple.check_batch_status_online(client).await;
        debug!("Result from check_batch_status_online: {:?}", result);
        assert!(result.is_err(), "{}", reason);

        // Only completed batches may update the stored metadata.
        let after = fs::read_to_string(&metadata_path).unwrap();
        pretty_assert_eq!(before, after);
    }

    #[traced_test]
    async fn test_batch_completed_no_files() {
        info!("Starting test_batch_completed_no_files");
        let batch_id = "test_batch_completed_no_files";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Completed, None, None),
        );

        let (online_status, persisted) = check_completed_batch(&mock_client, batch_id).await;

        assert!(!online_status.output_file_available());
        assert!(!online_status.error_file_available());
        assert!(!persisted.contains(OUTPUT_FILE_ID));
        assert!(!persisted.contains(ERROR_FILE_ID));
        info!("test_batch_completed_no_files passed successfully.");
    }

    #[traced_test]
    async fn test_batch_completed_with_output_only() {
        info!("Starting test_batch_completed_with_output_only");
        let batch_id = "test_batch_completed_with_output_only";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Completed, Some(OUTPUT_FILE_ID), None),
        );

        let (online_status, persisted) = check_completed_batch(&mock_client, batch_id).await;

        assert!(online_status.output_file_available());
        assert!(!online_status.error_file_available());
        assert!(persisted.contains(OUTPUT_FILE_ID), "output file id must be persisted");
        assert!(!persisted.contains(ERROR_FILE_ID));
        info!("test_batch_completed_with_output_only passed successfully.");
    }

    #[traced_test]
    async fn test_batch_completed_with_error_only() {
        info!("Starting test_batch_completed_with_error_only");
        let batch_id = "test_batch_completed_with_error_only";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Completed, None, Some(ERROR_FILE_ID)),
        );

        let (online_status, persisted) = check_completed_batch(&mock_client, batch_id).await;

        assert!(!online_status.output_file_available());
        assert!(online_status.error_file_available());
        assert!(!persisted.contains(OUTPUT_FILE_ID));
        assert!(persisted.contains(ERROR_FILE_ID), "error file id must be persisted");
        info!("test_batch_completed_with_error_only passed successfully.");
    }

    #[traced_test]
    async fn test_batch_completed_with_output_and_error() {
        info!("Starting test_batch_completed_with_output_and_error");
        let batch_id = "test_batch_completed_with_output_and_error";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(
                batch_id,
                BatchStatus::Completed,
                Some(OUTPUT_FILE_ID),
                Some(ERROR_FILE_ID),
            ),
        );

        let (online_status, persisted) = check_completed_batch(&mock_client, batch_id).await;

        assert!(online_status.output_file_available());
        assert!(online_status.error_file_available());
        assert!(persisted.contains(OUTPUT_FILE_ID), "output file id must be persisted");
        assert!(persisted.contains(ERROR_FILE_ID), "error file id must be persisted");
        info!("test_batch_completed_with_output_and_error passed successfully.");
    }

    #[traced_test]
    async fn test_batch_failed() {
        info!("Starting test_batch_failed");
        let batch_id = "test_batch_failed";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Failed, None, None),
        );

        assert_rejected_without_touching_metadata(
            &mock_client,
            batch_id,
            "Should return Err(...) for a failed batch",
        )
        .await;
        info!("test_batch_failed passed successfully.");
    }

    #[traced_test]
    async fn test_batch_inprogress() {
        info!("Starting test_batch_inprogress");
        let batch_id = "test_batch_inprogress";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::InProgress, None, None),
        );

        // "Still processing" is reported as an error.
        assert_rejected_without_touching_metadata(
            &mock_client,
            batch_id,
            "Should return Err(...) for an in-progress batch",
        )
        .await;
        info!("test_batch_inprogress passed successfully.");
    }

    #[traced_test]
    async fn test_batch_validating() {
        info!("Starting test_batch_validating");
        let batch_id = "test_batch_validating";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Validating, None, None),
        );

        assert_rejected_without_touching_metadata(
            &mock_client,
            batch_id,
            "Should return Err(...) for a validating batch",
        )
        .await;
        info!("test_batch_validating passed successfully.");
    }

    #[traced_test]
    async fn test_batch_finalizing() {
        info!("Starting test_batch_finalizing");
        let batch_id = "test_batch_finalizing";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Finalizing, None, None),
        );

        assert_rejected_without_touching_metadata(
            &mock_client,
            batch_id,
            "Should return Err(...) for a finalizing batch",
        )
        .await;
        info!("test_batch_finalizing passed successfully.");
    }

    #[traced_test]
    async fn test_batch_unknown_status() {
        info!("Starting test_batch_unknown_status");
        let batch_id = "test_batch_unknown_status";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        // `Cancelled` is not handled explicitly, so it takes the catch-all
        // "unknown status" path.
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Cancelled, None, None),
        );

        assert_rejected_without_touching_metadata(
            &mock_client,
            batch_id,
            "Should return Err(...) for an unknown batch status like Cancelled",
        )
        .await;
        info!("test_batch_unknown_status passed successfully.");
    }

    #[traced_test]
    async fn test_batch_expired() {
        info!("Starting test_batch_expired");
        let batch_id = "test_batch_expired";
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);
        mock_client.batches().write().unwrap().insert(
            batch_id.to_string(),
            make_batch(batch_id, BatchStatus::Expired, None, None),
        );

        assert_rejected_without_touching_metadata(
            &mock_client,
            batch_id,
            "Should return Err(...) for an expired batch",
        )
        .await;
        info!("test_batch_expired passed successfully.");
    }
}