// codetether_agent/provider/mod.rs
//! AI Provider abstraction layer
//!
//! Unified interface for multiple AI providers (OpenAI, Anthropic, Google, StepFun, etc.)

pub mod anthropic;
pub mod google;
pub mod models;
pub mod moonshot;
pub mod openai;
pub mod openrouter;
pub mod stepfun;

use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

/// A message in a conversation
///
/// Content is a sequence of [`ContentPart`]s, so one message can mix
/// text, images, files, tool calls, and tool results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced the message (system, user, assistant, or tool).
    pub role: Role,
    /// Ordered content parts making up the message body.
    pub content: Vec<ContentPart>,
}
25
/// The author of a [`Message`]; serialized in lowercase (e.g. `"assistant"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
34
/// One piece of message content.
///
/// Serialized as an internally tagged enum: the `"type"` field carries the
/// variant name in snake_case (e.g. `{"type": "tool_call", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text.
    Text { text: String },
    /// Image referenced by URL; MIME type is optional.
    Image { url: String, mime_type: Option<String> },
    /// File referenced by path; MIME type is optional.
    File { path: String, mime_type: Option<String> },
    /// A tool invocation requested by the model; `arguments` is the raw
    /// argument string (presumably JSON — see [`ToolDefinition::parameters`]).
    ToolCall { id: String, name: String, arguments: String },
    /// Result of an earlier tool call, matched up via `tool_call_id`.
    ToolResult { tool_call_id: String, content: String },
}
44
/// Tool definition for the model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name the model uses to invoke the tool.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// Parameter schema for the tool's arguments.
    pub parameters: serde_json::Value, // JSON Schema
}
52
/// Request to generate a completion
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Conversation history, oldest first.
    pub messages: Vec<Message>,
    /// Tools the model may call; empty when tool use is not wanted.
    pub tools: Vec<ToolDefinition>,
    /// Provider-specific model identifier.
    pub model: String,
    /// Sampling temperature; `None` leaves the provider default.
    pub temperature: Option<f32>,
    /// Nucleus-sampling threshold; `None` leaves the provider default.
    pub top_p: Option<f32>,
    /// Maximum number of tokens to generate; `None` leaves the provider default.
    pub max_tokens: Option<usize>,
    /// Sequences that stop generation when produced.
    pub stop: Vec<String>,
}
64
/// A streaming chunk from the model
///
/// Tool calls arrive as a `ToolCallStart`, zero or more `ToolCallDelta`s,
/// then `ToolCallEnd`, all correlated by `id`. `Done` ends the stream.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// Incremental text output.
    Text(String),
    /// A tool call has started; argument deltas with the same `id` follow.
    ToolCallStart { id: String, name: String },
    /// A fragment of the tool call's argument string.
    ToolCallDelta { id: String, arguments_delta: String },
    /// The tool call identified by `id` is complete.
    ToolCallEnd { id: String },
    /// End of stream, with token usage when the provider reports it.
    Done { usage: Option<Usage> },
    /// Provider-reported error, as a display string.
    Error(String),
}
75
/// Token usage information
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: usize,
    /// Tokens generated in the completion.
    pub completion_tokens: usize,
    /// Total tokens for the request (prompt + completion).
    pub total_tokens: usize,
    /// Tokens read from the provider's prompt cache, when reported.
    pub cache_read_tokens: Option<usize>,
    /// Tokens written to the provider's prompt cache, when reported.
    pub cache_write_tokens: Option<usize>,
}
85
/// Response from a completion request
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The message produced by the model.
    pub message: Message,
    /// Token accounting for this request.
    pub usage: Usage,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
93
/// Why the model stopped generating; serialized in snake_case
/// (e.g. `"tool_calls"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Natural end of output (or a stop sequence was hit).
    Stop,
    /// A token limit was reached.
    Length,
    /// The model stopped in order to issue tool calls.
    ToolCalls,
    /// Output was blocked by a content filter.
    ContentFilter,
    /// The provider reported an error.
    Error,
}
103
/// Provider trait that all AI providers must implement
///
/// Implementations must be `Send + Sync` so they can be shared across
/// tasks behind an `Arc` (see [`ProviderRegistry`]).
#[async_trait]
pub trait Provider: Send + Sync {
    /// Get the provider name (used as the key in [`ProviderRegistry`]).
    fn name(&self) -> &str;

