Skip to main content

lean_ctx/proxy/
mod.rs

1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod history_prune;
6pub mod introspect;
7pub mod metrics;
8pub mod openai;
9
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use axum::{
15    body::Body,
16    extract::State,
17    http::{Request, StatusCode},
18    response::{IntoResponse, Response},
19    routing::{any, get},
20    Router,
21};
22
23#[derive(Clone)]
24pub struct ProxyState {
25    pub client: reqwest::Client,
26    pub port: u16,
27    pub stats: Arc<ProxyStats>,
28    pub introspect: Arc<introspect::IntrospectState>,
29    pub anthropic_upstream: String,
30    pub openai_upstream: String,
31    pub gemini_upstream: String,
32}
33
/// Lock-free counters describing what the proxy has done since startup.
///
/// Every field starts at zero (`Default`) and is only ever incremented, using
/// relaxed atomics: the values are independent statistics, not synchronisation
/// points.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Number of requests recorded via [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of compressions recorded via [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Estimated tokens saved, computed as saved bytes divided by four.
    pub tokens_saved: AtomicU64,
    /// Sum of the pre-compression sizes, in bytes.
    pub bytes_original: AtomicU64,
    /// Sum of the post-compression sizes, in bytes.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compression that turned `original` bytes into `compressed` bytes.
    ///
    /// The token estimate uses a flat four-bytes-per-token heuristic and never
    /// goes negative: if `compressed` is larger than `original`, zero tokens
    /// are credited.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the percentage of bytes removed by compression so far.
    ///
    /// Yields `0.0` while no original bytes have been recorded (avoiding a
    /// division by zero). The result is negative when the recorded compressed
    /// total exceeds the original total.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
