1use crate::{chat, config::Config, models_cache::ModelsCache, provider::ChatRequest};
2use anyhow::Result;
3use axum::{
4 extract::{Query, State},
5 http::{HeaderMap, StatusCode},
6 response::Json,
7 routing::{get, post},
8 Router,
9};
10use colored::Colorize;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tower_http::cors::CorsLayer;
14
15#[derive(Clone)]
16pub struct ProxyState {
17 pub config: Config,
18 pub api_key: Option<String>,
19 pub provider_filter: Option<String>,
20 pub model_filter: Option<String>,
21}
22
23#[derive(Deserialize)]
24pub struct ProxyModelsQuery {
25 #[serde(default)]
26 pub provider: Option<String>,
27}
28
29#[derive(Serialize)]
30pub struct ProxyModelsResponse {
31 pub object: String,
32 pub data: Vec<ProxyModel>,
33}
34
35#[derive(Serialize)]
36pub struct ProxyModel {
37 pub id: String,
38 pub object: String,
39 pub created: u64,
40 pub owned_by: String,
41}
42
43#[derive(Deserialize)]
44pub struct ProxyChatRequest {
45 pub model: String,
46 pub messages: Vec<crate::provider::Message>,
47 pub max_tokens: Option<u32>,
48 pub temperature: Option<f32>,
49}
50
51#[derive(Serialize)]
52pub struct ProxyChatResponse {
53 pub id: String,
54 pub object: String,
55 pub created: u64,
56 pub model: String,
57 pub choices: Vec<ProxyChoice>,
58 pub usage: ProxyUsage,
59}
60
61#[derive(Serialize)]
62pub struct ProxyChoice {
63 pub index: u32,
64 pub message: crate::provider::Message,
65 pub finish_reason: String,
66}
67
68#[derive(Serialize)]
69pub struct ProxyUsage {
70 pub prompt_tokens: u32,
71 pub completion_tokens: u32,
72 pub total_tokens: u32,
73}
74
75pub async fn start_proxy_server(
76 host: String,
77 port: u16,
78 provider_filter: Option<String>,
79 model_filter: Option<String>,
80 api_key: Option<String>,
81) -> Result<()> {
82 let config = Config::load()?;
83
84 let final_api_key = if api_key.is_some() { api_key } else { None };
86
87 let state = ProxyState {
88 config,
89 api_key: final_api_key.clone(),
90 provider_filter,
91 model_filter,
92 };
93
94 let app = Router::new()
95 .route("/models", get(list_models))
96 .route("/v1/models", get(list_models))
97 .route("/chat/completions", post(chat_completions))
98 .route("/v1/chat/completions", post(chat_completions))
99 .layer(CorsLayer::permissive())
100 .with_state(Arc::new(state));
101
102 let addr = format!("{}:{}", host, port);
103 println!("{} Starting proxy server on {}", "🚀".blue(), addr.bold());
104
105 if let Some(ref key) = final_api_key {
106 println!(
107 "{} Authentication enabled with API key: {}",
108 "🔐".yellow(),
109 key
110 );
111 } else {
112 println!("{} No authentication required", "⚠️".yellow());
113 }
114
115 let listener = tokio::net::TcpListener::bind(&addr).await?;
116 println!("{} Server listening on http://{}", "✓".green(), addr);
117
118 axum::serve(listener, app).await?;
119
120 Ok(())
121}
122
123async fn authenticate(headers: &HeaderMap, state: &ProxyState) -> Result<(), StatusCode> {
124 if let Some(expected_key) = &state.api_key {
125 if let Some(auth_header) = headers.get("authorization") {
126 if let Ok(auth_str) = auth_header.to_str() {
127 if let Some(token) = auth_str.strip_prefix("Bearer ") {
128 if token == expected_key {
129 return Ok(());
130 }
131 }
132 }
133 }
134 return Err(StatusCode::UNAUTHORIZED);
135 }
136 Ok(())
137}
138
139async fn list_models(
140 Query(query): Query<ProxyModelsQuery>,
141 State(state): State<Arc<ProxyState>>,
142 headers: HeaderMap,
143) -> Result<Json<ProxyModelsResponse>, StatusCode> {
144 authenticate(&headers, &state).await?;
146
147 let mut models = Vec::new();
148 let current_time = std::time::SystemTime::now()
149 .duration_since(std::time::UNIX_EPOCH)
150 .unwrap_or(std::time::Duration::from_secs(0))
151 .as_secs();
152
153 let cache = ModelsCache::load().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
155
156 if cache.needs_refresh() {
158 tokio::spawn(async {
160 if let Ok(mut bg_cache) = ModelsCache::load() {
161 let _ = bg_cache.refresh().await;
162 }
163 });
164 }
165
166 let cached_models = cache.get_all_models();
168
169 for cached_model in cached_models {
170 let provider_name = &cached_model.provider;
171 let model_name = &cached_model.model;
172 let model_id = format!("{}:{}", provider_name, model_name);
173
174 if let Some(ref provider_filter) = state.provider_filter {
176 if provider_name != provider_filter {
177 continue;
178 }
179 }
180
181 if let Some(ref query_provider) = query.provider {
183 if provider_name != query_provider {
184 continue;
185 }
186 }
187
188 if let Some(ref model_filter) = state.model_filter {
190 if !model_id.contains(model_filter) && model_name != model_filter {
191 continue;
192 }
193 }
194
195 models.push(ProxyModel {
196 id: model_id,
197 object: "model".to_string(),
198 created: current_time,
199 owned_by: provider_name.clone(),
200 });
201 }
202
203 let response = ProxyModelsResponse {
204 object: "list".to_string(),
205 data: models,
206 };
207
208 Ok(Json(response))
209}
210
211async fn chat_completions(
212 State(state): State<Arc<ProxyState>>,
213 headers: HeaderMap,
214 Json(request): Json<ProxyChatRequest>,
215) -> Result<Json<ProxyChatResponse>, StatusCode> {
216 authenticate(&headers, &state).await?;
218
219 let (provider_name, model_name) =
221 parse_model_string(&request.model, &state.config).map_err(|_| StatusCode::BAD_REQUEST)?;
222
223 if let Some(ref provider_filter) = state.provider_filter {
225 if provider_name != *provider_filter {
226 return Err(StatusCode::BAD_REQUEST);
227 }
228 }
229
230 if let Some(ref model_filter) = state.model_filter {
232 if !request.model.contains(model_filter) && model_name != *model_filter {
233 return Err(StatusCode::BAD_REQUEST);
234 }
235 }
236
237 let mut config_mut = state.config.clone();
239 let client = chat::create_authenticated_client(&mut config_mut, &provider_name)
240 .await
241 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
242
243 let chat_request = ChatRequest {
245 model: model_name.clone(),
246 messages: request.messages,
247 max_tokens: request.max_tokens,
248 temperature: request.temperature,
249 tools: None, stream: None, };
252
253 let response_text = client
255 .chat(&chat_request)
256 .await
257 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
258
259 let current_time = std::time::SystemTime::now()
261 .duration_since(std::time::UNIX_EPOCH)
262 .unwrap_or(std::time::Duration::from_secs(0))
263 .as_secs();
264
265 let response = ProxyChatResponse {
266 id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
267 object: "chat.completion".to_string(),
268 created: current_time,
269 model: request.model,
270 choices: vec![ProxyChoice {
271 index: 0,
272 message: crate::provider::Message {
273 role: "assistant".to_string(),
274 content_type: crate::provider::MessageContent::Text {
275 content: Some(response_text),
276 },
277 tool_calls: None,
278 tool_call_id: None,
279 },
280 finish_reason: "stop".to_string(),
281 }],
282 usage: ProxyUsage {
283 prompt_tokens: 0, completion_tokens: 0,
285 total_tokens: 0,
286 },
287 };
288
289 Ok(Json(response))
290}
291
292pub fn parse_model_string(model: &str, config: &Config) -> Result<(String, String)> {
293 if let Some(alias_target) = config.get_alias(model) {
295 if alias_target.contains(':') {
296 let parts: Vec<&str> = alias_target.splitn(2, ':').collect();
297 if parts.len() == 2 {
298 return Ok((parts[0].to_string(), parts[1].to_string()));
299 }
300 }
301 return Err(anyhow::anyhow!("Invalid alias target format"));
302 }
303
304 if model.contains(':') {
306 let parts: Vec<&str> = model.splitn(2, ':').collect();
307 if parts.len() == 2 {
308 let provider_name = parts[0].to_string();
309 let model_name = parts[1].to_string();
310
311 if config.has_provider(&provider_name) {
313 return Ok((provider_name, model_name));
314 }
315 }
316 }
317
318 if let Some(default_provider) = &config.default_provider {
320 return Ok((default_provider.clone(), model.to_string()));
321 }
322
323 Err(anyhow::anyhow!(
324 "Could not determine provider for model: {}",
325 model
326 ))
327}
328
329pub fn generate_api_key() -> String {
330 use rand::Rng;
331 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
332 let mut rng = rand::thread_rng();
333
334 let key: String = (0..32)
335 .map(|_| {
336 let idx = rng.gen_range(0..CHARSET.len());
337 CHARSET[idx] as char
338 })
339 .collect();
340
341 format!("sk-{}", key)
342}