// agents_test/lib.rs

//! Test helpers for [`agents`].
//!
//! `agents-test` contains opt-in support for provider-backed integration tests,
//! including local Ollama container helpers and shared runner builders.

/// Helpers for starting and managing a local Ollama container for tests.
mod ollama_container;
7
8use std::collections::HashSet;
9use std::sync::{Arc, Once};
10
11use agents::LlmRunner;
12use agents::error::{Error as LlmError, LlmResult};
13use agents::provider::anthropic::{Anthropic, AnthropicConfig};
14use agents::provider::cloudflare::workers_ai::{WorkersAI, WorkersAIConfig};
15use agents::provider::ollama::{Ollama, OllamaConfig};
16use agents::provider::openai::{OpenAI, OpenAIConfig};
17use agents::provider::openrouter::{OpenRouter, OpenRouterConfig};
18use ollama_container::LlmContainer;
19use tokio::sync::{Mutex, OnceCell};
20
/// Guards one-time `.env` loading performed by [`init_test_env`].
static DOTENV: Once = Once::new();
/// Lazily initialized shared Ollama-backed [`TestContext`], reused by all tests
/// in the process so only one container is started.
static OLLAMA_CONTEXT: OnceCell<Arc<TestContext>> = OnceCell::const_new();
23
/// Placeholder hook for tracing setup.
///
/// Intentionally a no-op: the caller owns logging/tracing configuration.
/// Kept as a public function so test code has a stable call site if a
/// real implementation is added later.
pub fn init_tracing() {
    // Intentionally a no-op. The caller owns logging/tracing configuration.
}
27
28pub fn init_test_env() {
29    DOTENV.call_once(|| {
30        let _ = dotenvy::dotenv();
31    });
32}
33
/// Backends that a shared [`TestContext`] can be requested for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TestProvider {
    /// A locally containerized Ollama server.
    Ollama,
}
38
/// Shared state for provider-backed integration tests.
///
/// Obtain one via [`TestContext::shared`]; instances are process-wide
/// singletons per provider.
pub struct TestContext {
    /// Which backend this context was created for.
    provider: TestProvider,
    /// Base URL of the backing server (copied from the container at startup).
    base_url: String,
    /// Provider-specific runtime resources.
    runtime: TestRuntime,
}
44
/// Provider-specific runtime resources owned by a [`TestContext`].
enum TestRuntime {
    /// A shared Ollama server plus its pulled-model cache.
    Ollama(SharedOllamaServer),
}
48
/// A running Ollama container together with the set of models already ensured.
struct SharedOllamaServer {
    /// Handle to the running container.
    container: LlmContainer,
    /// Names of models already ensured, so each model is pulled at most once.
    /// Async mutex because `ensure_model` holds the lock across an `.await`.
    ensured_models: Mutex<HashSet<String>>,
}
53
54impl TestContext {
55    pub async fn shared(provider: TestProvider) -> LlmResult<Arc<Self>> {
56        init_tracing();
57        init_test_env();
58
59        match provider {
60            TestProvider::Ollama => OLLAMA_CONTEXT
61                .get_or_try_init(|| async {
62                    let container = LlmContainer::start_ollama().await?;
63                    Ok(Arc::new(Self {
64                        provider,
65                        base_url: container.base_url.clone(),
66                        runtime: TestRuntime::Ollama(SharedOllamaServer {
67                            container,
68                            ensured_models: Mutex::new(HashSet::new()),
69                        }),
70                    }))
71                })
72                .await
73                .map(Arc::clone),
74        }
75    }
76
77    pub fn provider(&self) -> TestProvider {
78        self.provider
79    }
80
81    pub fn base_url(&self) -> &str {
82        &self.base_url
83    }
84
85    pub async fn ensure_model(&self, model: &str) -> LlmResult<()> {
86        match &self.runtime {
87            TestRuntime::Ollama(server) => {
88                let mut ensured = server.ensured_models.lock().await;
89                if ensured.contains(model) {
90                    return Ok(());
91                }
92
93                server.container.ensure_model(model).await?;
94                ensured.insert(model.to_string());
95                Ok(())
96            }
97        }
98    }
99
100    pub async fn runner_for_model(&self, model: &str) -> LlmResult<LlmRunner> {
101        self.ensure_model(model).await?;
102        Ok(match self.provider {
103            TestProvider::Ollama => LlmRunner::builder()
104                .add_provider(self.ollama_provider_for_model(model).await?)
105                .build(),
106        })
107    }
108
109    pub async fn ollama_provider_for_model(&self, model: &str) -> LlmResult<Ollama> {
110        self.ensure_model(model).await?;
111        Ok(Ollama::new(
112            OllamaConfig::new(model.to_string()).with_base_url(self.base_url.clone()),
113        ))
114    }
115}
116
117pub fn required_test_env(name: &str) -> LlmResult<String> {
118    init_test_env();
119    std::env::var(name).map_err(|_| LlmError::Configuration(format!("missing test env var {name}")))
120}
121
122pub fn optional_test_env(name: &str) -> Option<String> {
123    init_test_env();
124    std::env::var(name).ok()
125}
126
127pub fn openai_provider_for_model(model: &str) -> LlmResult<OpenAI> {
128    let api_key = required_test_env("OPENAI_API_KEY")?;
129    let config = OpenAIConfig::new(api_key, model.to_string()).map_err(LlmError::OpenAIConfig)?;
130    Ok(OpenAI::new(config))
131}
132
133pub fn anthropic_provider_for_model(model: &str) -> LlmResult<Anthropic> {
134    let api_key = required_test_env("ANTHROPIC_API_KEY")?;
135    let config =
136        AnthropicConfig::new(api_key, model.to_string()).map_err(LlmError::AnthropicConfig)?;
137    Ok(Anthropic::new(config))
138}
139
140pub fn openrouter_provider_for_model(model: &str) -> LlmResult<OpenRouter> {
141    let api_key = required_test_env("OPENROUTER_API_KEY")?;
142    let config =
143        OpenRouterConfig::new(api_key, model.to_string()).map_err(LlmError::OpenRouterConfig)?;
144    Ok(OpenRouter::new(config))
145}
146
147pub fn runner_with_openai_model(model: &str) -> LlmResult<LlmRunner> {
148    Ok(LlmRunner::builder()
149        .add_provider(openai_provider_for_model(model)?)
150        .build())
151}
152
153pub fn runner_with_anthropic_model(model: &str) -> LlmResult<LlmRunner> {
154    Ok(LlmRunner::builder()
155        .add_provider(anthropic_provider_for_model(model)?)
156        .build())
157}
158
159pub fn runner_with_openrouter_model(model: &str) -> LlmResult<LlmRunner> {
160    Ok(LlmRunner::builder()
161        .add_provider(openrouter_provider_for_model(model)?)
162        .build())
163}
164
165pub fn workers_ai_provider_for_model(model: &str) -> LlmResult<WorkersAI> {
166    let api_token = optional_test_env("BORG_LLM_WORKERS_AI_API_TOKEN")
167        .or_else(|| optional_test_env("CLOUDFLARE_API_TOKEN"))
168        .ok_or_else(|| {
169            LlmError::Configuration(
170                "missing Workers AI API token (expected BORG_LLM_WORKERS_AI_API_TOKEN or CLOUDFLARE_API_TOKEN)".to_string(),
171            )
172        })?;
173    let account_id = optional_test_env("CLOUDFLARE_ACCOUNT_ID").ok_or_else(|| {
174        LlmError::Configuration(
175            "missing Workers AI account id (expected CLOUDFLARE_ACCOUNT_ID)".to_string(),
176        )
177    })?;
178
179    let mut config = WorkersAIConfig::new(api_token, account_id, model.to_string())
180        .map_err(LlmError::WorkersAIConfig)?;
181
182    if let Some(base_url) = optional_test_env("BORG_LLM_WORKERS_AI_BASE_URL") {
183        config = config.with_base_url(base_url);
184    }
185
186    Ok(WorkersAI::new(config))
187}
188
189pub fn runner_with_workers_ai_model(model: &str) -> LlmResult<LlmRunner> {
190    Ok(LlmRunner::builder()
191        .add_provider(workers_ai_provider_for_model(model)?)
192        .build())
193}