Skip to main content

git_iris/agents/
provider.rs

1//! Dynamic provider abstraction for rig-core 0.27+
2//!
3//! This module provides runtime provider selection using enum dispatch,
4//! allowing git-iris to work with any supported provider based on config.
5
6use anyhow::Result;
7use rig::{
8    agent::{Agent, AgentBuilder, PromptResponse},
9    client::{CompletionClient, ProviderClient},
10    completion::{CompletionModel, Prompt, PromptError},
11    providers::{anthropic, gemini, openai},
12};
13use serde_json::{Map, Value, json};
14use std::collections::HashMap;
15
16use crate::providers::{Provider, ProviderConfig};
17
/// Completion model types for each provider
///
/// Completion model type exposed by the `openai` provider module.
pub type OpenAIModel = openai::completion::CompletionModel;
/// Completion model type exposed by the `anthropic` provider module.
pub type AnthropicModel = anthropic::completion::CompletionModel;
/// Completion model type exposed by the `gemini` provider module.
pub type GeminiModel = gemini::completion::CompletionModel;

/// Agent builder types for each provider
///
/// Agent builder specialised for [`OpenAIModel`].
pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
/// Agent builder specialised for [`AnthropicModel`].
pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
/// Agent builder specialised for [`GeminiModel`].
pub type GeminiBuilder = AgentBuilder<GeminiModel>;
27
/// Kind of completion request being configured.
///
/// The profile selects the default `OpenAI` reasoning effort that is applied
/// when the caller has not supplied one explicitly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompletionProfile {
    /// Defaults to `"medium"` reasoning effort.
    MainAgent,
    /// Defaults to `"low"` reasoning effort.
    Subagent,
    /// Defaults to `"none"` reasoning effort.
    StatusMessage,
}

