//! OpenAI-compatible proxy server (lc/services/proxy.rs).

1use crate::{chat, config::Config, models_cache::ModelsCache, provider::ChatRequest};
2use anyhow::Result;
3use axum::{
4    extract::{Query, State},
5    http::{HeaderMap, StatusCode},
6    response::Json,
7    routing::{get, post},
8    Router,
9};
10use colored::Colorize;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tower_http::cors::CorsLayer;
14
/// Shared, read-only state handed to every proxy request handler.
#[derive(Clone)]
pub struct ProxyState {
    // Loaded application configuration (providers, aliases, default provider).
    pub config: Config,
    // Optional bearer token; `None` disables authentication entirely.
    pub api_key: Option<String>,
    // When set, only models from this provider are served.
    pub provider_filter: Option<String>,
    // When set, only models whose id contains (or name equals) this are served.
    pub model_filter: Option<String>,
}
22
/// Query parameters accepted by the `/models` endpoints.
#[derive(Deserialize)]
pub struct ProxyModelsQuery {
    // Optional per-request provider filter (`?provider=...`).
    #[serde(default)]
    pub provider: Option<String>,
}
28
/// OpenAI-style model list response: `{"object": "list", "data": [...]}`.
#[derive(Serialize)]
pub struct ProxyModelsResponse {
    // Always set to "list" by `list_models`.
    pub object: String,
    pub data: Vec<ProxyModel>,
}
34
/// A single model entry in the OpenAI-style model list.
#[derive(Serialize)]
pub struct ProxyModel {
    // "provider:model" identifier, as built by `list_models`.
    pub id: String,
    // Always set to "model".
    pub object: String,
    // Unix timestamp (seconds); filled with the response time, not the
    // model's actual creation time.
    pub created: u64,
    // Provider name the model belongs to.
    pub owned_by: String,
}
42
/// Incoming chat-completion request body (OpenAI-compatible subset:
/// no tools or streaming fields are accepted).
#[derive(Deserialize)]
pub struct ProxyChatRequest {
    // Model identifier: alias, "provider:model", or bare model name
    // (resolved by `parse_model_string`).
    pub model: String,
    pub messages: Vec<crate::provider::Message>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
}
50
/// Outgoing chat-completion response in OpenAI format.
#[derive(Serialize)]
pub struct ProxyChatResponse {
    // "chatcmpl-<uuid>" identifier generated per response.
    pub id: String,
    // Always set to "chat.completion".
    pub object: String,
    // Unix timestamp (seconds) at response-build time.
    pub created: u64,
    // Echoes the model string from the request.
    pub model: String,
    pub choices: Vec<ProxyChoice>,
    pub usage: ProxyUsage,
}
60
/// One completion choice; this proxy always produces exactly one.
#[derive(Serialize)]
pub struct ProxyChoice {
    pub index: u32,
    pub message: crate::provider::Message,
    // Always "stop" — the proxy does not surface other finish reasons.
    pub finish_reason: String,
}
67
/// Token usage block. Currently always zeroed — the proxy does not track
/// token counts (see `chat_completions`).
#[derive(Serialize)]
pub struct ProxyUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
74
75pub async fn start_proxy_server(
76    host: String,
77    port: u16,
78    provider_filter: Option<String>,
79    model_filter: Option<String>,
80    api_key: Option<String>,
81) -> Result<()> {
82    let config = Config::load()?;
83
84    // Generate API key if requested
85    let final_api_key = if api_key.is_some() { api_key } else { None };
86
87    let state = ProxyState {
88        config,
89        api_key: final_api_key.clone(),
90        provider_filter,
91        model_filter,
92    };
93
94    let app = Router::new()
95        .route("/models", get(list_models))
96        .route("/v1/models", get(list_models))
97        .route("/chat/completions", post(chat_completions))
98        .route("/v1/chat/completions", post(chat_completions))
99        .layer(CorsLayer::permissive())
100        .with_state(Arc::new(state));
101
102    let addr = format!("{}:{}", host, port);
103    println!("{} Starting proxy server on {}", "🚀".blue(), addr.bold());
104
105    if let Some(ref key) = final_api_key {
106        println!(
107            "{} Authentication enabled with API key: {}",
108            "🔐".yellow(),
109            key
110        );
111    } else {
112        println!("{} No authentication required", "⚠️".yellow());
113    }
114
115    let listener = tokio::net::TcpListener::bind(&addr).await?;
116    println!("{} Server listening on http://{}", "✓".green(), addr);
117
118    axum::serve(listener, app).await?;
119
120    Ok(())
121}
122
123async fn authenticate(headers: &HeaderMap, state: &ProxyState) -> Result<(), StatusCode> {
124    if let Some(expected_key) = &state.api_key {
125        if let Some(auth_header) = headers.get("authorization") {
126            if let Ok(auth_str) = auth_header.to_str() {
127                if let Some(token) = auth_str.strip_prefix("Bearer ") {
128                    if token == expected_key {
129                        return Ok(());
130                    }
131                }
132            }
133        }
134        return Err(StatusCode::UNAUTHORIZED);
135    }
136    Ok(())
137}
138
139async fn list_models(
140    Query(query): Query<ProxyModelsQuery>,
141    State(state): State<Arc<ProxyState>>,
142    headers: HeaderMap,
143) -> Result<Json<ProxyModelsResponse>, StatusCode> {
144    // Authenticate if API key is configured
145    authenticate(&headers, &state).await?;
146
147    let mut models = Vec::new();
148    let current_time = std::time::SystemTime::now()
149        .duration_since(std::time::UNIX_EPOCH)
150        .unwrap_or(std::time::Duration::from_secs(0))
151        .as_secs();
152
153    // Use models cache for fast response
154    let cache = ModelsCache::load().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
155
156    // Check if cache needs refresh and refresh in background if needed
157    if cache.needs_refresh() {
158        // Refresh cache in background, but don't block the response
159        tokio::spawn(async {
160            if let Ok(mut bg_cache) = ModelsCache::load() {
161                let _ = bg_cache.refresh().await;
162            }
163        });
164    }
165
166    // Get cached models
167    let cached_models = cache.get_all_models();
168
169    for cached_model in cached_models {
170        let provider_name = &cached_model.provider;
171        let model_name = &cached_model.model;
172        let model_id = format!("{}:{}", provider_name, model_name);
173
174        // Apply provider filter if specified
175        if let Some(ref provider_filter) = state.provider_filter {
176            if provider_name != provider_filter {
177                continue;
178            }
179        }
180
181        // Apply query provider filter if specified
182        if let Some(ref query_provider) = query.provider {
183            if provider_name != query_provider {
184                continue;
185            }
186        }
187
188        // Apply model filter if specified
189        if let Some(ref model_filter) = state.model_filter {
190            if !model_id.contains(model_filter) && model_name != model_filter {
191                continue;
192            }
193        }
194
195        models.push(ProxyModel {
196            id: model_id,
197            object: "model".to_string(),
198            created: current_time,
199            owned_by: provider_name.clone(),
200        });
201    }
202
203    let response = ProxyModelsResponse {
204        object: "list".to_string(),
205        data: models,
206    };
207
208    Ok(Json(response))
209}
210
211async fn chat_completions(
212    State(state): State<Arc<ProxyState>>,
213    headers: HeaderMap,
214    Json(request): Json<ProxyChatRequest>,
215) -> Result<Json<ProxyChatResponse>, StatusCode> {
216    // Authenticate if API key is configured
217    authenticate(&headers, &state).await?;
218
219    // Parse the model to determine provider and model name
220    let (provider_name, model_name) =
221        parse_model_string(&request.model, &state.config).map_err(|_| StatusCode::BAD_REQUEST)?;
222
223    // Check if provider is allowed by filter
224    if let Some(ref provider_filter) = state.provider_filter {
225        if provider_name != *provider_filter {
226            return Err(StatusCode::BAD_REQUEST);
227        }
228    }
229
230    // Check if model is allowed by filter
231    if let Some(ref model_filter) = state.model_filter {
232        if !request.model.contains(model_filter) && model_name != *model_filter {
233            return Err(StatusCode::BAD_REQUEST);
234        }
235    }
236
237    // Create client for the provider
238    let mut config_mut = state.config.clone();
239    let client = chat::create_authenticated_client(&mut config_mut, &provider_name)
240        .await
241        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
242
243    // Convert to internal chat request format
244    let chat_request = ChatRequest {
245        model: model_name.clone(),
246        messages: request.messages,
247        max_tokens: request.max_tokens,
248        temperature: request.temperature,
249        tools: None,  // Proxy doesn't support tools yet
250        stream: None, // Proxy doesn't support streaming yet
251    };
252
253    // Send the request
254    let response_text = client
255        .chat(&chat_request)
256        .await
257        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
258
259    // Create response in OpenAI format
260    let current_time = std::time::SystemTime::now()
261        .duration_since(std::time::UNIX_EPOCH)
262        .unwrap_or(std::time::Duration::from_secs(0))
263        .as_secs();
264
265    let response = ProxyChatResponse {
266        id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
267        object: "chat.completion".to_string(),
268        created: current_time,
269        model: request.model,
270        choices: vec![ProxyChoice {
271            index: 0,
272            message: crate::provider::Message {
273                role: "assistant".to_string(),
274                content_type: crate::provider::MessageContent::Text {
275                    content: Some(response_text),
276                },
277                tool_calls: None,
278                tool_call_id: None,
279            },
280            finish_reason: "stop".to_string(),
281        }],
282        usage: ProxyUsage {
283            prompt_tokens: 0, // We don't track token usage currently
284            completion_tokens: 0,
285            total_tokens: 0,
286        },
287    };
288
289    Ok(Json(response))
290}
291
292pub fn parse_model_string(model: &str, config: &Config) -> Result<(String, String)> {
293    // Check if it's an alias first
294    if let Some(alias_target) = config.get_alias(model) {
295        if alias_target.contains(':') {
296            let parts: Vec<&str> = alias_target.splitn(2, ':').collect();
297            if parts.len() == 2 {
298                return Ok((parts[0].to_string(), parts[1].to_string()));
299            }
300        }
301        return Err(anyhow::anyhow!("Invalid alias target format"));
302    }
303
304    // Check if it contains provider:model format
305    if model.contains(':') {
306        let parts: Vec<&str> = model.splitn(2, ':').collect();
307        if parts.len() == 2 {
308            let provider_name = parts[0].to_string();
309            let model_name = parts[1].to_string();
310
311            // Validate provider exists
312            if config.has_provider(&provider_name) {
313                return Ok((provider_name, model_name));
314            }
315        }
316    }
317
318    // If no provider specified, use default provider
319    if let Some(default_provider) = &config.default_provider {
320        return Ok((default_provider.clone(), model.to_string()));
321    }
322
323    Err(anyhow::anyhow!(
324        "Could not determine provider for model: {}",
325        model
326    ))
327}
328
329pub fn generate_api_key() -> String {
330    use rand::Rng;
331    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
332    let mut rng = rand::thread_rng();
333
334    let key: String = (0..32)
335        .map(|_| {
336            let idx = rng.gen_range(0..CHARSET.len());
337            CHARSET[idx] as char
338        })
339        .collect();
340
341    format!("sk-{}", key)
342}