codetether_agent/provider/
mod.rs1pub mod anthropic;
6pub mod google;
7pub mod models;
8pub mod moonshot;
9pub mod openai;
10pub mod openrouter;
11pub mod stepfun;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Message {
22 pub role: Role,
23 pub content: Vec<ContentPart>,
24}
25
26#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
27#[serde(rename_all = "lowercase")]
28pub enum Role {
29 System,
30 User,
31 Assistant,
32 Tool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum ContentPart {
38 Text { text: String },
39 Image { url: String, mime_type: Option<String> },
40 File { path: String, mime_type: Option<String> },
41 ToolCall { id: String, name: String, arguments: String },
42 ToolResult { tool_call_id: String, content: String },
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolDefinition {
48 pub name: String,
49 pub description: String,
50 pub parameters: serde_json::Value, }
52
53#[derive(Debug, Clone)]
55pub struct CompletionRequest {
56 pub messages: Vec<Message>,
57 pub tools: Vec<ToolDefinition>,
58 pub model: String,
59 pub temperature: Option<f32>,
60 pub top_p: Option<f32>,
61 pub max_tokens: Option<usize>,
62 pub stop: Vec<String>,
63}
64
65#[derive(Debug, Clone)]
67pub enum StreamChunk {
68 Text(String),
69 ToolCallStart { id: String, name: String },
70 ToolCallDelta { id: String, arguments_delta: String },
71 ToolCallEnd { id: String },
72 Done { usage: Option<Usage> },
73 Error(String),
74}
75
76#[derive(Debug, Clone, Default, Serialize, Deserialize)]
78pub struct Usage {
79 pub prompt_tokens: usize,
80 pub completion_tokens: usize,
81 pub total_tokens: usize,
82 pub cache_read_tokens: Option<usize>,
83 pub cache_write_tokens: Option<usize>,
84}
85
86#[derive(Debug, Clone)]
88pub struct CompletionResponse {
89 pub message: Message,
90 pub usage: Usage,
91 pub finish_reason: FinishReason,
92}
93
94#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
95#[serde(rename_all = "snake_case")]
96pub enum FinishReason {
97 Stop,
98 Length,
99 ToolCalls,
100 ContentFilter,
101 Error,
102}
103
104#[async_trait]
106pub trait Provider: Send + Sync {
107 fn name(&self) -> &str;
109
110 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
112
113 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
115
116 async fn complete_stream(
118 &self,
119 request: CompletionRequest,
120 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>>;
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct ModelInfo {
126 pub id: String,
127 pub name: String,
128 pub provider: String,
129 pub context_window: usize,
130 pub max_output_tokens: Option<usize>,
131 pub supports_vision: bool,
132 pub supports_tools: bool,
133 pub supports_streaming: bool,
134 pub input_cost_per_million: Option<f64>,
135 pub output_cost_per_million: Option<f64>,
136}
137
138pub struct ProviderRegistry {
140 providers: HashMap<String, Arc<dyn Provider>>,
141}
142
143impl std::fmt::Debug for ProviderRegistry {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("ProviderRegistry")
146 .field("provider_count", &self.providers.len())
147 .field("providers", &self.providers.keys().collect::<Vec<_>>())
148 .finish()
149 }
150}
151
152impl ProviderRegistry {
153 pub fn new() -> Self {
154 Self {
155 providers: HashMap::new(),
156 }
157 }
158
159 pub fn register(&mut self, provider: Arc<dyn Provider>) {
161 self.providers.insert(provider.name().to_string(), provider);
162 }
163
164 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
166 self.providers.get(name).cloned()
167 }
168
169 pub fn list(&self) -> Vec<&str> {
171 self.providers.keys().map(|s| s.as_str()).collect()
172 }
173
174 pub async fn from_config(config: &crate::config::Config) -> Result<Self> {
176 let mut registry = Self::new();
177
178 if let Some(provider_config) = config.providers.get("openai") {
180 if let Some(api_key) = &provider_config.api_key {
181 registry.register(Arc::new(openai::OpenAIProvider::new(api_key.clone())?));
182 }
183 } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
184 registry.register(Arc::new(openai::OpenAIProvider::new(api_key)?));
185 }
186
187 if let Some(provider_config) = config.providers.get("anthropic") {
189 if let Some(api_key) = &provider_config.api_key {
190 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key.clone())?));
191 }
192 } else if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
193 registry.register(Arc::new(anthropic::AnthropicProvider::new(api_key)?));
194 }
195
196 if let Some(provider_config) = config.providers.get("google") {
198 if let Some(api_key) = &provider_config.api_key {
199 registry.register(Arc::new(google::GoogleProvider::new(api_key.clone())?));
200 }
201 } else if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
202 registry.register(Arc::new(google::GoogleProvider::new(api_key)?));
203 }
204
205 Ok(registry)
206 }
207
208 pub async fn from_vault() -> Result<Self> {
213 let mut registry = Self::new();
214
215 let manager = match crate::secrets::secrets_manager() {
216 Some(m) => m,
217 None => {
218 tracing::warn!("Vault not configured, no providers will be available");
219 return Ok(registry);
220 }
221 };
222
223 let providers = manager.list_configured_providers().await?;
225 tracing::info!("Found {} providers configured in Vault", providers.len());
226
227 for provider_id in providers {
228 let secrets = match manager.get_provider_secrets(&provider_id).await? {
229 Some(s) => s,
230 None => continue,
231 };
232
233 let api_key = match secrets.api_key {
234 Some(key) => key,
235 None => continue,
236 };
237
238 match provider_id.as_str() {
240 "anthropic" | "anthropic-eu" | "anthropic-asia" => {
242 match anthropic::AnthropicProvider::new(api_key) {
243 Ok(p) => registry.register(Arc::new(p)),
244 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
245 }
246 }
247 "google" | "google-vertex" => {
248 match google::GoogleProvider::new(api_key) {
249 Ok(p) => registry.register(Arc::new(p)),
250 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
251 }
252 }
253 "stepfun" => {
255 match stepfun::StepFunProvider::new(api_key) {
256 Ok(p) => registry.register(Arc::new(p)),
257 Err(e) => tracing::warn!("Failed to init stepfun: {}", e),
258 }
259 }
260 "openrouter" => {
262 match openrouter::OpenRouterProvider::new(api_key) {
263 Ok(p) => registry.register(Arc::new(p)),
264 Err(e) => tracing::warn!("Failed to init openrouter: {}", e),
265 }
266 }
267 "moonshotai" | "moonshotai-cn" => {
269 match moonshot::MoonshotProvider::new(api_key) {
270 Ok(p) => registry.register(Arc::new(p)),
271 Err(e) => tracing::warn!("Failed to init moonshotai: {}", e),
272 }
273 }
274 "deepseek" | "groq" | "togetherai"
276 | "fireworks-ai" | "mistral" | "nvidia" | "alibaba"
277 | "openai" | "azure" => {
278 if let Some(base_url) = secrets.base_url {
279 match openai::OpenAIProvider::with_base_url(api_key, base_url, &provider_id) {
280 Ok(p) => registry.register(Arc::new(p)),
281 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
282 }
283 } else if provider_id == "openai" {
284 match openai::OpenAIProvider::new(api_key) {
286 Ok(p) => registry.register(Arc::new(p)),
287 Err(e) => tracing::warn!("Failed to init openai: {}", e),
288 }
289 } else {
290 if let Ok(catalog) = models::ModelCatalog::fetch().await {
292 if let Some(provider_info) = catalog.get_provider(&provider_id) {
293 if let Some(api_url) = &provider_info.api {
294 match openai::OpenAIProvider::with_base_url(api_key, api_url.clone(), &provider_id) {
295 Ok(p) => registry.register(Arc::new(p)),
296 Err(e) => tracing::warn!("Failed to init {}: {}", provider_id, e),
297 }
298 }
299 }
300 }
301 }
302 }
303 other => {
305 if let Some(base_url) = secrets.base_url {
306 match openai::OpenAIProvider::with_base_url(api_key, base_url, other) {
307 Ok(p) => registry.register(Arc::new(p)),
308 Err(e) => tracing::warn!("Failed to init {}: {}", other, e),
309 }
310 } else {
311 tracing::debug!("Unknown provider {} without base_url, skipping", other);
312 }
313 }
314 }
315 }
316
317 tracing::info!("Registered {} providers from Vault", registry.providers.len());
318 Ok(registry)
319 }
320}
321
322impl Default for ProviderRegistry {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
/// Splits a `provider/model` string at its first `/`.
///
/// Returns `(Some(provider), model)` when a `/` is present, with everything
/// after the first slash kept in `model` (so `"openrouter/a/b"` yields
/// `(Some("openrouter"), "a/b")`). Returns `(None, s)` when there is no slash.
pub fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    s.split_once('/')
        .map_or((None, s), |(provider, model)| (Some(provider), model))
}