Skip to main content

codetether_agent/provider/
mod.rs

1//! AI Provider abstraction layer
2//!
3//! Unified interface for multiple AI providers (OpenAI, Anthropic, Google, StepFun, etc.)
4
5pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod google;
9pub mod models;
10pub mod moonshot;
11pub mod openai;
12pub mod openrouter;
13pub mod stepfun;
14
15use anyhow::Result;
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20
21/// A message in a conversation
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct Message {
24    pub role: Role,
25    pub content: Vec<ContentPart>,
26}
27
28#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
29#[serde(rename_all = "lowercase")]
30pub enum Role {
31    System,
32    User,
33    Assistant,
34    Tool,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(tag = "type", rename_all = "snake_case")]
39pub enum ContentPart {
40    Text {
41        text: String,
42    },
43    Image {
44        url: String,
45        mime_type: Option<String>,
46    },
47    File {
48        path: String,
49        mime_type: Option<String>,
50    },
51    ToolCall {
52        id: String,
53        name: String,
54        arguments: String,
55    },
56    ToolResult {
57        tool_call_id: String,
58        content: String,
59    },
60}
61
62/// Tool definition for the model
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ToolDefinition {
65    pub name: String,
66    pub description: String,
67    pub parameters: serde_json::Value, // JSON Schema
68}
69
70/// Request to generate a completion
71#[derive(Debug, Clone)]
72pub struct CompletionRequest {
73    pub messages: Vec<Message>,
74    pub tools: Vec<ToolDefinition>,
75    pub model: String,
76    pub temperature: Option<f32>,
77    pub top_p: Option<f32>,
78    pub max_tokens: Option<usize>,
79    pub stop: Vec<String>,
80}
81
82/// A streaming chunk from the model
83#[derive(Debug, Clone)]
84pub enum StreamChunk {
85    Text(String),
86    ToolCallStart { id: String, name: String },
87    ToolCallDelta { id: String, arguments_delta: String },
88    ToolCallEnd { id: String },
89    Done { usage: Option<Usage> },
90    Error(String),
91}
92
93/// Token usage information
94#[derive(Debug, Clone, Default, Serialize, Deserialize)]
95pub struct Usage {
96    pub prompt_tokens: usize,
97    pub completion_tokens: usize,
98    pub total_tokens: usize,
99    pub cache_read_tokens: Option<usize>,
100    pub cache_write_tokens: Option<usize>,
101}
102
103/// Response from a completion request
104#[derive(Debug, Clone)]
105pub struct CompletionResponse {
106    pub message: Message,
107    pub usage: Usage,
108    pub finish_reason: FinishReason,
109}
110
111#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
112#[serde(rename_all = "snake_case")]
113pub enum FinishReason {
114    Stop,
115    Length,
116    ToolCalls,
117    ContentFilter,
118    Error,
119}
120
121/// Provider trait that all AI providers must implement
122#[async_trait]
123pub trait Provider: Send + Sync {
124    /// Get the provider name
125    fn name(&self) -> &str;
126
127    /// List available models
128    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
129
130    /// Generate a completion
131    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
132
133    /// Generate a streaming completion
134    async fn complete_stream(
135        &self,
136        request: CompletionRequest,
137    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
138}
139
140/// Information about a model
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ModelInfo {
143    pub id: String,
144    pub name: String,
145    pub provider: String,
146    pub context_window: usize,
147    pub max_output_tokens: Option<usize>,
148    pub supports_vision: bool,
149    pub supports_tools: bool,
150    pub supports_streaming: bool,
151    pub input_cost_per_million: Option<f64>,
152    pub output_cost_per_million: Option<f64>,
153}
154
155/// Registry of available providers
156pub struct ProviderRegistry {
157    providers: HashMap<String, Arc<dyn Provider>>,
158}
159
160impl std::fmt::Debug for ProviderRegistry {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("ProviderRegistry")
163            .field("provider_count", &self.providers.len())
164            .field("providers", &self.providers.keys().collect::<Vec<_>>())
165            .finish()
166    }
167}
168
169impl ProviderRegistry {
170    pub fn new() -> Self {
171        Self {
172            providers: HashMap::new(),
173        }
174    }
175
176    /// Register a provider
177    pub fn register(&mut self, provider: Arc<dyn Provider>) {
178        self.providers.insert(provider.name().to_string(), provider);
179    }
180
181    /// Get a provider by name
182    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
183        self.providers.get(name).cloned()
184    }
185
186    /// List all registered providers
187    pub fn list(&self) -> Vec<&str> {
188        self.providers.keys().map(|s| s.as_str()).collect()
189    }
190
191    /// Initialize with default providers from config
192    pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
193        let mut registry = Self::new();
194
195        // Always try to initialize OpenAI if key is available
196        if let Some(provider_config) = config.providers.get("openai") {
197            if let Some(api_key) = &provider_config.api_key {
198                registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
199            }
200        } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
201            registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
202        }
203
204        // Initialize Anthropic
205        if let Some(provider_config) = config.providers.get("anthropic") {
206            if let Some(api_key) = &provider_config.api_key {
207                registry.register(Arc::new(anthropic::AnthropicProvider::new(
208                    api_key.clone(),
209                )?));
210            }
211        } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
212            registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
213        }
214
215        // Initialize Google
216        if let Some(provider_config) = config.providers.get("google") {
217            if let Some(api_key) = &provider_config.api_key {
218                registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
219            }
220        } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
221            registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
222        }
223
224        // Initialize Novita (OpenAI-compatible)
225        if let Some(provider_config) = config.providers.get("novita") {
226            if let Some(api_key) = &provider_config.api_key {
227                let base_url = provider_config
228                    .base_url
229                    .clone()
230                    .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
231                registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
232                    api_key.clone(),
233                    base_url,
234                    "novita",
235                )?));
236            }
237        }
238
239        Ok(registry)
240    }
241
242    /// Initialize providers from HashiCorp Vault
243    ///
244    /// This loads API keys from Vault and creates providers dynamically.
245    /// Supports OpenAI-compatible providers via base_url.
246    pub async fn from_vault() -> Result<Self> {
247        let mut registry = Self::new();
248
249        if let Some(manager) = crate::secrets::secrets_manager() {
250            // List all configured providers from Vault
251            let providers = manager.list_configured_providers().await?;
252            tracing::info!("Found {} providers configured in Vault", providers.len());
253
254            for provider_id in providers {
255                let secrets = match manager.get_provider_secrets(&provider_id).await? {
256                    Some(s) => s,
257                    None => continue,
258                };
259
260                let api_key = match secrets.api_key {
261                    Some(key) => key,
262                    None => continue,
263                };
264
265                // Determine which provider implementation to use
266                match provider_id.as_str() {
267                    // Amazon Bedrock - native Converse API with bearer token
268                    "bedrock" | "aws-bedrock" => {
269                        let region = secrets
270                            .extra
271                            .get("region")
272                            .and_then(|v| v.as_str())
273                            .unwrap_or("us-east-1")
274                            .to_string();
275                        match bedrock::BedrockProvider::with_region(api_key, region) {
276                            Ok(p) => registry.register(Arc::new(p)),
277                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
278                        }
279                    }
280                    // Native providers
281                    "anthropic" | "anthropic-eu" | "anthropic-asia" => {
282                        match anthropic::AnthropicProvider::new(api_key) {
283                            Ok(p) => registry.register(Arc::new(p)),
284                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
285                        }
286                    }
287                    "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
288                        Ok(p) => registry.register(Arc::new(p)),
289                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
290                    },
291                    // StepFun - native provider (direct API, not via OpenRouter)
292                    "stepfun" => match stepfun::StepFunProvider::new(api_key) {
293                        Ok(p) => registry.register(Arc::new(p)),
294                        Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
295                    },
296                    // OpenRouter - native provider with support for extended response formats
297                    "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
298                        Ok(p) => registry.register(Arc::new(p)),
299                        Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
300                    },
301                    // Moonshot AI - native provider for Kimi models
302                    "moonshotai" | "moonshotai-cn" => {
303                        match moonshot::MoonshotProvider::new(api_key) {
304                            Ok(p) => registry.register(Arc::new(p)),
305                            Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
306                        }
307                    }
308                    // GitHub Copilot providers require custom headers/token semantics
309                    "github-copilot" => {
310                        let result = if let Some(base_url) = secrets.base_url.clone() {
311                            copilot::CopilotProvider::with_base_url(
312                                api_key,
313                                base_url,
314                                "github-copilot",
315                            )
316                        } else {
317                            copilot::CopilotProvider::new(api_key)
318                        };
319
320                        match result {
321                            Ok(p) => registry.register(Arc::new(p)),
322                            Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
323                        }
324                    }
325                    "github-copilot-enterprise" => {
326                        let enterprise_url = secrets
327                            .extra
328                            .get("enterpriseUrl")
329                            .and_then(|v| v.as_str())
330                            .or_else(|| {
331                                secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
332                            });
333
334                        let result = if let Some(base_url) = secrets.base_url.clone() {
335                            copilot::CopilotProvider::with_base_url(
336                                api_key,
337                                base_url,
338                                "github-copilot-enterprise",
339                            )
340                        } else if let Some(url) = enterprise_url {
341                            copilot::CopilotProvider::enterprise(api_key, url.to_string())
342                        } else {
343                            copilot::CopilotProvider::with_base_url(
344                                api_key,
345                                "https://api.githubcopilot.com".to_string(),
346                                "github-copilot-enterprise",
347                            )
348                        };
349
350                        match result {
351                            Ok(p) => registry.register(Arc::new(p)),
352                            Err(e) => {
353                                tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
354                            }
355                        }
356                    }
357                    // ZhipuAI - OpenAI-compatible coding API
358                    "zhipuai" => {
359                        let base_url = secrets
360                            .base_url
361                            .clone()
362                            .unwrap_or_else(|| "https://api.z.ai/api/coding/paas/v4".to_string());
363                        match openai::OpenAIProvider::with_base_url(api_key, base_url, "zhipuai") {
364                            Ok(p) => registry.register(Arc::new(p)),
365                            Err(e) => tracing::warn!("Failed to init zhipuai: {}", e),
366                        }
367                    }
368                    // Cerebras - OpenAI-compatible fast inference
369                    "cerebras" => {
370                        let base_url = secrets
371                            .base_url
372                            .clone()
373                            .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
374                        match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
375                            Ok(p) => registry.register(Arc::new(p)),
376                            Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
377                        }
378                    }
379                    // MiniMax - OpenAI-compatible API
380                    "minimax" => {
381                        let base_url = secrets
382                            .base_url
383                            .clone()
384                            .unwrap_or_else(|| "https://api.minimax.chat/v1".to_string());
385                        match openai::OpenAIProvider::with_base_url(api_key, base_url, "minimax") {
386                            Ok(p) => registry.register(Arc::new(p)),
387                            Err(e) => tracing::warn!("Failed to init minimax: {}", e),
388                        }
389                    }
390                    // OpenAI-compatible providers (with custom base_url)
391                    "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
392                    | "alibaba" | "openai" | "azure" | "novita" => {
393                        if let Some(base_url) = secrets.base_url.clone() {
394                            match openai::OpenAIProvider::with_base_url(
395                                api_key,
396                                base_url,
397                                &provider_id,
398                            ) {
399                                Ok(p) => registry.register(Arc::new(p)),
400                                Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
401                            }
402                        } else if provider_id == "openai" {
403                            // OpenAI doesn't need a custom base_url
404                            match openai::OpenAIProvider::new(api_key) {
405                                Ok(p) => registry.register(Arc::new(p)),
406                                Err(e) => tracing::warn!("Failed to init openai: {}", e),
407                            }
408                        } else if provider_id == "novita" {
409                            let base_url = "https://api.novita.ai/openai/v1".to_string();
410                            match openai::OpenAIProvider::with_base_url(
411                                api_key,
412                                base_url,
413                                &provider_id,
414                            ) {
415                                Ok(p) => registry.register(Arc::new(p)),
416                                Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
417                            }
418                        } else {
419                            // Try using the base_url from the models API
420                            if let Ok(catalog) = models::ModelCatalog::fetch().await {
421                                if let Some(provider_info) = catalog.get_provider(&provider_id) {
422                                    if let Some(api_url) = &provider_info.api {
423                                        match openai::OpenAIProvider::with_base_url(
424                                            api_key,
425                                            api_url.clone(),
426                                            &provider_id,
427                                        ) {
428                                            Ok(p) => registry.register(Arc::new(p)),
429                                            Err(e) => {
430                                                tracing::warn!(
431                                                    "Failed to init {}: {}",
432                                                    provider_id,
433                                                    e
434                                                )
435                                            }
436                                        }
437                                    }
438                                }
439                            }
440                        }
441                    }
442                    // Unknown providers - try as OpenAI-compatible with base_url from API
443                    other => {
444                        if let Some(base_url) = secrets.base_url {
445                            match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
446                                Ok(p) => registry.register(Arc::new(p)),
447                                Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
448                            }
449                        } else {
450                            tracing::debug!(
451                                "Unknown provider {} without base_url, skipping",
452                                other
453                            );
454                        }
455                    }
456                }
457            }
458        } else {
459            tracing::warn!("Vault not configured, no providers will be available");
460        }
461
462        tracing::info!(
463            "Registered {} providers from Vault",
464            registry.providers.len()
465        );
466        Ok(registry)
467    }
468}
469
470impl Default for ProviderRegistry {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
/// Splits a `"provider/model"` string at its first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present. The model part
/// keeps any further slashes, so `"openrouter/a/b"` yields
/// `(Some("openrouter"), "a/b")`. Without a `/`, returns `(None, s)`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    s.split_once('/')
        .map_or((None, s), |(prefix, rest)| (Some(prefix), rest))
}