Skip to main content

codetether_agent/provider/
mod.rs

1//! AI Provider abstraction layer
2//!
3//! Unified interface for multiple AI providers (OpenAI, Anthropic, Google, StepFun, etc.)
4
5pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod google;
9pub mod models;
10pub mod moonshot;
11pub mod openai;
12pub mod openrouter;
13pub mod stepfun;
14pub mod zai;
15
16use anyhow::Result;
17use async_trait::async_trait;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21
/// A single message in a conversation.
///
/// A message has exactly one [`Role`] and an ordered list of [`ContentPart`]s,
/// so one message can combine text, images, files, tool calls, and so on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Message body, in order.
    pub content: Vec<ContentPart>,
}
28
/// The author of a [`Message`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`, `"tool"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction message.
    System,
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// Output of a tool invocation (presumably carrying `ContentPart::ToolResult` parts — confirm).
    Tool,
}
37
/// One piece of a [`Message`]'s content.
///
/// Serialized as an internally tagged object: a `"type"` field holds the
/// snake_case variant name (`"text"`, `"image"`, `"file"`, `"tool_call"`,
/// `"tool_result"`, `"thinking"`) alongside the variant's own fields, e.g.
/// `{"type": "text", "text": "hi"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL, with an optional MIME type.
    Image {
        url: String,
        mime_type: Option<String>,
    },
    /// A file referenced by path, with an optional MIME type.
    File {
        path: String,
        mime_type: Option<String>,
    },
    /// A tool invocation requested by the model.
    ToolCall {
        id: String,
        name: String,
        // Kept as a raw string (presumably JSON-encoded arguments — confirm per provider).
        arguments: String,
    },
    /// The result of a tool invocation.
    ToolResult {
        // Presumably matches the `id` of the corresponding `ToolCall` — confirm.
        tool_call_id: String,
        content: String,
    },
    /// Model "thinking" / reasoning text, kept separate from regular text.
    Thinking {
        text: String,
    },
}
65
/// A tool the model may call, as advertised in a [`CompletionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name the model uses to refer to it.
    pub name: String,
    /// Description given to the model.
    pub description: String,
    /// JSON Schema describing the tool's arguments.
    pub parameters: serde_json::Value, // JSON Schema
}
73
/// Request to generate a completion, passed to [`Provider::complete`] or
/// [`Provider::complete_stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Conversation so far.
    pub messages: Vec<Message>,
    /// Tools the model may call; empty for none.
    pub tools: Vec<ToolDefinition>,
    /// Model identifier sent to the provider.
    pub model: String,
    /// Sampling temperature; `None` leaves it unset (presumably provider default).
    pub temperature: Option<f32>,
    /// Nucleus-sampling `top_p`; `None` leaves it unset.
    pub top_p: Option<f32>,
    /// Output token limit; `None` leaves it unset.
    pub max_tokens: Option<usize>,
    /// Stop sequences; empty for none.
    pub stop: Vec<String>,
}
85
/// A streaming chunk yielded by [`Provider::complete_stream`].
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of generated text.
    Text(String),
    /// A tool call with the given `id` has started.
    ToolCallStart { id: String, name: String },
    /// Incremental argument text for tool call `id` (presumably appended in order).
    ToolCallDelta { id: String, arguments_delta: String },
    /// Tool call `id` is complete.
    ToolCallEnd { id: String },
    /// The stream finished; usage is included when the provider reports it.
    Done { usage: Option<Usage> },
    /// An error message reported in-band on the stream.
    Error(String),
}
96
/// Token usage information for a completion.
///
/// `Default` yields all-zero counts with no cache figures.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens in the prompt.
    pub prompt_tokens: usize,
    /// Tokens generated.
    pub completion_tokens: usize,
    /// Total as reported (presumably prompt + completion; not enforced here).
    pub total_tokens: usize,
    /// Prompt-cache read tokens, when the provider reports them.
    pub cache_read_tokens: Option<usize>,
    /// Prompt-cache write tokens, when the provider reports them.
    pub cache_write_tokens: Option<usize>,
}
106
/// Response from a non-streaming completion request ([`Provider::complete`]).
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The generated message.
    pub message: Message,
    /// Token usage for this request.
    pub usage: Usage,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
