Skip to main content

git_iris/agents/
provider.rs

1//! Dynamic provider abstraction for rig-core 0.27+
2//!
3//! This module provides runtime provider selection using enum dispatch,
4//! allowing git-iris to work with any supported provider based on config.
5
6use anyhow::Result;
7use rig::{
8    agent::{Agent, AgentBuilder, PromptResponse},
9    client::{CompletionClient, ProviderClient},
10    completion::{CompletionModel, Prompt, PromptError},
11    providers::{anthropic, gemini, openai},
12};
13use serde_json::{Map, Value, json};
14use std::collections::HashMap;
15
16use crate::providers::{Provider, ProviderConfig};
17
18/// Completion model types for each provider
19pub type OpenAIModel = openai::completion::CompletionModel;
20pub type AnthropicModel = anthropic::completion::CompletionModel;
21pub type GeminiModel = gemini::completion::CompletionModel;
22
23/// Agent builder types for each provider
24pub type OpenAIBuilder = AgentBuilder<OpenAIModel>;
25pub type AnthropicBuilder = AgentBuilder<AnthropicModel>;
26pub type GeminiBuilder = AgentBuilder<GeminiModel>;
27
/// Kind of completion an agent is being configured for.
///
/// The profile picks per-request defaults; currently it chooses the reasoning
/// effort that `completion_params_json` injects for GPT-5-family `OpenAI` models
/// when the user has not configured one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompletionProfile {
    /// Default `OpenAI` reasoning effort: `"medium"`.
    MainAgent,
    /// Default `OpenAI` reasoning effort: `"low"`.
    Subagent,
    /// Default `OpenAI` reasoning effort: `"none"`.
    StatusMessage,
}

