1pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod google;
9pub mod models;
10pub mod moonshot;
11pub mod openai;
12pub mod openrouter;
13pub mod stepfun;
14pub mod zai;
15
16use anyhow::Result;
17use async_trait::async_trait;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::Arc;
21
/// A single conversation message: who authored it and its ordered content parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Message body; one message may mix text, media, tool calls/results and thinking parts.
    pub content: Vec<ContentPart>,
}
28
/// Author of a [`Message`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`, `"tool"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
37
/// One piece of a [`Message`]'s content.
///
/// Serialized as an internally tagged object: the variant name in snake_case goes in a
/// `"type"` field next to the variant's own fields, e.g.
/// `{"type": "tool_call", "id": ..., "name": ..., "arguments": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by `url`, with an optional MIME type.
    Image {
        url: String,
        mime_type: Option<String>,
    },
    /// A file referenced by `path`, with an optional MIME type.
    File {
        path: String,
        mime_type: Option<String>,
    },
    /// A tool invocation requested by the model.
    ToolCall {
        /// Identifier of this call.
        id: String,
        /// Name of the tool to invoke; presumably matches a [`ToolDefinition::name`].
        name: String,
        /// Call arguments as a raw string (presumably JSON-encoded — confirm per provider).
        arguments: String,
    },
    /// The output of a tool call.
    ToolResult {
        /// Presumably the `id` of the [`ContentPart::ToolCall`] this result answers.
        tool_call_id: String,
        content: String,
    },
    /// Model reasoning text, kept separate from the visible answer.
    Thinking {
        text: String,
    },
}
65
/// A tool the model may call during a completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name the model uses when emitting a [`ContentPart::ToolCall`].
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter specification as arbitrary JSON (presumably a JSON Schema object — confirm).
    pub parameters: serde_json::Value,
}
73
/// Input to [`Provider::complete`] and [`Provider::complete_stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Conversation history, in order.
    pub messages: Vec<Message>,
    /// Tools the model may call; empty when no tools are offered.
    pub tools: Vec<ToolDefinition>,
    /// Identifier of the model to run.
    pub model: String,
    /// Sampling temperature; `None` means not specified.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass; `None` means not specified.
    pub top_p: Option<f32>,
    /// Upper bound on generated tokens; `None` means not specified.
    pub max_tokens: Option<usize>,
    /// Stop sequences; empty means none.
    pub stop: Vec<String>,
}
85
/// An incremental event yielded by the stream returned from [`Provider::complete_stream`].
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of generated text.
    Text(String),
    /// A tool call with the given `id` and tool `name` has begun.
    ToolCallStart { id: String, name: String },
    /// A fragment of the arguments for tool call `id` (presumably concatenated in arrival order).
    ToolCallDelta { id: String, arguments_delta: String },
    /// Tool call `id` is complete.
    ToolCallEnd { id: String },
    /// The completion has finished; carries usage statistics when available.
    Done { usage: Option<Usage> },
    /// An error reported as a stream item, described by the message.
    Error(String),
}
96
/// Token accounting for a completion.
///
/// The counts are whatever the provider reports; nothing here derives one from another.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the input.
    pub prompt_tokens: usize,
    /// Tokens generated in the output.
    pub completion_tokens: usize,
    /// Total tokens as reported (presumably prompt + completion — confirm per provider).
    pub total_tokens: usize,
    /// Tokens served from the prompt cache, if the provider reports it.
    pub cache_read_tokens: Option<usize>,
    /// Tokens written to the prompt cache, if the provider reports it.
    pub cache_write_tokens: Option<usize>,
}
106
/// Result of a non-streaming [`Provider::complete`] call.
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The message produced by the model.
    pub message: Message,
    /// Token accounting for the call.
    pub usage: Usage,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
114
/// Why the model stopped generating.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"tool_calls"`, `"content_filter"`,
/// `"error"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Natural end of output or a stop sequence.
    Stop,
    /// The token limit was reached.
    Length,
    /// The model stopped to request tool calls.
    ToolCalls,
    /// Output was withheld by a content filter.
    ContentFilter,
    /// Generation ended because of an error.
    Error,
}
124
/// A backend that can list models and run completions.
///
/// Implementations must be `Send + Sync` so they can be shared as `Arc<dyn Provider>`
/// inside a [`ProviderRegistry`].
#[async_trait]
pub trait Provider: Send + Sync {
    /// Name of this provider; [`ProviderRegistry::register`] uses it as the registry key.
    fn name(&self) -> &str;

