1use anyhow::Result;
7use rig::{
8 agent::{Agent, AgentBuilder, PromptResponse},
9 client::{CompletionClient, ProviderClient},
10 completion::{CompletionModel, Prompt, PromptError},
11 providers::{anthropic, gemini, openai},
12};
13use serde_json::{Map, Value, json};
14use std::collections::HashMap;
15
16use crate::providers::{Provider, ProviderConfig};
17
/// Completion model type used by OpenAI-backed agents.
pub type OpenAIModel = openai::completion::CompletionModel;
/// Completion model type used by Anthropic-backed agents.
pub type AnthropicModel = anthropic::completion::CompletionModel;
/// Completion model type used by Gemini-backed agents.
pub type GeminiModel = gemini::completion::CompletionModel;

/// Agent builder specialised to [`OpenAIModel`].
pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
/// Agent builder specialised to [`AnthropicModel`].
pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
/// Agent builder specialised to [`GeminiModel`].
pub type GeminiBuilder = AgentBuilder<GeminiModel>;
27
/// Selects which set of default completion parameters applies to a request.
///
/// In the code visible here the profile only decides the default OpenAI
/// reasoning effort that is injected when the caller supplies none.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionProfile {
    /// Profile whose default OpenAI reasoning effort is `"medium"`.
    MainAgent,
    /// Profile whose default OpenAI reasoning effort is `"low"`.
    Subagent,
    /// Profile whose default OpenAI reasoning effort is `"none"`.
    StatusMessage,
}

