1#![allow(unused_variables)]
3
4crate::ix!();
5
// Error enum used as the `E` parameter of `MockLanguageModelClient` in this file's tests.
//
// Declared through the project's `error_tree!` macro. The tests instantiate
// `MockLanguageModelClient::<MockBatchClientError>`, whose impls require
// `From<OpenAIClientError>`, `From<std::io::Error>`, `From<BatchDownloadError>` and
// `From<BatchMetadataError>`; presumably `error_tree!` derives those conversions from the
// single-field variants below — NOTE(review): confirm against the macro's definition.
error_tree!{

    pub enum MockBatchClientError {
        // Errors produced by the (mocked) OpenAI API layer.
        OpenAIClientError(OpenAIClientError),
        BatchDownloadError(BatchDownloadError),
        BatchMetadataError(BatchMetadataError),
        // Simulated or real filesystem / I/O failures.
        IoError(std::io::Error),

        // Payload-less variant; not constructed in the code visible in this file.
        BatchProcessingError,

        JsonParseError(JsonParseError),
        BatchValidationError(BatchValidationError),

        // Reconciliation failure tagged with the index of the batch involved.
        BatchReconciliationError {
            index: BatchIndex,
        },

        BatchErrorProcessingError(BatchErrorProcessingError),
        // Payload-less variant; not constructed in the code visible in this file.
        BatchOutputProcessingError,
        FileMoveError(FileMoveError),
    }
}
30
/// In-memory stand-in for the OpenAI batch client, used by tests.
///
/// All mutable state sits behind `RwLock`s so the trait methods, which take `&self`,
/// can still update it. `E` is the error type returned by the trait impls; it is only
/// carried as a `PhantomData` marker.
#[derive(Getters, Setters, Builder, Debug)]
#[builder(pattern = "owned")]
pub struct MockLanguageModelClient<E> {
    /// Batches known to the mock, keyed by batch id.
    #[getset(get = "pub", set = "pub")]
    #[builder(default)]
    batches: RwLock<HashMap<String, Batch>>,

    /// Simulated file contents (uploads and batch output/error files), keyed by file id.
    #[getset(get = "pub", set = "pub")]
    #[builder(default)]
    files: RwLock<HashMap<String, Bytes>>,

    /// When `true`, `upload_batch_file_path` fails with an `OpenAIClientError`.
    #[builder(default="false")]
    #[getset(get = "pub", set = "pub")]
    fail_on_file_create_openai_error: bool,

    /// When `true`, `upload_batch_file_path` fails with a `std::io::Error`
    /// (checked after `fail_on_file_create_openai_error`).
    #[builder(default="false")]
    #[getset(get = "pub", set = "pub")]
    fail_on_file_create_other_error: bool,

    // Ties the struct to the error type `E` without storing a value of it.
    #[builder(default)]
    _error_marker: PhantomData<E>,

    /// Scripted behaviour (forced failures, attempt counters, planned completions)
    /// consulted by `retrieve_batch`.
    #[getset(get = "pub", set = "pub")]
    #[builder(default)]
    mock_batch_config: RwLock<MockBatchConfig>,
}
58
/// Scripted behaviour consulted by `MockLanguageModelClient::retrieve_batch`.
#[derive(MutGetters,Getters,Setters,Builder,Debug,Default)]
#[builder(setter(into), default, pattern = "owned")]
#[getset(get="pub",set="pub",get_mut="pub")]
pub struct MockBatchConfig {
    /// Batch ids that are forced to `Failed` when their attempt counter reaches 1.
    fails_on_attempt_1: HashSet<String>,

    /// Number of `retrieve_batch` calls seen so far, per batch id.
    attempt_counters: HashMap<String, u32>,