impl CompletionProfile {
    /// Reasoning effort used when no explicit `reasoning` parameter is configured.
    const fn default_openai_reasoning_effort(self) -> &'static str {
        match self {
            Self::StatusMessage => "none",
            Self::Subagent => "low",
            Self::MainAgent => "medium",
        }
    }
}
44
45/// Dynamic agent that can be any provider's agent type
46pub enum DynAgent {
47    OpenAI(Agent<OpenAIModel>),
48    Anthropic(Agent<AnthropicModel>),
49    Gemini(Agent<GeminiModel>),
50}
51
52impl DynAgent {
53    /// Simple prompt - returns response string
54    ///
55    /// # Errors
56    ///
57    /// Returns an error when the underlying provider request fails.
58    pub async fn prompt(&self, msg: &str) -> Result<String, PromptError> {
59        match self {
60            Self::OpenAI(a) => a.prompt(msg).await,
61            Self::Anthropic(a) => a.prompt(msg).await,
62            Self::Gemini(a) => a.prompt(msg).await,
63        }
64    }
65
66    /// Multi-turn prompt with specified depth for tool calling
67    ///
68    /// # Errors
69    ///
70    /// Returns an error when the underlying provider request fails.
71    pub async fn prompt_multi_turn(&self, msg: &str, depth: usize) -> Result<String, PromptError> {
72        match self {
73            Self::OpenAI(a) => a.prompt(msg).max_turns(depth).await,
74            Self::Anthropic(a) => a.prompt(msg).max_turns(depth).await,
75            Self::Gemini(a) => a.prompt(msg).max_turns(depth).await,
76        }
77    }
78
79    /// Multi-turn prompt with extended details (token usage, etc.)
80    ///
81    /// # Errors
82    ///
83    /// Returns an error when the underlying provider request fails.
84    pub async fn prompt_extended(
85        &self,
86        msg: &str,
87        depth: usize,
88    ) -> Result<PromptResponse, PromptError> {
89        match self {
90            Self::OpenAI(a) => a.prompt(msg).max_turns(depth).extended_details().await,
91            Self::Anthropic(a) => a.prompt(msg).max_turns(depth).extended_details().await,
92            Self::Gemini(a) => a.prompt(msg).max_turns(depth).extended_details().await,
93        }
94    }
95}
96
/// Where `resolve_api_key` obtained the API key (for logging/debugging).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiKeySource {
    /// Taken from the `api_key` value passed in from configuration.
    Config,
    /// Taken from the provider's API-key environment variable.
    Environment,
    /// No key was found; the caller decides how to build the client.
    ClientDefault,
}
104
105/// Validate API key format and log warnings for suspicious keys
106fn validate_and_warn(key: &str, provider: Provider, source: &str) {
107    if let Err(warning) = provider.validate_api_key_format(key) {
108        tracing::warn!(
109            provider = %provider,
110            source = source,
111            "API key format warning: {}",
112            warning
113        );
114    }
115}
116
117/// Resolve API key from config or environment variable.
118///
119/// Resolution order:
120/// 1. If `api_key` is `Some` and non-empty, use it (from config)
121/// 2. Otherwise, check the provider's environment variable
122/// 3. If neither has a key, returns `None` (caller will use `from_env()`)
123///
124/// Note: An empty string in config is treated as "not configured" and falls
125/// back to the environment variable. This allows users to override env vars
126/// in config while still supporting env-only setups.
127pub fn resolve_api_key(
128    api_key: Option<&str>,
129    provider: Provider,
130) -> (Option<String>, ApiKeySource) {
131    // If explicit key provided and non-empty, use it
132    if let Some(key) = api_key
133        && !key.is_empty()
134    {
135        tracing::trace!(
136            provider = %provider,
137            source = "config",
138            "Using API key from configuration"
139        );
140        validate_and_warn(key, provider, "config");
141        return (Some(key.to_string()), ApiKeySource::Config);
142    }
143
144    // Fall back to environment variable
145    if let Ok(key) = std::env::var(provider.api_key_env()) {
146        tracing::trace!(
147            provider = %provider,
148            env_var = %provider.api_key_env(),
149            source = "environment",
150            "Using API key from environment variable"
151        );
152        validate_and_warn(&key, provider, "environment");
153        return (Some(key), ApiKeySource::Environment);
154    }
155
156    tracing::trace!(
157        provider = %provider,
158        source = "client_default",
159        "No API key found, will use client's from_env()"
160    );
161    (None, ApiKeySource::ClientDefault)
162}
163
164/// Create an `OpenAI` agent builder
165///
166/// # Arguments
167/// * `model` - The model name to use
168/// * `api_key` - Optional API key from config. Resolution order:
169///   1. Non-empty `api_key` parameter (from config)
170///   2. `OPENAI_API_KEY` environment variable
171///   3. Client's `from_env()` (requires env var to be set)
172///
173/// # Errors
174/// Returns an error if client creation fails (invalid credentials or missing env var).
175///
176/// # Security
177/// Error messages are sanitized to prevent potential API key exposure.
178pub fn openai_builder(model: &str, api_key: Option<&str>) -> Result<OpenAIBuilder> {
179    let (resolved_key, _source) = resolve_api_key(api_key, Provider::OpenAI);
180    let client = match resolved_key {
181        Some(key) => openai::Client::new(&key)
182            // Sanitize error to prevent potential key exposure in error messages
183            .map_err(|_| {
184                anyhow::anyhow!(
185                    "Failed to create OpenAI client: authentication or configuration error"
186                )
187            })?,
188        None => openai::Client::from_env(),
189    };
190    Ok(client.completions_api().agent(model))
191}
192
193/// Create an Anthropic agent builder
194///
195/// # Arguments
196/// * `model` - The model name to use
197/// * `api_key` - Optional API key from config. Resolution order:
198///   1. Non-empty `api_key` parameter (from config)
199///   2. `ANTHROPIC_API_KEY` environment variable
200///   3. Client's `from_env()` (requires env var to be set)
201///
202/// # Errors
203/// Returns an error if client creation fails (invalid credentials or missing env var).
204///
205/// # Security
206/// Error messages are sanitized to prevent potential API key exposure.
207pub fn anthropic_builder(model: &str, api_key: Option<&str>) -> Result<AnthropicBuilder> {
208    let (resolved_key, _source) = resolve_api_key(api_key, Provider::Anthropic);
209    let client = match resolved_key {
210        Some(key) => anthropic::Client::new(&key)
211            // Sanitize error to prevent potential key exposure in error messages
212            .map_err(|_| {
213                anyhow::anyhow!(
214                    "Failed to create Anthropic client: authentication or configuration error"
215                )
216            })?,
217        None => anthropic::Client::from_env(),
218    };
219    Ok(client.agent(model))
220}
221
222/// Create a Gemini agent builder
223///
224/// # Arguments
225/// * `model` - The model name to use
226/// * `api_key` - Optional API key from config. Resolution order:
227///   1. Non-empty `api_key` parameter (from config)
228///   2. `GOOGLE_API_KEY` environment variable
229///   3. Client's `from_env()` (requires env var to be set)
230///
231/// # Errors
232/// Returns an error if client creation fails (invalid credentials or missing env var).
233///
234/// # Security
235/// Error messages are sanitized to prevent potential API key exposure.
236pub fn gemini_builder(model: &str, api_key: Option<&str>) -> Result<GeminiBuilder> {
237    let (resolved_key, _source) = resolve_api_key(api_key, Provider::Google);
238    let client = match resolved_key {
239        Some(key) => gemini::Client::new(&key)
240            // Sanitize error to prevent potential key exposure in error messages
241            .map_err(|_| {
242                anyhow::anyhow!(
243                    "Failed to create Gemini client: authentication or configuration error"
244                )
245            })?,
246        None => gemini::Client::from_env(),
247    };
248    Ok(client.agent(model))
249}
250
251fn parse_additional_param_value(raw: &str) -> Value {
252    serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.to_string()))
253}
254
255fn additional_params_json<S>(
256    additional_params: Option<&HashMap<String, String, S>>,
257) -> Map<String, Value>
258where
259    S: std::hash::BuildHasher,
260{
261    let mut params = Map::new();
262    if let Some(additional_params) = additional_params {
263        for (key, value) in additional_params {
264            params.insert(key.clone(), parse_additional_param_value(value));
265        }
266    }
267    params
268}
269
/// Whether `model` belongs to the GPT-5 family (case-insensitive `gpt-5` prefix).
///
/// Only these models receive a profile-based default `reasoning` parameter.
fn supports_openai_reasoning_defaults(model: &str) -> bool {
    const GPT5_PREFIX: &str = "gpt-5";
    model
        .get(..GPT5_PREFIX.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(GPT5_PREFIX))
}
273
274fn completion_params_json<S>(
275    additional_params: Option<&HashMap<String, String, S>>,
276    provider: Provider,
277    model: &str,
278    max_tokens: u64,
279    profile: CompletionProfile,
280) -> Map<String, Value>
281where
282    S: std::hash::BuildHasher,
283{
284    let mut params = additional_params_json(additional_params);
285
286    if provider == Provider::OpenAI && needs_max_completion_tokens(model) {
287        params.insert("max_completion_tokens".to_string(), json!(max_tokens));
288    }
289
290    if provider == Provider::OpenAI
291        && supports_openai_reasoning_defaults(model)
292        && !params.contains_key("reasoning")
293    {
294        params.insert(
295            "reasoning".to_string(),
296            json!({ "effort": profile.default_openai_reasoning_effort() }),
297        );
298    }
299
300    params
301}
302
/// Whether an `OpenAI` `model` takes `max_completion_tokens` instead of `max_tokens`.
///
/// Matches by case-insensitive prefix against the known model families.
fn needs_max_completion_tokens(model: &str) -> bool {
    const PREFIXES: [&str; 5] = ["gpt-5", "gpt-4.1", "o1", "o3", "o4"];
    let lowered = model.to_lowercase();
    PREFIXES.iter().any(|&prefix| lowered.starts_with(prefix))
}
311
312pub fn apply_completion_params<M, S>(
313    mut builder: AgentBuilder<M>,
314    provider: Provider,
315    model: &str,
316    max_tokens: u64,
317    additional_params: Option<&HashMap<String, String, S>>,
318    profile: CompletionProfile,
319) -> AgentBuilder<M>
320where
321    M: CompletionModel,
322    S: std::hash::BuildHasher,
323{
324    if !(provider == Provider::OpenAI && needs_max_completion_tokens(model)) {
325        builder = builder.max_tokens(max_tokens);
326    }
327
328    let params = completion_params_json(additional_params, provider, model, max_tokens, profile);
329
330    if params.is_empty() {
331        builder
332    } else {
333        builder.additional_params(Value::Object(params))
334    }
335}
336
337/// Parse a configured provider name into the canonical provider enum.
338///
339/// # Errors
340///
341/// Returns an error when the provider name is not supported.
342pub fn provider_from_name(provider: &str) -> Result<Provider> {
343    provider
344        .parse()
345        .map_err(|_| anyhow::anyhow!("Unsupported provider: {}", provider))
346}
347
348#[must_use]
349pub fn current_provider_config<'a>(
350    config: Option<&'a crate::config::Config>,
351    provider: &str,
352) -> Option<&'a ProviderConfig> {
353    config.and_then(|config| config.get_provider_config(provider))
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_api_key_uses_config_when_provided() {
        // Config key takes precedence
        let (key, source) = resolve_api_key(Some("sk-config-key-1234567890"), Provider::OpenAI);
        assert_eq!(key, Some("sk-config-key-1234567890".to_string()));
        assert_eq!(source, ApiKeySource::Config);
    }

    #[test]
    fn test_resolve_api_key_empty_config_not_used() {
        // Empty config should NOT be treated as a valid key;
        // it should fall through to the env var or client default.
        let empty_config: Option<&str> = Some("");
        let (_key, source) = resolve_api_key(empty_config, Provider::OpenAI);

        // The empty string must be treated as "not configured".
        assert_ne!(source, ApiKeySource::Config);
    }

    #[test]
    fn test_resolve_api_key_none_config_checks_env() {
        // When config is None, the env var is checked.
        let (key, source) = resolve_api_key(None, Provider::OpenAI);

        // The result depends on whether OPENAI_API_KEY is set in the test
        // environment, so only the source/key pairing is verified.
        match source {
            ApiKeySource::Environment => {
                assert!(key.is_some());
            }
            ApiKeySource::ClientDefault => {
                assert!(key.is_none());
            }
            ApiKeySource::Config => {
                unreachable!("Should not return Config source when config is None");
            }
        }
    }

    #[test]
    fn test_api_key_source_enum_equality() {
        assert_eq!(ApiKeySource::Config, ApiKeySource::Config);
        assert_eq!(ApiKeySource::Environment, ApiKeySource::Environment);
        assert_eq!(ApiKeySource::ClientDefault, ApiKeySource::ClientDefault);
        assert_ne!(ApiKeySource::Config, ApiKeySource::Environment);
    }

    #[test]
    fn test_resolve_api_key_all_providers() {
        // resolve_api_key must work for every supported provider.
        for provider in Provider::ALL {
            let (key, source) = resolve_api_key(Some("test-key-123456789012345"), *provider);
            assert_eq!(key, Some("test-key-123456789012345".to_string()));
            assert_eq!(source, ApiKeySource::Config);
        }
    }

    #[test]
    fn test_resolve_api_key_config_precedence() {
        // Even if the env var is set, config takes precedence. Env vars are
        // not mocked here; this checks that a provided config key is always
        // used regardless of env state.
        let config_key = "sk-from-config-abcdef1234567890";
        let (key, source) = resolve_api_key(Some(config_key), Provider::OpenAI);

        assert_eq!(key.as_deref(), Some(config_key));
        assert_eq!(source, ApiKeySource::Config);
    }

    #[test]
    fn test_api_key_source_debug_impl() {
        // Debug is needed for logging purposes.
        let source = ApiKeySource::Config;
        let debug_str = format!("{:?}", source);
        assert!(debug_str.contains("Config"));
    }

    #[test]
    fn test_additional_params_json_parses_json_like_values() {
        let mut additional_params = HashMap::new();
        additional_params.insert("temperature".to_string(), "0.7".to_string());
        additional_params.insert("reasoning".to_string(), r#"{"effort":"low"}"#.to_string());

        let params = additional_params_json(Some(&additional_params));
        assert_eq!(params.get("temperature"), Some(&json!(0.7)));
        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "low"})));
    }

    #[test]
    fn test_additional_params_json_keeps_non_json_values_as_strings() {
        let mut additional_params = HashMap::new();
        additional_params.insert("service_tier".to_string(), "flex".to_string());

        let params = additional_params_json(Some(&additional_params));
        assert_eq!(params.get("service_tier"), Some(&json!("flex")));
    }

    #[test]
    fn test_additional_params_json_none_is_empty() {
        let params = additional_params_json::<std::collections::hash_map::RandomState>(None);
        assert!(params.is_empty());
    }

    #[test]
    fn test_completion_params_use_profile_specific_openai_reasoning_defaults() {
        let main_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4",
            16_384,
            CompletionProfile::MainAgent,
        );
        assert_eq!(
            main_params.get("reasoning"),
            Some(&json!({"effort": "medium"}))
        );
        assert_eq!(
            main_params.get("max_completion_tokens"),
            Some(&json!(16_384))
        );

        let status_params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-5.4-mini",
            50,
            CompletionProfile::StatusMessage,
        );
        assert_eq!(
            status_params.get("reasoning"),
            Some(&json!({"effort": "none"}))
        );
        assert_eq!(status_params.get("max_completion_tokens"), Some(&json!(50)));
    }

    #[test]
    fn test_completion_params_preserve_explicit_reasoning_overrides() {
        let mut additional_params = HashMap::new();
        additional_params.insert("reasoning".to_string(), r#"{"effort":"high"}"#.to_string());

        let params = completion_params_json(
            Some(&additional_params),
            Provider::OpenAI,
            "gpt-5.4",
            4096,
            CompletionProfile::MainAgent,
        );

        assert_eq!(params.get("reasoning"), Some(&json!({"effort": "high"})));
    }

    #[test]
    fn test_completion_params_skip_openai_reasoning_defaults_for_non_gpt5_models() {
        let params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::OpenAI,
            "gpt-4.1",
            4096,
            CompletionProfile::MainAgent,
        );

        assert!(!params.contains_key("reasoning"));
        assert_eq!(params.get("max_completion_tokens"), Some(&json!(4096)));
    }

    #[test]
    fn test_completion_params_empty_for_non_openai_providers() {
        // OpenAI-specific defaults must not leak into other providers, even for
        // model names that would otherwise match the OpenAI prefixes.
        let params = completion_params_json::<std::collections::hash_map::RandomState>(
            None,
            Provider::Anthropic,
            "gpt-5.4",
            4096,
            CompletionProfile::MainAgent,
        );

        assert!(params.is_empty());
    }

    #[test]
    fn test_provider_from_name_supports_aliases() {
        assert_eq!(provider_from_name("openai").ok(), Some(Provider::OpenAI));
        assert_eq!(provider_from_name("claude").ok(), Some(Provider::Anthropic));
        assert_eq!(provider_from_name("gemini").ok(), Some(Provider::Google));
    }

    #[test]
    fn test_needs_max_completion_tokens_for_gpt5_family() {
        assert!(needs_max_completion_tokens("gpt-5.4"));
        assert!(needs_max_completion_tokens("o3"));
        assert!(!needs_max_completion_tokens("claude-opus-4-6"));
    }

    #[test]
    fn test_model_prefix_checks_are_case_insensitive() {
        assert!(needs_max_completion_tokens("GPT-4.1-mini"));
        assert!(supports_openai_reasoning_defaults("GPT-5"));
        assert!(!supports_openai_reasoning_defaults("gpt-4o"));
    }
}