1pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod google;
9pub mod models;
10pub mod moonshot;
11pub mod openai;
12pub mod openrouter;
13pub mod stepfun;
14
15use anyhow::Result;
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct Message {
24 pub role: Role,
25 pub content: Vec<ContentPart>,
26}
27
28#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
29#[serde(rename_all = "lowercase")]
30pub enum Role {
31 System,
32 User,
33 Assistant,
34 Tool,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(tag = "type", rename_all = "snake_case")]
39pub enum ContentPart {
40 Text {
41 text: String,
42 },
43 Image {
44 url: String,
45 mime_type: Option<String>,
46 },
47 File {
48 path: String,
49 mime_type: Option<String>,
50 },
51 ToolCall {
52 id: String,
53 name: String,
54 arguments: String,
55 },
56 ToolResult {
57 tool_call_id: String,
58 content: String,
59 },
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ToolDefinition {
65 pub name: String,
66 pub description: String,
67 pub parameters: serde_json::Value, }
69
70#[derive(Debug, Clone)]
72pub struct CompletionRequest {
73 pub messages: Vec<Message>,
74 pub tools: Vec<ToolDefinition>,
75 pub model: String,
76 pub temperature: Option<f32>,
77 pub top_p: Option<f32>,
78 pub max_tokens: Option<usize>,
79 pub stop: Vec<String>,
80}
81
82#[derive(Debug, Clone)]
84pub enum StreamChunk {
85 Text(String),
86 ToolCallStart { id: String, name: String },
87 ToolCallDelta { id: String, arguments_delta: String },
88 ToolCallEnd { id: String },
89 Done { usage: Option<Usage> },
90 Error(String),
91}
92
93#[derive(Debug, Clone, Default, Serialize, Deserialize)]
95pub struct Usage {
96 pub prompt_tokens: usize,
97 pub completion_tokens: usize,
98 pub total_tokens: usize,
99 pub cache_read_tokens: Option<usize>,
100 pub cache_write_tokens: Option<usize>,
101}
102
103#[derive(Debug, Clone)]
105pub struct CompletionResponse {
106 pub message: Message,
107 pub usage: Usage,
108 pub finish_reason: FinishReason,
109}
110
111#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
112#[serde(rename_all = "snake_case")]
113pub enum FinishReason {
114 Stop,
115 Length,
116 ToolCalls,
117 ContentFilter,
118 Error,
119}
120
121#[async_trait]
123pub trait Provider: Send + Sync {
124 fn name(&self) -> &str;
126
127 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
129
130 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
132
133 async fn complete_stream(
135 &self,
136 request: CompletionRequest,
137 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ModelInfo {
143 pub id: String,
144 pub name: String,
145 pub provider: String,
146 pub context_window: usize,
147 pub max_output_tokens: Option<usize>,
148 pub supports_vision: bool,
149 pub supports_tools: bool,
150 pub supports_streaming: bool,
151 pub input_cost_per_million: Option<f64>,
152 pub output_cost_per_million: Option<f64>,
153}
154
155pub struct ProviderRegistry {
157 providers: HashMap<String, Arc<dyn Provider>>,
158}
159
160impl std::fmt::Debug for ProviderRegistry {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_struct("ProviderRegistry")
163 .field("provider_count", &self.providers.len())
164 .field("providers", &self.providers.keys().collect::<Vec<_>>())
165 .finish()
166 }
167}
168
169impl ProviderRegistry {
170 pub fn new() -> Self {
171 Self {
172 providers: HashMap::new(),
173 }
174 }
175
176 pub fn register(&mut self, provider: Arc<dyn Provider>) {
178 self.providers.insert(provider.name().to_string(), provider);
179 }
180
181 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
183 self.providers.get(name).cloned()
184 }
185
186 pub fn list(&self) -> Vec<&str> {
188 self.providers.keys().map(|s| s.as_str()).collect()
189 }
190
191 pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
193 let mut registry = Self::new();
194
195 if let Some(provider_config) = config.providers.get("openai") {
197 if let Some(api_key) = &provider_config.api_key {
198 registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
199 }
200 } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
201 registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
202 }
203
204 if let Some(provider_config) = config.providers.get("anthropic") {
206 if let Some(api_key) = &provider_config.api_key {
207 registry.register(Arc::new(anthropic::AnthropicProvider::new(
208 api_key.clone(),
209 )?));
210 }
211 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
212 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
213 }
214
215 if let Some(provider_config) = config.providers.get("google") {
217 if let Some(api_key) = &provider_config.api_key {
218 registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
219 }
220 } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
221 registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
222 }
223
224 if let Some(provider_config) = config.providers.get("novita") {
226 if let Some(api_key) = &provider_config.api_key {
227 let base_url = provider_config
228 .base_url
229 .clone()
230 .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
231 registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
232 api_key.clone(),
233 base_url,
234 "novita",
235 )?));
236 }
237 }
238
239 Ok(registry)
240 }
241
242 pub async fn from_vault() -> Result<Self> {
247 let mut registry = Self::new();
248
249 if let Some(manager) = crate::secrets::secrets_manager() {
250 let providers = manager.list_configured_providers().await?;
252 tracing::info!("Found {} providers configured in Vault", providers.len());
253
254 for provider_id in providers {
255 let secrets = match manager.get_provider_secrets(&provider_id).await? {
256 Some(s) => s,
257 None => continue,
258 };
259
260 let api_key = match secrets.api_key {
261 Some(key) => key,
262 None => continue,
263 };
264
265 match provider_id.as_str() {
267 "bedrock" | "aws-bedrock" => {
269 let region = secrets
270 .extra
271 .get("region")
272 .and_then(|v| v.as_str())
273 .unwrap_or("us-east-1")
274 .to_string();
275 match bedrock::BedrockProvider::with_region(api_key, region) {
276 Ok(p) => registry.register(Arc::new(p)),
277 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
278 }
279 }
280 "anthropic" | "anthropic-eu" | "anthropic-asia" => {
282 match anthropic::AnthropicProvider::new(api_key) {
283 Ok(p) => registry.register(Arc::new(p)),
284 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
285 }
286 }
287 "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
288 Ok(p) => registry.register(Arc::new(p)),
289 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
290 },
291 "stepfun" => match stepfun::StepFunProvider::new(api_key) {
293 Ok(p) => registry.register(Arc::new(p)),
294 Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
295 },
296 "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
298 Ok(p) => registry.register(Arc::new(p)),
299 Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
300 },
301 "moonshotai" | "moonshotai-cn" => {
303 match moonshot::MoonshotProvider::new(api_key) {
304 Ok(p) => registry.register(Arc::new(p)),
305 Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
306 }
307 }
308 "github-copilot" => {
310 let result = if let Some(base_url) = secrets.base_url.clone() {
311 copilot::CopilotProvider::with_base_url(
312 api_key,
313 base_url,
314 "github-copilot",
315 )
316 } else {
317 copilot::CopilotProvider::new(api_key)
318 };
319
320 match result {
321 Ok(p) => registry.register(Arc::new(p)),
322 Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
323 }
324 }
325 "github-copilot-enterprise" => {
326 let enterprise_url = secrets
327 .extra
328 .get("enterpriseUrl")
329 .and_then(|v| v.as_str())
330 .or_else(|| {
331 secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
332 });
333
334 let result = if let Some(base_url) = secrets.base_url.clone() {
335 copilot::CopilotProvider::with_base_url(
336 api_key,
337 base_url,
338 "github-copilot-enterprise",
339 )
340 } else if let Some(url) = enterprise_url {
341 copilot::CopilotProvider::enterprise(api_key, url.to_string())
342 } else {
343 copilot::CopilotProvider::with_base_url(
344 api_key,
345 "https://api.githubcopilot.com".to_string(),
346 "github-copilot-enterprise",
347 )
348 };
349
350 match result {
351 Ok(p) => registry.register(Arc::new(p)),
352 Err(e) => {
353 tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
354 }
355 }
356 }
357 "zhipuai" => {
359 let base_url = secrets
360 .base_url
361 .clone()
362 .unwrap_or_else(|| "https://api.z.ai/api/coding/paas/v4".to_string());
363 match openai::OpenAIProvider::with_base_url(api_key, base_url, "zhipuai") {
364 Ok(p) => registry.register(Arc::new(p)),
365 Err(e) => tracing::warn!("Failed to init zhipuai: {}", e),
366 }
367 }
368 "cerebras" => {
370 let base_url = secrets
371 .base_url
372 .clone()
373 .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
374 match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
375 Ok(p) => registry.register(Arc::new(p)),
376 Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
377 }
378 }
379 "minimax" => {
381 let base_url = secrets
382 .base_url
383 .clone()
384 .unwrap_or_else(|| "https://api.minimax.chat/v1".to_string());
385 match openai::OpenAIProvider::with_base_url(api_key, base_url, "minimax") {
386 Ok(p) => registry.register(Arc::new(p)),
387 Err(e) => tracing::warn!("Failed to init minimax: {}", e),
388 }
389 }
390 "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
392 | "alibaba" | "openai" | "azure" | "novita" => {
393 if let Some(base_url) = secrets.base_url.clone() {
394 match openai::OpenAIProvider::with_base_url(
395 api_key,
396 base_url,
397 &provider_id,
398 ) {
399 Ok(p) => registry.register(Arc::new(p)),
400 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
401 }
402 } else if provider_id == "openai" {
403 match openai::OpenAIProvider::new(api_key) {
405 Ok(p) => registry.register(Arc::new(p)),
406 Err(e) => tracing::warn!("Failed to init openai: {}", e),
407 }
408 } else if provider_id == "novita" {
409 let base_url = "https://api.novita.ai/openai/v1".to_string();
410 match openai::OpenAIProvider::with_base_url(
411 api_key,
412 base_url,
413 &provider_id,
414 ) {
415 Ok(p) => registry.register(Arc::new(p)),
416 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
417 }
418 } else {
419 if let Ok(catalog) = models::ModelCatalog::fetch().await {
421 if let Some(provider_info) = catalog.get_provider(&provider_id) {
422 if let Some(api_url) = &provider_info.api {
423 match openai::OpenAIProvider::with_base_url(
424 api_key,
425 api_url.clone(),
426 &provider_id,
427 ) {
428 Ok(p) => registry.register(Arc::new(p)),
429 Err(e) => {
430 tracing::warn!(
431 "Failed to init {}: {}",
432 provider_id,
433 e
434 )
435 }
436 }
437 }
438 }
439 }
440 }
441 }
442 other => {
444 if let Some(base_url) = secrets.base_url {
445 match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
446 Ok(p) => registry.register(Arc::new(p)),
447 Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
448 }
449 } else {
450 tracing::debug!(
451 "Unknown provider {} without base_url, skipping",
452 other
453 );
454 }
455 }
456 }
457 }
458 } else {
459 tracing::warn!("Vault not configured, no providers will be available");
460 }
461
462 tracing::info!(
463 "Registered {} providers from Vault",
464 registry.providers.len()
465 );
466 Ok(registry)
467 }
468}
469
470impl Default for ProviderRegistry {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
/// Splits a model string of the form `provider/model` at the first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present, and `(None, s)`
/// otherwise. Only the first `/` separates the two, so any later `/` stays in
/// the model part: `"a/b/c"` yields `(Some("a"), "b/c")`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.find('/') {
        // '/' is one byte long, so `slash + 1` is always a char boundary.
        Some(slash) => (Some(&s[..slash]), &s[slash + 1..]),
        None => (None, s),
    }
}