Skip to main content

codetether_agent/provider/
mod.rs

1//! AI Provider abstraction layer
2//!
3//! Unified interface for multiple AI providers (OpenAI, Anthropic, Google, StepFun, etc.)
4
5pub mod anthropic;
6pub mod google;
7pub mod models;
8pub mod moonshot;
9pub mod openai;
10pub mod openrouter;
11pub mod stepfun;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19/// A message in a conversation
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Message {
22    pub role: Role,
23    pub content: Vec<ContentPart>,
24}
25
26#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
27#[serde(rename_all = "lowercase")]
28pub enum Role {
29    System,
30    User,
31    Assistant,
32    Tool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum ContentPart {
38    Text {
39        text: String,
40    },
41    Image {
42        url: String,
43        mime_type: Option<String>,
44    },
45    File {
46        path: String,
47        mime_type: Option<String>,
48    },
49    ToolCall {
50        id: String,
51        name: String,
52        arguments: String,
53    },
54    ToolResult {
55        tool_call_id: String,
56        content: String,
57    },
58}
59
60/// Tool definition for the model
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ToolDefinition {
63    pub name: String,
64    pub description: String,
65    pub parameters: serde_json::Value, // JSON Schema
66}
67
68/// Request to generate a completion
69#[derive(Debug, Clone)]
70pub struct CompletionRequest {
71    pub messages: Vec<Message>,
72    pub tools: Vec<ToolDefinition>,
73    pub model: String,
74    pub temperature: Option<f32>,
75    pub top_p: Option<f32>,
76    pub max_tokens: Option<usize>,
77    pub stop: Vec<String>,
78}
79
80/// A streaming chunk from the model
81#[derive(Debug, Clone)]
82pub enum StreamChunk {
83    Text(String),
84    ToolCallStart { id: String, name: String },
85    ToolCallDelta { id: String, arguments_delta: String },
86    ToolCallEnd { id: String },
87    Done { usage: Option<Usage> },
88    Error(String),
89}
90
91/// Token usage information
92#[derive(Debug, Clone, Default, Serialize, Deserialize)]
93pub struct Usage {
94    pub prompt_tokens: usize,
95    pub completion_tokens: usize,
96    pub total_tokens: usize,
97    pub cache_read_tokens: Option<usize>,
98    pub cache_write_tokens: Option<usize>,
99}
100
101/// Response from a completion request
102#[derive(Debug, Clone)]
103pub struct CompletionResponse {
104    pub message: Message,
105    pub usage: Usage,
106    pub finish_reason: FinishReason,
107}
108
109#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum FinishReason {
112    Stop,
113    Length,
114    ToolCalls,
115    ContentFilter,
116    Error,
117}
118
119/// Provider trait that all AI providers must implement
120#[async_trait]
121pub trait Provider: Send + Sync {
122    /// Get the provider name
123    fn name(&self) -> &str;
124
125    /// List available models
126    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
127
128    /// Generate a completion
129    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
130
131    /// Generate a streaming completion
132    async fn complete_stream(
133        &self,
134        request: CompletionRequest,
135    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
136}
137
138/// Information about a model
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ModelInfo {
141    pub id: String,
142    pub name: String,
143    pub provider: String,
144    pub context_window: usize,
145    pub max_output_tokens: Option<usize>,
146    pub supports_vision: bool,
147    pub supports_tools: bool,
148    pub supports_streaming: bool,
149    pub input_cost_per_million: Option<f64>,
150    pub output_cost_per_million: Option<f64>,
151}
152
153/// Registry of available providers
154pub struct ProviderRegistry {
155    providers: HashMap<String, Arc<dyn Provider>>,
156}
157
158impl std::fmt::Debug for ProviderRegistry {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("ProviderRegistry")
161            .field("provider_count", &self.providers.len())
162            .field("providers", &self.providers.keys().collect::<Vec<_>>())
163            .finish()
164    }
165}
166
167impl ProviderRegistry {
168    pub fn new() -> Self {
169        Self {
170            providers: HashMap::new(),
171        }
172    }
173
174    /// Register a provider
175    pub fn register(&mut self, provider: Arc<dyn Provider>) {
176        self.providers.insert(provider.name().to_string(), provider);
177    }
178
179    /// Get a provider by name
180    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
181        self.providers.get(name).cloned()
182    }
183
184    /// List all registered providers
185    pub fn list(&self) -> Vec<&str> {
186        self.providers.keys().map(|s| s.as_str()).collect()
187    }
188
189    /// Initialize with default providers from config
190    pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
191        let mut registry = Self::new();
192
193        // Always try to initialize OpenAI if key is available
194        if let Some(provider_config) = config.providers.get("openai") {
195            if let Some(api_key) = &provider_config.api_key {
196                registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
197            }
198        } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
199            registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
200        }
201
202        // Initialize Anthropic
203        if let Some(provider_config) = config.providers.get("anthropic") {
204            if let Some(api_key) = &provider_config.api_key {
205                registry.register(Arc::new(anthropic::AnthropicProvider::new(
206                    api_key.clone(),
207                )?));
208            }
209        } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
210            registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
211        }
212
213        // Initialize Google
214        if let Some(provider_config) = config.providers.get("google") {
215            if let Some(api_key) = &provider_config.api_key {
216                registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
217            }
218        } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
219            registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
220        }
221
222        // Initialize Novita (OpenAI-compatible)
223        if let Some(provider_config) = config.providers.get("novita") {
224            if let Some(api_key) = &provider_config.api_key {
225                let base_url = provider_config
226                    .base_url
227                    .clone()
228                    .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
229                registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
230                    api_key.clone(),
231                    base_url,
232                    "novita",
233                )?));
234            }
235        }
236
237        Ok(registry)
238    }
239
240    /// Initialize providers from HashiCorp Vault
241    ///
242    /// This loads API keys from Vault and creates providers dynamically.
243    /// Supports OpenAI-compatible providers via base_url.
244    pub async fn from_vault() -> Result<Self> {
245        let mut registry = Self::new();
246
247        let manager = match crate::secrets::secrets_manager() {
248            Some(m) => m,
249            None => {
250                tracing::warn!("Vault not configured, no providers will be available");
251                return Ok(registry);
252            }
253        };
254
255        // List all configured providers from Vault
256        let providers = manager.list_configured_providers().await?;
257        tracing::info!("Found {} providers configured in Vault", providers.len());
258
259        for provider_id in providers {
260            let secrets = match manager.get_provider_secrets(&provider_id).await? {
261                Some(s) => s,
262                None => continue,
263            };
264
265            let api_key = match secrets.api_key {
266                Some(key) => key,
267                None => continue,
268            };
269
270            // Determine which provider implementation to use
271            match provider_id.as_str() {
272                // Native providers
273                "anthropic" | "anthropic-eu" | "anthropic-asia" => {
274                    match anthropic::AnthropicProvider::new(api_key) {
275                        Ok(p) => registry.register(Arc::new(p)),
276                        Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
277                    }
278                }
279                "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
280                    Ok(p) => registry.register(Arc::new(p)),
281                    Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
282                },
283                // StepFun - native provider (direct API, not via OpenRouter)
284                "stepfun" => match stepfun::StepFunProvider::new(api_key) {
285                    Ok(p) => registry.register(Arc::new(p)),
286                    Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
287                },
288                // OpenRouter - native provider with support for extended response formats
289                "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
290                    Ok(p) => registry.register(Arc::new(p)),
291                    Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
292                },
293                // Moonshot AI - native provider for Kimi models
294                "moonshotai" | "moonshotai-cn" => match moonshot::MoonshotProvider::new(api_key) {
295                    Ok(p) => registry.register(Arc::new(p)),
296                    Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
297                },
298                // ZhipuAI - OpenAI-compatible coding API
299                "zhipuai" => {
300                    let base_url = secrets
301                        .base_url
302                        .clone()
303                        .unwrap_or_else(|| "https://api.z.ai/api/coding/paas/v4".to_string());
304                    match openai::OpenAIProvider::with_base_url(api_key, base_url, "zhipuai") {
305                        Ok(p) => registry.register(Arc::new(p)),
306                        Err(e) => tracing::warn!("Failed to init zhipuai: {}", e),
307                    }
308                }
309                // OpenAI-compatible providers (with custom base_url)
310                "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
311                | "alibaba" | "openai" | "azure" | "novita" => {
312                    if let Some(base_url) = secrets.base_url {
313                        match openai::OpenAIProvider::with_base_url(api_key, base_url, &provider_id)
314                        {
315                            Ok(p) => registry.register(Arc::new(p)),
316                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
317                        }
318                    } else if provider_id == "openai" {
319                        // OpenAI doesn't need a custom base_url
320                        match openai::OpenAIProvider::new(api_key) {
321                            Ok(p) => registry.register(Arc::new(p)),
322                            Err(e) => tracing::warn!("Failed to init openai: {}", e),
323                        }
324                    } else if provider_id == "novita" {
325                        let base_url = "https://api.novita.ai/openai/v1".to_string();
326                        match openai::OpenAIProvider::with_base_url(api_key, base_url, &provider_id)
327                        {
328                            Ok(p) => registry.register(Arc::new(p)),
329                            Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
330                        }
331                    } else {
332                        // Try using the base_url from the models API
333                        if let Ok(catalog) = models::ModelCatalog::fetch().await {
334                            if let Some(provider_info) = catalog.get_provider(&provider_id) {
335                                if let Some(api_url) = &provider_info.api {
336                                    match openai::OpenAIProvider::with_base_url(
337                                        api_key,
338                                        api_url.clone(),
339                                        &provider_id,
340                                    ) {
341                                        Ok(p) => registry.register(Arc::new(p)),
342                                        Err(e) => {
343                                            tracing::warn!("Failed to init {}: {}", provider_id, e)
344                                        }
345                                    }
346                                }
347                            }
348                        }
349                    }
350                }
351                // Unknown providers - try as OpenAI-compatible with base_url from API
352                other => {
353                    if let Some(base_url) = secrets.base_url {
354                        match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
355                            Ok(p) => registry.register(Arc::new(p)),
356                            Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
357                        }
358                    } else {
359                        tracing::debug!("Unknown provider {} without base_url, skipping", other);
360                    }
361                }
362            }
363        }
364
365        tracing::info!(
366            "Registered {} providers from Vault",
367            registry.providers.len()
368        );
369        Ok(registry)
370    }
371}
372
373impl Default for ProviderRegistry {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
/// Split a model identifier of the form `"provider/model"`.
///
/// Only the first `/` separates the provider. Everything after it, including
/// further slashes (e.g. `"openrouter/anthropic/claude"`), is the model name.
/// A string without `/` yields `(None, s)`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.find('/') {
        // '/' is a single byte, so `idx + 1` is always a char boundary.
        Some(idx) => (Some(&s[..idx]), &s[idx + 1..]),
        None => (None, s),
    }
}