    /// List available models
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Generate a completion
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;

    /// Generate a streaming completion
    ///
    /// Returns a boxed stream of [`StreamChunk`]s; the `'static` bound means
    /// the stream owns its data and may outlive the request.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
}
122
/// Information about a model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier, as passed in [`CompletionRequest::model`].
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Name of the provider that serves this model.
    pub provider: String,
    /// Context window size, in tokens.
    pub context_window: usize,
    /// Maximum output tokens, when the provider publishes a limit.
    pub max_output_tokens: Option<usize>,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports tool/function calling.
    pub supports_tools: bool,
    /// Whether the model supports streaming responses.
    pub supports_streaming: bool,
    /// Input cost per million tokens, when known.
    pub input_cost_per_million: Option<f64>,
    /// Output cost per million tokens, when known.
    pub output_cost_per_million: Option<f64>,
}
137
/// Registry of available providers
///
/// Providers are keyed by their [`Provider::name`] and shared via `Arc`,
/// so [`ProviderRegistry::get`] hands out cheap clones.
pub struct ProviderRegistry {
    // name -> provider instance
    providers: HashMap<String, Arc<dyn Provider>>,
}
142
143impl std::fmt::Debug for ProviderRegistry {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        f.debug_struct("ProviderRegistry")
146            .field("provider_count", &self.providers.len())
147            .field("providers", &self.providers.keys().collect::<Vec<_>>())
148            .finish()
149    }
150}
151
impl ProviderRegistry {
    /// Create an empty registry with no providers.
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    /// Register a provider
    ///
    /// Keyed by [`Provider::name`]; registering a second provider with the
    /// same name replaces the first.
    pub fn register(&mut self, provider: Arc<dyn Provider>) {
        self.providers.insert(provider.name().to_string(), provider);
    }