    /// Returns metadata ([`ModelInfo`]) for the models available from this provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Runs a completion for `request` and returns the full response at once.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;

    /// Starts a streaming completion for `request`.
    ///
    /// The returned stream yields [`StreamChunk`]s and is `'static`, so it does not borrow
    /// `self`.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
}
143
/// Descriptive metadata about a model offered by a [`Provider`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Display name.
    pub name: String,
    /// Name of the provider offering this model.
    pub provider: String,
    /// Context window size (presumably in tokens — confirm).
    pub context_window: usize,
    /// Maximum output length, if known (presumably in tokens — confirm).
    pub max_output_tokens: Option<usize>,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports tool calls.
    pub supports_tools: bool,
    /// Whether the model supports streaming responses.
    pub supports_streaming: bool,
    /// Input price per million tokens, if known (currency not stated here — presumably USD).
    pub input_cost_per_million: Option<f64>,
    /// Output price per million tokens, if known (currency not stated here — presumably USD).
    pub output_cost_per_million: Option<f64>,
}
158
/// Collection of available providers, keyed by each provider's [`Provider::name`].
pub struct ProviderRegistry {
    /// Provider name -> shared provider handle.
    providers: HashMap<String, Arc<dyn Provider>>,
}
163
164impl std::fmt::Debug for ProviderRegistry {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 f.debug_struct("ProviderRegistry")
167 .field("provider_count", &self.providers.len())
168 .field("providers", &self.providers.keys().collect::<Vec<_>>())
169 .finish()
170 }
171}
172
173impl ProviderRegistry {
174 pub fn new() -> Self {
175 Self {
176 providers: HashMap::new(),
177 }
178 }
179
180 pub fn register(&mut self, provider: Arc<dyn Provider>) {
182 self.providers.insert(provider.name().to_string(), provider);
183 }
184
185 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
187 self.providers.get(name).cloned()
188 }
189
190 pub fn list(&self) -> Vec<&str> {
192 self.providers.keys().map(|s| s.as_str()).collect()
193 }
194
195 pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
197 let mut registry = Self::new();
198
199 if let Some(provider_config) = config.providers.get("openai") {
201 if let Some(api_key) = &provider_config.api_key {
202 registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
203 }
204 } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
205 registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
206 }
207
208 if let Some(provider_config) = config.providers.get("anthropic") {
210 if let Some(api_key) = &provider_config.api_key {
211 let provider = if let Some(base_url) = provider_config.base_url.clone() {
212 anthropic::AnthropicProvider::with_base_url(
213 api_key.clone(),
214 base_url,
215 "anthropic",
216 )?
217 } else {
218 anthropic::AnthropicProvider::new(api_key.clone())?
219 };
220 registry.register(Arc::new(provider));
221 }
222 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
223 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
224 }
225
226 if let Some(provider_config) = config.providers.get("google") {
228 if let Some(api_key) = &provider_config.api_key {
229 registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
230 }
231 } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
232 registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
233 }
234
235 if let Some(provider_config) = config.providers.get("novita") {
237 if let Some(api_key) = &provider_config.api_key {
238 let base_url = provider_config
239 .base_url
240 .clone()
241 .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
242 registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
243 api_key.clone(),
244 base_url,
245 "novita",
246 )?));
247 }
248 }
249
250 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
252 let region = bedrock::AwsCredentials::detect_region()
253 .unwrap_or_else(|| bedrock::DEFAULT_REGION.to_string());
254 match bedrock::BedrockProvider::with_credentials(creds, region) {
255 Ok(p) => registry.register(Arc::new(p)),
256 Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
257 }
258 }
259
260 Ok(registry)
261 }
262
263 pub async fn from_vault() -> Result<Self> {
268 let mut registry = Self::new();
269
270 if let Some(manager) = crate::secrets::secrets_manager() {
271 let providers = manager.list_configured_providers().await?;
273 tracing::info!("Found {} providers configured in Vault", providers.len());
274
275 for provider_id in providers {
276 let secrets = match manager.get_provider_secrets(&provider_id).await? {
277 Some(s) => s,
278 None => continue,
279 };
280
281 if matches!(provider_id.as_str(), "bedrock" | "aws-bedrock") {
284 let region = secrets
285 .extra
286 .get("region")
287 .and_then(|v| v.as_str())
288 .unwrap_or("us-east-1")
289 .to_string();
290
291 let aws_key_id = secrets
293 .extra
294 .get("aws_access_key_id")
295 .and_then(|v| v.as_str());
296 let aws_secret = secrets
297 .extra
298 .get("aws_secret_access_key")
299 .and_then(|v| v.as_str());
300
301 let result = if let (Some(key_id), Some(secret)) = (aws_key_id, aws_secret) {
302 let creds = bedrock::AwsCredentials {
303 access_key_id: key_id.to_string(),
304 secret_access_key: secret.to_string(),
305 session_token: secrets
306 .extra
307 .get("aws_session_token")
308 .and_then(|v| v.as_str())
309 .map(|s| s.to_string()),
310 };
311 bedrock::BedrockProvider::with_credentials(creds, region)
312 } else if let Some(ref key) = secrets.api_key {
313 bedrock::BedrockProvider::with_region(key.clone(), region)
314 } else {
315 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
317 bedrock::BedrockProvider::with_credentials(creds, region)
318 } else {
319 Err(anyhow::anyhow!(
320 "No AWS credentials or API key found for Bedrock"
321 ))
322 }
323 };
324
325 match result {
326 Ok(p) => registry.register(Arc::new(p)),
327 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
328 }
329 continue;
330 }
331
332 let api_key = match secrets.api_key {
333 Some(key) => key,
334 None => continue,
335 };
336
337 match provider_id.as_str() {
339 "anthropic" | "anthropic-eu" | "anthropic-asia" => {
341 let base_url = secrets
342 .base_url
343 .clone()
344 .unwrap_or_else(|| "https://api.anthropic.com".to_string());
345 match anthropic::AnthropicProvider::with_base_url(
346 api_key,
347 base_url,
348 &provider_id,
349 ) {
350 Ok(p) => registry.register(Arc::new(p)),
351 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
352 }
353 }
354 "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
355 Ok(p) => registry.register(Arc::new(p)),
356 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
357 },
358 "stepfun" => match stepfun::StepFunProvider::new(api_key) {
360 Ok(p) => registry.register(Arc::new(p)),
361 Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
362 },
363 "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
365 Ok(p) => registry.register(Arc::new(p)),
366 Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
367 },
368 "moonshotai" | "moonshotai-cn" => {
370 match moonshot::MoonshotProvider::new(api_key) {
371 Ok(p) => registry.register(Arc::new(p)),
372 Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
373 }
374 }
375 "github-copilot" => {
377 let result = if let Some(base_url) = secrets.base_url.clone() {
378 copilot::CopilotProvider::with_base_url(
379 api_key,
380 base_url,
381 "github-copilot",
382 )
383 } else {
384 copilot::CopilotProvider::new(api_key)
385 };
386
387 match result {
388 Ok(p) => registry.register(Arc::new(p)),
389 Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
390 }
391 }
392 "github-copilot-enterprise" => {
393 let enterprise_url = secrets
394 .extra
395 .get("enterpriseUrl")
396 .and_then(|v| v.as_str())
397 .or_else(|| {
398 secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
399 });
400
401 let result = if let Some(base_url) = secrets.base_url.clone() {
402 copilot::CopilotProvider::with_base_url(
403 api_key,
404 base_url,
405 "github-copilot-enterprise",
406 )
407 } else if let Some(url) = enterprise_url {
408 copilot::CopilotProvider::enterprise(api_key, url.to_string())
409 } else {
410 copilot::CopilotProvider::with_base_url(
411 api_key,
412 "https://api.githubcopilot.com".to_string(),
413 "github-copilot-enterprise",
414 )
415 };
416
417 match result {
418 Ok(p) => registry.register(Arc::new(p)),
419 Err(e) => {
420 tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
421 }
422 }
423 }
424 "zhipuai" | "zai" => {
426 let base_url = secrets
427 .base_url
428 .clone()
429 .unwrap_or_else(|| "https://api.z.ai/api/paas/v4".to_string());
430 match zai::ZaiProvider::with_base_url(api_key, base_url) {
431 Ok(p) => registry.register(Arc::new(p)),
432 Err(e) => tracing::warn!("Failed to init zai: {}", e),
433 }
434 }
435 "cerebras" => {
437 let base_url = secrets
438 .base_url
439 .clone()
440 .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
441 match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
442 Ok(p) => registry.register(Arc::new(p)),
443 Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
444 }
445 }
446 "minimax" => {
448 let base_url = secrets
449 .base_url
450 .clone()
451 .unwrap_or_else(|| "https://api.minimax.io/anthropic".to_string());
452 let base_url = normalize_minimax_anthropic_base_url(&base_url);
453 match anthropic::AnthropicProvider::with_base_url(
454 api_key, base_url, "minimax",
455 ) {
456 Ok(p) => registry.register(Arc::new(p)),
457 Err(e) => tracing::warn!("Failed to init minimax: {}", e),
458 }
459 }
460 "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
462 | "alibaba" | "openai" | "azure" | "novita" => {
463 if let Some(base_url) = secrets.base_url.clone() {
464 match openai::OpenAIProvider::with_base_url(
465 api_key,
466 base_url,
467 &provider_id,
468 ) {
469 Ok(p) => registry.register(Arc::new(p)),
470 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
471 }
472 } else if provider_id == "openai" {
473 match openai::OpenAIProvider::new(api_key) {
475 Ok(p) => registry.register(Arc::new(p)),
476 Err(e) => tracing::warn!("Failed to init openai: {}", e),
477 }
478 } else if provider_id == "novita" {
479 let base_url = "https://api.novita.ai/openai/v1".to_string();
480 match openai::OpenAIProvider::with_base_url(
481 api_key,
482 base_url,
483 &provider_id,
484 ) {
485 Ok(p) => registry.register(Arc::new(p)),
486 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
487 }
488 } else {
489 if let Ok(catalog) = models::ModelCatalog::fetch().await {
491 if let Some(provider_info) = catalog.get_provider(&provider_id) {
492 if let Some(api_url) = &provider_info.api {
493 match openai::OpenAIProvider::with_base_url(
494 api_key,
495 api_url.clone(),
496 &provider_id,
497 ) {
498 Ok(p) => registry.register(Arc::new(p)),
499 Err(e) => {
500 tracing::warn!(
501 "Failed to init {}: {}",
502 provider_id,
503 e
504 )
505 }
506 }
507 }
508 }
509 }
510 }
511 }
512 other => {
514 if let Some(base_url) = secrets.base_url {
515 match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
516 Ok(p) => registry.register(Arc::new(p)),
517 Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
518 }
519 } else {
520 tracing::debug!(
521 "Unknown provider {} without base_url, skipping",
522 other
523 );
524 }
525 }
526 }
527 }
528 } else {
529 tracing::warn!("Vault not configured, no providers will be available from Vault");
530 }
531
532 if !registry.providers.contains_key("bedrock") {
534 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
535 let region = bedrock::AwsCredentials::detect_region()
536 .unwrap_or_else(|| "us-east-1".to_string());
537 match bedrock::BedrockProvider::with_credentials(creds, region) {
538 Ok(p) => {
539 tracing::info!("Registered Bedrock provider from local AWS credentials");
540 registry.register(Arc::new(p));
541 }
542 Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
543 }
544 }
545 }
546
547 tracing::info!(
548 "Registered {} providers from Vault",
549 registry.providers.len()
550 );
551 Ok(registry)
552 }
553}
554
/// Normalizes a MiniMax base URL for use with the Anthropic-compatible client.
///
/// Surrounding whitespace and any trailing `/` characters are removed. The OpenAI-style
/// endpoint `https://api.minimax.io/v1` (matched ASCII case-insensitively) is rewritten to
/// the Anthropic-style endpoint `https://api.minimax.io/anthropic`. Any other URL is
/// returned as cleaned.
fn normalize_minimax_anthropic_base_url(base_url: &str) -> String {
    const OPENAI_STYLE_URL: &str = "https://api.minimax.io/v1";
    const ANTHROPIC_STYLE_URL: &str = "https://api.minimax.io/anthropic";

    let cleaned = base_url.trim().trim_end_matches('/');
    let chosen = if cleaned.eq_ignore_ascii_case(OPENAI_STYLE_URL) {
        ANTHROPIC_STYLE_URL
    } else {
        cleaned
    };
    chosen.to_owned()
}
563
564impl Default for ProviderRegistry {
565 fn default() -> Self {
566 Self::new()
567 }
568}
569
/// Splits a model string of the form `provider/model` at the first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present, and `(None, s)` otherwise.
/// Only the first slash separates. Model ids that themselves contain slashes, e.g.
/// `openrouter/anthropic/claude`, keep them: `(Some("openrouter"), "anthropic/claude")`.
/// The provider part may be empty when `s` starts with `/`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    s.split_once('/')
        .map_or((None, s), |(provider, model)| (Some(provider), model))
}