1pub mod anthropic;
6pub mod google;
7pub mod models;
8pub mod moonshot;
9pub mod openai;
10pub mod openrouter;
11pub mod stepfun;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Message {
22 pub role: Role,
23 pub content: Vec<ContentPart>,
24}
25
26#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
27#[serde(rename_all = "lowercase")]
28pub enum Role {
29 System,
30 User,
31 Assistant,
32 Tool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum ContentPart {
38 Text {
39 text: String,
40 },
41 Image {
42 url: String,
43 mime_type: Option<String>,
44 },
45 File {
46 path: String,
47 mime_type: Option<String>,
48 },
49 ToolCall {
50 id: String,
51 name: String,
52 arguments: String,
53 },
54 ToolResult {
55 tool_call_id: String,
56 content: String,
57 },
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ToolDefinition {
63 pub name: String,
64 pub description: String,
65 pub parameters: serde_json::Value, }
67
68#[derive(Debug, Clone)]
70pub struct CompletionRequest {
71 pub messages: Vec<Message>,
72 pub tools: Vec<ToolDefinition>,
73 pub model: String,
74 pub temperature: Option<f32>,
75 pub top_p: Option<f32>,
76 pub max_tokens: Option<usize>,
77 pub stop: Vec<String>,
78}
79
80#[derive(Debug, Clone)]
82pub enum StreamChunk {
83 Text(String),
84 ToolCallStart { id: String, name: String },
85 ToolCallDelta { id: String, arguments_delta: String },
86 ToolCallEnd { id: String },
87 Done { usage: Option<Usage> },
88 Error(String),
89}
90
91#[derive(Debug, Clone, Default, Serialize, Deserialize)]
93pub struct Usage {
94 pub prompt_tokens: usize,
95 pub completion_tokens: usize,
96 pub total_tokens: usize,
97 pub cache_read_tokens: Option<usize>,
98 pub cache_write_tokens: Option<usize>,
99}
100
101#[derive(Debug, Clone)]
103pub struct CompletionResponse {
104 pub message: Message,
105 pub usage: Usage,
106 pub finish_reason: FinishReason,
107}
108
109#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum FinishReason {
112 Stop,
113 Length,
114 ToolCalls,
115 ContentFilter,
116 Error,
117}
118
119#[async_trait]
121pub trait Provider: Send + Sync {
122 fn name(&self) -> &str;
124
125 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
127
128 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
130
131 async fn complete_stream(
133 &self,
134 request: CompletionRequest,
135 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ModelInfo {
141 pub id: String,
142 pub name: String,
143 pub provider: String,
144 pub context_window: usize,
145 pub max_output_tokens: Option<usize>,
146 pub supports_vision: bool,
147 pub supports_tools: bool,
148 pub supports_streaming: bool,
149 pub input_cost_per_million: Option<f64>,
150 pub output_cost_per_million: Option<f64>,
151}
152
153pub struct ProviderRegistry {
155 providers: HashMap<String, Arc<dyn Provider>>,
156}
157
158impl std::fmt::Debug for ProviderRegistry {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("ProviderRegistry")
161 .field("provider_count", &self.providers.len())
162 .field("providers", &self.providers.keys().collect::<Vec<_>>())
163 .finish()
164 }
165}
166
167impl ProviderRegistry {
168 pub fn new() -> Self {
169 Self {
170 providers: HashMap::new(),
171 }
172 }
173
174 pub fn register(&mut self, provider: Arc<dyn Provider>) {
176 self.providers.insert(provider.name().to_string(), provider);
177 }
178
179 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
181 self.providers.get(name).cloned()
182 }
183
184 pub fn list(&self) -> Vec<&str> {
186 self.providers.keys().map(|s| s.as_str()).collect()
187 }
188
189 pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
191 let mut registry = Self::new();
192
193 if let Some(provider_config) = config.providers.get("openai") {
195 if let Some(api_key) = &provider_config.api_key {
196 registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
197 }
198 } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
199 registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
200 }
201
202 if let Some(provider_config) = config.providers.get("anthropic") {
204 if let Some(api_key) = &provider_config.api_key {
205 registry.register(Arc::new(anthropic::AnthropicProvider::new(
206 api_key.clone(),
207 )?));
208 }
209 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
210 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
211 }
212
213 if let Some(provider_config) = config.providers.get("google") {
215 if let Some(api_key) = &provider_config.api_key {
216 registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
217 }
218 } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
219 registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
220 }
221
222 if let Some(provider_config) = config.providers.get("novita") {
224 if let Some(api_key) = &provider_config.api_key {
225 let base_url = provider_config
226 .base_url
227 .clone()
228 .unwrap_or_else(|| "https://api.novita.ai/openai/v1".to_string());
229 registry.register(Arc::new(openai::OpenAIProvider::with_base_url(
230 api_key.clone(),
231 base_url,
232 "novita",
233 )?));
234 }
235 }
236
237 Ok(registry)
238 }
239
240 pub async fn from_vault() -> Result<Self> {
245 let mut registry = Self::new();
246
247 let manager = match crate::secrets::secrets_manager() {
248 Some(m) => m,
249 None => {
250 tracing::warn!("Vault not configured, no providers will be available");
251 return Ok(registry);
252 }
253 };
254
255 let providers = manager.list_configured_providers().await?;
257 tracing::info!("Found {} providers configured in Vault", providers.len());
258
259 for provider_id in providers {
260 let secrets = match manager.get_provider_secrets(&provider_id).await? {
261 Some(s) => s,
262 None => continue,
263 };
264
265 let api_key = match secrets.api_key {
266 Some(key) => key,
267 None => continue,
268 };
269
270 match provider_id.as_str() {
272 "anthropic" | "anthropic-eu" | "anthropic-asia" => {
274 match anthropic::AnthropicProvider::new(api_key) {
275 Ok(p) => registry.register(Arc::new(p)),
276 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
277 }
278 }
279 "google" | "google-vertex" => match google::GoogleProvider::new(api_key) {
280 Ok(p) => registry.register(Arc::new(p)),
281 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
282 },
283 "stepfun" => match stepfun::StepFunProvider::new(api_key) {
285 Ok(p) => registry.register(Arc::new(p)),
286 Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
287 },
288 "openrouter" => match openrouter::OpenRouterProvider::new(api_key) {
290 Ok(p) => registry.register(Arc::new(p)),
291 Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
292 },
293 "moonshotai" | "moonshotai-cn" => match moonshot::MoonshotProvider::new(api_key) {
295 Ok(p) => registry.register(Arc::new(p)),
296 Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
297 },
298 "zhipuai" => {
300 let base_url = secrets
301 .base_url
302 .clone()
303 .unwrap_or_else(|| "https://api.z.ai/api/coding/paas/v4".to_string());
304 match openai::OpenAIProvider::with_base_url(api_key, base_url, "zhipuai") {
305 Ok(p) => registry.register(Arc::new(p)),
306 Err(e) => tracing::warn!("Failed to init zhipuai: {}", e),
307 }
308 }
309 "deepseek" | "groq" | "togetherai" | "fireworks-ai" | "mistral" | "nvidia"
311 | "alibaba" | "openai" | "azure" | "novita" => {
312 if let Some(base_url) = secrets.base_url {
313 match openai::OpenAIProvider::with_base_url(api_key, base_url, &provider_id)
314 {
315 Ok(p) => registry.register(Arc::new(p)),
316 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
317 }
318 } else if provider_id == "openai" {
319 match openai::OpenAIProvider::new(api_key) {
321 Ok(p) => registry.register(Arc::new(p)),
322 Err(e) => tracing::warn!("Failed to init openai: {}", e),
323 }
324 } else if provider_id == "novita" {
325 let base_url = "https://api.novita.ai/openai/v1".to_string();
326 match openai::OpenAIProvider::with_base_url(api_key, base_url, &provider_id)
327 {
328 Ok(p) => registry.register(Arc::new(p)),
329 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
330 }
331 } else {
332 if let Ok(catalog) = models::ModelCatalog::fetch().await {
334 if let Some(provider_info) = catalog.get_provider(&provider_id) {
335 if let Some(api_url) = &provider_info.api {
336 match openai::OpenAIProvider::with_base_url(
337 api_key,
338 api_url.clone(),
339 &provider_id,
340 ) {
341 Ok(p) => registry.register(Arc::new(p)),
342 Err(e) => {
343 tracing::warn!("Failed to init {}: {}", provider_id, e)
344 }
345 }
346 }
347 }
348 }
349 }
350 }
351 other => {
353 if let Some(base_url) = secrets.base_url {
354 match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
355 Ok(p) => registry.register(Arc::new(p)),
356 Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
357 }
358 } else {
359 tracing::debug!("Unknown provider {} without base_url, skipping", other);
360 }
361 }
362 }
363 }
364
365 tracing::info!(
366 "Registered {} providers from Vault",
367 registry.providers.len()
368 );
369 Ok(registry)
370 }
371}
372
373impl Default for ProviderRegistry {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
/// Splits a `provider/model` identifier into an optional provider prefix and the model name.
///
/// Only the first `/` separates the provider, so model names that contain slashes
/// survive intact: `"openrouter/anthropic/claude-3"` yields
/// `(Some("openrouter"), "anthropic/claude-3")`. A string without `/` is returned
/// whole as the model with no provider.
///
/// An empty provider prefix (e.g. `"/gpt-4o"`) is treated as no provider, not as a
/// provider named `""`.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.split_once('/') {
        Some(("", model)) => (None, model),
        Some((provider, model)) => (Some(provider), model),
        None => (None, s),
    }
}
386}