1use anyhow::Result;
7use rig::{
8 agent::{Agent, AgentBuilder, PromptResponse},
9 client::{CompletionClient, ProviderClient},
10 completion::{Prompt, PromptError},
11 providers::{anthropic, gemini, openai},
12};
13
14use crate::providers::Provider;
15
16pub type OpenAIModel = openai::completion::CompletionModel;
18pub type AnthropicModel = anthropic::completion::CompletionModel;
19pub type GeminiModel = gemini::completion::CompletionModel;
20
21pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
23pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
24pub type GeminiBuilder = AgentBuilder<GeminiModel>;
25
26pub enum DynAgent {
28 OpenAI(Agent<OpenAIModel>),
29 Anthropic(Agent<AnthropicModel>),
30 Gemini(Agent<GeminiModel>),
31}
32
33impl DynAgent {
34 pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
36 match self {
37 Self::OpenAI(a) => a.prompt(msg).await,
38 Self::Anthropic(a) => a.prompt(msg).await,
39 Self::Gemini(a) => a.prompt(msg).await,
40 }
41 }
42
43 pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
45 match self {
46 Self::OpenAI(a) => a.prompt(msg).multi_turn(depth).await,
47 Self::Anthropic(a) => a.prompt(msg).multi_turn(depth).await,
48 Self::Gemini(a) => a.prompt(msg).multi_turn(depth).await,
49 }
50 }
51
52 pub async fn prompt_extended(
54 &self,
55 msg: &str,
56 depth: usize,
57 ) -> Result<PromptResponse, PromptError> {
58 match self {
59 Self::OpenAI(a) => a.prompt(msg).multi_turn(depth).extended_details().await,
60 Self::Anthropic(a) => a.prompt(msg).multi_turn(depth).extended_details().await,
61 Self::Gemini(a) => a.prompt(msg).multi_turn(depth).extended_details().await,
62 }
63 }
64}
65
/// Where [`resolve_api_key`] found (or failed to find) an API key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiKeySource {
    /// A non-empty key was passed in explicitly (the `api_key` argument).
    Config,
    /// The key was read from the provider's environment variable
    /// (`Provider::api_key_env`).
    Environment,
    /// No key was found; the provider builders fall back to the client's own
    /// `from_env()` constructor.
    ClientDefault,
}
73
74fn validate_and_warn(key: &str, provider: Provider, source: &str) {
76 if let Err(warning) = provider.validate_api_key_format(key) {
77 tracing::warn!(
78 provider = %provider,
79 source = source,
80 "API key format warning: {}",
81 warning
82 );
83 }
84}
85
86pub fn resolve_api_key(api_key: Option<&str>, provider: Provider) -> (Option<String>, ApiKeySource) {
97 if let Some(key) = api_key
99 && !key.is_empty()
100 {
101 tracing::trace!(
102 provider = %provider,
103 source = "config",
104 "Using API key from configuration"
105 );
106 validate_and_warn(key, provider, "config");
107 return (Some(key.to_string()), ApiKeySource::Config);
108 }
109
110 if let Ok(key) = std::env::var(provider.api_key_env()) {
112 tracing::trace!(
113 provider = %provider,
114 env_var = %provider.api_key_env(),
115 source = "environment",
116 "Using API key from environment variable"
117 );
118 validate_and_warn(&key, provider, "environment");
119 return (Some(key), ApiKeySource::Environment);
120 }
121
122 tracing::trace!(
123 provider = %provider,
124 source = "client_default",
125 "No API key found, will use client's from_env()"
126 );
127 (None, ApiKeySource::ClientDefault)
128}
129
130pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
145 let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
146 let client = match resolved_key {
147 Some(key) => openai::Client::new(&key)
148 .map_err(|_| anyhow::anyhow!(
150 "Failed to create OpenAI client: authentication or configuration error"
151 ))?,
152 None => openai::Client::from_env(),
153 };
154 Ok(client.completions_api().agent(model))
155}
156
157pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
172 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
173 let client = match resolved_key {
174 Some(key) => anthropic::Client::new(&key)
175 .map_err(|_| anyhow::anyhow!(
177 "Failed to create Anthropic client: authentication or configuration error"
178 ))?,
179 None => anthropic::Client::from_env(),
180 };
181 Ok(client.agent(model))
182}
183
184pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
199 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
200 let client = match resolved_key {
201 Some(key) => gemini::Client::new(&key)
202 .map_err(|_| anyhow::anyhow!(
204 "Failed to create Gemini client: authentication or configuration error"
205 ))?,
206 None => gemini::Client::from_env(),
207 };
208 Ok(client.agent(model))
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit, non-empty key wins and is reported as `Config`.
    #[test]
    fn test_resolve_api_key_uses_config_when_provided() {
        let expected = "sk-config-key-1234567890";
        let (key, source) = resolve_api_key(Some(expected), Provider::OpenAI);
        assert_eq!(key.as_deref(), Some(expected));
        assert_eq!(source, ApiKeySource::Config);
    }

    /// An empty configured key must be skipped, whatever the environment holds.
    #[test]
    fn test_resolve_api_key_empty_config_not_used() {
        let (_, source) = resolve_api_key(Some(""), Provider::OpenAI);
        assert_ne!(source, ApiKeySource::Config);
    }

    /// Without a configured key, the result depends on the test environment,
    /// so only the pairing of source and key presence is checked.
    #[test]
    fn test_resolve_api_key_none_config_checks_env() {
        let (key, source) = resolve_api_key(None, Provider::OpenAI);
        match (source, key.is_some()) {
            (ApiKeySource::Environment, has_key) => assert!(has_key),
            (ApiKeySource::ClientDefault, has_key) => assert!(!has_key),
            (ApiKeySource::Config, _) => {
                panic!("Should not return Config source when config is None")
            }
        }
    }

    #[test]
    fn test_api_key_source_enum_equality() {
        let variants = [
            ApiKeySource::Config,
            ApiKeySource::Environment,
            ApiKeySource::ClientDefault,
        ];
        for variant in variants {
            assert_eq!(variant, variant);
        }
        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
    }

    /// Every provider honours an explicitly configured key.
    #[test]
    fn test_resolve_api_key_all_providers() {
        let key_text = "test-key-123456789012345";
        for &provider in Provider::ALL {
            let (key, source) = resolve_api_key(Some(key_text), provider);
            assert_eq!(key, Some(key_text.to_string()));
            assert_eq!(source, ApiKeySource::Config);
        }
    }

    #[test]
    fn test_resolve_api_key_config_precedence() {
        let config_key = "sk-from-config-abcdef1234567890";
        let resolved = resolve_api_key(Some(config_key), Provider::OpenAI);
        assert_eq!(
            resolved,
            (Some(config_key.to_string()), ApiKeySource::Config)
        );
    }

    #[test]
    fn test_api_key_source_debug_impl() {
        let rendered = format!("{:?}", ApiKeySource::Config);
        assert!(rendered.contains("Config"));
    }
}
294}