1crate::ix!();
3
4#[async_trait]
6impl<E> CheckBatchStatusOnline<E> for BatchFileTriple
7where
8 E: From<BatchDownloadError>
9 + From<OpenAIClientError>
10 + From<BatchMetadataError>
11 + From<std::io::Error>
12 + std::fmt::Debug,
13{
14 async fn check_batch_status_online(
15 &self,
16 client: &dyn LanguageModelClientInterface<E>,
17 ) -> Result<BatchOnlineStatus, E> {
18 info!("checking batch status online");
19
20 let metadata_filename: PathBuf = if let Some(path) = self.associated_metadata() {
22 path.clone()
23 } else {
24 self.effective_metadata_filename()
25 };
26 debug!("Using metadata file: {:?}", metadata_filename);
27
28 let mut metadata = BatchMetadata::load_from_file(&metadata_filename).await?;
29 let batch_id = metadata.batch_id().to_string();
30
31 let batch = client.retrieve_batch(&batch_id).await?;
32 match batch.status {
33 BatchStatus::Completed => {
34 metadata.set_output_file_id(batch.output_file_id.clone());
36 metadata.set_error_file_id(batch.error_file_id.clone());
37 metadata.save_to_file(&metadata_filename).await?;
38
39 Ok(BatchOnlineStatus::from(&batch))
40 }
41 BatchStatus::Failed => {
42 Err(BatchDownloadError::BatchFailed { batch_id }.into())
43 }
44 BatchStatus::Validating
45 | BatchStatus::InProgress
46 | BatchStatus::Finalizing => {
47 Err(BatchDownloadError::BatchStillProcessing { batch_id }.into())
48 }
49 _ => {
50 Err(BatchDownloadError::UnknownBatchStatus {
51 batch_id,
52 status: batch.status.clone(),
53 }
54 .into())
55 }
56 }
57 }
58}
59
#[cfg(test)]
mod check_batch_status_online_tests {
    use super::*;
    use futures::executor::block_on;
    use tempfile::tempdir;
    use tracing::{debug, error, info, trace, warn};
    use std::fs;

    /// Builds a `Batch` with the given id, status, and result-file ids; all other
    /// fields hold fixed placeholder values (`None` for optional timestamps/counts).
    fn mock_batch(
        batch_id: &str,
        status: BatchStatus,
        output_file_id: Option<&str>,
        error_file_id: Option<&str>,
    ) -> Batch {
        Batch {
            id: batch_id.to_string(),
            object: "batch".to_string(),
            endpoint: "/v1/chat/completions".to_string(),
            errors: None,
            input_file_id: "input_file_id".to_string(),
            completion_window: "24h".to_string(),
            status,
            output_file_id: output_file_id.map(String::from),
            error_file_id: error_file_id.map(String::from),
            created_at: 0,
            in_progress_at: None,
            expires_at: None,
            finalizing_at: None,
            completed_at: None,
            failed_at: None,
            expired_at: None,
            cancelling_at: None,
            cancelled_at: None,
            request_counts: None,
            metadata: None,
        }
    }

    /// Builds a mock client whose batch table contains `batch`, keyed by its id.
    fn mock_client_with(
        batch: Batch,
    ) -> impl LanguageModelClientInterface<MockBatchClientError> + std::fmt::Debug {
        let client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        trace!("Inserting batch with ID={}", batch.id);
        // The write guard is a temporary and is released at the end of this statement.
        client
            .batches()
            .write()
            .unwrap()
            .insert(batch.id.clone(), batch);
        debug!("Mock client: {:?}", client);
        client
    }

    /// Saves a `BatchMetadata` for `batch_id` (with the given result-file ids) to
    /// `dir/metadata.json` and returns that path.
    async fn write_metadata(
        dir: &std::path::Path,
        batch_id: &str,
        output_file_id: Option<&str>,
        error_file_id: Option<&str>,
    ) -> PathBuf {
        let metadata_path = dir.join("metadata.json");
        let metadata = BatchMetadataBuilder::default()
            .batch_id(batch_id.to_string())
            .input_file_id("input_file_id".to_string())
            .output_file_id(output_file_id.map(String::from))
            .error_file_id(error_file_id.map(String::from))
            .build()
            .unwrap();
        info!("Saving metadata at {:?}", metadata_path);
        metadata.save_to_file(&metadata_path).await.unwrap();
        metadata_path
    }

    /// Runs `check_batch_status_online` for a triple whose metadata lives at
    /// `metadata_path`, logging and returning the result.
    async fn run_check(
        client: &dyn LanguageModelClientInterface<MockBatchClientError>,
        metadata_path: &std::path::Path,
    ) -> Result<BatchOnlineStatus, MockBatchClientError> {
        trace!("Creating BatchFileTriple with known metadata path...");
        let mut triple =
            BatchFileTriple::new_for_test_with_metadata_path(metadata_path.to_path_buf());
        triple.set_metadata_path(Some(metadata_path.to_path_buf()));

        trace!("Calling check_batch_status_online...");
        let result = triple.check_batch_status_online(client).await;
        debug!("Result from check_batch_status_online: {:?}", result);
        result
    }

