1pub mod anthropic;
6pub mod bedrock;
7pub mod copilot;
8pub mod gemini_web;
9pub mod google;
10pub mod metrics;
11pub mod models;
12pub mod moonshot;
13pub mod openai;
14pub mod openai_codex;
15pub mod openrouter;
16pub mod stepfun;
17pub mod vertex_anthropic;
18pub mod vertex_glm;
19pub mod zai;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::Arc;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct Message {
30 pub role: Role,
31 pub content: Vec<ContentPart>,
32}
33
34#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
35#[serde(rename_all = "lowercase")]
36pub enum Role {
37 System,
38 User,
39 Assistant,
40 Tool,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(tag = "type", rename_all = "snake_case")]
45pub enum ContentPart {
46 Text {
47 text: String,
48 },
49 Image {
50 url: String,
51 mime_type: Option<String>,
52 },
53 File {
54 path: String,
55 mime_type: Option<String>,
56 },
57 ToolCall {
58 id: String,
59 name: String,
60 arguments: String,
61 #[serde(skip_serializing_if = "Option::is_none")]
63 thought_signature: Option<String>,
64 },
65 ToolResult {
66 tool_call_id: String,
67 content: String,
68 },
69 Thinking {
70 text: String,
71 },
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ToolDefinition {
77 pub name: String,
78 pub description: String,
79 pub parameters: serde_json::Value, }
81
82#[derive(Debug, Clone)]
84pub struct CompletionRequest {
85 pub messages: Vec<Message>,
86 pub tools: Vec<ToolDefinition>,
87 pub model: String,
88 pub temperature: Option<f32>,
89 pub top_p: Option<f32>,
90 pub max_tokens: Option<usize>,
91 pub stop: Vec<String>,
92}
93
94#[derive(Debug, Clone)]
96pub enum StreamChunk {
97 Text(String),
98 ToolCallStart { id: String, name: String },
99 ToolCallDelta { id: String, arguments_delta: String },
100 ToolCallEnd { id: String },
101 Done { usage: Option<Usage> },
102 Error(String),
103}
104
105#[derive(Debug, Clone, Default, Serialize, Deserialize)]
107pub struct Usage {
108 pub prompt_tokens: usize,
109 pub completion_tokens: usize,
110 pub total_tokens: usize,
111 pub cache_read_tokens: Option<usize>,
112 pub cache_write_tokens: Option<usize>,
113}
114
115#[derive(Debug, Clone)]
117pub struct CompletionResponse {
118 pub message: Message,
119 pub usage: Usage,
120 pub finish_reason: FinishReason,
121}
122
123#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
124#[serde(rename_all = "snake_case")]
125pub enum FinishReason {
126 Stop,
127 Length,
128 ToolCalls,
129 ContentFilter,
130 Error,
131}
132
133#[async_trait]
135pub trait Provider: Send + Sync {
136 fn name(&self) -> &str;
138
139 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
141
142 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
144
145 async fn complete_stream(
147 &self,
148 request: CompletionRequest,
149 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ModelInfo {
155 pub id: String,
156 pub name: String,
157 pub provider: String,
158 pub context_window: usize,
159 pub max_output_tokens: Option<usize>,
160 pub supports_vision: bool,
161 pub supports_tools: bool,
162 pub supports_streaming: bool,
163 pub input_cost_per_million: Option<f64>,
164 pub output_cost_per_million: Option<f64>,
165}
166
167pub struct ProviderRegistry {
169 providers: HashMap<String, Arc<dyn Provider>>,
170}
171
172impl std::fmt::Debug for ProviderRegistry {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 f.debug_struct("ProviderRegistry")
175 .field("provider_count", &self.providers.len())
176 .field("providers", &self.providers.keys().collect::<Vec<_>>())
177 .finish()
178 }
179}
180
181impl ProviderRegistry {
182 pub fn new() -> Self {
183 Self {
184 providers: HashMap::new(),
185 }
186 }
187
188 pub fn register(&mut self, provider: Arc<dyn Provider>) {
190 let name = provider.name().to_string();
191 let wrapped = metrics::MetricsProvider::wrap(provider);
192 self.providers.insert(name, wrapped);
193 }
194
195 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
197 self.providers.get(name).cloned()
198 }
199
200 pub fn list(&self) -> Vec<&str> {
202 self.providers.keys().map(|s| s.as_str()).collect()
203 }
204
205 pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
207 let mut registry = Self::new();
208
209 if let Some(provider_config) = config.providers.get("openai") {
211 if let Some(api_key) = &provider_config.api_key {
212 registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
213 }
214 } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
215 registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
216 }
217
218 if let Some(provider_config) = config.providers.get("anthropic") {
220 if let Some(api_key) = &provider_config.api_key {
221 let provider = if let Some(base_url) = provider_config.base_url.clone() {
222 anthropic::AnthropicProvider::with_base_url(
223 api_key.clone(),
224 base_url,
225 "anthropic",
226 )?
227 } else {
228 anthropic::AnthropicProvider::new(api_key.clone())?
229 };
230 registry.register(Arc::new(provider));
231 }
232 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
233 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
234 }
235
236 if let Some(provider_config) = config.providers.get("google") {
238 if let Some(api_key) = &provider_config.api_key {
239 registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
240 }
241 } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
242 registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
243 }
244
245 if let Some(provider_config) = config.providers.get("novita") {
247 if let Some(api_key) = &provider_config.api_key {
248 let base_url = provider_config
249 .base_url
250 .clone()
251 .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
252 registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
253 api_key.clone(),
254 base_url,
255 "novita",
256 )?));
257 }
258 }
259
260 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
262 let region = bedrock::AwsCredentials::detect_region()
263 .unwrap_or_else(|| bedrock::DEFAULT_REGION.to_string());
264 match bedrock::BedrockProvider::with_credentials(creds, region) {
265 Ok(p) => registry.register(Arc::new(p)),
266 Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
267 }
268 }
269
270 Ok(registry)
271 }
272
273 pub async fn from_vault() -> Result<Self> {
278 let mut registry = Self::new();
279
280 if let Some(manager) = crate::secrets::secrets_manager() {
281 let providers = manager.list_configured_providers().await?;
283 tracing::info!("Found {} providers configured in Vault", providers.len());
284
285 for provider_id in providers {
286 let secrets = match manager.get_provider_secrets(&provider_id).await? {
287 Some(s) => s,
288 None => continue,
289 };
290
291 if matches!(provider_id.as_str(), "bedrock" | "aws-bedrock") {
294 let region = secrets
295 .extra
296 .get("region")
297 .and_then(|v| v.as_str())
298 .unwrap_or("us-east-1")
299 .to_string();
300
301 let aws_key_id = secrets
303 .extra
304 .get("aws_access_key_id")
305 .and_then(|v| v.as_str());
306 let aws_secret = secrets
307 .extra
308 .get("aws_secret_access_key")
309 .and_then(|v| v.as_str());
310
311 let result = if let (Some(key_id), Some(secret)) = (aws_key_id, aws_secret) {
312 let creds = bedrock::AwsCredentials {
313 access_key_id: key_id.to_string(),
314 secret_access_key: secret.to_string(),
315 session_token: secrets
316 .extra
317 .get("aws_session_token")
318 .and_then(|v| v.as_str())
319 .map(|s| s.to_string()),
320 };
321 bedrock::BedrockProvider::with_credentials(creds, region)
322 } else if let Some(ref key) = secrets.api_key {
323 bedrock::BedrockProvider::with_region(key.clone(), region)
324 } else {
325 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
327 bedrock::BedrockProvider::with_credentials(creds, region)
328 } else {
329 Err(anyhow::anyhow!(
330 "No AWS credentials or API key found for Bedrock"
331 ))
332 }
333 };
334
335 match result {
336 Ok(p) => registry.register(Arc::new(p)),
337 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
338 }
339 continue;
340 }
341
342 if matches!(provider_id.as_str(), "vertex-glm" | "vertex-ai" | "gcp-glm") {
345 let sa_json = secrets
346 .extra
347 .get("service_account_json")
348 .and_then(|v| v.as_str());
349
350 if let Some(sa_json) = sa_json {
351 let project_id = secrets
352 .extra
353 .get("project_id")
354 .and_then(|v| v.as_str())
355 .or_else(|| secrets.extra.get("projectId").and_then(|v| v.as_str()))
356 .map(|s| s.to_string());
357
358 match vertex_glm::VertexGlmProvider::new(sa_json, project_id) {
359 Ok(p) => registry.register(Arc::new(p)),
360 Err(e) => tracing::warn!("Failed to init vertex-glm: {e}"),
361 }
362 } else {
363 tracing::warn!(
364 "vertex-glm provider requires service_account_json in Vault secrets"
365 );
366 }
367 continue;
368 }
369
370 if matches!(
373 provider_id.as_str(),
374 "vertex-anthropic" | "vertex-claude" | "gcp-anthropic"
375 ) {
376 let sa_json = secrets
377 .extra
378 .get("service_account_json")
379 .and_then(|v| v.as_str());
380
381 if let Some(sa_json) = sa_json {
382 let project_id = secrets
383 .extra
384 .get("project_id")
385 .and_then(|v| v.as_str())
386 .or_else(|| secrets.extra.get("projectId").and_then(|v| v.as_str()))
387 .map(|s| s.to_string());
388
389 match vertex_anthropic::VertexAnthropicProvider::new(sa_json, project_id) {
390 Ok(p) => registry.register(Arc::new(p)),
391 Err(e) => tracing::warn!("Failed to init vertex-anthropic: {e}"),
392 }
393 } else {
394 tracing::warn!(
395 "vertex-anthropic provider requires service_account_json in Vault secrets"
396 );
397 }
398 continue;
399 }
400
401 if matches!(provider_id.as_str(), "openai-codex" | "codex" | "chatgpt") {
404 let access_token = secrets.extra.get("access_token").and_then(|v| v.as_str());
405 let refresh_token = secrets.extra.get("refresh_token").and_then(|v| v.as_str());
406 let expires_at = secrets.extra.get("expires_at").and_then(|v| v.as_u64());
407
408 match (access_token, refresh_token, expires_at) {
409 (Some(access), Some(refresh), Some(expires)) => {
410 let creds = openai_codex::OAuthCredentials {
411 access_token: access.to_string(),
412 refresh_token: refresh.to_string(),
413 expires_at: expires,
414 };
415 let provider =
416 openai_codex::OpenAiCodexProvider::from_credentials(creds);
417 registry.register(Arc::new(provider));
418 }
419 _ => {
420 tracing::warn!(
421 "openai-codex provider requires access_token, refresh_token, and expires_at in Vault secrets"
422 );
423 }
424 }
425 continue;
426 }
427
428 if matches!(provider_id.as_str(), "gemini-web") {
430 let cookies = secrets
431 .extra
432 .get("cookies")
433 .and_then(|v| v.as_str());
434
435 if let Some(cookies) = cookies {
436 match gemini_web::GeminiWebProvider::new(cookies.to_string()) {
437 Ok(p) => registry.register(Arc::new(p)),
438 Err(e) => tracing::warn!("Failed to init gemini-web: {e}"),
439 }
440 } else {
441 tracing::warn!(
442 "gemini-web provider requires \"cookies\" field in Vault secrets \
443 (tab-separated Cookie-Editor export)"
444 );
445 }
446 continue;
447 }
448
449 let api_key = match secrets.api_key {
450 Some(key) => key,
451 None => continue,
452 };
453
454 match provider_id.as_str() {
456 "anthropic" | "anthropic-eu" | "anthropic-asia" => {
458 let base_url = secrets
459 .base_url
460 .clone()
461 .unwrap_or_else(|| "https://api.anthropic.com".to_string());
462 match anthropic::AnthropicProvider::with_base_url(
463 api_key,
464 base_url,
465 &provider_id,
466 ) {
467 Ok(p) => registry.register(Arc::new(p)),
468 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
469 }
470 }
471 "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
472 Ok(p) => registry.register(Arc::new(p)),
473 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
474 },
475 "stepfun" => match stepfun::StepFunProvider::new(api_key) {
477 Ok(p) => registry.register(Arc::new(p)),
478 Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
479 },
480 "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
482 Ok(p) => registry.register(Arc::new(p)),
483 Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
484 },
485 "moonshotai" | "moonshotai-cn" => {
487 match moonshot::MoonshotProvider::new(api_key) {
488 Ok(p) => registry.register(Arc::new(p)),
489 Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
490 }
491 }
492 "github-copilot" => {
494 let result = if let Some(base_url) = secrets.base_url.clone() {
495 copilot::CopilotProvider::with_base_url(
496 api_key,
497 base_url,
498 "github-copilot",
499 )
500 } else {
501 copilot::CopilotProvider::new(api_key)
502 };
503
504 match result {
505 Ok(p) => registry.register(Arc::new(p)),
506 Err(e) => tracing::warn!("Failed to init github-copilot: {}", e),
507 }
508 }
509 "github-copilot-enterprise" => {
510 let enterprise_url = secrets
511 .extra
512 .get("enterpriseUrl")
513 .and_then(|v| v.as_str())
514 .or_else(|| {
515 secrets.extra.get("enterprise_url").and_then(|v| v.as_str())
516 });
517
518 let result = if let Some(base_url) = secrets.base_url.clone() {
519 copilot::CopilotProvider::with_base_url(
520 api_key,
521 base_url,
522 "github-copilot-enterprise",
523 )
524 } else if let Some(url) = enterprise_url {
525 copilot::CopilotProvider::enterprise(api_key, url.to_string())
526 } else {
527 copilot::CopilotProvider::with_base_url(
528 api_key,
529 "https://api.githubcopilot.com".to_string(),
530 "github-copilot-enterprise",
531 )
532 };
533
534 match result {
535 Ok(p) => registry.register(Arc::new(p)),
536 Err(e) => {
537 tracing::warn!("Failed to init github-copilot-enterprise: {}", e)
538 }
539 }
540 }
541 "zhipuai" | "zai" => {
543 let base_url = secrets
544 .base_url
545 .clone()
546 .unwrap_or_else(|| "https://api.z.ai/api/paas/v4".to_string());
547 match zai::ZaiProvider::with_base_url(api_key, base_url) {
548 Ok(p) => registry.register(Arc::new(p)),
549 Err(e) => tracing::warn!("Failed to init zai: {}", e),
550 }
551 }
552
553 "cerebras" => {
555 let base_url = secrets
556 .base_url
557 .clone()
558 .unwrap_or_else(|| "https://api.cerebras.ai/v1".to_string());
559 match openai::OpenAIProvider::with_base_url(api_key, base_url, "cerebras") {
560 Ok(p) => registry.register(Arc::new(p)),
561 Err(e) => tracing::warn!("Failed to init cerebras: {}", e),
562 }
563 }
564 "minimax" | "minimax-credits" => {
568 let base_url = secrets
569 .base_url
570 .clone()
571 .unwrap_or_else(|| "https://api.minimax.io/anthropic".to_string());
572 let base_url = normalize_minimax_anthropic_base_url(&base_url);
573 match anthropic::AnthropicProvider::with_base_url(
574 api_key,
575 base_url,
576 &provider_id,
577 ) {
578 Ok(p) => registry.register(Arc::new(p)),
579 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
580 }
581 }
582 "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
584 | "alibaba" | "openai" | "azure" | "novita" => {
585 if let Some(base_url) = secrets.base_url.clone() {
586 match openai::OpenAIProvider::with_base_url(
587 api_key,
588 base_url,
589 &provider_id,
590 ) {
591 Ok(p) => registry.register(Arc::new(p)),
592 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
593 }
594 } else if provider_id == "openai" {
595 match openai::OpenAIProvider::new(api_key) {
597 Ok(p) => registry.register(Arc::new(p)),
598 Err(e) => tracing::warn!("Failed to init openai: {}", e),
599 }
600 } else if provider_id == "novita" {
601 let base_url = "https://api.novita.ai/openai/v1".to_string();
602 match openai::OpenAIProvider::with_base_url(
603 api_key,
604 base_url,
605 &provider_id,
606 ) {
607 Ok(p) => registry.register(Arc::new(p)),
608 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
609 }
610 } else {
611 tracing::warn!(
612 "Provider {} has no built-in base_url; set base_url in Vault secrets",
613 provider_id
614 );
615 }
616 }
617 other => {
619 if let Some(base_url) = secrets.base_url {
620 match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
621 Ok(p) => registry.register(Arc::new(p)),
622 Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
623 }
624 } else {
625 tracing::debug!(
626 "Unknown provider {} without base_url, skipping",
627 other
628 );
629 }
630 }
631 }
632 }
633 } else {
634 tracing::warn!("Vault not configured, no providers will be available from Vault");
635 }
636
637 if !registry.providers.contains_key("bedrock") {
639 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
640 let region = bedrock::AwsCredentials::detect_region()
641 .unwrap_or_else(|| "us-east-1".to_string());
642 match bedrock::BedrockProvider::with_credentials(creds, region) {
643 Ok(p) => {
644 tracing::info!("Registered Bedrock provider from local AWS credentials");
645 registry.register(Arc::new(p));
646 }
647 Err(e) => tracing::warn!("Failed to init bedrock from AWS credentials: {}", e),
648 }
649 }
650 }
651
652 Self::register_env_fallbacks(&mut registry);
654
655 tracing::info!(
656 "Registered {} providers (Vault + env fallback)",
657 registry.providers.len()
658 );
659 Ok(registry)
660 }
661
662 fn register_env_fallbacks(registry: &mut Self) {
664 let fallbacks: &[(&str, &str, fn(String) -> Result<Arc<dyn Provider>>)] = &[
665 ("openai", "OPENAI_API_KEY", |key| Ok(Arc::new(openai::OpenAIProvider::new(key)?))),
666 ("anthropic", "ANTHROPIC_API_KEY", |key| Ok(Arc::new(anthropic::AnthropicProvider::new(key)?))),
667 ("google", "GOOGLE_API_KEY", |key| Ok(Arc::new(google::GoogleProvider::new(key)?))),
668 ("openrouter", "OPENROUTER_API_KEY", |key| Ok(Arc::new(openrouter::OpenRouterProvider::new(key)?))),
669 ];
670
671 for (provider_id, env_var, constructor) in fallbacks {
672 if !registry.providers.contains_key(*provider_id) {
673 if let Ok(api_key) = std::env::var(env_var) {
674 match constructor(api_key) {
675 Ok(p) => {
676 tracing::info!("Registered {} provider from {} env var", provider_id, env_var);
677 registry.register(p);
678 }
679 Err(e) => tracing::warn!("Failed to init {} from env: {}", provider_id, e),
680 }
681 }
682 }
683 }
684 }
685}
686
/// Normalizes a MiniMax base URL for use with the Anthropic-compatible client.
///
/// Surrounding whitespace is trimmed first, then trailing slashes are removed. If the
/// result equals `https://api.minimax.io/v1` (ASCII case-insensitive), the Anthropic
/// endpoint `https://api.minimax.io/anthropic` is returned instead. Any other URL is
/// returned in its trimmed form.
fn normalize_minimax_anthropic_base_url(base_url: &str) -> String {
    const OPENAI_COMPAT_URL: &str = "https://api.minimax.io/v1";
    const ANTHROPIC_COMPAT_URL: &str = "https://api.minimax.io/anthropic";

    let cleaned = base_url.trim().trim_end_matches('/');
    if !cleaned.eq_ignore_ascii_case(OPENAI_COMPAT_URL) {
        return cleaned.to_owned();
    }
    ANTHROPIC_COMPAT_URL.to_owned()
}
695
696impl Default for ProviderRegistry {
697 fn default() -> Self {
698 Self::new()
699 }
700}
701
/// Splits a `"provider/model"` string at its first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present; everything after the first
/// slash, including any further slashes, is the model. Returns `(None, s)` when there is
/// no slash. Empty segments are passed through as-is (e.g. `"/x"` gives `(Some(""), "x")`).
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.split_once('/') {
        Some((provider, model)) => (Some(provider), model),
        None => (None, s),
    }
}