78
79pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
80    let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
81    start_proxy_with_token(port, Some(token)).await
82}
83
84pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
85    use crate::core::config::{Config, ProxyProvider};
86
87    let client = reqwest::Client::builder()
88        .timeout(std::time::Duration::from_mins(2))
89        .build()?;
90
91    let cfg = Config::load();
92    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
93    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
94    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
95
96    let state = ProxyState {
97        client,
98        port,
99        stats: Arc::new(ProxyStats::default()),
100        introspect: Arc::new(introspect::IntrospectState::default()),
101        anthropic_upstream: anthropic_upstream.clone(),
102        openai_upstream: openai_upstream.clone(),
103        gemini_upstream: gemini_upstream.clone(),
104    };
105
106    let mut app = Router::new()
107        .route("/health", get(health))
108        .route("/status", get(status_handler))
109        .route("/v1/messages", any(anthropic::handler))
110        .route("/v1/messages/{*rest}", any(anthropic::handler))
111        .route("/v1/chat/completions", any(openai::handler))
112        .route("/v1/references/{id}", get(v1_resolve_reference))
113        .fallback(fallback_router)
114        .layer(axum::middleware::from_fn(host_guard))
115        .with_state(state);
116
117    if let Some(ref token) = auth_token {
118        let expected = token.clone();
119        app = app.layer(axum::middleware::from_fn(move |req, next| {
120            let expected = expected.clone();
121            proxy_auth_guard(req, next, expected)
122        }));
123    }
124
125    let addr = SocketAddr::from(([127, 0, 0, 1], port));
126    if auth_token.is_some() {
127        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
128    } else {
129        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
130    }
131    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
132    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
133    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
134
135    let listener = tokio::net::TcpListener::bind(addr).await?;
136    axum::serve(listener, app)
137        .with_graceful_shutdown(shutdown_signal())
138        .await?;
139
140    println!("lean-ctx proxy shut down cleanly.");
141    Ok(())
142}
143
144async fn shutdown_signal() {
145    let ctrl_c = tokio::signal::ctrl_c();
146
147    #[cfg(unix)]
148    {
149        // Fall back to Ctrl-C only if the SIGTERM handler cannot be installed,
150        // rather than panicking the proxy on startup.
151        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
152            Ok(mut sigterm) => {
153                tokio::select! {
154                    _ = ctrl_c => {},
155                    _ = sigterm.recv() => {},
156                }
157            }
158            Err(e) => {
159                tracing::warn!("lean-ctx proxy: SIGTERM handler unavailable ({e}); Ctrl-C only");
160                ctrl_c.await.ok();
161            }
162        }
163    }
164
165    #[cfg(not(unix))]
166    {
167        ctrl_c.await.ok();
168    }
169
170    println!("lean-ctx proxy: received shutdown signal, draining…");
171}
172
173async fn health() -> impl IntoResponse {
174    let body = serde_json::json!({
175        "status": "ok",
176        "pid": std::process::id(),
177    });
178    (StatusCode::OK, axum::Json(body))
179}
180
181async fn v1_resolve_reference(
182    axum::extract::Path(id): axum::extract::Path<String>,
183) -> impl IntoResponse {
184    match crate::server::reference_store::resolve(&id) {
185        Some(content) => (StatusCode::OK, content),
186        None => (
187            StatusCode::NOT_FOUND,
188            "Reference expired or not found".to_string(),
189        ),
190    }
191}
192
193async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
194    use std::sync::atomic::Ordering::Relaxed;
195    let s = &state.stats;
196    let i = &state.introspect;
197
198    let last_breakdown = i
199        .last_breakdown
200        .lock()
201        .ok()
202        .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
203        .flatten();
204
205    let body = serde_json::json!({
206        "status": "running",
207        "port": state.port,
208        "requests_total": s.requests_total.load(Relaxed),
209        "requests_compressed": s.requests_compressed.load(Relaxed),
210        "tokens_saved": s.tokens_saved.load(Relaxed),
211        "bytes_original": s.bytes_original.load(Relaxed),
212        "bytes_compressed": s.bytes_compressed.load(Relaxed),
213        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
214        "introspect": {
215            "total_requests_analyzed": i.total_requests.load(Relaxed),
216            "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
217            "last_breakdown": last_breakdown,
218        }
219    });
220    (StatusCode::OK, axum::Json(body))
221}
222
223async fn proxy_auth_guard(
224    req: axum::extract::Request,
225    next: axum::middleware::Next,
226    expected_token: String,
227) -> Result<Response, Response> {
228    let path = req.uri().path();
229    if path == "/health" {
230        return Ok(next.run(req).await);
231    }
232
233    // Accept Bearer token (lean-ctx session token)
234    if let Some(auth) = req
235        .headers()
236        .get("authorization")
237        .and_then(|v| v.to_str().ok())
238    {
239        if let Some(token) = auth.strip_prefix("Bearer ") {
240            if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
241                return Ok(next.run(req).await);
242            }
243        }
244    }
245
246    // Accept provider API keys on provider routes (loopback-only, host_guard runs first).
247    // AI tools like Claude Code send x-api-key, not Bearer tokens. Since the proxy
248    // only binds to 127.0.0.1, the presence of a valid API key header is sufficient
249    // to authenticate the request as coming from a local AI tool.
250    if has_provider_api_key(&req) && is_provider_route(path) {
251        return Ok(next.run(req).await);
252    }
253
254    let cfg = crate::core::config::Config::load();
255    let hint = match cfg.proxy_enabled {
256        Some(true) => "lean-ctx proxy requires authentication. Use a Bearer token (LEAN_CTX_PROXY_TOKEN) or configure your AI tool's API key.",
257        Some(false) => "lean-ctx proxy is disabled but still running. Run: lean-ctx proxy cleanup",
258        None => "lean-ctx proxy is not configured. Your AI tool's ANTHROPIC_BASE_URL may be pointing here by mistake. Fix: lean-ctx proxy cleanup  OR  lean-ctx proxy enable",
259    };
260
261    let body = serde_json::json!({
262        "type": "error",
263        "error": {
264            "type": "authentication_error",
265            "message": format!("401 Unauthorized — {hint}")
266        }
267    });
268
269    Err((StatusCode::UNAUTHORIZED, axum::Json(body)).into_response())
270}
271
272fn has_provider_api_key(req: &axum::extract::Request) -> bool {
273    let headers = req.headers();
274    for key in ["x-api-key", "x-goog-api-key", "api-key"] {
275        if headers
276            .get(key)
277            .and_then(|v| v.to_str().ok())
278            .is_some_and(|v| !v.trim().is_empty())
279        {
280            return true;
281        }
282    }
283    if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
284        if auth.starts_with("Bearer sk-") || auth.starts_with("Bearer gsk_") {
285            return true;
286        }
287    }
288    false
289}
290
/// Returns `true` when `path` begins with one of the prefixes treated as
/// provider API routes: `/v1/`, `/v1beta/` or `/chat/completions`.
fn is_provider_route(path: &str) -> bool {
    const PROVIDER_PREFIXES: [&str; 3] = ["/v1/", "/v1beta/", "/chat/completions"];
    PROVIDER_PREFIXES
        .iter()
        .any(|prefix| path.starts_with(*prefix))
}
296
297fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
298    use subtle::ConstantTimeEq;
299    if a.len() != b.len() {
300        return false;
301    }
302    bool::from(a.ct_eq(b))
303}
304
305async fn host_guard(
306    req: axum::extract::Request,
307    next: axum::middleware::Next,
308) -> Result<Response, StatusCode> {
309    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
310        let h = host.split(':').next().unwrap_or(host);
311        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
312            return Ok(next.run(req).await);
313        }
314    }
315    Err(StatusCode::FORBIDDEN)
316}
317
318async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
319    let path = req.uri().path().to_string();
320
321    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
322        match google::handler(State(state), req).await {
323            Ok(resp) => resp,
324            Err(status) => Response::builder()
325                .status(status)
326                .body(Body::from("proxy error"))
327                .expect("BUG: building error response with valid status should never fail"),
328        }
329    } else {
330        let method = req.method().to_string();
331        eprintln!("lean-ctx proxy: unmatched {method} {path}");
332        Response::builder()
333            .status(StatusCode::NOT_FOUND)
334            .body(Body::from(format!(
335                "lean-ctx proxy: no handler for {method} {path}"
336            )))
337            .expect("BUG: building 404 response should never fail")
338    }
339}
340
#[cfg(test)]
mod auth_tests {
    use super::*;

