git_iris/llm_providers/
test.rs

1use super::{LLMProvider, LLMProviderConfig, ProviderMetadata};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use std::sync::{
5    atomic::{AtomicU64, AtomicUsize, Ordering},
6    Arc,
7};
8use std::time::Duration;
9use tokio::time::sleep;
10
/// A mock LLM provider for tests. Failure counts, artificial delays, canned
/// responses, and simulated JSON-validation failures can be injected so that
/// retry/timeout/validation logic in callers can be exercised deterministically.
#[derive(Clone)]
pub struct TestLLMProvider {
    /// Provider configuration; `config.model` is echoed in the default response.
    config: LLMProviderConfig,
    /// Number of initial calls that should fail with a simulated error.
    fail_count: Arc<AtomicUsize>,
    /// Artificial per-call delay in milliseconds (0 = no delay).
    delay: Arc<AtomicU64>,
    /// Running count of `generate_message` invocations.
    total_calls: Arc<AtomicUsize>,
    /// Canned success response; when `None`, a default message is synthesized.
    response: Option<String>,
    /// Canned invalid payload returned while simulating JSON-validation failures.
    bad_response: Option<String>,
    /// Number of initial successful calls that should yield an invalid payload.
    json_validation_failures: Arc<AtomicUsize>,
}
21
22impl TestLLMProvider {
23    /// Creates a new instance of `TestLLMProvider` with the given configuration
24    pub fn new(config: LLMProviderConfig) -> Self {
25        Self {
26            config,
27            fail_count: Arc::new(AtomicUsize::new(0)),
28            delay: Arc::new(AtomicU64::new(0)),
29            total_calls: Arc::new(AtomicUsize::new(0)),
30            response: None,
31            bad_response: None,
32            json_validation_failures: Arc::new(AtomicUsize::new(0)),
33        }
34    }
35
36    pub fn set_fail_count(&self, count: usize) {
37        self.fail_count.store(count, Ordering::SeqCst);
38    }
39
40    pub fn set_delay(&self, delay_ms: u64) {
41        self.delay.store(delay_ms, Ordering::SeqCst);
42    }
43
44    pub fn get_total_calls(&self) -> usize {
45        self.total_calls.load(Ordering::SeqCst)
46    }
47
48    pub fn set_response(&mut self, response: String) {
49        self.response = Some(response);
50    }
51
52    pub fn set_bad_response(&mut self, bad_response: String) {
53        self.bad_response = Some(bad_response);
54    }
55
56    pub fn set_json_validation_failures(&self, count: usize) {
57        self.json_validation_failures.store(count, Ordering::SeqCst);
58    }
59
60    pub fn reset(&self) {
61        self.fail_count.store(0, Ordering::SeqCst);
62        self.delay.store(0, Ordering::SeqCst);
63        self.total_calls.store(0, Ordering::SeqCst);
64        self.json_validation_failures.store(0, Ordering::SeqCst);
65    }
66}
67
68#[async_trait]
69impl LLMProvider for TestLLMProvider {
70    /// Generates a message using the Test provider (returns model name + it's own prompts as the message)
71    async fn generate_message(&self, system_prompt: &str, user_prompt: &str) -> Result<String> {
72        let total_calls = self.total_calls.fetch_add(1, Ordering::SeqCst);
73        println!(
74            "TestLLMProvider: generate_message called (total calls: {})",
75            total_calls + 1
76        );
77
78        let delay = self.delay.load(Ordering::SeqCst);
79        if delay > 0 {
80            println!("TestLLMProvider: Delaying for {delay} ms");
81            sleep(Duration::from_millis(delay)).await;
82        }
83
84        let fail_count = self.fail_count.load(Ordering::SeqCst);
85        if total_calls < fail_count {
86            println!("TestLLMProvider: Simulating failure");
87            Err(anyhow!("Simulated failure"))
88        } else {
89            println!("TestLLMProvider: Generating success response");
90            let json_validation_failures = self.json_validation_failures.load(Ordering::SeqCst);
91            if total_calls < json_validation_failures {
92                if let Some(bad_response) = &self.bad_response {
93                    Ok(bad_response.clone())
94                } else {
95                    Err(anyhow!("Simulated JSON validation failure"))
96                }
97            } else if let Some(response) = &self.response {
98                Ok(response.clone())
99            } else {
100                Ok(format!(
101                    "Test response from model '{}'. System prompt: '{}', User prompt: '{}'",
102                    self.config.model,
103                    system_prompt.replace('\'', "\\'"),
104                    user_prompt.replace('\'', "\\'")
105                ))
106            }
107        }
108    }
109}
110
111pub(super) fn get_metadata() -> ProviderMetadata {
112    ProviderMetadata {
113        name: "Test",
114        default_model: "test-model",
115        default_token_limit: 1000,
116        requires_api_key: false,
117    }
118}