    /// Get a provider by name
    ///
    /// Returns a cheap `Arc` clone, or `None` if no provider with that
    /// name is registered.
    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
        self.providers.get(name).cloned()
    }

    /// List all registered providers
    ///
    /// Names are returned in unspecified (hash-map iteration) order.
    pub fn list(&self) -> Vec<&str> {
        self.providers.keys().map(|s| s.as_str()).collect()
    }

    /// Initialize with default providers from config
    ///
    /// For each of OpenAI, Anthropic, and Google: the API key comes from
    /// the config entry when one exists; only when there is *no* config
    /// entry at all does the conventional environment variable
    /// (`OPENAI_API_KEY`, etc.) serve as a fallback.
    ///
    /// # Errors
    /// Propagates any error from a provider constructor.
    pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
        let mut registry = Self::new();

        // Always try to initialize OpenAI if key is available.
        // NOTE(review): a config entry whose api_key is None suppresses the
        // env-var fallback (because of the `else if`) — confirm intended.
        if let Some(provider_config) = config.providers.get("openai") {
            if let Some(api_key) = &provider_config.api_key {
                registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
            }
        } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
        }

        // Initialize Anthropic (same config-then-env pattern as OpenAI above)
        if let Some(provider_config) = config.providers.get("anthropic") {
            if let Some(api_key) = &provider_config.api_key {
                registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key.clone())?));
            }
        } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
        }

        // Initialize Google (same config-then-env pattern as OpenAI above)
        if let Some(provider_config) = config.providers.get("google") {
            if let Some(api_key) = &provider_config.api_key {
                registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
            }
        } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
            registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
        }

        Ok(registry)
    }

    /// Initialize providers from HashiCorp Vault
    ///
    /// This loads API keys from Vault and creates providers dynamically.
    /// Supports OpenAI-compatible providers via base_url.
    ///
    /// Returns an empty registry when Vault is not configured. Individual
    /// provider *construction* failures are logged and skipped.
    ///
    /// # Errors
    /// NOTE(review): listing providers and fetching each provider's secrets
    /// use `?`, so a single Vault read failure aborts the whole load —
    /// confirm that all-or-nothing behavior is intended.
    pub async fn from_vault() -> Result<Self> {
        let mut registry = Self::new();
        
        // Without a secrets manager there is nothing to load.
        let manager = match crate::secrets::secrets_manager() {
            Some(m) => m,
            None => {
                tracing::warn!("Vault not configured, no providers will be available");
                return Ok(registry);
            }
        };

        // List all configured providers from Vault
        let providers = manager.list_configured_providers().await?;
        tracing::info!("Found {} providers configured in Vault", providers.len());

        for provider_id in providers {
            // Skip entries with no stored secrets or no API key.
            let secrets = match manager.get_provider_secrets(&provider_id).await? {
                Some(s) => s,
                None => continue,
            };

            let api_key = match secrets.api_key {
                Some(key) => key,
                None => continue,
            };

            // Determine which provider implementation to use
            match provider_id.as_str() {
                // Native providers
                "anthropic" | "anthropic-eu" | "anthropic-asia" => {
                    match anthropic::AnthropicProvider::new(api_key) {
                        Ok(p) => registry.register(Arc::new(p)),
                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
                    }
                }
                "google" | "google-vertex" => {
                    match google::GoogleProvider::new(api_key) {
                        Ok(p) => registry.register(Arc::new(p)),
                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
                    }
                }
                // StepFun - native provider (direct API, not via OpenRouter)
                "stepfun" => {
                    match stepfun::StepFunProvider::new(api_key) {
                        Ok(p) => registry.register(Arc::new(p)),
                        Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
                    }
                }
                // OpenRouter - native provider with support for extended response formats
                "openrouter" => {
                    match openrouter::OpenRouterProvider::new(api_key) {
                        Ok(p) => registry.register(Arc::new(p)),
                        Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
                    }
                }
                // Moonshot AI - native provider for Kimi models
                // NOTE(review): the log below hardcodes "moonshotai" even for
                // the "moonshotai-cn" id — consider logging provider_id.
                "moonshotai" | "moonshotai-cn" => {
                    match moonshot::MoonshotProvider::new(api_key) {
                        Ok(p) => registry.register(Arc::new(p)),
                        Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
                    }
                }
                // OpenAI-compatible providers (with custom base_url)
                "deepseek" | "groq" | "togetherai"
                | "fireworks-ai" | "mistral" | "nvidia" | "alibaba"
                | "openai" | "azure" => {
                    if let Some(base_url) = secrets.base_url {
                        match openai::OpenAIProvider::with_base_url(api_key, base_url, &provider_id) {
                            Ok(p) => registry.register(Arc::new(p)),
                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
                        }
                    } else if provider_id == "openai" {
                        // OpenAI doesn't need a custom base_url
                        match openai::OpenAIProvider::new(api_key) {
                            Ok(p) => registry.register(Arc::new(p)),
                            Err(e) => tracing::warn!("Failed to init openai: {}", e),
                        }
                    } else {
                        // Try using the base_url from the models API.
                        // Catalog fetch failures are silently ignored: the
                        // provider simply ends up unregistered.
                        if let Ok(catalog) = models::ModelCatalog::fetch().await {
                            if let Some(provider_info) = catalog.get_provider(&provider_id) {
                                if let Some(api_url) = &provider_info.api {
                                    match openai::OpenAIProvider::with_base_url(api_key, api_url.clone(), &provider_id) {
                                        Ok(p) => registry.register(Arc::new(p)),
                                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
                                    }
                                }
                            }
                        }
                    }
                }
                // Unknown providers - try as OpenAI-compatible with base_url from API
                other => {
                    if let Some(base_url) = secrets.base_url {
                        match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
                            Ok(p) => registry.register(Arc::new(p)),
                            Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
                        }
                    } else {
                        tracing::debug!("Unknown provider {} without base_url, skipping", other);
                    }
                }
            }
        }

        tracing::info!("Registered {} providers from Vault", registry.providers.len());
        Ok(registry)
    }
}
321
322impl Default for ProviderRegistry {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
/// Parse a model string in the format "provider/model".
///
/// Splits on the *first* `/`, so nested ids such as
/// `"openrouter/anthropic/claude-3"` yield `(Some("openrouter"),
/// "anthropic/claude-3")`. A string without a `/` comes back unchanged
/// with no provider.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.split_once('/') {
        Some((provider, model)) => (Some(provider), model),
        None => (None, s),
    }
}