    /// Builds a body-less request for `path` carrying the given headers.
    fn build_request(headers: &[(&str, &str)], path: &str) -> axum::extract::Request {
        headers
            .iter()
            .fold(
                axum::http::Request::builder().uri(path),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(axum::body::Body::empty())
            .unwrap()
    }

    // --- is_provider_route -------------------------------------------------

    #[test]
    fn is_provider_route_v1() {
        for path in ["/v1/chat/completions", "/v1/messages", "/v1/completions"] {
            assert!(is_provider_route(path));
        }
    }

    #[test]
    fn is_provider_route_anthropic_subpaths() {
        for path in [
            "/v1/messages/count_tokens",
            "/v1/messages/batches",
            "/v1/messages/batches/batch_123",
        ] {
            assert!(is_provider_route(path));
        }
    }

    #[test]
    fn is_provider_route_v1beta() {
        assert!(is_provider_route("/v1beta/models"));
    }

    #[test]
    fn is_provider_route_chat() {
        assert!(is_provider_route("/chat/completions"));
    }

    #[test]
    fn is_provider_route_rejects_non_provider() {
        for path in ["/health", "/api/v2/test", "/"] {
            assert!(!is_provider_route(path));
        }
    }

    // --- has_provider_api_key ----------------------------------------------

    #[test]
    fn has_provider_api_key_x_api_key() {
        let headers = [("x-api-key", "sk-ant-abc123")];
        assert!(has_provider_api_key(&build_request(&headers, "/v1/messages")));
    }

    #[test]
    fn has_provider_api_key_x_goog() {
        let headers = [("x-goog-api-key", "AIzaSyAbc")];
        assert!(has_provider_api_key(&build_request(&headers, "/v1beta/models")));
    }

    #[test]
    fn has_provider_api_key_azure() {
        let headers = [("api-key", "deadbeef")];
        assert!(has_provider_api_key(&build_request(&headers, "/v1/completions")));
    }

    #[test]
    fn has_provider_api_key_bearer_sk() {
        let headers = [("authorization", "Bearer sk-proj-abc123")];
        assert!(has_provider_api_key(&build_request(
            &headers,
            "/v1/chat/completions"
        )));
    }

    #[test]
    fn has_provider_api_key_empty_rejected() {
        let headers = [("x-api-key", "  ")];
        assert!(!has_provider_api_key(&build_request(&headers, "/v1/messages")));
    }

    #[test]
    fn has_provider_api_key_no_headers() {
        assert!(!has_provider_api_key(&build_request(&[], "/v1/messages")));
    }

    #[test]
    fn has_provider_api_key_regular_bearer_rejected() {
        let headers = [("authorization", "Bearer some-session-token")];
        let req = build_request(&headers, "/v1/chat");
        assert!(
            !has_provider_api_key(&req),
            "non-sk/gsk Bearer should not count as provider key"
        );
    }
}
434}