1pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod google;
9pub mod models;
10pub mod moonshot;
11pub mod openai;
12pub mod openrouter;
13pub mod stepfun;
14pub mod zai;
15
16use anyhow::Result;
17use async_trait::async_trait;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Message {
25 pub role: Role,
26 pub content: Vec<ContentPart>,
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
30#[serde(rename_all = "lowercase")]
31pub enum Role {
32 System,
33 User,
34 Assistant,
35 Tool,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum ContentPart {
41 Text {
42 text: String,
43 },
44 Image {
45 url: String,
46 mime_type: Option<String>,
47 },
48 File {
49 path: String,
50 mime_type: Option<String>,
51 },
52 ToolCall {
53 id: String,
54 name: String,
55 arguments: String,
56 },
57 ToolResult {
58 tool_call_id: String,
59 content: String,
60 },
61 Thinking {
62 text: String,
63 },
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ToolDefinition {
69 pub name: String,
70 pub description: String,
71 pub parameters: serde_json::Value, }
73
74#[derive(Debug, Clone)]
76pub struct CompletionRequest {
77 pub messages: Vec<Message>,
78 pub tools: Vec<ToolDefinition>,
79 pub model: String,
80 pub temperature: Option<f32>,
81 pub top_p: Option<f32>,
82 pub max_tokens: Option<usize>,
83 pub stop: Vec<String>,
84}
85
86#[derive(Debug, Clone)]
88pub enum StreamChunk {
89 Text(String),
90 ToolCallStart { id: String, name: String },
91 ToolCallDelta { id: String, arguments_delta: String },
92 ToolCallEnd { id: String },
93 Done { usage: Option<Usage> },
94 Error(String),
95}
96
97#[derive(Debug, Clone, Default, Serialize, Deserialize)]
99pub struct Usage {
100 pub prompt_tokens: usize,
101 pub completion_tokens: usize,
102 pub total_tokens: usize,
103 pub cache_read_tokens: Option<usize>,
104 pub cache_write_tokens: Option<usize>,
105}
106
107#[derive(Debug, Clone)]
109pub struct CompletionResponse {
110 pub message: Message,
111 pub usage: Usage,
112 pub finish_reason: FinishReason,
113}
114
115#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
116#[serde(rename_all = "snake_case")]
117pub enum FinishReason {
118 Stop,
119 Length,
120 ToolCalls,
121 ContentFilter,
122 Error,
123}
124
125#[async_trait]
127pub trait Provider: Send + Sync {
128 fn name(&self) -> &str;
130
131 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
133
134 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
136
137 async fn complete_stream(
139 &self,
140 request: CompletionRequest,
141 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ModelInfo {
147 pub id: String,
148 pub name: String,
149 pub provider: String,
150 pub context_window: usize,
151 pub max_output_tokens: Option<usize>,
152 pub supports_vision: bool,
153 pub supports_tools: bool,
154 pub supports_streaming: bool,
155 pub input_cost_per_million: Option<f64>,
156 pub output_cost_per_million: Option<f64>,
157}
158
159pub struct ProviderRegistry {
161 providers: HashMap<String, Arc<dyn Provider>>,
162}
163
164impl std::fmt::Debug for ProviderRegistry {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 f.debug_struct("ProviderRegistry")
167 .field("provider_count", &self.providers.len())
168 .field("providers", &self.providers.keys().collect::<Vec<_>>())
169 .finish()
170 }
171}
172
173impl ProviderRegistry {
174 pub fn new() -> Self {
175 Self {
176 providers: HashMap::new(),
177 }
178 }
179
180 pub fn register(&mut self, provider: Arc<dyn Provider>) {
182 self.providers.insert(provider.name().to_string(), provider);
183 }
184
185 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
187 self.providers.get(name).cloned()
188 }
189
190 pub fn list(&self) -> Vec<&str> {
192 self.providers.keys().map(|s| s.as_str()).collect()
193 }
194
195 pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
197 let mut registry = Self::new();
198
199 if let Some(provider_config) = config.providers.get("openai") {
201 if let Some(api_key) = &provider_config.api_key {
202 registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
203 }
204 } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
205 registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
206 }
207
208 if let Some(provider_config) = config.providers.get("anthropic") {
210 if let Some(api_key) = &provider_config.api_key {
211 registry.register(Arc::new(anthropic::AnthropicProvider::new(
212 api_key.clone(),
213 )?));
214 }
215 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
216 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
217 }
218
219 if let Some(provider_config) = config.providers.get("google") {
221 if let Some(api_key) = &provider_config.api_key {
222 registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
223 }
224 } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
225 registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
226 }
227
228 if let Some(provider_config) = config.providers.get("novita") {
230 if let Some(api_key) = &provider_config.api_key {
231 let base_url = provider_config
232 .base_url
233 .clone()
234 .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
235 registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
236 api_key.clone(),
237 base_url,
238 "novita",
239 )?));
240 }
241 }
242
243 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
245 let region = bedrock::AwsCredentials::detect_region()
246 .unwrap_or_else(|| bedrock::DEFAULT_REGION.to_string());
247 match bedrock::BedrockProvider::with_credentials(creds, region) {
248 Ok(p) => registry.register(Arc::new(p)),
249 Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
250 }
251 }
252
253 Ok(registry)
254 }
255
256 pub async fn from_vault() -> Result<Self> {
261 let mut registry = Self::new();
262
263 if let Some(manager) = crate::secrets::secrets_manager() {
264 let providers = manager.list_configured_providers().await?;
266 tracing::info!("Found {} providers configured in Vault", providers.len());
267
268 for provider_id in providers {
269 let secrets = match manager.get_provider_secrets(&provider_id).await? {
270 Some(s) => s,
271 None => continue,
272 };
273
274 if matches!(provider_id.as_str(), "bedrock" | "aws-bedrock") {
277 let region = secrets
278 .extra
279 .get("region")
280 .and_then(|v| v.as_str())
281 .unwrap_or("us-east-1")
282 .to_string();
283
284 let aws_key_id = secrets
286 .extra
287 .get("aws_access_key_id")
288 .and_then(|v| v.as_str());
289 let aws_secret = secrets
290 .extra
291 .get("aws_secret_access_key")
292 .and_then(|v| v.as_str());
293
294 let result = if let (Some(key_id), Some(secret)) = (aws_key_id, aws_secret) {
295 let creds = bedrock::AwsCredentials {
296 access_key_id: key_id.to_string(),
297 secret_access_key: secret.to_string(),
298 session_token: secrets
299 .extra
300 .get("aws_session_token")
301 .and_then(|v| v.as_str())
302 .map(|s| s.to_string()),
303 };
304 bedrock::BedrockProvider::with_credentials(creds, region)
305 } else if let Some(ref key) = secrets.api_key {
306 bedrock::BedrockProvider::with_region(key.clone(), region)
307 } else {
308 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
310 bedrock::BedrockProvider::with_credentials(creds, region)
311 } else {
312 Err(anyhow::anyhow!(
313 "No AWS credentials or API key found for Bedrock"
314 ))
315 }
316 };
317
318 match result {
319 Ok(p) => registry.register(Arc::new(p)),
320 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
321 }
322 continue;
323 }
324
325 let api_key = match secrets.api_key {
326 Some(key) => key,
327 None => continue,
328 };
329
330 match provider_id.as_str() {
332 "anthropic" | "anthropic-eu" | "anthropic-asia" => {
334 match anthropic::AnthropicProvider::new(api_key) {
335 Ok(p) => registry.register(Arc::new(p)),
336 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
337 }
338 }
339 "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
340 Ok(p) => registry.register(Arc::new(p)),
341 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
342 },
343 "stepfun" => match stepfun::StepFunProvider::new(api_key) {
345 Ok(p) => registry.register(Arc::new(p)),
346 Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
347 },
348 "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
350 Ok(p) => registry.register(Arc::new(p)),
351 Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
352 },
353 "moonshotai" | "moonshotai-cn" => {
355 match moonshot::MoonshotProvider::new(api_key) {
356 Ok(p) => registry.register(Arc::new(p)),
357 Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
358 }
359 }
360 "github-copilot" => {
362 let result = if let Some(base_url) = secrets.base_url.clone() {
363 copilot::CopilotProvider::with_base_url(
364 api_key,
365 base_url,
366 "github-copilot",
367 )
368 } else {
369 copilot::CopilotProvider::new(api_key)
370 };
371
372 match result {
373 Ok(p) => registry.register(Arc::new(p)),
374 Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
375 }
376 }
377 "github-copilot-enterprise" => {
378 let enterprise_url = secrets
379 .extra
380 .get("enterpriseUrl")
381 .and_then(|v| v.as_str())
382 .or_else(|| {
383 secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
384 });
385
386 let result = if let Some(base_url) = secrets.base_url.clone() {
387 copilot::CopilotProvider::with_base_url(
388 api_key,
389 base_url,
390 "github-copilot-enterprise",
391 )
392 } else if let Some(url) = enterprise_url {
393 copilot::CopilotProvider::enterprise(api_key, url.to_string())
394 } else {
395 copilot::CopilotProvider::with_base_url(
396 api_key,
397 "https://api.githubcopilot.com".to_string(),
398 "github-copilot-enterprise",
399 )
400 };
401
402 match result {
403 Ok(p) => registry.register(Arc::new(p)),
404 Err(e) => {
405 tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
406 }
407 }
408 }
409 "zhipuai" | "zai" => {
411 let base_url = secrets
412 .base_url
413 .clone()
414 .unwrap_or_else(|| "https://api.z.ai/api/coding/paas/v4".to_string());
415 match zai::ZaiProvider::with_base_url(api_key, base_url) {
416 Ok(p) => registry.register(Arc::new(p)),
417 Err(e) => tracing::warn!("Failed to init zai: {}", e),
418 }
419 }
420 "cerebras" => {
422 let base_url = secrets
423 .base_url
424 .clone()
425 .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
426 match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
427 Ok(p) => registry.register(Arc::new(p)),
428 Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
429 }
430 }
431 "minimax" => {
433 let base_url = secrets
434 .base_url
435 .clone()
436 .unwrap_or_else(|| "https://api.minimax.chat/v1".to_string());
437 match openai::OpenAIProvider::with_base_url(api_key, base_url, "minimax") {
438 Ok(p) => registry.register(Arc::new(p)),
439 Err(e) => tracing::warn!("Failed to init minimax: {}", e),
440 }
441 }
442 "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
444 | "alibaba" | "openai" | "azure" | "novita" => {
445 if let Some(base_url) = secrets.base_url.clone() {
446 match openai::OpenAIProvider::with_base_url(
447 api_key,
448 base_url,
449 &provider_id,
450 ) {
451 Ok(p) => registry.register(Arc::new(p)),
452 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
453 }
454 } else if provider_id == "openai" {
455 match openai::OpenAIProvider::new(api_key) {
457 Ok(p) => registry.register(Arc::new(p)),
458 Err(e) => tracing::warn!("Failed to init openai: {}", e),
459 }
460 } else if provider_id == "novita" {
461 let base_url = "https://api.novita.ai/openai/v1".to_string();
462 match openai::OpenAIProvider::with_base_url(
463 api_key,
464 base_url,
465 &provider_id,
466 ) {
467 Ok(p) => registry.register(Arc::new(p)),
468 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
469 }
470 } else {
471 if let Ok(catalog) = models::ModelCatalog::fetch().await {
473 if let Some(provider_info) = catalog.get_provider(&provider_id) {
474 if let Some(api_url) = &provider_info.api {
475 match openai::OpenAIProvider::with_base_url(
476 api_key,
477 api_url.clone(),
478 &provider_id,
479 ) {
480 Ok(p) => registry.register(Arc::new(p)),
481 Err(e) => {
482 tracing::warn!(
483 "Failed to init {}: {}",
484 provider_id,
485 e
486 )
487 }
488 }
489 }
490 }
491 }
492 }
493 }
494 other => {
496 if let Some(base_url) = secrets.base_url {
497 match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
498 Ok(p) => registry.register(Arc::new(p)),
499 Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
500 }
501 } else {
502 tracing::debug!(
503 "Unknown provider {} without base_url, skipping",
504 other
505 );
506 }
507 }
508 }
509 }
510 } else {
511 tracing::warn!("Vault not configured, no providers will be available from Vault");
512 }
513
514 if !registry.providers.contains_key("bedrock") {
516 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
517 let region = bedrock::AwsCredentials::detect_region()
518 .unwrap_or_else(|| "us-east-1".to_string());
519 match bedrock::BedrockProvider::with_credentials(creds, region) {
520 Ok(p) => {
521 tracing::info!("Registered Bedrock provider from local AWS credentials");
522 registry.register(Arc::new(p));
523 }
524 Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
525 }
526 }
527 }
528
529 tracing::info!(
530 "Registered {} providers from Vault",
531 registry.providers.len()
532 );
533 Ok(registry)
534 }
535}
536
537impl Default for ProviderRegistry {
538 fn default() -> Self {
539 Self::new()
540 }
541}
542
/// Splits a model string of the form `provider/model` at its first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present; everything after
/// the first `/` (including further slashes) is the model part. Returns
/// `(None, s)` when `s` contains no `/`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.split_once('/') {
        Some((provider, model)) => (Some(provider), model),
        None => (None, s),
    }
}