1use anyhow::Result;
7use rig::{
8 agent::{Agent, AgentBuilder, PromptResponse},
9 client::{CompletionClient, ProviderClient},
10 completion::{Prompt, PromptError},
11 providers::{anthropic, gemini, openai},
12};
13
14use crate::providers::Provider;
15
16pub type OpenAIModel = openai::completion::CompletionModel;
18pub type AnthropicModel = anthropic::completion::CompletionModel;
19pub type GeminiModel = gemini::completion::CompletionModel;
20
21pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
23pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
24pub type GeminiBuilder = AgentBuilder<GeminiModel>;
25
26pub enum DynAgent {
28 OpenAI(Agent<OpenAIModel>),
29 Anthropic(Agent<AnthropicModel>),
30 Gemini(Agent<GeminiModel>),
31}
32
33impl DynAgent {
34 pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
36 match self {
37 Self::OpenAI(a) => a.prompt(msg).await,
38 Self::Anthropic(a) => a.prompt(msg).await,
39 Self::Gemini(a) => a.prompt(msg).await,
40 }
41 }
42
43 pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
45 match self {
46 Self::OpenAI(a) => a.prompt(msg).multi_turn(depth).await,
47 Self::Anthropic(a) => a.prompt(msg).multi_turn(depth).await,
48 Self::Gemini(a) => a.prompt(msg).multi_turn(depth).await,
49 }
50 }
51
52 pub async fn prompt_extended(
54 &self,
55 msg: &str,
56 depth: usize,
57 ) -> Result<PromptResponse, PromptError> {
58 match self {
59 Self::OpenAI(a) => a.prompt(msg).multi_turn(depth).extended_details().await,
60 Self::Anthropic(a) => a.prompt(msg).multi_turn(depth).extended_details().await,
61 Self::Gemini(a) => a.prompt(msg).multi_turn(depth).extended_details().await,
62 }
63 }
64}
65
/// Where [`resolve_api_key`] found (or failed to find) an API key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiKeySource {
    /// The key came from the caller-supplied configuration value.
    Config,
    /// The key came from the provider's API-key environment variable.
    Environment,
    /// No key was resolved; callers fall back to the client's `from_env()` constructor.
    ClientDefault,
}
73
74fn validate_and_warn(key: &str, provider: Provider, source: &str) {
76 if let Err(warning) = provider.validate_api_key_format(key) {
77 tracing::warn!(
78 provider = %provider,
79 source = source,
80 "API key format warning: {}",
81 warning
82 );
83 }
84}
85
86pub fn resolve_api_key(
97 api_key: Option<&str>,
98 provider: Provider,
99) -> (Option<String>, ApiKeySource) {
100 if let Some(key) = api_key
102 && !key.is_empty()
103 {
104 tracing::trace!(
105 provider = %provider,
106 source = "config",
107 "Using API key from configuration"
108 );
109 validate_and_warn(key, provider, "config");
110 return (Some(key.to_string()), ApiKeySource::Config);
111 }
112
113 if let Ok(key) = std::env::var(provider.api_key_env()) {
115 tracing::trace!(
116 provider = %provider,
117 env_var = %provider.api_key_env(),
118 source = "environment",
119 "Using API key from environment variable"
120 );
121 validate_and_warn(&key, provider, "environment");
122 return (Some(key), ApiKeySource::Environment);
123 }
124
125 tracing::trace!(
126 provider = %provider,
127 source = "client_default",
128 "No API key found, will use client's from_env()"
129 );
130 (None, ApiKeySource::ClientDefault)
131}
132
133pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
148 let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
149 let client = match resolved_key {
150 Some(key) => openai::Client::new(&key)
151 .map_err(|_| {
153 anyhow::anyhow!(
154 "Failed to create OpenAI client: authentication or configuration error"
155 )
156 })?,
157 None => openai::Client::from_env(),
158 };
159 Ok(client.completions_api().agent(model))
160}
161
162pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
177 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
178 let client = match resolved_key {
179 Some(key) => anthropic::Client::new(&key)
180 .map_err(|_| {
182 anyhow::anyhow!(
183 "Failed to create Anthropic client: authentication or configuration error"
184 )
185 })?,
186 None => anthropic::Client::from_env(),
187 };
188 Ok(client.agent(model))
189}
190
191pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
206 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
207 let client = match resolved_key {
208 Some(key) => gemini::Client::new(&key)
209 .map_err(|_| {
211 anyhow::anyhow!(
212 "Failed to create Gemini client: authentication or configuration error"
213 )
214 })?,
215 None => gemini::Client::from_env(),
216 };
217 Ok(client.agent(model))
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-empty configured key is returned verbatim and tagged as `Config`.
    #[test]
    fn test_resolve_api_key_uses_config_when_provided() {
        let configured = "sk-config-key-1234567890";
        let (resolved, origin) = resolve_api_key(Some(configured), Provider::OpenAI);

        assert_eq!(resolved.as_deref(), Some(configured));
        assert_eq!(origin, ApiKeySource::Config);
    }

    /// An empty configured key is skipped. The outcome then depends on the
    /// environment, but it must never be `Config`.
    #[test]
    fn test_resolve_api_key_empty_config_not_used() {
        let (_, origin) = resolve_api_key(Some(""), Provider::OpenAI);
        assert_ne!(origin, ApiKeySource::Config);
    }

    /// Without a configured key, the result is either an environment key or the
    /// client-default fallback, and the key's presence matches the source.
    #[test]
    fn test_resolve_api_key_none_config_checks_env() {
        let (resolved, origin) = resolve_api_key(None, Provider::OpenAI);

        match origin {
            ApiKeySource::Config => {
                panic!("Should not return Config source when config is None")
            }
            ApiKeySource::Environment => assert!(resolved.is_some()),
            ApiKeySource::ClientDefault => assert!(resolved.is_none()),
        }
    }

    /// The derived `PartialEq` distinguishes variants and is reflexive.
    #[test]
    fn test_api_key_source_enum_equality() {
        let variants = [
            ApiKeySource::Config,
            ApiKeySource::Environment,
            ApiKeySource::ClientDefault,
        ];
        for variant in variants {
            assert_eq!(variant, variant);
        }
        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
    }

    /// Configuration takes precedence for every known provider.
    #[test]
    fn test_resolve_api_key_all_providers() {
        let configured = "test-key-123456789012345";
        for &provider in Provider::ALL {
            let (resolved, origin) = resolve_api_key(Some(configured), provider);
            assert_eq!(resolved.as_deref(), Some(configured));
            assert_eq!(origin, ApiKeySource::Config);
        }
    }

    /// A configured key wins regardless of what the environment contains.
    #[test]
    fn test_resolve_api_key_config_precedence() {
        let from_config = "sk-from-config-abcdef1234567890";
        let (resolved, origin) = resolve_api_key(Some(from_config), Provider::OpenAI);

        assert_eq!(resolved.as_deref(), Some(from_config));
        assert_eq!(origin, ApiKeySource::Config);
    }

    /// The derived `Debug` output names the variant.
    #[test]
    fn test_api_key_source_debug_impl() {
        let rendered = format!("{:?}", ApiKeySource::Config);
        assert!(rendered.contains("Config"));
    }
}
301}