    /// Checks a `Completed` batch carrying the given result-file ids (mirrored in the
    /// local metadata): the call must succeed and report exactly those files.
    async fn assert_completed_batch(
        batch_id: &str,
        output_file_id: Option<&str>,
        error_file_id: Option<&str>,
        failure_message: &str,
    ) {
        let mock_client = mock_client_with(mock_batch(
            batch_id,
            BatchStatus::Completed,
            output_file_id,
            error_file_id,
        ));
        // `tmpdir` must outlive the check, since the metadata file lives inside it.
        let tmpdir = tempdir().unwrap();
        let metadata_path =
            write_metadata(tmpdir.path(), batch_id, output_file_id, error_file_id).await;

        let result = run_check(&mock_client, &metadata_path).await;
        assert!(result.is_ok(), "{}", failure_message);

        let online_status = result.unwrap();
        pretty_assert_eq!(online_status.output_file_available(), output_file_id.is_some());
        pretty_assert_eq!(online_status.error_file_available(), error_file_id.is_some());
    }

    /// Checks that a batch in `status` (with no result files) is reported as an error.
    async fn assert_batch_error(batch_id: &str, status: BatchStatus, failure_message: &str) {
        let mock_client = mock_client_with(mock_batch(batch_id, status, None, None));
        let tmpdir = tempdir().unwrap();
        let metadata_path = write_metadata(tmpdir.path(), batch_id, None, None).await;

        let result = run_check(&mock_client, &metadata_path).await;
        assert!(result.is_err(), "{}", failure_message);
    }

    #[traced_test]
    async fn test_batch_completed_no_files() {
        info!("Starting test_batch_completed_no_files");
        assert_completed_batch(
            "test_batch_completed_no_files",
            None,
            None,
            "Should return Ok(...) for a completed batch with no output/error",
        )
        .await;
        info!("test_batch_completed_no_files passed successfully.");
    }

    #[traced_test]
    async fn test_batch_completed_with_output_only() {
        info!("Starting test_batch_completed_with_output_only");
        assert_completed_batch(
            "test_batch_completed_with_output_only",
            Some("mock_output_file_id"),
            None,
            "Should return Ok(...) for a completed batch with output only",
        )
        .await;
        info!("test_batch_completed_with_output_only passed successfully.");
    }

    #[traced_test]
    async fn test_batch_completed_with_error_only() {
        info!("Starting test_batch_completed_with_error_only");
        assert_completed_batch(
            "test_batch_completed_with_error_only",
            None,
            Some("mock_err_file_id"),
            "Should return Ok(...) for a completed batch with error only",
        )
        .await;
        info!("test_batch_completed_with_error_only passed successfully.");
    }

    #[traced_test]
    async fn test_batch_completed_with_output_and_error() {
        info!("Starting test_batch_completed_with_output_and_error");
        assert_completed_batch(
            "test_batch_completed_with_output_and_error",
            Some("mock_output_file_id"),
            Some("mock_err_file_id"),
            "Should return Ok(...) for a completed batch with both output and error",
        )
        .await;
        info!("test_batch_completed_with_output_and_error passed successfully.");
    }

    #[traced_test]
    async fn test_batch_failed() {
        info!("Starting test_batch_failed");
        assert_batch_error(
            "test_batch_failed",
            BatchStatus::Failed,
            "Should return Err(...) for a failed batch",
        )
        .await;
        info!("test_batch_failed passed successfully.");
    }

    #[traced_test]
    async fn test_batch_inprogress() {
        info!("Starting test_batch_inprogress");
        assert_batch_error(
            "test_batch_inprogress",
            BatchStatus::InProgress,
            "Should return Err(...) for an in-progress batch",
        )
        .await;
        info!("test_batch_inprogress passed successfully.");
    }

    #[traced_test]
    async fn test_batch_validating_and_finalizing() {
        info!("Starting test_batch_validating_and_finalizing");
        // The remaining "still processing" statuses share the InProgress arm.
        for (batch_id, status) in [
            ("test_batch_validating", BatchStatus::Validating),
            ("test_batch_finalizing", BatchStatus::Finalizing),
        ] {
            assert_batch_error(
                batch_id,
                status,
                "Should return Err(...) for a batch that is still processing",
            )
            .await;
        }
        info!("test_batch_validating_and_finalizing passed successfully.");
    }

    #[traced_test]
    async fn test_batch_unknown_status() {
        info!("Starting test_batch_unknown_status");
        assert_batch_error(
            "test_batch_unknown_status",
            BatchStatus::Cancelled,
            "Should return Err(...) for an unknown batch status like Cancelled",
        )
        .await;
        info!("test_batch_unknown_status passed successfully.");
    }
}