    /// Batch ids that flip from `InProgress` to `Completed` when retrieved, mapped to
    /// `(want_output, want_error)`: whether the completed batch gets a mock output file
    /// and/or a mock error file attached.
    planned_completions: HashMap<String, (bool, bool)>,
}
74
75impl<E> MockLanguageModelClient<E>
76where
77 E: From<OpenAIClientError> + From<std::io::Error> + Debug + Send + Sync,
78{
79 pub fn configure_inprogress_then_complete_with(
86 &self,
87 batch_id: &str,
88 want_output: bool,
89 want_error: bool,
90 ) {
91 let mut map_guard = self.batches().write().unwrap();
94 map_guard.insert(
95 batch_id.to_string(),
96 Batch {
97 id: batch_id.to_string(),
98 object: "batch".to_string(),
99 endpoint: "/v1/chat/completions".to_string(),
100 errors: None,
101 input_file_id: batch_id.to_string(),
102 completion_window: "24h".to_string(),
103 status: BatchStatus::InProgress,
104 output_file_id: None,
105 error_file_id: None,
106 created_at: 0,
107 in_progress_at: None,
108 expires_at: None,
109 finalizing_at: None,
110 completed_at: None,
111 failed_at: None,
112 expired_at: None,
113 cancelling_at: None,
114 cancelled_at: None,
115 request_counts: None,
116 metadata: None,
117 },
118 );
119 drop(map_guard);
120
121 let mut cfg_guard = self.mock_batch_config().write().unwrap();
122 cfg_guard
123 .planned_completions_mut()
124 .insert(batch_id.to_string(), (want_output, want_error));
125 }
126
127 pub fn configure_failure(&self, batch_id: &str, is_immediate: bool) {
128 let mut guard = self.batches().write().unwrap();
131 if is_immediate {
132 guard.insert(
133 batch_id.to_string(),
134 Batch {
135 id: batch_id.to_string(),
136 object: "batch".to_string(),
137 endpoint: "/v1/chat/completions".to_string(),
138 errors: None,
139 input_file_id: format!("immediate_fail_for_{batch_id}"),
140 completion_window: "24h".to_string(),
141 status: BatchStatus::Failed,
142 output_file_id: None,
143 error_file_id: None,
144 created_at: 0,
145 in_progress_at: None,
146 expires_at: None,
147 finalizing_at: None,
148 completed_at: None,
149 failed_at: None,
150 expired_at: None,
151 cancelling_at: None,
152 cancelled_at: None,
153 request_counts: None,
154 metadata: None,
155 },
156 );
157 } else {
158 let mut cfg = self.mock_batch_config().write().unwrap();
161 cfg.fails_on_attempt_1_mut().insert(batch_id.to_string());
162 drop(cfg);
163
164 guard.insert(
165 batch_id.to_string(),
166 Batch {
167 id: batch_id.to_string(),
168 object: "batch".to_string(),
169 endpoint: "/v1/chat/completions".to_string(),
170 errors: None,
171 input_file_id: format!("eventual_fail_for_{batch_id}"),
172 completion_window: "24h".to_string(),
173 status: BatchStatus::InProgress,
174 output_file_id: None,
175 error_file_id: None,
176 created_at: 0,
177 in_progress_at: None,
178 expires_at: None,
179 finalizing_at: None,
180 completed_at: None,
181 failed_at: None,
182 expired_at: None,
183 cancelling_at: None,
184 cancelled_at: None,
185 request_counts: None,
186 metadata: None,
187 },
188 );
189 }
190 }
191
192}
193
194impl<E> MockLanguageModelClient<E>
195where
196 E: From<OpenAIClientError>
197 + From<std::io::Error>
198 + Debug
199 + Send
200 + Sync,
201{
202 pub fn new() -> Self {
203 if std::env::var("OPENAI_API_KEY").is_err() {
205 panic!("OPENAI_API_KEY environment variable not set (Mock client requires it for test)");
206 }
207
208 MockLanguageModelClientBuilder::<E>::default()
209 .build()
210 .expect("Failed to build mock client")
211 }
212
213 pub fn set_batch_to_inprogress_then_complete_with(
221 &self,
222 batch_id: &str,
223 want_output: bool,
224 want_error: bool,
225 ) {
226 {
227 let mut guard = self.batches().write().unwrap();
228 guard.insert(
229 batch_id.to_string(),
230 Batch {
231 id: batch_id.to_string(),
232 object: "batch".to_string(),
233 endpoint: "/v1/chat/completions".to_string(),
234 errors: None,
235 input_file_id: format!("some_input_file_for_{}", batch_id),
236 completion_window: "24h".to_string(),
237 status: BatchStatus::InProgress,
238 output_file_id: None,
239 error_file_id: None,
240 created_at: 0,
241 in_progress_at: None,
242 expires_at: None,
243 finalizing_at: None,
244 completed_at: None,
245 failed_at: None,
246 expired_at: None,
247 cancelling_at: None,
248 cancelled_at: None,
249 request_counts: None,
250 metadata: None,
251 },
252 );
253 }
254
255 let mut config_guard = self.mock_batch_config().write().unwrap();
258 config_guard.attempt_counters.insert(batch_id.to_string(), 0);
259
260 config_guard
262 .planned_completions
263 .insert(batch_id.to_string(), (want_output, want_error));
264 }
265}
266
267#[async_trait]
268impl<E> RetrieveBatchById for MockLanguageModelClient<E>
269where
270 E: From<OpenAIClientError>
271 + From<std::io::Error>
272 + Debug
273 + Send
274 + Sync,
275{
276 type Error = E;
277
278 async fn retrieve_batch(&self, batch_id: &str) -> Result<Batch, Self::Error> {
279 info!("Mock: retrieve_batch called with batch_id={batch_id}");
280
281 if batch_id.is_empty() {
283 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
284 message: "Cannot retrieve batch with empty batch_id".to_owned(),
285 r#type: None,
286 param: None,
287 code: None,
288 });
289 return Err(E::from(openai_err));
290 }
291 if batch_id == "trigger_api_error" {
292 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
293 message: "Simulated retrieve_batch OpenAI error".to_owned(),
294 r#type: None,
295 param: None,
296 code: None,
297 });
298 return Err(E::from(openai_err));
299 }
300 if batch_id == "trigger_other_error" {
301 let io_err = std::io::Error::new(
302 std::io::ErrorKind::Other,
303 "Simulated retrieve_batch non-OpenAI error",
304 );
305 return Err(E::from(io_err));
306 }
307
308 let (attempt_so_far, is_fail_on_attempt1, maybe_plan) = {
311 let mut cfg_guard = self.mock_batch_config().write().unwrap();
312 let count_ref = cfg_guard
314 .attempt_counters_mut()
315 .entry(batch_id.to_string())
316 .and_modify(|c| *c += 1)
317 .or_insert(1);
318
319 let current_attempt = *count_ref;
320 let fail1 = cfg_guard.fails_on_attempt_1().contains(batch_id);
321 let plan = cfg_guard.planned_completions().get(batch_id).cloned();
322 (current_attempt, fail1, plan)
323 };
324
325 let mut map_guard = self.batches().write().unwrap();
327 let batch_entry = map_guard.entry(batch_id.to_string()).or_insert_with(|| {
328 info!("Mock: auto-creating an InProgress batch for id={batch_id}");
329 Batch {
330 id: batch_id.to_string(),
331 object: "batch".to_string(),
332 endpoint: "/v1/chat/completions".to_string(),
333 errors: None,
334 input_file_id: format!("auto_{batch_id}"),
335 completion_window: "24h".to_string(),
336 status: BatchStatus::InProgress,
337 output_file_id: None,
338 error_file_id: None,
339 created_at: 0,
340 in_progress_at: None,
341 expires_at: None,
342 finalizing_at: None,
343 completed_at: None,
344 failed_at: None,
345 expired_at: None,
346 cancelling_at: None,
347 cancelled_at: None,
348 request_counts: None,
349 metadata: None,
350 }
351 });
352
353 if batch_id == "immediate_failure_id" {
355 batch_entry.status = BatchStatus::Failed;
356 }
357
358 if is_fail_on_attempt1 && attempt_so_far == 1 {
360 info!("Mock: forcibly failing {batch_id} on attempt=1 (fails_on_attempt_1)");
361 batch_entry.status = BatchStatus::Failed;
362 }
363
364 if batch_entry.status == BatchStatus::InProgress {
366 if let Some((want_output, want_error)) = maybe_plan {
367 info!("Mock: flipping {batch_id} from InProgress -> Completed (because of planned_completions).");
368 batch_entry.status = BatchStatus::Completed;
369
370 if want_output {
372 let out_id = "mock_out_file_id".to_string();
373 batch_entry.output_file_id = Some(out_id.clone());
374 self.files().write().unwrap().insert(
375 out_id,
376 Bytes::from(
377r#"{"id":"batch_req_mock_output","custom_id":"mock_out","response":{"status_code":200,"request_id":"resp_req_mock_output","body":{"id":"success-id","object":"chat.completion","created":0,"model":"test-model","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}},"error":null}"#
378 ),
379 );
380 }
381 if want_error {
383 let err_id = "mock_err_file_id".to_string();
384 batch_entry.error_file_id = Some(err_id.clone());
385 self.files().write().unwrap().insert(
386 err_id,
387 Bytes::from(
388r#"{"id":"batch_req_mock_error","custom_id":"mock_err","response":{"status_code":400,"request_id":"resp_req_mock_error","body":{"error":{"message":"Some error message","type":"test_error","param":null,"code":null}}},"error":null}"#
389 ),
390 );
391 }
392 } else {
393 debug!("Mock: no planned completion => leaving status=InProgress for {batch_id}");
394 }
395 }
396
397 let final_batch = batch_entry.clone();
399 drop(map_guard);
400
401 debug!(
402 "Mock: retrieve_batch => returning final batch with status={:?}",
403 final_batch.status
404 );
405 Ok(final_batch)
406 }
407}
408
409#[async_trait]
410impl<E> GetBatchFileContent for MockLanguageModelClient<E>
411where
412 E: From<OpenAIClientError>
413 + From<std::io::Error>
414 + Debug
415 + Send
416 + Sync,
417{
418 type Error = E;
419
420 async fn file_content(&self, file_id: &str) -> Result<Bytes, Self::Error> {
421 info!("Mock: file_content called with file_id={}", file_id);
422
423 {
425 let mut guard = self.files().write().unwrap();
426 if file_id == "valid_file_id" && !guard.contains_key(file_id) {
427 debug!("Mock: auto-inserting 'valid_file_id' => 'some mock content'");
428 guard.insert("valid_file_id".to_string(), Bytes::from("some mock content"));
429 }
430 }
431
432 let files_guard = self.files().read().unwrap();
433 if let Some(data) = files_guard.get(file_id) {
434 debug!("Mock: Found file content for id={}", file_id);
435 Ok(data.clone())
436 } else {
437 warn!("Mock: No file found for id={}", file_id);
438 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
439 message: format!("No file found for id={}", file_id),
440 r#type: None,
441 param: None,
442 code: None,
443 });
444 Err(E::from(openai_err))
445 }
446 }
447}
448
449#[async_trait]
450impl<E> CreateBatch for MockLanguageModelClient<E>
451where
452 E: From<OpenAIClientError>
453 + From<std::io::Error>
454 + Debug
455 + Send
456 + Sync,
457{
458 type Error = E;
459
460 async fn create_batch(&self, input_file_id: &str) -> Result<Batch, Self::Error> {
461 info!("Mock: create_batch called with input_file_id={}", input_file_id);
462
463 if input_file_id.is_empty() {
465 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
466 message: "Cannot create batch with empty input_file_id".to_string(),
467 r#type: None,
468 param: None,
469 code: None,
470 });
471 return Err(E::from(openai_err));
472 }
473 if input_file_id == "trigger_api_error" {
474 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
475 message: "Simulated OpenAI error (trigger_api_error)".to_string(),
476 r#type: None,
477 param: None,
478 code: None,
479 });
480 return Err(E::from(openai_err));
481 }
482 if input_file_id == "trigger_other_error" {
483 let io_err = std::io::Error::new(std::io::ErrorKind::Other, "Simulated other error");
484 return Err(E::from(io_err));
485 }
486
487 let mock_id = format!("mock_batch_id_for_{}", input_file_id);
489
490 let mut map_guard = self.batches().write().unwrap();
491 if let Some(existing) = map_guard.get(&mock_id) {
492 return Ok(existing.clone());
495 }
496
497 let new_batch = Batch {
499 id: mock_id.clone(),
500 object: "batch".to_string(),
501 endpoint: "/v1/chat/completions".to_string(),
502 errors: None,
503 input_file_id: input_file_id.to_string(),
504 completion_window: "24h".to_string(),
505 status: BatchStatus::InProgress,
506 output_file_id: None,
507 error_file_id: None,
508 created_at: 0,
509 in_progress_at: None,
510 expires_at: None,
511 finalizing_at: None,
512 completed_at: None,
513 failed_at: None,
514 expired_at: None,
515 cancelling_at: None,
516 cancelled_at: None,
517 request_counts: None,
518 metadata: None,
519 };
520 map_guard.insert(mock_id.clone(), new_batch.clone());
521 Ok(new_batch)
522 }
523}
524
525#[async_trait]
526impl<E> WaitForBatchCompletion for MockLanguageModelClient<E>
527where
528 E: From<OpenAIClientError>
529 + From<std::io::Error>
530 + Debug
531 + Send
532 + Sync,
533{
534 type Error = E;
535
536 async fn wait_for_batch_completion(&self, batch_id: &str) -> Result<Batch, Self::Error> {
537 info!("Mock: wait_for_batch_completion called with batch_id={}", batch_id);
538
539 for attempt in 0..3 {
540 debug!("Mock: attempt #{} checking batch_id={}", attempt, batch_id);
541
542 let batch = self.retrieve_batch(batch_id).await?;
543 match batch.status {
544 BatchStatus::Completed => {
545 debug!("Mock: batch is Completed => returning Ok(batch)");
546 return Ok(batch);
547 }
548 BatchStatus::Failed => {
549 warn!("Mock: batch is Failed => returning error");
550 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
551 message: "Batch failed".to_owned(),
552 r#type: None,
553 param: None,
554 code: None,
555 });
556 return Err(E::from(openai_err));
557 }
558 _ => {
560 info!("Mock: batch has status={:?}, continuing loop", batch.status);
561 }
562 }
563
564 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
565 }
566
567 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
569 message: format!("Timed out waiting for batch {batch_id} to complete"),
570 r#type: None,
571 param: None,
572 code: None,
573 });
574 Err(E::from(openai_err))
575 }
576}
577
578#[async_trait]
579impl<E> UploadBatchFileCore for MockLanguageModelClient<E>
580where
581 E: From<OpenAIClientError>
582 + From<std::io::Error>
583 + Debug
584 + Send
585 + Sync,
586{
587 type Error = E;
588
589 async fn upload_batch_file_path(
590 &self,
591 file_path: &Path
592 ) -> Result<OpenAIFile, Self::Error> {
593 info!("Mock: upload_batch_file_path called with path={:?}", file_path);
594
595 let path_str = file_path.display().to_string();
596
597 if path_str.contains("trigger_api_error") {
599 warn!("Mock: forcibly returning an OpenAIClientError for file upload (trigger_api_error detected)");
600 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
601 message: "Simulated upload error (mocked as openai error)".to_string(),
602 r#type: None,
603 param: None,
604 code: None,
605 });
606 return Err(E::from(openai_err));
607 }
608
609 if path_str.contains("trigger_other_error") {
611 warn!("Mock: forcibly returning an IoError for file upload (trigger_other_error detected)");
612 let io_err = std::io::Error::new(
613 std::io::ErrorKind::Other,
614 "Simulated other error triggered in upload_batch_file_path"
615 );
616 return Err(E::from(io_err));
617 }
618
619 if *self.fail_on_file_create_openai_error() {
621 warn!("Mock: forcibly returning an OpenAIClientError for file upload due to fail_on_file_create_openai_error=true");
622 let openai_err = OpenAIClientError::ApiError(OpenAIApiError {
623 message: "Simulated upload error (mocked as openai error)".to_string(),
624 r#type: None,
625 param: None,
626 code: None,
627 });
628 return Err(E::from(openai_err));
629 }
630
631 if *self.fail_on_file_create_other_error() {
633 warn!("Mock: forcibly returning an IoError for file upload due to fail_on_file_create_other_error=true");
634 let io_err = std::io::Error::new(
635 std::io::ErrorKind::Other,
636 "Simulated other error triggered in upload_batch_file_path"
637 );
638 return Err(E::from(io_err));
639 }
640
641 if !file_path.exists() {
643 let io_err = std::io::Error::new(
644 std::io::ErrorKind::NotFound,
645 format!("File not found at {:?}", file_path),
646 );
647 error!("Mock: returning IoError for missing file {:?}", file_path);
648 return Err(E::from(io_err));
649 }
650
651 let file_id = format!("mock_file_id_{}", path_str);
653 debug!("Mock: Storing synthetic file content for file_id={}", file_id);
654
655 {
656 let mut files_guard = self.files().write().unwrap();
657 files_guard.insert(file_id.clone(), Bytes::from("mock file content"));
658 }
659
660 #[allow(deprecated)]
661 let openai_file = OpenAIFile {
662 id: file_id.clone(),
663 bytes: 123,
664 created_at: 0,
665 filename: file_path
666 .file_name()
667 .map(|os| os.to_string_lossy().into_owned())
668 .unwrap_or_else(|| "unknown".to_string()),
669 purpose: OpenAIFilePurpose::Batch,
670 object: "file".to_string(),
671 status: Some("uploaded".to_string()),
672 status_details: None,
673 };
674
675 Ok(openai_file)
676 }
677}
678
679
680#[async_trait]
682impl<E> LanguageModelClientInterface<E> for MockLanguageModelClient<E>
683where
684 E: From<OpenAIClientError>
685 + From<BatchDownloadError>
686 + From<std::io::Error>
687 + From<BatchMetadataError>
688 + Debug
689 + Send
690 + Sync,
691{
692 }
694
#[cfg(test)]
mod mock_client_handle_tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Arc, Mutex, MutexGuard};
    use tracing::{debug, info, trace, warn};

    const API_KEY_VAR: &str = "OPENAI_API_KEY";

    /// Serializes every test in this module that reads or mutates `OPENAI_API_KEY`.
    ///
    /// The test harness runs tests on parallel threads. Without this lock, one test can
    /// observe another's value: for example, the "env var missing" test removing the
    /// key while the "env var present" test is building its client.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` for the lifetime of a test and restores the original value of
    /// `OPENAI_API_KEY` when dropped, including when the test panics.
    struct ApiKeyEnvGuard {
        original: Option<OsString>,
        // Declared last so that `Drop::drop` restores the variable while the lock is
        // still held.
        _lock: MutexGuard<'static, ()>,
    }

    impl ApiKeyEnvGuard {
        fn acquire() -> Self {
            // A panicking test poisons the mutex. The `()` payload carries no state, so
            // recovering the guard is safe.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                original: std::env::var_os(API_KEY_VAR),
                _lock: lock,
            }
        }

        fn set(&self, value: &str) {
            // SAFETY: ENV_LOCK is held, so no other test in this module touches the
            // environment concurrently. NOTE(review): tests elsewhere in the crate that
            // read or write environment variables are not covered by this lock.
            unsafe { std::env::set_var(API_KEY_VAR, value) };
        }

        fn remove(&self) {
            // SAFETY: see `set`.
            unsafe { std::env::remove_var(API_KEY_VAR) };
        }
    }

    impl Drop for ApiKeyEnvGuard {
        fn drop(&mut self) {
            // SAFETY: runs before `_lock` is released, so ENV_LOCK is still held.
            match self.original.take() {
                Some(value) => unsafe { std::env::set_var(API_KEY_VAR, value) },
                None => unsafe { std::env::remove_var(API_KEY_VAR) },
            }
        }
    }

    #[traced_test]
    fn test_new_openai_client_handle_env_var_missing() {
        info!("Beginning test_new_openai_client_handle_env_var_missing");

        let env = ApiKeyEnvGuard::acquire();
        if env.original.is_some() {
            trace!("OPENAI_API_KEY is currently set; removing it for this test...");
            env.remove();
        }

        if std::env::var(API_KEY_VAR).is_ok() {
            warn!("Skipping test_new_openai_client_handle_env_var_missing because we couldn't unset OPENAI_API_KEY in this environment.");
            return;
        }

        let result = std::panic::catch_unwind(|| {
            MockLanguageModelClient::<MockBatchClientError>::new();
        });
        debug!("Result from calling new() without the env var: {:?}", result);

        assert!(
            result.is_err(),
            "Expected new() to panic when OPENAI_API_KEY is unset"
        );

        info!("test_new_openai_client_handle_env_var_missing passed (or skipped).");
    }

    #[traced_test]
    fn test_new_openai_client_handle_env_var_present() {
        info!("Beginning test_new_openai_client_handle_env_var_present");

        let env = ApiKeyEnvGuard::acquire();
        let test_value = "test_openai_api_key_12345";
        trace!("Temporarily setting OPENAI_API_KEY to {}", test_value);
        env.set(test_value);

        let result = std::panic::catch_unwind(|| {
            MockLanguageModelClient::<MockBatchClientError>::new()
        });
        debug!("Result from calling new() with env var set: {:?}", result);

        assert!(
            result.is_ok(),
            "Expected new() to succeed when OPENAI_API_KEY is set"
        );
        let handle = result.unwrap();
        debug!("Created handle: {:?}", handle);

        info!("test_new_openai_client_handle_env_var_present passed.");
    }

    #[traced_test]
    fn test_delegate_methods() {
        info!("Beginning test_delegate_methods");

        let env = ApiKeyEnvGuard::acquire();
        env.set("mock_test_key");

        let handle: MockLanguageModelClient<MockBatchClientError> = std::panic::catch_unwind(|| {
            MockLanguageModelClient::<MockBatchClientError>::new()
        })
        .expect("Should not panic for mock_test_key");

        debug!("Successfully created handle: {:?}", handle);

        let _batches = handle.batches();
        let _files = handle.files();

        info!("Handle's delegated methods (batches, files) are callable without error.");
        info!("test_delegate_methods passed.");
    }

    #[traced_test]
    fn test_aggregator_trait_compatibility() {
        info!("Beginning test_aggregator_trait_compatibility");
        trace!("Ensuring that `MockLanguageModelClient` can be used as `LanguageModelClientInterface` object.");

        let env = ApiKeyEnvGuard::acquire();
        env.set("some_mock_key");

        let handle_arc = Arc::new(
            std::panic::catch_unwind(|| MockLanguageModelClient::<MockBatchClientError>::new())
                .expect("Should not panic with some_mock_key"),
        );

        let client_interface_arc: Arc<dyn LanguageModelClientInterface<MockBatchClientError>> =
            handle_arc as Arc<dyn LanguageModelClientInterface<MockBatchClientError>>;
        debug!(
            "We can coerce the handle into the aggregator trait object: {:?}",
            client_interface_arc
        );

        info!("test_aggregator_trait_compatibility passed.");
    }

    #[traced_test]
    async fn test_mock_language_model_client_basic_usage() {
        info!("Starting test_mock_language_model_client_basic_usage");

        // Built directly via the builder, so this test does not depend on OPENAI_API_KEY.
        let mock = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .expect("Failed to build mock client");

        mock.configure_inprogress_then_complete_with("mock_batch_id_for_example_file_id", false, false);

        info!("Creating a batch via the mock client...");
        let created = mock.create_batch("example_file_id").await;
        assert!(created.is_ok(), "Should create batch successfully");
        let created_batch = created.unwrap();
        pretty_assert_eq!(created_batch.status, BatchStatus::InProgress);

        info!("Retrieving the newly created batch...");
        let retrieved = mock.retrieve_batch(&created_batch.id).await;
        assert!(retrieved.is_ok(), "Should retrieve batch successfully");

        info!("Waiting for batch completion...");
        let wait_result = mock.wait_for_batch_completion(&created_batch.id).await;
        debug!("Result from wait_for_batch_completion: {:?}", wait_result);

        assert!(wait_result.is_ok(), "Should complete batch successfully");
        let completed_batch = wait_result.unwrap();
        pretty_assert_eq!(completed_batch.status, BatchStatus::Completed);

        info!("Trying to retrieve a non-existent file...");
        let file_content_result = mock.file_content("non_existent_file").await;
        assert!(file_content_result.is_err(), "Should fail for unknown file ID");
    }
}