114
/// Why the model stopped generating.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"tool_calls"`,
/// `"content_filter"`, `"error"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Normal completion (presumably natural end or a stop sequence).
    Stop,
    /// Output hit a length limit (presumably `max_tokens`).
    Length,
    /// The model stopped to request tool calls.
    ToolCalls,
    /// Output was stopped by a content filter.
    ContentFilter,
    /// Generation ended because of an error.
    Error,
}
124
/// Provider trait that all AI providers must implement.
///
/// Implementations are stored as `Arc<dyn Provider>` in [`ProviderRegistry`],
/// hence the `Send + Sync` bound. Async methods are provided via `#[async_trait]`.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Get the provider name.
    ///
    /// [`ProviderRegistry::register`] uses this as the registry key, so it should
    /// be unique among registered providers.
    fn name(&self) -> &str;

    /// List the models available from this provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Generate a complete (non-streaming) response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;

    /// Generate a streaming completion.
    ///
    /// The returned stream is `'static`, so it cannot borrow from `self` or `request`.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
}
143
/// Information about a model, as returned by [`Provider::list_models`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Model name (presumably human-readable, as opposed to `id`).
    pub name: String,
    /// Name of the provider serving this model.
    pub provider: String,
    /// Context window size (presumably in tokens).
    pub context_window: usize,
    /// Maximum output tokens, if known.
    pub max_output_tokens: Option<usize>,
    /// Whether image input is supported.
    pub supports_vision: bool,
    /// Whether tool calling is supported.
    pub supports_tools: bool,
    /// Whether streaming is supported.
    pub supports_streaming: bool,
    /// Input cost per million tokens, if known (currency presumably USD — confirm).
    pub input_cost_per_million: Option<f64>,
    /// Output cost per million tokens, if known (currency presumably USD — confirm).
    pub output_cost_per_million: Option<f64>,
}
158
/// Registry of available providers.
///
/// Providers are keyed by the name returned from [`Provider::name`] and shared
/// as `Arc<dyn Provider>`, so lookups hand out cheap clones.
pub struct ProviderRegistry {
    providers: HashMap<String, Arc<dyn Provider>>,
}
163
164impl std::fmt::Debug for ProviderRegistry {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        f.debug_struct("ProviderRegistry")
167            .field("provider_count", &self.providers.len())
168            .field("providers", &self.providers.keys().collect::<Vec<_>>())
169            .finish()
170    }
171}
172
173impl ProviderRegistry {
174    pub fn new() -> Self {
175        Self {
176            providers: HashMap::new(),
177        }
178    }
179
180    /// Register a provider
181    pub fn register(&mut self, provider: Arc<dyn Provider>) {
182        self.providers.insert(provider.name().to_string(), provider);
183    }
184
185    /// Get a provider by name
186    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
187        self.providers.get(name).cloned()
188    }
189
190    /// List all registered providers
191    pub fn list(&self) -> Vec<&str> {
192        self.providers.keys().map(|s| s.as_str()).collect()
193    }
194
195    /// Initialize with default providers from config
196    pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
197        let mut registry = Self::new();
198
199        // Always try to initialize OpenAI if key is available
200        if let Some(provider_config) = config.providers.get("openai") {
201            if let Some(api_key) = &provider_config.api_key {
202                registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
203            }
204        } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
205            registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
206        }
207
208        // Initialize Anthropic
209        if let Some(provider_config) = config.providers.get("anthropic") {
210            if let Some(api_key) = &provider_config.api_key {
211                registry.register(Arc::new(anthropic::AnthropicProvider::new(
212                    api_key.clone(),
213                )?));
214            }
215        } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
216            registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
217        }
218
219        // Initialize Google
220        if let Some(provider_config) = config.providers.get("google") {
221            if let Some(api_key) = &provider_config.api_key {
222                registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
223            }
224        } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
225            registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
226        }
227
228        // Initialize Novita (OpenAI-compatible)
229        if let Some(provider_config) = config.providers.get("novita") {
230            if let Some(api_key) = &provider_config.api_key {
231                let base_url = provider_config
232                    .base_url
233                    .clone()
234                    .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
235                registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
236                    api_key.clone(),
237                    base_url,
238                    "novita",
239                )?));
240            }
241        }
242
243        // Initialize Bedrock via AWS credentials (env vars or ~/.aws/credentials)
244        if let Some(creds) = bedrock::AwsCredentials::from_environment() {
245            let region = bedrock::AwsCredentials::detect_region()
246                .unwrap_or_else(|| bedrock::DEFAULT_REGION.to_string());
247            match bedrock::BedrockProvider::with_credentials(creds, region) {
248                Ok(p) => registry.register(Arc::new(p)),
249                Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
250            }
251        }
252
253        Ok(registry)
254    }
255
256    /// Initialize providers from HashiCorp Vault
257    ///
258    /// This loads API keys from Vault and creates providers dynamically.
259    /// Supports OpenAI-compatible providers via base_url.
260    pub async fn from_vault() -> Result<Self> {
261        let mut registry = Self::new();
262
263        if let Some(manager) = crate::secrets::secrets_manager() {
264            // List all configured providers from Vault
265            let providers = manager.list_configured_providers().await?;
266            tracing::info!("Found {} providers configured in Vault", providers.len());
267
268            for provider_id in providers {
269                let secrets = match manager.get_provider_secrets(&provider_id).await? {
270                    Some(s) => s,
271                    None => continue,
272                };
273
274                // Handle Bedrock before api_key extraction since it can use
275                // AWS IAM credentials instead of an API key.
276                if matches!(provider_id.as_str(), "bedrock" | "aws-bedrock") {
277                    let region = secrets
278                        .extra
279                        .get("region")
280                        .and_then(|v| v.as_str())
281                        .unwrap_or("us-east-1")
282                        .to_string();
283
284                    // Prefer SigV4 if AWS credentials are in Vault
285                    let aws_key_id = secrets
286                        .extra
287                        .get("aws_access_key_id")
288                        .and_then(|v| v.as_str());
289                    let aws_secret = secrets
290                        .extra
291                        .get("aws_secret_access_key")
292                        .and_then(|v| v.as_str());
293
294                    let result = if let (Some(key_id), Some(secret)) = (aws_key_id, aws_secret) {
295                        let creds = bedrock::AwsCredentials {
296                            access_key_id: key_id.to_string(),
297                            secret_access_key: secret.to_string(),
298                            session_token: secrets
299                                .extra
300                                .get("aws_session_token")
301                                .and_then(|v| v.as_str())
302                                .map(|s| s.to_string()),
303                        };
304                        bedrock::BedrockProvider::with_credentials(creds, region)
305                    } else if let Some(ref key) = secrets.api_key {
306                        bedrock::BedrockProvider::with_region(key.clone(), region)
307                    } else {
308                        // Try auto-detecting from environment as last resort
309                        if let Some(creds) = bedrock::AwsCredentials::from_environment() {
310                            bedrock::BedrockProvider::with_credentials(creds, region)
311                        } else {
312                            Err(anyhow::anyhow!(
313                                "No AWS credentials or API key found for Bedrock"
314                            ))
315                        }
316                    };
317
318                    match result {
319                        Ok(p) => registry.register(Arc::new(p)),
320                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
321                    }
322                    continue;
323                }
324
325                let api_key = match secrets.api_key {
326                    Some(key) => key,
327                    None => continue,
328                };
329
330                // Determine which provider implementation to use
331                match provider_id.as_str() {
332                    // Native providers
333                    "anthropic" | "anthropic-eu" | "anthropic-asia" => {
334                        match anthropic::AnthropicProvider::new(api_key) {
335                            Ok(p) => registry.register(Arc::new(p)),
336                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
337                        }
338                    }
339                    "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
340                        Ok(p) => registry.register(Arc::new(p)),
341                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
342                    },
343                    // StepFun - native provider (direct API, not via OpenRouter)
344                    "stepfun" => match stepfun::StepFunProvider::new(api_key) {
345                        Ok(p) => registry.register(Arc::new(p)),
346                        Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
347                    },
348                    // OpenRouter - native provider with support for extended response formats
349                    "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
350                        Ok(p) => registry.register(Arc::new(p)),
351                        Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
352                    },
353                    // Moonshot AI - native provider for Kimi models
354                    "moonshotai" | "moonshotai-cn" => {
355                        match moonshot::MoonshotProvider::new(api_key) {
356                            Ok(p) => registry.register(Arc::new(p)),
357                            Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
358                        }
359                    }
360                    // GitHub Copilot providers require custom headers/token semantics
361                    "github-copilot" => {
362                        let result = if let Some(base_url) = secrets.base_url.clone() {
363                            copilot::CopilotProvider::with_base_url(
364                                api_key,
365                                base_url,
366                                "github-copilot",
367                            )
368                        } else {
369                            copilot::CopilotProvider::new(api_key)
370                        };
371
372                        match result {
373                            Ok(p) => registry.register(Arc::new(p)),
374                            Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
375                        }
376                    }
377                    "github-copilot-enterprise" => {
378                        let enterprise_url = secrets
379                            .extra
380                            .get("enterpriseUrl")
381                            .and_then(|v| v.as_str())
382                            .or_else(|| {
383                                secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
384                            });
385
386                        let result = if let Some(base_url) = secrets.base_url.clone() {
387                            copilot::CopilotProvider::with_base_url(
388                                api_key,
389                                base_url,
390                                "github-copilot-enterprise",
391                            )
392                        } else if let Some(url) = enterprise_url {
393                            copilot::CopilotProvider::enterprise(api_key, url.to_string())
394                        } else {
395                            copilot::CopilotProvider::with_base_url(
396                                api_key,
397                                "https://api.githubcopilot.com".to_string(),
398                                "github-copilot-enterprise",
399                            )
400                        };
401
402                        match result {
403                            Ok(p) => registry.register(Arc::new(p)),
404                            Err(e) => {
405                                tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
406                            }
407                        }
408                    }
409                    // Z.AI (formerly ZhipuAI) — first-class provider for GLM models
410                    "zhipuai" | "zai" => {
411                        let base_url = secrets
412                            .base_url
413                            .clone()
414                            .unwrap_or_else(|| "https://api.z.ai/api/coding/paas/v4".to_string());
415                        match zai::ZaiProvider::with_base_url(api_key, base_url) {
416                            Ok(p) => registry.register(Arc::new(p)),
417                            Err(e) => tracing::warn!("Failed to init zai: {}", e),
418                        }
419                    }
420                    // Cerebras - OpenAI-compatible fast inference
421                    "cerebras" => {
422                        let base_url = secrets
423                            .base_url
424                            .clone()
425                            .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
426                        match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
427                            Ok(p) => registry.register(Arc::new(p)),
428                            Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
429                        }
430                    }
431                    // MiniMax - OpenAI-compatible API
432                    "minimax" => {
433                        let base_url = secrets
434                            .base_url
435                            .clone()
436                            .unwrap_or_else(|| "https://api.minimax.chat/v1".to_string());
437                        match openai::OpenAIProvider::with_base_url(api_key, base_url, "minimax") {
438                            Ok(p) => registry.register(Arc::new(p)),
439                            Err(e) => tracing::warn!("Failed to init minimax: {}", e),
440                        }
441                    }
442                    // OpenAI-compatible providers (with custom base_url)
443                    "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
444                    | "alibaba" | "openai" | "azure" | "novita" => {
445                        if let Some(base_url) = secrets.base_url.clone() {
446                            match openai::OpenAIProvider::with_base_url(
447                                api_key,
448                                base_url,
449                                &provider_id,
450                            ) {
451                                Ok(p) => registry.register(Arc::new(p)),
452                                Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
453                            }
454                        } else if provider_id == "openai" {
455                            // OpenAI doesn't need a custom base_url
456                            match openai::OpenAIProvider::new(api_key) {
457                                Ok(p) => registry.register(Arc::new(p)),
458                                Err(e) => tracing::warn!("Failed to init openai: {}", e),
459                            }
460                        } else if provider_id == "novita" {
461                            let base_url = "https://api.novita.ai/openai/v1".to_string();
462                            match openai::OpenAIProvider::with_base_url(
463                                api_key,
464                                base_url,
465                                &provider_id,
466                            ) {
467                                Ok(p) => registry.register(Arc::new(p)),
468                                Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
469                            }
470                        } else {
471                            // Try using the base_url from the models API
472                            if let Ok(catalog) = models::ModelCatalog::fetch().await {
473                                if let Some(provider_info) = catalog.get_provider(&provider_id) {
474                                    if let Some(api_url) = &provider_info.api {
475                                        match openai::OpenAIProvider::with_base_url(
476                                            api_key,
477                                            api_url.clone(),
478                                            &provider_id,
479                                        ) {
480                                            Ok(p) => registry.register(Arc::new(p)),
481                                            Err(e) => {
482                                                tracing::warn!(
483                                                    "Failed to init {}: {}",
484                                                    provider_id,
485                                                    e
486                                                )
487                                            }
488                                        }
489                                    }
490                                }
491                            }
492                        }
493                    }
494                    // Unknown providers - try as OpenAI-compatible with base_url from API
495                    other => {
496                        if let Some(base_url) = secrets.base_url {
497                            match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
498                                Ok(p) => registry.register(Arc::new(p)),
499                                Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
500                            }
501                        } else {
502                            tracing::debug!(
503                                "Unknown provider {} without base_url, skipping",
504                                other
505                            );
506                        }
507                    }
508                }
509            }
510        } else {
511            tracing::warn!("Vault not configured, no providers will be available from Vault");
512        }
513
514        // If Bedrock wasn't registered via Vault, try auto-detecting AWS credentials
515        if !registry.providers.contains_key("bedrock") {
516            if let Some(creds) = bedrock::AwsCredentials::from_environment() {
517                let region = bedrock::AwsCredentials::detect_region()
518                    .unwrap_or_else(|| "us-east-1".to_string());
519                match bedrock::BedrockProvider::with_credentials(creds, region) {
520                    Ok(p) => {
521                        tracing::info!("Registered Bedrock provider from local AWS credentials");
522                        registry.register(Arc::new(p));
523                    }
524                    Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
525                }
526            }
527        }
528
529        tracing::info!(
530            "Registered {} providers from Vault",
531            registry.providers.len()
532        );
533        Ok(registry)
534    }
535}
536
537impl Default for ProviderRegistry {
538    fn default() -> Self {
539        Self::new()
540    }
541}
542
/// Split a model string of the form `"provider/model"`.
///
/// Only the first `/` separates the provider, so `"openrouter/anthropic/claude"`
/// yields `(Some("openrouter"), "anthropic/claude")`. A string without `/` yields
/// `(None, s)`. An empty provider segment (`"/model"`) yields `Some("")`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.split_once('/') {
        Some((provider, model)) => (Some(provider), model),
        None => (None, s),
    }
}