git_iris/llm_providers/
test.rs1use super::{LLMProvider, LLMProviderConfig, ProviderMetadata};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use std::sync::{
5 atomic::{AtomicU64, AtomicUsize, Ordering},
6 Arc,
7};
8use std::time::Duration;
9use tokio::time::sleep;
10
11#[derive(Clone)]
12pub struct TestLLMProvider {
13 config: LLMProviderConfig,
14 fail_count: Arc<AtomicUsize>,
15 delay: Arc<AtomicU64>,
16 total_calls: Arc<AtomicUsize>,
17 response: Option<String>,
18 bad_response: Option<String>,
19 json_validation_failures: Arc<AtomicUsize>,
20}
21
22impl TestLLMProvider {
23 pub fn new(config: LLMProviderConfig) -> Self {
25 Self {
26 config,
27 fail_count: Arc::new(AtomicUsize::new(0)),
28 delay: Arc::new(AtomicU64::new(0)),
29 total_calls: Arc::new(AtomicUsize::new(0)),
30 response: None,
31 bad_response: None,
32 json_validation_failures: Arc::new(AtomicUsize::new(0)),
33 }
34 }
35
36 pub fn set_fail_count(&self, count: usize) {
37 self.fail_count.store(count, Ordering::SeqCst);
38 }
39
40 pub fn set_delay(&self, delay_ms: u64) {
41 self.delay.store(delay_ms, Ordering::SeqCst);
42 }
43
44 pub fn get_total_calls(&self) -> usize {
45 self.total_calls.load(Ordering::SeqCst)
46 }
47
48 pub fn set_response(&mut self, response: String) {
49 self.response = Some(response);
50 }
51
52 pub fn set_bad_response(&mut self, bad_response: String) {
53 self.bad_response = Some(bad_response);
54 }
55
56 pub fn set_json_validation_failures(&self, count: usize) {
57 self.json_validation_failures.store(count, Ordering::SeqCst);
58 }
59
60 pub fn reset(&self) {
61 self.fail_count.store(0, Ordering::SeqCst);
62 self.delay.store(0, Ordering::SeqCst);
63 self.total_calls.store(0, Ordering::SeqCst);
64 self.json_validation_failures.store(0, Ordering::SeqCst);
65 }
66}
67
68#[async_trait]
69impl LLMProvider for TestLLMProvider {
70 async fn generate_message(&self, system_prompt: &str, user_prompt: &str) -> Result<String> {
72 let total_calls = self.total_calls.fetch_add(1, Ordering::SeqCst);
73 println!(
74 "TestLLMProvider: generate_message called (total calls: {})",
75 total_calls + 1
76 );
77
78 let delay = self.delay.load(Ordering::SeqCst);
79 if delay > 0 {
80 println!("TestLLMProvider: Delaying for {delay} ms");
81 sleep(Duration::from_millis(delay)).await;
82 }
83
84 let fail_count = self.fail_count.load(Ordering::SeqCst);
85 if total_calls < fail_count {
86 println!("TestLLMProvider: Simulating failure");
87 Err(anyhow!("Simulated failure"))
88 } else {
89 println!("TestLLMProvider: Generating success response");
90 let json_validation_failures = self.json_validation_failures.load(Ordering::SeqCst);
91 if total_calls < json_validation_failures {
92 if let Some(bad_response) = &self.bad_response {
93 Ok(bad_response.clone())
94 } else {
95 Err(anyhow!("Simulated JSON validation failure"))
96 }
97 } else if let Some(response) = &self.response {
98 Ok(response.clone())
99 } else {
100 Ok(format!(
101 "Test response from model '{}'. System prompt: '{}', User prompt: '{}'",
102 self.config.model,
103 system_prompt.replace('\'', "\\'"),
104 user_prompt.replace('\'', "\\'")
105 ))
106 }
107 }
108 }
109}
110
111pub(super) fn get_metadata() -> ProviderMetadata {
112 ProviderMetadata {
113 name: "Test",
114 default_model: "test-model",
115 default_token_limit: 1000,
116 requires_api_key: false,
117 }
118}