1use anyhow::Result;
7use rig::{
8 agent::{Agent, AgentBuilder, PromptResponse},
9 client::{CompletionClient, ProviderClient},
10 completion::{CompletionModel, Prompt, PromptError},
11 providers::{anthropic, gemini, openai},
12};
13use serde_json::{Map, Value, json};
14use std::collections::HashMap;
15
16use crate::providers::{Provider, ProviderConfig};
17
/// OpenAI completion model as exposed by the rig provider crate.
pub type OpenAIModel = openai::completion::CompletionModel;
/// Anthropic completion model as exposed by the rig provider crate.
pub type AnthropicModel = anthropic::completion::CompletionModel;
/// Gemini completion model as exposed by the rig provider crate.
pub type GeminiModel = gemini::completion::CompletionModel;

/// Agent builder specialized to the OpenAI completion model.
pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
/// Agent builder specialized to the Anthropic completion model.
pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
/// Agent builder specialized to the Gemini completion model.
pub type GeminiBuilder = AgentBuilder<GeminiModel>;
27
/// The role a completion request plays; selects provider-specific defaults
/// (currently only the OpenAI reasoning-effort default, see
/// `default_openai_reasoning_effort`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionProfile {
    /// Primary interactive agent.
    MainAgent,
    /// Delegated sub-agent work.
    Subagent,
    /// Short status/notification text.
    StatusMessage,
}
34
impl CompletionProfile {
    /// Default OpenAI `reasoning.effort` value for this profile:
    /// "medium" for the main agent, "low" for subagents, and "none" for
    /// status messages (which need no reasoning at all).
    const fn default_openai_reasoning_effort(self) -> &'static str {
        match self {
            Self::MainAgent => "medium",
            Self::Subagent => "low",
            Self::StatusMessage => "none",
        }
    }
}
44
/// Provider-erased agent: wraps the concrete rig `Agent` for whichever
/// provider was configured, so callers can prompt without being generic
/// over the completion model type.
pub enum DynAgent {
    OpenAI(Agent<OpenAIModel>),
    Anthropic(Agent<AnthropicModel>),
    Gemini(Agent<GeminiModel>),
}
51
52impl DynAgent {
53 pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
59 match self {
60 Self::OpenAI(a) => a.prompt(msg).await,
61 Self::Anthropic(a) => a.prompt(msg).await,
62 Self::Gemini(a) => a.prompt(msg).await,
63 }
64 }
65
66 pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
72 match self {
73 Self::OpenAI(a) => a.prompt(msg).max_turns(depth).await,
74 Self::Anthropic(a) => a.prompt(msg).max_turns(depth).await,
75 Self::Gemini(a) => a.prompt(msg).max_turns(depth).await,
76 }
77 }
78
79 pub async fn prompt_extended(
85 &self,
86 msg: &str,
87 depth: usize,
88 ) -> Result<PromptResponse, PromptError> {
89 match self {
90 Self::OpenAI(a) => a.prompt(msg).max_turns(depth).extended_details().await,
91 Self::Anthropic(a) => a.prompt(msg).max_turns(depth).extended_details().await,
92 Self::Gemini(a) => a.prompt(msg).max_turns(depth).extended_details().await,
93 }
94 }
95}
96
/// Where a resolved API key came from; reported for logging/diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiKeySource {
    /// Key supplied explicitly in the application configuration.
    Config,
    /// Key read from the provider's API-key environment variable.
    Environment,
    /// No key found; the provider client's own `from_env()` will be used.
    ClientDefault,
}
104
105fn validate_and_warn(key: &str, provider: Provider, source: &str) {
107 if let Err(warning) = provider.validate_api_key_format(key) {
108 tracing::warn!(
109 provider = %provider,
110 source = source,
111 "API key format warning: {}",
112 warning
113 );
114 }
115}
116
117pub fn resolve_api_key(
128 api_key: Option<&str>,
129 provider: Provider,
130) -> (Option<String>, ApiKeySource) {
131 if let Some(key) = api_key
133 && !key.is_empty()
134 {
135 tracing::trace!(
136 provider = %provider,
137 source = "config",
138 "Using API key from configuration"
139 );
140 validate_and_warn(key, provider, "config");
141 return (Some(key.to_string()), ApiKeySource::Config);
142 }
143
144 if let Ok(key) = std::env::var(provider.api_key_env()) {
146 tracing::trace!(
147 provider = %provider,
148 env_var = %provider.api_key_env(),
149 source = "environment",
150 "Using API key from environment variable"
151 );
152 validate_and_warn(&key, provider, "environment");
153 return (Some(key), ApiKeySource::Environment);
154 }
155
156 tracing::trace!(
157 provider = %provider,
158 source = "client_default",
159 "No API key found, will use client's from_env()"
160 );
161 (None, ApiKeySource::ClientDefault)
162}
163
164pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
179 let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
180 let client = match resolved_key {
181 Some(key) => openai::Client::new(&key)
182 .map_err(|_| {
184 anyhow::anyhow!(
185 "Failed to create OpenAI client: authentication or configuration error"
186 )
187 })?,
188 None => openai::Client::from_env(),
189 };
190 Ok(client.completions_api().agent(model))
191}
192
193pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
208 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
209 let client = match resolved_key {
210 Some(key) => anthropic::Client::new(&key)
211 .map_err(|_| {
213 anyhow::anyhow!(
214 "Failed to create Anthropic client: authentication or configuration error"
215 )
216 })?,
217 None => anthropic::Client::from_env(),
218 };
219 Ok(client.agent(model))
220}
221
222pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
237 let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
238 let client = match resolved_key {
239 Some(key) => gemini::Client::new(&key)
240 .map_err(|_| {
242 anyhow::anyhow!(
243 "Failed to create Gemini client: authentication or configuration error"
244 )
245 })?,
246 None => gemini::Client::from_env(),
247 };
248 Ok(client.agent(model))
249}
250
251fn parse_additional_param_value(raw: &str) -> Value {
252 serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.to_string()))
253}
254
255fn additional_params_json<S>(
256 additional_params: Option<&HashMap<String, String, S>>,
257) -> Map<String, Value>
258where
259 S: std::hash::BuildHasher,
260{
261 let mut params = Map::new();
262 if let Some(additional_params) = additional_params {
263 for (key, value) in additional_params {
264 params.insert(key.clone(), parse_additional_param_value(value));
265 }
266 }
267 params
268}
269
/// Whether `model` is in the GPT-5 family (case-insensitive prefix match),
/// i.e. eligible for a default `reasoning` block in the request params.
fn supports_openai_reasoning_defaults(model: &str) -> bool {
    let lowered = model.to_lowercase();
    lowered.starts_with("gpt-5")
}
273
274fn completion_params_json<S>(
275 additional_params: Option<&HashMap<String, String, S>>,
276 provider: Provider,
277 model: &str,
278 max_tokens: u64,
279 profile: CompletionProfile,
280) -> Map<String, Value>
281where
282 S: std::hash::BuildHasher,
283{
284 let mut params = additional_params_json(additional_params);
285
286 if provider == Provider::OpenAI && needs_max_completion_tokens(model) {
287 params.insert("max_completion_tokens".to_string(), json!(max_tokens));
288 }
289
290 if provider == Provider::OpenAI
291 && supports_openai_reasoning_defaults(model)
292 && !params.contains_key("reasoning")
293 {
294 params.insert(
295 "reasoning".to_string(),
296 json!({ "effort": profile.default_openai_reasoning_effort() }),
297 );
298 }
299
300 params
301}
302
/// OpenAI model families that reject the legacy `max_tokens` field and
/// require `max_completion_tokens` instead (case-insensitive prefix match).
fn needs_max_completion_tokens(model: &str) -> bool {
    const PREFIXES: [&str; 5] = ["gpt-5", "gpt-4.1", "o1", "o3", "o4"];
    let lowered = model.to_lowercase();
    PREFIXES.iter().any(|prefix| lowered.starts_with(prefix))
}
311
312pub fn apply_completion_params<M, S>(
313 mut builder: AgentBuilder<M>,
314 provider: Provider,
315 model: &str,
316 max_tokens: u64,
317 additional_params: Option<&HashMap<String, String, S>>,
318 profile: CompletionProfile,
319) -> AgentBuilder<M>
320where
321 M: CompletionModel,
322 S: std::hash::BuildHasher,
323{
324 if !(provider == Provider::OpenAI && needs_max_completion_tokens(model)) {
325 builder = builder.max_tokens(max_tokens);
326 }
327
328 let params = completion_params_json(additional_params, provider, model, max_tokens, profile);
329
330 if params.is_empty() {
331 builder
332 } else {
333 builder.additional_params(Value::Object(params))
334 }
335}
336
337pub fn provider_from_name(provider: &str) -> Result<Provider> {
343 provider
344 .parse()
345 .map_err(|_| anyhow::anyhow!("Unsupported provider: {}", provider))
346}
347
348#[must_use]
349pub fn current_provider_config<'a>(
350 config: Option<&'a crate::config::Config>,
351 provider: &str,
352) -> Option<&'a ProviderConfig> {
353 config.and_then(|config| config.get_provider_config(provider))
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    // A non-empty config key is returned verbatim with source Config.
    #[test]
    fn test_resolve_api_key_uses_config_when_provided() {
        let (key, source) = resolve_api_key(Some("sk-config-key-1234567890"), Provider::OpenAI);
        assert_eq!(key, Some("sk-config-key-1234567890".to_string()));
        assert_eq!(source, ApiKeySource::Config);
    }

    // An empty config key must fall through to env/client-default; we only
    // assert the source is not Config because the env var may or may not be
    // set in the test environment.
    #[test]
    fn test_resolve_api_key_empty_config_not_used() {
        let empty_config: Option<&str> = Some("");
        let (_key, source) = resolve_api_key(empty_config, Provider::OpenAI);

        assert_ne!(source, ApiKeySource::Config);
    }

    // With no config key, resolution depends on the ambient environment:
    // either an env key (Some + Environment) or the client-default fallback
    // (None + ClientDefault). Config is impossible here.
    #[test]
    fn test_resolve_api_key_none_config_checks_env() {
        let (key, source) = resolve_api_key(None, Provider::OpenAI);

        match source {
            ApiKeySource::Environment => {
                assert!(key.is_some());
            }
            ApiKeySource::ClientDefault => {
                assert!(key.is_none());
            }
            ApiKeySource::Config => {
                unreachable!("Should not return Config source when config is None");
            }
        }
    }

    // Sanity-check the derived PartialEq on ApiKeySource.
    #[test]
    fn test_api_key_source_enum_equality() {
        assert_eq!(ApiKeySource::Config, ApiKeySource::Config);
        assert_eq!(ApiKeySource::Environment, ApiKeySource::Environment);
        assert_eq!(ApiKeySource::ClientDefault, ApiKeySource::ClientDefault);
        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
    }

    // Config-key precedence holds for every provider variant.
    #[test]
    fn test_resolve_api_key_all_providers() {
        for provider in Provider::ALL {
            let (key, source) = resolve_api_key(Some("test-key-123456789012345"), *provider);
            assert_eq!(key, Some("test-key-123456789012345".to_string()));
            assert_eq!(source, ApiKeySource::Config);
        }
    }

    // Config beats the environment even if an env key is also present.
    #[test]
    fn test_resolve_api_key_config_precedence() {
        let config_key = "sk-from-config-abcdef1234567890";
        let (key, source) = resolve_api_key(Some(config_key), Provider::OpenAI);

        assert_eq!(key.as_deref(), Some(config_key));
        assert_eq!(source, ApiKeySource::Config);
    }

    // Debug formatting should name the variant (used in log output).
    #[test]
    fn test_api_key_source_debug_impl() {
        let source = ApiKeySource::Config;
        let debug_str = format!("{:?}", source);
        assert!(debug_str.contains("Config"));
    }

    // String params that parse as JSON become typed values; others would
    // fall back to JSON strings (see parse_additional_param_value).
    #[test]
    fn test_apply_completion_params_parses_json_like_additional_params() {
        let mut additional_params = HashMap::new();
        additional_params.insert("temperature".to_string(), "0.7".to_string());
        additional_params.insert("reasoning".to_string(), r#"{"effort":"low"}"#.to_string());

        let params = additional_params_json(Some(&additional_params));
        assert_eq!(params.get("temperature"), Some(&json!(0.7)));
        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "low"})));
    }

    // Each CompletionProfile maps to its own OpenAI reasoning-effort
    // default, and GPT-5 models also get max_completion_tokens.
    #[test]
    fn test_completion_params_use_profile_specific_openai_reasoning_defaults() {
        let main_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4",
            16_384,
            CompletionProfile::MainAgent,
        );
        assert_eq!(
            main_params.get("reasoning"),
            Some(&json!({"effort": "medium"}))
        );
        assert_eq!(
            main_params.get("max_completion_tokens"),
            Some(&json!(16_384))
        );

        let status_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4-mini",
            50,
            CompletionProfile::StatusMessage,
        );
        assert_eq!(
            status_params.get("reasoning"),
            Some(&json!({"effort": "none"}))
        );
        assert_eq!(status_params.get("max_completion_tokens"), Some(&json!(50)));
    }

    // A user-supplied `reasoning` entry must not be overwritten by the
    // profile default.
    #[test]
    fn test_completion_params_preserve_explicit_reasoning_overrides() {
        let mut additional_params = HashMap::new();
        additional_params.insert("reasoning".to_string(), r#"{"effort":"high"}"#.to_string());

        let params = completion_params_json(
            Some(&additional_params),
            Provider::OpenAI,
            "gpt-5.4",
            4096,
            CompletionProfile::MainAgent,
        );

        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "high"})));
    }

    // Non-GPT-5 models get no reasoning default but may still need
    // max_completion_tokens (gpt-4.1 does).
    #[test]
    fn test_completion_params_skip_openai_reasoning_defaults_for_non_gpt5_models() {
        let params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-4.1",
            4096,
            CompletionProfile::MainAgent,
        );

        assert!(!params.contains_key("reasoning"));
        assert_eq!(params.get("max_completion_tokens"), Some(&json!(4096)));
    }

    // Aliases ("claude", "gemini") resolve via Provider's FromStr impl.
    #[test]
    fn test_provider_from_name_supports_aliases() {
        assert_eq!(provider_from_name("openai").ok(), Some(Provider::OpenAI));
        assert_eq!(provider_from_name("claude").ok(), Some(Provider::Anthropic));
        assert_eq!(provider_from_name("gemini").ok(), Some(Provider::Google));
    }

    // Spot-check the model families that require max_completion_tokens.
    #[test]
    fn test_needs_max_completion_tokens_for_gpt5_family() {
        assert!(needs_max_completion_tokens("gpt-5.4"));
        assert!(needs_max_completion_tokens("o3"));
        assert!(!needs_max_completion_tokens("claude-opus-4-6"));
    }
}
524}