Skip to main content

codetether_agent/provider/
mod.rs

1//! AI Provider abstraction layer
2//!
3//! Unified interface for multiple AI providers (OpenAI, Anthropic, Google, StepFun, etc.)
4
5pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod google;
9pub mod models;
10pub mod moonshot;
11pub mod openai;
12pub mod openrouter;
13pub mod stepfun;
14pub mod zai;
15
16use anyhow::Result;
17use async_trait::async_trait;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21
22/// A message in a conversation
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Message {
25    pub role: Role,
26    pub content: Vec<ContentPart>,
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
30#[serde(rename_all = "lowercase")]
31pub enum Role {
32    System,
33    User,
34    Assistant,
35    Tool,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum ContentPart {
41    Text {
42        text: String,
43    },
44    Image {
45        url: String,
46        mime_type: Option<String>,
47    },
48    File {
49        path: String,
50        mime_type: Option<String>,
51    },
52    ToolCall {
53        id: String,
54        name: String,
55        arguments: String,
56    },
57    ToolResult {
58        tool_call_id: String,
59        content: String,
60    },
61    Thinking {
62        text: String,
63    },
64}
65
66/// Tool definition for the model
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ToolDefinition {
69    pub name: String,
70    pub description: String,
71    pub parameters: serde_json::Value, // JSON Schema
72}
73
74/// Request to generate a completion
75#[derive(Debug, Clone)]
76pub struct CompletionRequest {
77    pub messages: Vec<Message>,
78    pub tools: Vec<ToolDefinition>,
79    pub model: String,
80    pub temperature: Option<f32>,
81    pub top_p: Option<f32>,
82    pub max_tokens: Option<usize>,
83    pub stop: Vec<String>,
84}
85
86/// A streaming chunk from the model
87#[derive(Debug, Clone)]
88pub enum StreamChunk {
89    Text(String),
90    ToolCallStart { id: String, name: String },
91    ToolCallDelta { id: String, arguments_delta: String },
92    ToolCallEnd { id: String },
93    Done { usage: Option<Usage> },
94    Error(String),
95}
96
97/// Token usage information
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
99pub struct Usage {
100    pub prompt_tokens: usize,
101    pub completion_tokens: usize,
102    pub total_tokens: usize,
103    pub cache_read_tokens: Option<usize>,
104    pub cache_write_tokens: Option<usize>,
105}
106
107/// Response from a completion request
108#[derive(Debug, Clone)]
109pub struct CompletionResponse {
110    pub message: Message,
111    pub usage: Usage,
112    pub finish_reason: FinishReason,
113}
114
115#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
116#[serde(rename_all = "snake_case")]
117pub enum FinishReason {
118    Stop,
119    Length,
120    ToolCalls,
121    ContentFilter,
122    Error,
123}
124
125/// Provider trait that all AI providers must implement
126#[async_trait]
127pub trait Provider: Send + Sync {
128    /// Get the provider name
129    fn name(&self) -> &str;
130
131    /// List available models
132    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
133
134    /// Generate a completion
135    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
136
137    /// Generate a streaming completion
138    async fn complete_stream(
139        &self,
140        request: CompletionRequest,
141    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
142}
143
144/// Information about a model
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ModelInfo {
147    pub id: String,
148    pub name: String,
149    pub provider: String,
150    pub context_window: usize,
151    pub max_output_tokens: Option<usize>,
152    pub supports_vision: bool,
153    pub supports_tools: bool,
154    pub supports_streaming: bool,
155    pub input_cost_per_million: Option<f64>,
156    pub output_cost_per_million: Option<f64>,
157}
158
159/// Registry of available providers
160pub struct ProviderRegistry {
161    providers: HashMap<String, Arc<dyn Provider>>,
162}
163
164impl std::fmt::Debug for ProviderRegistry {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        f.debug_struct("ProviderRegistry")
167            .field("provider_count", &self.providers.len())
168            .field("providers", &self.providers.keys().collect::<Vec<_>>())
169            .finish()
170    }
171}
172
173impl ProviderRegistry {
174    pub fn new() -> Self {
175        Self {
176            providers: HashMap::new(),
177        }
178    }
179
180    /// Register a provider
181    pub fn register(&mut self, provider: Arc<dyn Provider>) {
182        self.providers.insert(provider.name().to_string(), provider);
183    }
184
185    /// Get a provider by name
186    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
187        self.providers.get(name).cloned()
188    }
189
190    /// List all registered providers
191    pub fn list(&self) -> Vec<&str> {
192        self.providers.keys().map(|s| s.as_str()).collect()
193    }
194
195    /// Initialize with default providers from config
196    pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
197        let mut registry = Self::new();
198
199        // Always try to initialize OpenAI if key is available
200        if let Some(provider_config) = config.providers.get("openai") {
201            if let Some(api_key) = &provider_config.api_key {
202                registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
203            }
204        } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
205            registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
206        }
207
208        // Initialize Anthropic
209        if let Some(provider_config) = config.providers.get("anthropic") {
210            if let Some(api_key) = &provider_config.api_key {
211                let provider = if let Some(base_url) = provider_config.base_url.clone() {
212                    anthropic::AnthropicProvider::with_base_url(
213                        api_key.clone(),
214                        base_url,
215                        "anthropic",
216                    )?
217                } else {
218                    anthropic::AnthropicProvider::new(api_key.clone())?
219                };
220                registry.register(Arc::new(provider));
221            }
222        } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
223            registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
224        }
225
226        // Initialize Google
227        if let Some(provider_config) = config.providers.get("google") {
228            if let Some(api_key) = &provider_config.api_key {
229                registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
230            }
231        } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
232            registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
233        }
234
235        // Initialize Novita (OpenAI-compatible)
236        if let Some(provider_config) = config.providers.get("novita") {
237            if let Some(api_key) = &provider_config.api_key {
238                let base_url = provider_config
239                    .base_url
240                    .clone()
241                    .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
242                registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
243                    api_key.clone(),
244                    base_url,
245                    "novita",
246                )?));
247            }
248        }
249
250        // Initialize Bedrock via AWS credentials (env vars or ~/.aws/credentials)
251        if let Some(creds) = bedrock::AwsCredentials::from_environment() {
252            let region = bedrock::AwsCredentials::detect_region()
253                .unwrap_or_else(|| bedrock::DEFAULT_REGION.to_string());
254            match bedrock::BedrockProvider::with_credentials(creds, region) {
255                Ok(p) => registry.register(Arc::new(p)),
256                Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
257            }
258        }
259
260        Ok(registry)
261    }
262
263    /// Initialize providers from HashiCorp Vault
264    ///
265    /// This loads API keys from Vault and creates providers dynamically.
266    /// Supports OpenAI-compatible providers via base_url.
267    pub async fn from_vault() -> Result<Self> {
268        let mut registry = Self::new();
269
270        if let Some(manager) = crate::secrets::secrets_manager() {
271            // List all configured providers from Vault
272            let providers = manager.list_configured_providers().await?;
273            tracing::info!("Found {} providers configured in Vault", providers.len());
274
275            for provider_id in providers {
276                let secrets = match manager.get_provider_secrets(&provider_id).await? {
277                    Some(s) => s,
278                    None => continue,
279                };
280
281                // Handle Bedrock before api_key extraction since it can use
282                // AWS IAM credentials instead of an API key.
283                if matches!(provider_id.as_str(), "bedrock" | "aws-bedrock") {
284                    let region = secrets
285                        .extra
286                        .get("region")
287                        .and_then(|v| v.as_str())
288                        .unwrap_or("us-east-1")
289                        .to_string();
290
291                    // Prefer SigV4 if AWS credentials are in Vault
292                    let aws_key_id = secrets
293                        .extra
294                        .get("aws_access_key_id")
295                        .and_then(|v| v.as_str());
296                    let aws_secret = secrets
297                        .extra
298                        .get("aws_secret_access_key")
299                        .and_then(|v| v.as_str());
300
301                    let result = if let (Some(key_id), Some(secret)) = (aws_key_id, aws_secret) {
302                        let creds = bedrock::AwsCredentials {
303                            access_key_id: key_id.to_string(),
304                            secret_access_key: secret.to_string(),
305                            session_token: secrets
306                                .extra
307                                .get("aws_session_token")
308                                .and_then(|v| v.as_str())
309                                .map(|s| s.to_string()),
310                        };
311                        bedrock::BedrockProvider::with_credentials(creds, region)
312                    } else if let Some(ref key) = secrets.api_key {
313                        bedrock::BedrockProvider::with_region(key.clone(), region)
314                    } else {
315                        // Try auto-detecting from environment as last resort
316                        if let Some(creds) = bedrock::AwsCredentials::from_environment() {
317                            bedrock::BedrockProvider::with_credentials(creds, region)
318                        } else {
319                            Err(anyhow::anyhow!(
320                                "No AWS credentials or API key found for Bedrock"
321                            ))
322                        }
323                    };
324
325                    match result {
326                        Ok(p) => registry.register(Arc::new(p)),
327                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
328                    }
329                    continue;
330                }
331
332                let api_key = match secrets.api_key {
333                    Some(key) => key,
334                    None => continue,
335                };
336
337                // Determine which provider implementation to use
338                match provider_id.as_str() {
339                    // Native providers
340                    "anthropic" | "anthropic-eu" | "anthropic-asia" => {
341                        let base_url = secrets
342                            .base_url
343                            .clone()
344                            .unwrap_or_else(|| "https://api.anthropic.com".to_string());
345                        match anthropic::AnthropicProvider::with_base_url(
346                            api_key,
347                            base_url,
348                            &provider_id,
349                        ) {
350                            Ok(p) => registry.register(Arc::new(p)),
351                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
352                        }
353                    }
354                    "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
355                        Ok(p) => registry.register(Arc::new(p)),
356                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
357                    },
358                    // StepFun - native provider (direct API, not via OpenRouter)
359                    "stepfun" => match stepfun::StepFunProvider::new(api_key) {
360                        Ok(p) => registry.register(Arc::new(p)),
361                        Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
362                    },
363                    // OpenRouter - native provider with support for extended response formats
364                    "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
365                        Ok(p) => registry.register(Arc::new(p)),
366                        Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
367                    },
368                    // Moonshot AI - native provider for Kimi models
369                    "moonshotai" | "moonshotai-cn" => {
370                        match moonshot::MoonshotProvider::new(api_key) {
371                            Ok(p) => registry.register(Arc::new(p)),
372                            Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
373                        }
374                    }
375                    // GitHub Copilot providers require custom headers/token semantics
376                    "github-copilot" => {
377                        let result = if let Some(base_url) = secrets.base_url.clone() {
378                            copilot::CopilotProvider::with_base_url(
379                                api_key,
380                                base_url,
381                                "github-copilot",
382                            )
383                        } else {
384                            copilot::CopilotProvider::new(api_key)
385                        };
386
387                        match result {
388                            Ok(p) => registry.register(Arc::new(p)),
389                            Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
390                        }
391                    }
392                    "github-copilot-enterprise" => {
393                        let enterprise_url = secrets
394                            .extra
395                            .get("enterpriseUrl")
396                            .and_then(|v| v.as_str())
397                            .or_else(|| {
398                                secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
399                            });
400
401                        let result = if let Some(base_url) = secrets.base_url.clone() {
402                            copilot::CopilotProvider::with_base_url(
403                                api_key,
404                                base_url,
405                                "github-copilot-enterprise",
406                            )
407                        } else if let Some(url) = enterprise_url {
408                            copilot::CopilotProvider::enterprise(api_key, url.to_string())
409                        } else {
410                            copilot::CopilotProvider::with_base_url(
411                                api_key,
412                                "https://api.githubcopilot.com".to_string(),
413                                "github-copilot-enterprise",
414                            )
415                        };
416
417                        match result {
418                            Ok(p) => registry.register(Arc::new(p)),
419                            Err(e) => {
420                                tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
421                            }
422                        }
423                    }
424                    // Z.AI (formerly ZhipuAI) — first-class provider for GLM models
425                    "zhipuai" | "zai" => {
426                        let base_url = secrets
427                            .base_url
428                            .clone()
429                            .unwrap_or_else(|| "https://api.z.ai/api/paas/v4".to_string());
430                        match zai::ZaiProvider::with_base_url(api_key, base_url) {
431                            Ok(p) => registry.register(Arc::new(p)),
432                            Err(e) => tracing::warn!("Failed to init zai: {}", e),
433                        }
434                    }
435                    // Cerebras - OpenAI-compatible fast inference
436                    "cerebras" => {
437                        let base_url = secrets
438                            .base_url
439                            .clone()
440                            .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
441                        match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
442                            Ok(p) => registry.register(Arc::new(p)),
443                            Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
444                        }
445                    }
446                    // MiniMax - Anthropic-compatible API harness (recommended by MiniMax docs)
447                    "minimax" => {
448                        let base_url = secrets
449                            .base_url
450                            .clone()
451                            .unwrap_or_else(|| "https://api.minimax.io/anthropic".to_string());
452                        let base_url = normalize_minimax_anthropic_base_url(&base_url);
453                        match anthropic::AnthropicProvider::with_base_url(
454                            api_key, base_url, "minimax",
455                        ) {
456                            Ok(p) => registry.register(Arc::new(p)),
457                            Err(e) => tracing::warn!("Failed to init minimax: {}", e),
458                        }
459                    }
460                    // OpenAI-compatible providers (with custom base_url)
461                    "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
462                    | "alibaba" | "openai" | "azure" | "novita" => {
463                        if let Some(base_url) = secrets.base_url.clone() {
464                            match openai::OpenAIProvider::with_base_url(
465                                api_key,
466                                base_url,
467                                &provider_id,
468                            ) {
469                                Ok(p) => registry.register(Arc::new(p)),
470                                Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
471                            }
472                        } else if provider_id == "openai" {
473                            // OpenAI doesn't need a custom base_url
474                            match openai::OpenAIProvider::new(api_key) {
475                                Ok(p) => registry.register(Arc::new(p)),
476                                Err(e) => tracing::warn!("Failed to init openai: {}", e),
477                            }
478                        } else if provider_id == "novita" {
479                            let base_url = "https://api.novita.ai/openai/v1".to_string();
480                            match openai::OpenAIProvider::with_base_url(
481                                api_key,
482                                base_url,
483                                &provider_id,
484                            ) {
485                                Ok(p) => registry.register(Arc::new(p)),
486                                Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
487                            }
488                        } else {
489                            // Try using the base_url from the models API
490                            if let Ok(catalog) = models::ModelCatalog::fetch().await {
491                                if let Some(provider_info) = catalog.get_provider(&provider_id) {
492                                    if let Some(api_url) = &provider_info.api {
493                                        match openai::OpenAIProvider::with_base_url(
494                                            api_key,
495                                            api_url.clone(),
496                                            &provider_id,
497                                        ) {
498                                            Ok(p) => registry.register(Arc::new(p)),
499                                            Err(e) => {
500                                                tracing::warn!(
501                                                    "Failed to init {}: {}",
502                                                    provider_id,
503                                                    e
504                                                )
505                                            }
506                                        }
507                                    }
508                                }
509                            }
510                        }
511                    }
512                    // Unknown providers - try as OpenAI-compatible with base_url from API
513                    other => {
514                        if let Some(base_url) = secrets.base_url {
515                            match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
516                                Ok(p) => registry.register(Arc::new(p)),
517                                Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
518                            }
519                        } else {
520                            tracing::debug!(
521                                "Unknown provider {} without base_url, skipping",
522                                other
523                            );
524                        }
525                    }
526                }
527            }
528        } else {
529            tracing::warn!("Vault not configured, no providers will be available from Vault");
530        }
531
532        // If Bedrock wasn't registered via Vault, try auto-detecting AWS credentials
533        if !registry.providers.contains_key("bedrock") {
534            if let Some(creds) = bedrock::AwsCredentials::from_environment() {
535                let region = bedrock::AwsCredentials::detect_region()
536                    .unwrap_or_else(|| "us-east-1".to_string());
537                match bedrock::BedrockProvider::with_credentials(creds, region) {
538                    Ok(p) => {
539                        tracing::info!("Registered Bedrock provider from local AWS credentials");
540                        registry.register(Arc::new(p));
541                    }
542                    Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
543                }
544            }
545        }
546
547        tracing::info!(
548            "Registered {} providers from Vault",
549            registry.providers.len()
550        );
551        Ok(registry)
552    }
553}
554
/// Map MiniMax's OpenAI-style endpoint onto its Anthropic-compatible one.
///
/// Surrounding whitespace is trimmed first, then any trailing slashes. If the
/// cleaned URL equals `https://api.minimax.io/v1` (ASCII case-insensitively),
/// `https://api.minimax.io/anthropic` is returned. Any other URL is returned in
/// its cleaned form.
fn normalize_minimax_anthropic_base_url(base_url: &str) -> String {
    const OPENAI_STYLE_ENDPOINT: &str = "https://api.minimax.io/v1";
    const ANTHROPIC_ENDPOINT: &str = "https://api.minimax.io/anthropic";

    let cleaned = base_url.trim().trim_end_matches('/');
    let chosen = if cleaned.eq_ignore_ascii_case(OPENAI_STYLE_ENDPOINT) {
        ANTHROPIC_ENDPOINT
    } else {
        cleaned
    };
    String::from(chosen)
}
563
564impl Default for ProviderRegistry {
565    fn default() -> Self {
566        Self::new()
567    }
568}
569
/// Split a `"provider/model"` string at its first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present; everything after
/// the first `/` is the model, so it may itself contain slashes. Without a `/`,
/// returns `(None, s)`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.find('/') {
        // '/' is a single ASCII byte, so `idx + 1` is always a char boundary.
        Some(idx) => (Some(&s[..idx]), &s[idx + 1..]),
        None => (None, s),
    }
}