1use anyhow::Result;
7use rig::{
8 agent::{Agent, AgentBuilder, PromptResponse},
9 client::{CompletionClient, ProviderClient},
10 completion::{CompletionModel, Prompt, PromptError},
11 providers::{anthropic, gemini, openai},
12};
13use serde_json::{Map, Value, json};
14use std::collections::HashMap;
15
16use crate::providers::{Provider, ProviderConfig};
17
// Concrete `rig` completion-model types for each supported provider.
pub type OpenAIModel = openai::completion::CompletionModel;
pub type AnthropicModel = anthropic::completion::CompletionModel;
pub type GeminiModel = gemini::completion::CompletionModel;

// Agent builders parameterized over the matching provider model above.
pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
pub type GeminiBuilder = AgentBuilder<GeminiModel>;
27
/// Tuning profile for a completion request; selects per-role defaults
/// (currently only the OpenAI reasoning-effort default below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionProfile {
    MainAgent,
    Subagent,
    StatusMessage,
}

impl CompletionProfile {
    /// Default OpenAI `reasoning.effort` value for this profile. Applied by
    /// `completion_params_json` only when the caller supplies no explicit
    /// "reasoning" parameter.
    const fn default_openai_reasoning_effort(self) -> &'static str {
        match self {
            Self::MainAgent => "medium",
            Self::Subagent => "low",
            Self::StatusMessage => "none",
        }
    }
}
44
45pub enum DynAgent {
47 OpenAI(Agent<OpenAIModel>),
48 Anthropic(Agent<AnthropicModel>),
49 Gemini(Agent<GeminiModel>),
50}
51
52impl DynAgent {
53 pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
59 match self {
60 Self::OpenAI(a) => a.prompt(msg).await,
61 Self::Anthropic(a) => a.prompt(msg).await,
62 Self::Gemini(a) => a.prompt(msg).await,
63 }
64 }
65
66 pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
72 match self {
73 Self::OpenAI(a) => a.prompt(msg).max_turns(depth).await,
74 Self::Anthropic(a) => a.prompt(msg).max_turns(depth).await,
75 Self::Gemini(a) => a.prompt(msg).max_turns(depth).await,
76 }
77 }
78
79 pub async fn prompt_extended(
85 &self,
86 msg: &str,
87 depth: usize,
88 ) -> Result<PromptResponse, PromptError> {
89 match self {
90 Self::OpenAI(a) => a.prompt(msg).max_turns(depth).extended_details().await,
91 Self::Anthropic(a) => a.prompt(msg).max_turns(depth).extended_details().await,
92 Self::Gemini(a) => a.prompt(msg).max_turns(depth).extended_details().await,
93 }
94 }
95}
96
/// Where `resolve_api_key` found the key it returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiKeySource {
    /// Explicit key passed in from configuration.
    Config,
    /// Key read from the provider's API-key environment variable.
    Environment,
    /// No key found; the provider client's own `from_env()` default is used.
    ClientDefault,
}
104
105fn validate_and_warn(key: &str, provider: Provider, source: &str) {
107 if let Err(warning) = provider.validate_api_key_format(key) {
108 tracing::warn!(
109 provider = %provider,
110 source = source,
111 "API key format warning: {}",
112 warning
113 );
114 }
115}
116
117pub fn resolve_api_key(
128 api_key: Option<&str>,
129 provider: Provider,
130) -> (Option<String>, ApiKeySource) {
131 if let Some(key) = api_key
133 && !key.is_empty()
134 {
135 tracing::trace!(
136 provider = %provider,
137 source = "config",
138 "Using API key from configuration"
139 );
140 validate_and_warn(key, provider, "config");
141 return (Some(key.to_string()), ApiKeySource::Config);
142 }
143
144 if let Ok(key) = std::env::var(provider.api_key_env()) {
146 tracing::trace!(
147 provider = %provider,
148 env_var = %provider.api_key_env(),
149 source = "environment",
150 "Using API key from environment variable"
151 );
152 validate_and_warn(&key, provider, "environment");
153 return (Some(key), ApiKeySource::Environment);
154 }
155
156 tracing::trace!(
157 provider = %provider,
158 source = "client_default",
159 "No API key found, will use client's from_env()"
160 );
161 (None, ApiKeySource::ClientDefault)
162}
163
164pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
179 let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
180 let client = match resolved_key {
181 Some(key) => openai::Client::new(&key)
182 .map_err(|_| {
184 anyhow::anyhow!(
185 "Failed to create OpenAI client: authentication or configuration error"
186 )
187 })?,
188 None => openai::Client::from_env()
189 .map_err(|_| anyhow::anyhow!("Failed to create OpenAI client from environment"))?,
190 };
191 Ok(client.completions_api().agent(model))
192}
193
194pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
209 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
210 let client = match resolved_key {
211 Some(key) => anthropic::Client::new(&key)
212 .map_err(|_| {
214 anyhow::anyhow!(
215 "Failed to create Anthropic client: authentication or configuration error"
216 )
217 })?,
218 None => anthropic::Client::from_env()
219 .map_err(|_| anyhow::anyhow!("Failed to create Anthropic client from environment"))?,
220 };
221 Ok(client.agent(model))
222}
223
224pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
239 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
240 let client = match resolved_key {
241 Some(key) => gemini::Client::new(&key)
242 .map_err(|_| {
244 anyhow::anyhow!(
245 "Failed to create Gemini client: authentication or configuration error"
246 )
247 })?,
248 None => gemini::Client::from_env()
249 .map_err(|_| anyhow::anyhow!("Failed to create Gemini client from environment"))?,
250 };
251 Ok(client.agent(model))
252}
253
254fn parse_additional_param_value(raw: &str) -> Value {
255 serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.to_string()))
256}
257
258fn additional_params_json<S>(
259 additional_params: Option<&HashMap<String, String, S>>,
260) -> Map<String, Value>
261where
262 S: std::hash::BuildHasher,
263{
264 let mut params = Map::new();
265 if let Some(additional_params) = additional_params {
266 for (key, value) in additional_params {
267 params.insert(key.clone(), parse_additional_param_value(value));
268 }
269 }
270 params
271}
272
/// True when `model` is in the GPT-5 family (case-insensitive "gpt-5" prefix),
/// i.e. the OpenAI models that receive a default `reasoning` payload.
fn supports_openai_reasoning_defaults(model: &str) -> bool {
    // Compare the first five bytes case-insensitively. `get` returns `None`
    // (=> false) when the string is shorter than the prefix or the cut falls
    // inside a multi-byte character — in both cases the lowercased string
    // could not start with "gpt-5" either.
    model
        .get(..5)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("gpt-5"))
}
276
277fn completion_params_json<S>(
278 additional_params: Option<&HashMap<String, String, S>>,
279 provider: Provider,
280 model: &str,
281 max_tokens: u64,
282 profile: CompletionProfile,
283) -> Map<String, Value>
284where
285 S: std::hash::BuildHasher,
286{
287 let mut params = additional_params_json(additional_params);
288
289 if provider == Provider::OpenAI && needs_max_completion_tokens(model) {
290 params.insert("max_completion_tokens".to_string(), json!(max_tokens));
291 }
292
293 if provider == Provider::OpenAI
294 && supports_openai_reasoning_defaults(model)
295 && !params.contains_key("reasoning")
296 {
297 params.insert(
298 "reasoning".to_string(),
299 json!({ "effort": profile.default_openai_reasoning_effort() }),
300 );
301 }
302
303 params
304}
305
/// True for OpenAI model families that take `max_completion_tokens` instead
/// of the builder-level `max_tokens` (prefix match, case-insensitive).
fn needs_max_completion_tokens(model: &str) -> bool {
    const PREFIXES: [&str; 5] = ["gpt-5", "gpt-4.1", "o1", "o3", "o4"];
    let lowered = model.to_lowercase();
    PREFIXES.into_iter().any(|prefix| lowered.starts_with(prefix))
}
314
315pub fn apply_completion_params<M, S>(
316 mut builder: AgentBuilder<M>,
317 provider: Provider,
318 model: &str,
319 max_tokens: u64,
320 additional_params: Option<&HashMap<String, String, S>>,
321 profile: CompletionProfile,
322) -> AgentBuilder<M>
323where
324 M: CompletionModel,
325 S: std::hash::BuildHasher,
326{
327 if !(provider == Provider::OpenAI && needs_max_completion_tokens(model)) {
328 builder = builder.max_tokens(max_tokens);
329 }
330
331 let params = completion_params_json(additional_params, provider, model, max_tokens, profile);
332
333 if params.is_empty() {
334 builder
335 } else {
336 builder.additional_params(Value::Object(params))
337 }
338}
339
340pub fn provider_from_name(provider: &str) -> Result<Provider> {
346 provider
347 .parse()
348 .map_err(|_| anyhow::anyhow!("Unsupported provider: {}", provider))
349}
350
351#[must_use]
352pub fn current_provider_config<'a>(
353 config: Option<&'a crate::config::Config>,
354 provider: &str,
355) -> Option<&'a ProviderConfig> {
356 config.and_then(|config| config.get_provider_config(provider))
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    // A non-empty config key is returned verbatim and attributed to `Config`.
    #[test]
    fn test_resolve_api_key_uses_config_when_provided() {
        let (key, source) = resolve_api_key(Some("sk-config-key-1234567890"), Provider::OpenAI);
        assert_eq!(key, Some("sk-config-key-1234567890".to_string()));
        assert_eq!(source, ApiKeySource::Config);
    }

    // An empty config key must not count as a key; the resolver falls through.
    // Only the source is asserted because the fallback (environment vs client
    // default) depends on the ambient test environment.
    #[test]
    fn test_resolve_api_key_empty_config_not_used() {
        let empty_config: Option<&str> = Some("");
        let (_key, source) = resolve_api_key(empty_config, Provider::OpenAI);

        assert_ne!(source, ApiKeySource::Config);
    }

    // With no config key the outcome depends on whether the provider's env
    // var is set, so both remaining sources are tolerated — but `Config` is
    // impossible.
    #[test]
    fn test_resolve_api_key_none_config_checks_env() {
        let (key, source) = resolve_api_key(None, Provider::OpenAI);

        match source {
            ApiKeySource::Environment => {
                assert!(key.is_some());
            }
            ApiKeySource::ClientDefault => {
                assert!(key.is_none());
            }
            ApiKeySource::Config => {
                unreachable!("Should not return Config source when config is None");
            }
        }
    }

    // Sanity checks on the derived `PartialEq` for `ApiKeySource`.
    #[test]
    fn test_api_key_source_enum_equality() {
        assert_eq!(ApiKeySource::Config, ApiKeySource::Config);
        assert_eq!(ApiKeySource::Environment, ApiKeySource::Environment);
        assert_eq!(ApiKeySource::ClientDefault, ApiKeySource::ClientDefault);
        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
    }

    // Config-key resolution behaves identically for every provider variant.
    #[test]
    fn test_resolve_api_key_all_providers() {
        for provider in Provider::ALL {
            let (key, source) = resolve_api_key(Some("test-key-123456789012345"), *provider);
            assert_eq!(key, Some("test-key-123456789012345".to_string()));
            assert_eq!(source, ApiKeySource::Config);
        }
    }

    // A config key wins regardless of what the environment holds.
    #[test]
    fn test_resolve_api_key_config_precedence() {
        let config_key = "sk-from-config-abcdef1234567890";
        let (key, source) = resolve_api_key(Some(config_key), Provider::OpenAI);

        assert_eq!(key.as_deref(), Some(config_key));
        assert_eq!(source, ApiKeySource::Config);
    }

    // The derived `Debug` output should name the variant.
    #[test]
    fn test_api_key_source_debug_impl() {
        let source = ApiKeySource::Config;
        let debug_str = format!("{:?}", source);
        assert!(debug_str.contains("Config"));
    }

    // Values that parse as JSON become typed JSON values; anything else is
    // carried as a plain string.
    #[test]
    fn test_apply_completion_params_parses_json_like_additional_params() {
        let mut additional_params = HashMap::new();
        additional_params.insert("temperature".to_string(), "0.7".to_string());
        additional_params.insert("reasoning".to_string(), r#"{"effort":"low"}"#.to_string());

        let params = additional_params_json(Some(&additional_params));
        assert_eq!(params.get("temperature"), Some(&json!(0.7)));
        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "low"})));
    }

    // Each profile injects its own default reasoning effort for GPT-5 models,
    // alongside `max_completion_tokens`.
    #[test]
    fn test_completion_params_use_profile_specific_openai_reasoning_defaults() {
        let main_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4",
            16_384,
            CompletionProfile::MainAgent,
        );
        assert_eq!(
            main_params.get("reasoning"),
            Some(&json!({"effort": "medium"}))
        );
        assert_eq!(
            main_params.get("max_completion_tokens"),
            Some(&json!(16_384))
        );

        let status_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4-mini",
            50,
            CompletionProfile::StatusMessage,
        );
        assert_eq!(
            status_params.get("reasoning"),
            Some(&json!({"effort": "none"}))
        );
        assert_eq!(status_params.get("max_completion_tokens"), Some(&json!(50)));
    }

    // A caller-supplied "reasoning" param suppresses the profile default.
    #[test]
    fn test_completion_params_preserve_explicit_reasoning_overrides() {
        let mut additional_params = HashMap::new();
        additional_params.insert("reasoning".to_string(), r#"{"effort":"high"}"#.to_string());

        let params = completion_params_json(
            Some(&additional_params),
            Provider::OpenAI,
            "gpt-5.4",
            4096,
            CompletionProfile::MainAgent,
        );

        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "high"})));
    }

    // Non-GPT-5 models get no reasoning default, but gpt-4.1 still requires
    // `max_completion_tokens`.
    #[test]
    fn test_completion_params_skip_openai_reasoning_defaults_for_non_gpt5_models() {
        let params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-4.1",
            4096,
            CompletionProfile::MainAgent,
        );

        assert!(!params.contains_key("reasoning"));
        assert_eq!(params.get("max_completion_tokens"), Some(&json!(4096)));
    }

    // Alias handling lives in `Provider`'s `FromStr`; this just spot-checks it.
    #[test]
    fn test_provider_from_name_supports_aliases() {
        assert_eq!(provider_from_name("openai").ok(), Some(Provider::OpenAI));
        assert_eq!(provider_from_name("claude").ok(), Some(Provider::Anthropic));
        assert_eq!(provider_from_name("gemini").ok(), Some(Provider::Google));
    }

    // Prefix matching for models that take `max_completion_tokens`.
    #[test]
    fn test_needs_max_completion_tokens_for_gpt5_family() {
        assert!(needs_max_completion_tokens("gpt-5.4"));
        assert!(needs_max_completion_tokens("o3"));
        assert!(!needs_max_completion_tokens("claude-opus-4-6"));
    }
}