impl CompletionProfile {
    /// Returns the reasoning-effort string used as the `OpenAI` default for
    /// this profile.
    const fn default_openai_reasoning_effort(self) -> &'static str {
        // Kept as an exhaustive match so a new variant forces a decision here.
        match self {
            CompletionProfile::StatusMessage => "none",
            CompletionProfile::Subagent => "low",
            CompletionProfile::MainAgent => "medium",
        }
    }
}
44
45/// Dynamic agent that can be any provider's agent type
46pub enum DynAgent {
47    OpenAI(Agent<OpenAIModel>),
48    Anthropic(Agent<AnthropicModel>),
49    Gemini(Agent<GeminiModel>),
50}
51
52impl DynAgent {
53    /// Simple prompt - returns response string
54    ///
55    /// # Errors
56    ///
57    /// Returns an error when the underlying provider request fails.
58    pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
59        match self {
60            Self::OpenAI(a) => a.prompt(msg).await,
61            Self::Anthropic(a) => a.prompt(msg).await,
62            Self::Gemini(a) => a.prompt(msg).await,
63        }
64    }
65
66    /// Multi-turn prompt with specified depth for tool calling
67    ///
68    /// # Errors
69    ///
70    /// Returns an error when the underlying provider request fails.
71    pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
72        match self {
73            Self::OpenAI(a) => a.prompt(msg).max_turns(depth).await,
74            Self::Anthropic(a) => a.prompt(msg).max_turns(depth).await,
75            Self::Gemini(a) => a.prompt(msg).max_turns(depth).await,
76        }
77    }
78
79    /// Multi-turn prompt with extended details (token usage, etc.)
80    ///
81    /// # Errors
82    ///
83    /// Returns an error when the underlying provider request fails.
84    pub async fn prompt_extended(
85        &self,
86        msg: &str,
87        depth: usize,
88    ) -> Result<PromptResponse, PromptError> {
89        match self {
90            Self::OpenAI(a) => a.prompt(msg).max_turns(depth).extended_details().await,
91            Self::Anthropic(a) => a.prompt(msg).max_turns(depth).extended_details().await,
92            Self::Gemini(a) => a.prompt(msg).max_turns(depth).extended_details().await,
93        }
94    }
95}
96
/// Where a resolved API key came from (for logging/debugging).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiKeySource {
    /// The key was passed in explicitly from configuration.
    Config,
    /// The key was read from the provider's environment variable.
    Environment,
    /// No key was resolved; the provider client is left to find its own.
    ClientDefault,
}
104
105/// Validate API key format and log warnings for suspicious keys
106fn validate_and_warn(key: &str, provider: Provider, source: &str) {
107    if let Err(warning) = provider.validate_api_key_format(key) {
108        tracing::warn!(
109            provider = %provider,
110            source = source,
111            "API key format warning: {}",
112            warning
113        );
114    }
115}
116
117/// Resolve API key from config or environment variable.
118///
119/// Resolution order:
120/// 1. If `api_key` is `Some` and non-empty, use it (from config)
121/// 2. Otherwise, check the provider's environment variable
122/// 3. If neither has a key, returns `None` (caller will use `from_env()`)
123///
124/// Note: An empty string in config is treated as "not configured" and falls
125/// back to the environment variable. This allows users to override env vars
126/// in config while still supporting env-only setups.
127pub fn resolve_api_key(
128    api_key: Option<&str>,
129    provider: Provider,
130) -> (Option<String>, ApiKeySource) {
131    // If explicit key provided and non-empty, use it
132    if let Some(key) = api_key
133        && !key.is_empty()
134    {
135        tracing::trace!(
136            provider = %provider,
137            source = "config",
138            "Using API key from configuration"
139        );
140        validate_and_warn(key, provider, "config");
141        return (Some(key.to_string()), ApiKeySource::Config);
142    }
143
144    // Fall back to environment variable
145    if let Ok(key) = std::env::var(provider.api_key_env()) {
146        tracing::trace!(
147            provider = %provider,
148            env_var = %provider.api_key_env(),
149            source = "environment",
150            "Using API key from environment variable"
151        );
152        validate_and_warn(&key, provider, "environment");
153        return (Some(key), ApiKeySource::Environment);
154    }
155
156    tracing::trace!(
157        provider = %provider,
158        source = "client_default",
159        "No API key found, will use client's from_env()"
160    );
161    (None, ApiKeySource::ClientDefault)
162}
163
164/// Create an `OpenAI` agent builder
165///
166/// # Arguments
167/// * `model` - The model name to use
168/// * `api_key` - Optional API key from config. Resolution order:
169///   1. Non-empty `api_key` parameter (from config)
170///   2. `OPENAI_API_KEY` environment variable
171///   3. Client's `from_env()` (requires env var to be set)
172///
173/// # Errors
174/// Returns an error if client creation fails (invalid credentials or missing env var).
175///
176/// # Security
177/// Error messages are sanitized to prevent potential API key exposure.
178pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
179    let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
180    let client = match resolved_key {
181        Some(key) => openai::Client::new(&key)
182            // Sanitize error to prevent potential key exposure in error messages
183            .map_err(|_| {
184                anyhow::anyhow!(
185                    "Failed to create OpenAI client: authentication or configuration error"
186                )
187            })?,
188        None => openai::Client::from_env()
189            .map_err(|_| anyhow::anyhow!("Failed to create OpenAI client from environment"))?,
190    };
191    Ok(client.completions_api().agent(model))
192}
193
194/// Wrap an existing Anthropic client into an agent builder with prompt caching.
195///
196/// Enables Anthropic's automatic prompt caching: a top-level `cache_control`
197/// breakpoint that the API places on the last cacheable block and advances as
198/// the conversation grows. On Iris's multi-turn tool loops, where every turn
199/// otherwise re-sends the whole transcript at full input price, cached turns
200/// are billed at a fraction of that cost. Caching is unconditional: prompts
201/// below the model's cacheable minimum are simply not cached by the API.
202pub fn anthropic_agent_builder(client: &anthropic::Client, model: &str) -> AnthropicBuilder {
203    AgentBuilder::new(client.completion_model(model).with_automatic_caching())
204}
205
206/// Create an Anthropic agent builder with prompt caching enabled.
207///
208/// # Arguments
209/// * `model` - The model name to use
210/// * `api_key` - Optional API key from config. Resolution order:
211///   1. Non-empty `api_key` parameter (from config)
212///   2. `ANTHROPIC_API_KEY` environment variable
213///   3. Client's `from_env()` (requires env var to be set)
214///
215/// # Errors
216/// Returns an error if client creation fails (invalid credentials or missing env var).
217///
218/// # Security
219/// Error messages are sanitized to prevent potential API key exposure.
220pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
221    let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
222    let client = match resolved_key {
223        Some(key) => anthropic::Client::new(&key)
224            // Sanitize error to prevent potential key exposure in error messages
225            .map_err(|_| {
226                anyhow::anyhow!(
227                    "Failed to create Anthropic client: authentication or configuration error"
228                )
229            })?,
230        None => anthropic::Client::from_env()
231            .map_err(|_| anyhow::anyhow!("Failed to create Anthropic client from environment"))?,
232    };
233    Ok(anthropic_agent_builder(&client, model))
234}
235
236/// Create a Gemini agent builder
237///
238/// # Arguments
239/// * `model` - The model name to use
240/// * `api_key` - Optional API key from config. Resolution order:
241///   1. Non-empty `api_key` parameter (from config)
242///   2. `GOOGLE_API_KEY` environment variable
243///   3. Client's `from_env()` (requires env var to be set)
244///
245/// # Errors
246/// Returns an error if client creation fails (invalid credentials or missing env var).
247///
248/// # Security
249/// Error messages are sanitized to prevent potential API key exposure.
250pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
251    let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
252    let client = match resolved_key {
253        Some(key) => gemini::Client::new(&key)
254            // Sanitize error to prevent potential key exposure in error messages
255            .map_err(|_| {
256                anyhow::anyhow!(
257                    "Failed to create Gemini client: authentication or configuration error"
258                )
259            })?,
260        None => gemini::Client::from_env()
261            .map_err(|_| anyhow::anyhow!("Failed to create Gemini client from environment"))?,
262    };
263    Ok(client.agent(model))
264}
265
266fn parse_additional_param_value(raw: &str) -> Value {
267    serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.to_string()))
268}
269
270fn additional_params_json<S>(
271    additional_params: Option<&HashMap<String, String, S>>,
272) -> Map<String, Value>
273where
274    S: std::hash::BuildHasher,
275{
276    let mut params = Map::new();
277    if let Some(additional_params) = additional_params {
278        for (key, value) in additional_params {
279            params.insert(key.clone(), parse_additional_param_value(value));
280        }
281    }
282    params
283}
284
/// Reports whether `model` gets a default reasoning effort.
///
/// True when the model name, compared case-insensitively, starts with
/// `gpt-5`.
fn supports_openai_reasoning_defaults(model: &str) -> bool {
    const REASONING_MODEL_PREFIX: &str = "gpt-5";

    let normalized = model.to_lowercase();
    normalized.starts_with(REASONING_MODEL_PREFIX)
}
288
289fn completion_params_json<S>(
290    additional_params: Option<&HashMap<String, String, S>>,
291    provider: Provider,
292    model: &str,
293    max_tokens: u64,
294    profile: CompletionProfile,
295) -> Map<String, Value>
296where
297    S: std::hash::BuildHasher,
298{
299    let mut params = additional_params_json(additional_params);
300
301    if provider == Provider::OpenAI && needs_max_completion_tokens(model) {
302        params.insert("max_completion_tokens".to_string(), json!(max_tokens));
303    }
304
305    if provider == Provider::OpenAI
306        && supports_openai_reasoning_defaults(model)
307        && !params.contains_key("reasoning")
308    {
309        params.insert(
310            "reasoning".to_string(),
311            json!({ "effort": profile.default_openai_reasoning_effort() }),
312        );
313    }
314
315    params
316}
317
/// Reports whether the token limit for `model` is sent as
/// `max_completion_tokens` rather than through the builder's `max_tokens`.
///
/// Matches, case-insensitively, model names starting with `gpt-5`,
/// `gpt-4.1`, `o1`, `o3` or `o4`.
fn needs_max_completion_tokens(model: &str) -> bool {
    const MODEL_PREFIXES: [&str; 5] = ["gpt-5", "gpt-4.1", "o1", "o3", "o4"];

    let normalized = model.to_lowercase();
    MODEL_PREFIXES
        .iter()
        .any(|prefix| normalized.starts_with(*prefix))
}
326
327pub fn apply_completion_params<M, S>(
328    mut builder: AgentBuilder<M>,
329    provider: Provider,
330    model: &str,
331    max_tokens: u64,
332    additional_params: Option<&HashMap<String, String, S>>,
333    profile: CompletionProfile,
334) -> AgentBuilder<M>
335where
336    M: CompletionModel,
337    S: std::hash::BuildHasher,
338{
339    if !(provider == Provider::OpenAI && needs_max_completion_tokens(model)) {
340        builder = builder.max_tokens(max_tokens);
341    }
342
343    let params = completion_params_json(additional_params, provider, model, max_tokens, profile);
344
345    if params.is_empty() {
346        builder
347    } else {
348        builder.additional_params(Value::Object(params))
349    }
350}
351
352/// Parse a configured provider name into the canonical provider enum.
353///
354/// # Errors
355///
356/// Returns an error when the provider name is not supported.
357pub fn provider_from_name(provider: &str) -> Result<Provider> {
358    provider
359        .parse()
360        .map_err(|_| anyhow::anyhow!("Unsupported provider: {}", provider))
361}
362
363#[must_use]
364pub fn current_provider_config<'a>(
365    config: Option<&'a crate::config::Config>,
366    provider: &str,
367) -> Option<&'a ProviderConfig> {
368    config.and_then(|config| config.get_provider_config(provider))
369}
370
371#[cfg(test)]
372mod tests {
373    use super::*;
374
375    #[test]
376    fn test_resolve_api_key_uses_config_when_provided() {
377        // Config key takes precedence
378        let (key, source) = resolve_api_key(Some("sk-config-key-1234567890"), Provider::OpenAI);
379        assert_eq!(key, Some("sk-config-key-1234567890".to_string()));
380        assert_eq!(source, ApiKeySource::Config);
381    }
382
383    #[test]
384    fn test_resolve_api_key_empty_config_not_used() {
385        // Empty config should NOT be treated as a valid key
386        // It should fall through to env var or client default
387        let empty_config: Option<&str> = Some("");
388        let (_key, source) = resolve_api_key(empty_config, Provider::OpenAI);
389
390        // Empty config should NOT return Config source
391        // This test verifies the empty string is treated as "not configured"
392        assert_ne!(source, ApiKeySource::Config);
393    }
394
395    #[test]
396    fn test_resolve_api_key_none_config_checks_env() {
397        // When config is None, should check env var
398        let (key, source) = resolve_api_key(None, Provider::OpenAI);
399
400        // Result depends on whether OPENAI_API_KEY is set in the environment
401        // We just verify the function doesn't panic and returns appropriate source
402        match source {
403            ApiKeySource::Environment => {
404                assert!(key.is_some());
405            }
406            ApiKeySource::ClientDefault => {
407                assert!(key.is_none());
408            }
409            ApiKeySource::Config => {
410                unreachable!("Should not return Config source when config is None");
411            }
412        }
413    }
414
415    #[test]
416    fn test_api_key_source_enum_equality() {
417        assert_eq!(ApiKeySource::Config, ApiKeySource::Config);
418        assert_eq!(ApiKeySource::Environment, ApiKeySource::Environment);
419        assert_eq!(ApiKeySource::ClientDefault, ApiKeySource::ClientDefault);
420        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
421    }
422
423    #[test]
424    fn test_resolve_api_key_all_providers() {
425        // Test that resolve_api_key works for all supported providers
426        for provider in Provider::ALL {
427            let (key, source) = resolve_api_key(Some("test-key-123456789012345"), *provider);
428            assert_eq!(key, Some("test-key-123456789012345".to_string()));
429            assert_eq!(source, ApiKeySource::Config);
430        }
431    }
432
433    #[test]
434    fn test_resolve_api_key_config_precedence() {
435        // Even if env var is set, config should take precedence
436        // We can't easily mock env vars in unit tests, but we can verify
437        // that a provided config key is always used regardless of env state
438        let config_key = "sk-from-config-abcdef1234567890";
439        let (key, source) = resolve_api_key(Some(config_key), Provider::OpenAI);
440
441        assert_eq!(key.as_deref(), Some(config_key));
442        assert_eq!(source, ApiKeySource::Config);
443    }
444
445    #[test]
446    fn test_api_key_source_debug_impl() {
447        // Verify Debug is implemented for logging purposes
448        let source = ApiKeySource::Config;
449        let debug_str = format!("{:?}", source);
450        assert!(debug_str.contains("Config"));
451    }
452
453    #[test]
454    fn test_apply_completion_params_parses_json_like_additional_params() {
455        let mut additional_params = HashMap::new();
456        additional_params.insert("temperature".to_string(), "0.7".to_string());
457        additional_params.insert("reasoning".to_string(), r#"{"effort":"low"}"#.to_string());
458
459        let params = additional_params_json(Some(&additional_params));
460        assert_eq!(params.get("temperature"), Some(&json!(0.7)));
461        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "low"})));
462    }
463
464    #[test]
465    fn test_completion_params_use_profile_specific_openai_reasoning_defaults() {
466        let main_params = completion_params_json::<std::collections::hash_map::RandomState>(
467            None,
468            Provider::OpenAI,
469            "gpt-5.4",
470            16_384,
471            CompletionProfile::MainAgent,
472        );
473        assert_eq!(
474            main_params.get("reasoning"),
475            Some(&json!({"effort": "medium"}))
476        );
477        assert_eq!(
478            main_params.get("max_completion_tokens"),
479            Some(&json!(16_384))
480        );
481
482        let status_params = completion_params_json::<std::collections::hash_map::RandomState>(
483            None,
484            Provider::OpenAI,
485            "gpt-5.4-mini",
486            50,
487            CompletionProfile::StatusMessage,
488        );
489        assert_eq!(
490            status_params.get("reasoning"),
491            Some(&json!({"effort": "none"}))
492        );
493        assert_eq!(status_params.get("max_completion_tokens"), Some(&json!(50)));
494    }
495
496    #[test]
497    fn test_completion_params_preserve_explicit_reasoning_overrides() {
498        let mut additional_params = HashMap::new();
499        additional_params.insert("reasoning".to_string(), r#"{"effort":"high"}"#.to_string());
500
501        let params = completion_params_json(
502            Some(&additional_params),
503            Provider::OpenAI,
504            "gpt-5.4",
505            4096,
506            CompletionProfile::MainAgent,
507        );
508
509        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "high"})));
510    }
511
512    #[test]
513    fn test_completion_params_skip_openai_reasoning_defaults_for_non_gpt5_models() {
514        let params = completion_params_json::<std::collections::hash_map::RandomState>(
515            None,
516            Provider::OpenAI,
517            "gpt-4.1",
518            4096,
519            CompletionProfile::MainAgent,
520        );
521
522        assert!(!params.contains_key("reasoning"));
523        assert_eq!(params.get("max_completion_tokens"), Some(&json!(4096)));
524    }
525
526    #[test]
527    fn test_provider_from_name_supports_aliases() {
528        assert_eq!(provider_from_name("openai").ok(), Some(Provider::OpenAI));
529        assert_eq!(provider_from_name("claude").ok(), Some(Provider::Anthropic));
530        assert_eq!(provider_from_name("gemini").ok(), Some(Provider::Google));
531    }
532
533    #[test]
534    fn test_needs_max_completion_tokens_for_gpt5_family() {
535        assert!(needs_max_completion_tokens("gpt-5.4"));
536        assert!(needs_max_completion_tokens("o3"));
537        assert!(!needs_max_completion_tokens("claude-opus-4-6"));
538    }
539}