impl CompletionProfile {
    /// Returns the reasoning-effort label used as the OpenAI default for this
    /// profile.
    const fn default_openai_reasoning_effort(self) -> &'static str {
        // Arms are listed from least to most effort; the match stays
        // exhaustive so a new variant forces a decision here.
        match self {
            CompletionProfile::StatusMessage => "none",
            CompletionProfile::Subagent => "low",
            CompletionProfile::MainAgent => "medium",
        }
    }
}
44
45pub enum DynAgent {
47 OpenAI(Agent<OpenAIModel>),
48 Anthropic(Agent<AnthropicModel>),
49 Gemini(Agent<GeminiModel>),
50}
51
52impl DynAgent {
53 pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
59 match self {
60 Self::OpenAI(a) => a.prompt(msg).await,
61 Self::Anthropic(a) => a.prompt(msg).await,
62 Self::Gemini(a) => a.prompt(msg).await,
63 }
64 }
65
66 pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
72 match self {
73 Self::OpenAI(a) => a.prompt(msg).max_turns(depth).await,
74 Self::Anthropic(a) => a.prompt(msg).max_turns(depth).await,
75 Self::Gemini(a) => a.prompt(msg).max_turns(depth).await,
76 }
77 }
78
79 pub async fn prompt_extended(
85 &self,
86 msg: &str,
87 depth: usize,
88 ) -> Result<PromptResponse, PromptError> {
89 match self {
90 Self::OpenAI(a) => a.prompt(msg).max_turns(depth).extended_details().await,
91 Self::Anthropic(a) => a.prompt(msg).max_turns(depth).extended_details().await,
92 Self::Gemini(a) => a.prompt(msg).max_turns(depth).extended_details().await,
93 }
94 }
95}
96
/// Origin of the API key selected by [`resolve_api_key`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiKeySource {
    /// A non-empty key was passed in by the caller (logged with source `config`).
    Config,
    /// The key was read from the provider's API-key environment variable.
    Environment,
    /// No key was found; the caller is expected to fall back to the client's
    /// own `from_env()` constructor.
    ClientDefault,
}
104
105fn validate_and_warn(key: &str, provider: Provider, source: &str) {
107 if let Err(warning) = provider.validate_api_key_format(key) {
108 tracing::warn!(
109 provider = %provider,
110 source = source,
111 "API key format warning: {}",
112 warning
113 );
114 }
115}
116
117pub fn resolve_api_key(
128 api_key: Option<&str>,
129 provider: Provider,
130) -> (Option<String>, ApiKeySource) {
131 if let Some(key) = api_key
133 && !key.is_empty()
134 {
135 tracing::trace!(
136 provider = %provider,
137 source = "config",
138 "Using API key from configuration"
139 );
140 validate_and_warn(key, provider, "config");
141 return (Some(key.to_string()), ApiKeySource::Config);
142 }
143
144 if let Ok(key) = std::env::var(provider.api_key_env()) {
146 tracing::trace!(
147 provider = %provider,
148 env_var = %provider.api_key_env(),
149 source = "environment",
150 "Using API key from environment variable"
151 );
152 validate_and_warn(&key, provider, "environment");
153 return (Some(key), ApiKeySource::Environment);
154 }
155
156 tracing::trace!(
157 provider = %provider,
158 source = "client_default",
159 "No API key found, will use client's from_env()"
160 );
161 (None, ApiKeySource::ClientDefault)
162}
163
164pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
179 let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
180 let client = match resolved_key {
181 Some(key) => openai::Client::new(&key)
182 .map_err(|_| {
184 anyhow::anyhow!(
185 "Failed to create OpenAI client: authentication or configuration error"
186 )
187 })?,
188 None => openai::Client::from_env()
189 .map_err(|_| anyhow::anyhow!("Failed to create OpenAI client from environment"))?,
190 };
191 Ok(client.completions_api().agent(model))
192}
193
194pub fn anthropic_agent_builder(client: &anthropic::Client, model: &str) -> AnthropicBuilder {
203 AgentBuilder::new(client.completion_model(model).with_automatic_caching())
204}
205
206pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
221 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
222 let client = match resolved_key {
223 Some(key) => anthropic::Client::new(&key)
224 .map_err(|_| {
226 anyhow::anyhow!(
227 "Failed to create Anthropic client: authentication or configuration error"
228 )
229 })?,
230 None => anthropic::Client::from_env()
231 .map_err(|_| anyhow::anyhow!("Failed to create Anthropic client from environment"))?,
232 };
233 Ok(anthropic_agent_builder(&client, model))
234}
235
236pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
251 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
252 let client = match resolved_key {
253 Some(key) => gemini::Client::new(&key)
254 .map_err(|_| {
256 anyhow::anyhow!(
257 "Failed to create Gemini client: authentication or configuration error"
258 )
259 })?,
260 None => gemini::Client::from_env()
261 .map_err(|_| anyhow::anyhow!("Failed to create Gemini client from environment"))?,
262 };
263 Ok(client.agent(model))
264}
265
266fn parse_additional_param_value(raw: &str) -> Value {
267 serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.to_string()))
268}
269
270fn additional_params_json<S>(
271 additional_params: Option<&HashMap<String, String, S>>,
272) -> Map<String, Value>
273where
274 S: std::hash::BuildHasher,
275{
276 let mut params = Map::new();
277 if let Some(additional_params) = additional_params {
278 for (key, value) in additional_params {
279 params.insert(key.clone(), parse_additional_param_value(value));
280 }
281 }
282 params
283}
284
/// Reports whether `model` belongs to the family that receives a default
/// `reasoning` parameter, i.e. its name starts with `gpt-5` (case-insensitive).
fn supports_openai_reasoning_defaults(model: &str) -> bool {
    const REASONING_FAMILY_PREFIX: &str = "gpt-5";

    let normalized = model.to_lowercase();
    normalized.starts_with(REASONING_FAMILY_PREFIX)
}
288
289fn completion_params_json<S>(
290 additional_params: Option<&HashMap<String, String, S>>,
291 provider: Provider,
292 model: &str,
293 max_tokens: u64,
294 profile: CompletionProfile,
295) -> Map<String, Value>
296where
297 S: std::hash::BuildHasher,
298{
299 let mut params = additional_params_json(additional_params);
300
301 if provider == Provider::OpenAI && needs_max_completion_tokens(model) {
302 params.insert("max_completion_tokens".to_string(), json!(max_tokens));
303 }
304
305 if provider == Provider::OpenAI
306 && supports_openai_reasoning_defaults(model)
307 && !params.contains_key("reasoning")
308 {
309 params.insert(
310 "reasoning".to_string(),
311 json!({ "effort": profile.default_openai_reasoning_effort() }),
312 );
313 }
314
315 params
316}
317
/// Reports whether `model` (compared case-insensitively) starts with one of
/// the prefixes `gpt-5`, `gpt-4.1`, `o1`, `o3` or `o4`.
fn needs_max_completion_tokens(model: &str) -> bool {
    const PREFIXES: [&str; 5] = ["gpt-5", "gpt-4.1", "o1", "o3", "o4"];

    let normalized = model.to_lowercase();
    PREFIXES
        .iter()
        .any(|&prefix| normalized.starts_with(prefix))
}
326
327pub fn apply_completion_params<M, S>(
328 mut builder: AgentBuilder<M>,
329 provider: Provider,
330 model: &str,
331 max_tokens: u64,
332 additional_params: Option<&HashMap<String, String, S>>,
333 profile: CompletionProfile,
334) -> AgentBuilder<M>
335where
336 M: CompletionModel,
337 S: std::hash::BuildHasher,
338{
339 if !(provider == Provider::OpenAI && needs_max_completion_tokens(model)) {
340 builder = builder.max_tokens(max_tokens);
341 }
342
343 let params = completion_params_json(additional_params, provider, model, max_tokens, profile);
344
345 if params.is_empty() {
346 builder
347 } else {
348 builder.additional_params(Value::Object(params))
349 }
350}
351
352pub fn provider_from_name(provider: &str) -> Result<Provider> {
358 provider
359 .parse()
360 .map_err(|_| anyhow::anyhow!("Unsupported provider: {}", provider))
361}
362
363#[must_use]
364pub fn current_provider_config<'a>(
365 config: Option<&'a crate::config::Config>,
366 provider: &str,
367) -> Option<&'a ProviderConfig> {
368 config.and_then(|config| config.get_provider_config(provider))
369}
370
/// Unit tests for API-key resolution, completion-parameter assembly and
/// provider-name parsing.
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-empty key passed by the caller is returned unchanged and tagged `Config`.
    #[test]
    fn test_resolve_api_key_uses_config_when_provided() {
        let (key, source) = resolve_api_key(Some("sk-config-key-1234567890"), Provider::OpenAI);
        assert_eq!(key, Some("sk-config-key-1234567890".to_string()));
        assert_eq!(source, ApiKeySource::Config);
    }

    /// An empty caller-supplied key must not be reported as coming from config.
    #[test]
    fn test_resolve_api_key_empty_config_not_used() {
        let empty_config: Option<&str> = Some("");
        let (_key, source) = resolve_api_key(empty_config, Provider::OpenAI);

        // The exact fallback depends on the test process environment, so only
        // the negative is asserted.
        assert_ne!(source, ApiKeySource::Config);
    }

    /// With no caller-supplied key the result is either the environment key or
    /// the client-default marker, and the returned key matches the source.
    #[test]
    fn test_resolve_api_key_none_config_checks_env() {
        let (key, source) = resolve_api_key(None, Provider::OpenAI);

        // Which arm is hit depends on whether the provider's environment
        // variable is set where the tests run.
        match source {
            ApiKeySource::Environment => {
                assert!(key.is_some());
            }
            ApiKeySource::ClientDefault => {
                assert!(key.is_none());
            }
            ApiKeySource::Config => {
                unreachable!("Should not return Config source when config is None");
            }
        }
    }

    /// `ApiKeySource` compares equal to itself and unequal across variants.
    #[test]
    fn test_api_key_source_enum_equality() {
        assert_eq!(ApiKeySource::Config, ApiKeySource::Config);
        assert_eq!(ApiKeySource::Environment, ApiKeySource::Environment);
        assert_eq!(ApiKeySource::ClientDefault, ApiKeySource::ClientDefault);
        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
    }

    /// A caller-supplied key is honoured for every provider in `Provider::ALL`.
    #[test]
    fn test_resolve_api_key_all_providers() {
        for provider in Provider::ALL {
            let (key, source) = resolve_api_key(Some("test-key-123456789012345"), *provider);
            assert_eq!(key, Some("test-key-123456789012345".to_string()));
            assert_eq!(source, ApiKeySource::Config);
        }
    }

    /// A caller-supplied key is returned with `Config` as its source,
    /// whatever the environment contains.
    #[test]
    fn test_resolve_api_key_config_precedence() {
        let config_key = "sk-from-config-abcdef1234567890";
        let (key, source) = resolve_api_key(Some(config_key), Provider::OpenAI);

        assert_eq!(key.as_deref(), Some(config_key));
        assert_eq!(source, ApiKeySource::Config);
    }

    /// The derived `Debug` output names the variant.
    #[test]
    fn test_api_key_source_debug_impl() {
        let source = ApiKeySource::Config;
        let debug_str = format!("{:?}", source);
        assert!(debug_str.contains("Config"));
    }

    /// String values that are valid JSON are converted to typed JSON values.
    #[test]
    fn test_apply_completion_params_parses_json_like_additional_params() {
        let mut additional_params = HashMap::new();
        additional_params.insert("temperature".to_string(), "0.7".to_string());
        additional_params.insert("reasoning".to_string(), r#"{"effort":"low"}"#.to_string());

        let params = additional_params_json(Some(&additional_params));
        assert_eq!(params.get("temperature"), Some(&json!(0.7)));
        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "low"})));
    }

    /// For `gpt-5*` models the default reasoning effort follows the profile and
    /// the token limit is sent as `max_completion_tokens`.
    #[test]
    fn test_completion_params_use_profile_specific_openai_reasoning_defaults() {
        let main_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4",
            16_384,
            CompletionProfile::MainAgent,
        );
        assert_eq!(
            main_params.get("reasoning"),
            Some(&json!({"effort": "medium"}))
        );
        assert_eq!(
            main_params.get("max_completion_tokens"),
            Some(&json!(16_384))
        );

        let status_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4-mini",
            50,
            CompletionProfile::StatusMessage,
        );
        assert_eq!(
            status_params.get("reasoning"),
            Some(&json!({"effort": "none"}))
        );
        assert_eq!(status_params.get("max_completion_tokens"), Some(&json!(50)));
    }

    /// An explicit `reasoning` entry from the caller is not replaced by the
    /// profile default.
    #[test]
    fn test_completion_params_preserve_explicit_reasoning_overrides() {
        let mut additional_params = HashMap::new();
        additional_params.insert("reasoning".to_string(), r#"{"effort":"high"}"#.to_string());

        let params = completion_params_json(
            Some(&additional_params),
            Provider::OpenAI,
            "gpt-5.4",
            4096,
            CompletionProfile::MainAgent,
        );

        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "high"})));
    }

    /// Models outside the `gpt-5` family get no default `reasoning` entry, but
    /// `gpt-4.1` still receives `max_completion_tokens`.
    #[test]
    fn test_completion_params_skip_openai_reasoning_defaults_for_non_gpt5_models() {
        let params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-4.1",
            4096,
            CompletionProfile::MainAgent,
        );

        assert!(!params.contains_key("reasoning"));
        assert_eq!(params.get("max_completion_tokens"), Some(&json!(4096)));
    }

    /// The names `openai`, `claude` and `gemini` each resolve to a provider.
    #[test]
    fn test_provider_from_name_supports_aliases() {
        assert_eq!(provider_from_name("openai").ok(), Some(Provider::OpenAI));
        assert_eq!(provider_from_name("claude").ok(), Some(Provider::Anthropic));
        assert_eq!(provider_from_name("gemini").ok(), Some(Provider::Google));
    }

    /// Spot-checks the prefix rule behind `needs_max_completion_tokens`.
    #[test]
    fn test_needs_max_completion_tokens_for_gpt5_family() {
        assert!(needs_max_completion_tokens("gpt-5.4"));
        assert!(needs_max_completion_tokens("o3"));
        assert!(!needs_max_completion_tokens("claude-opus-4-6"));
    }
}