1mod ollama_container;
7
8use std::collections::HashSet;
9use std::sync::{Arc, Once};
10
11use agents::LlmRunner;
12use agents::error::{Error as LlmError, LlmResult};
13use agents::provider::anthropic::{Anthropic, AnthropicConfig};
14use agents::provider::cloudflare::workers_ai::{WorkersAI, WorkersAIConfig};
15use agents::provider::ollama::{Ollama, OllamaConfig};
16use agents::provider::openai::{OpenAI, OpenAIConfig};
17use agents::provider::openrouter::{OpenRouter, OpenRouterConfig};
18use ollama_container::LlmContainer;
19use tokio::sync::{Mutex, OnceCell};
20
/// Guards the one-time `.env` load performed by `init_test_env`.
static DOTENV: Once = Once::new();
/// Process-wide Ollama test context; filled on first successful call to
/// `TestContext::shared(TestProvider::Ollama)` and reused by every later caller.
static OLLAMA_CONTEXT: OnceCell<Arc<TestContext>> = OnceCell::const_new();
23
/// Hook for tracing/logging setup in integration tests.
///
/// It currently does nothing. `TestContext::shared` still calls it, so any setup added
/// here runs before a shared context is created.
pub fn init_tracing() {}
27
28pub fn init_test_env() {
29 DOTENV.call_once(|| {
30 let _ = dotenvy::dotenv();
31 });
32}
33
/// Backend a `TestContext` is provisioned against.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TestProvider {
    /// A locally started Ollama server (via `LlmContainer::start_ollama`).
    Ollama,
}
38
/// Shared environment for provider integration tests.
///
/// Obtained through `TestContext::shared`, which hands out a lazily initialised,
/// process-wide instance wrapped in an `Arc`.
pub struct TestContext {
    // Backend this context was created for.
    provider: TestProvider,
    // Base URL of the backend's HTTP API, copied from the container at startup.
    base_url: String,
    // Backend-specific resources (e.g. the running container) owned by the context.
    runtime: TestRuntime,
}
44
/// Provider-specific resources owned by a `TestContext`.
enum TestRuntime {
    /// State for the Ollama container behind the process-wide Ollama context.
    Ollama(SharedOllamaServer),
}
48
/// An Ollama container plus bookkeeping about which models were already prepared.
struct SharedOllamaServer {
    // The running container; NOTE(review): presumably dropping it stops the
    // server — confirm against `LlmContainer`.
    container: LlmContainer,
    // Names of models for which `LlmContainer::ensure_model` already succeeded, so
    // repeat requests skip the call. An async mutex is used because the guard is
    // held across the `ensure_model` await.
    ensured_models: Mutex<HashSet<String>>,
}
53
54impl TestContext {
55 pub async fn shared(provider: TestProvider) -> LlmResult<Arc<Self>> {
56 init_tracing();
57 init_test_env();
58
59 match provider {
60 TestProvider::Ollama => OLLAMA_CONTEXT
61 .get_or_try_init(|| async {
62 let container = LlmContainer::start_ollama().await?;
63 Ok(Arc::new(Self {
64 provider,
65 base_url: container.base_url.clone(),
66 runtime: TestRuntime::Ollama(SharedOllamaServer {
67 container,
68 ensured_models: Mutex::new(HashSet::new()),
69 }),
70 }))
71 })
72 .await
73 .map(Arc::clone),
74 }
75 }
76
77 pub fn provider(&self) -> TestProvider {
78 self.provider
79 }
80
81 pub fn base_url(&self) -> &str {
82 &self.base_url
83 }
84
85 pub async fn ensure_model(&self, model: &str) -> LlmResult<()> {
86 match &self.runtime {
87 TestRuntime::Ollama(server) => {
88 let mut ensured = server.ensured_models.lock().await;
89 if ensured.contains(model) {
90 return Ok(());
91 }
92
93 server.container.ensure_model(model).await?;
94 ensured.insert(model.to_string());
95 Ok(())
96 }
97 }
98 }
99
100 pub async fn runner_for_model(&self, model: &str) -> LlmResult<LlmRunner> {
101 self.ensure_model(model).await?;
102 Ok(match self.provider {
103 TestProvider::Ollama => LlmRunner::builder()
104 .add_provider(self.ollama_provider_for_model(model).await?)
105 .build(),
106 })
107 }
108
109 pub async fn ollama_provider_for_model(&self, model: &str) -> LlmResult<Ollama> {
110 self.ensure_model(model).await?;
111 Ok(Ollama::new(
112 OllamaConfig::new(model.to_string()).with_base_url(self.base_url.clone()),
113 ))
114 }
115}
116
117pub fn required_test_env(name: &str) -> LlmResult<String> {
118 init_test_env();
119 std::env::var(name).map_err(|_| LlmError::Configuration(format!("missing test env var {name}")))
120}
121
122pub fn optional_test_env(name: &str) -> Option<String> {
123 init_test_env();
124 std::env::var(name).ok()
125}
126
127pub fn openai_provider_for_model(model: &str) -> LlmResult<OpenAI> {
128 let api_key = required_test_env("OPENAI_API_KEY")?;
129 let config = OpenAIConfig::new(api_key, model.to_string()).map_err(LlmError::OpenAIConfig)?;
130 Ok(OpenAI::new(config))
131}
132
133pub fn anthropic_provider_for_model(model: &str) -> LlmResult<Anthropic> {
134 let api_key = required_test_env("ANTHROPIC_API_KEY")?;
135 let config =
136 AnthropicConfig::new(api_key, model.to_string()).map_err(LlmError::AnthropicConfig)?;
137 Ok(Anthropic::new(config))
138}
139
140pub fn openrouter_provider_for_model(model: &str) -> LlmResult<OpenRouter> {
141 let api_key = required_test_env("OPENROUTER_API_KEY")?;
142 let config =
143 OpenRouterConfig::new(api_key, model.to_string()).map_err(LlmError::OpenRouterConfig)?;
144 Ok(OpenRouter::new(config))
145}
146
147pub fn runner_with_openai_model(model: &str) -> LlmResult<LlmRunner> {
148 Ok(LlmRunner::builder()
149 .add_provider(openai_provider_for_model(model)?)
150 .build())
151}
152
153pub fn runner_with_anthropic_model(model: &str) -> LlmResult<LlmRunner> {
154 Ok(LlmRunner::builder()
155 .add_provider(anthropic_provider_for_model(model)?)
156 .build())
157}
158
159pub fn runner_with_openrouter_model(model: &str) -> LlmResult<LlmRunner> {
160 Ok(LlmRunner::builder()
161 .add_provider(openrouter_provider_for_model(model)?)
162 .build())
163}
164
165pub fn workers_ai_provider_for_model(model: &str) -> LlmResult<WorkersAI> {
166 let api_token = optional_test_env("BORG_LLM_WORKERS_AI_API_TOKEN")
167 .or_else(|| optional_test_env("CLOUDFLARE_API_TOKEN"))
168 .ok_or_else(|| {
169 LlmError::Configuration(
170 "missing Workers AI API token (expected BORG_LLM_WORKERS_AI_API_TOKEN or CLOUDFLARE_API_TOKEN)".to_string(),
171 )
172 })?;
173 let account_id = optional_test_env("CLOUDFLARE_ACCOUNT_ID").ok_or_else(|| {
174 LlmError::Configuration(
175 "missing Workers AI account id (expected CLOUDFLARE_ACCOUNT_ID)".to_string(),
176 )
177 })?;
178
179 let mut config = WorkersAIConfig::new(api_token, account_id, model.to_string())
180 .map_err(LlmError::WorkersAIConfig)?;
181
182 if let Some(base_url) = optional_test_env("BORG_LLM_WORKERS_AI_BASE_URL") {
183 config = config.with_base_url(base_url);
184 }
185
186 Ok(WorkersAI::new(config))
187}
188
189pub fn runner_with_workers_ai_model(model: &str) -> LlmResult<LlmRunner> {
190 Ok(LlmRunner::builder()
191 .add_provider(workers_ai_provider_for_model(model)?)
192 .build())
193}