Skip to main content

lean_ctx/proxy/
mod.rs

1pub mod anthropic;
2pub mod compress;
3pub mod cost;
4pub mod forward;
5pub mod google;
6pub mod history_prune;
7pub mod introspect;
8pub mod metrics;
9pub mod openai;
10pub mod openai_responses;
11pub mod tool_kind;
12
13use std::net::SocketAddr;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
17use axum::{
18    body::Body,
19    extract::State,
20    http::{Request, StatusCode},
21    response::{IntoResponse, Response},
22    routing::{any, get},
23    Router,
24};
25
26#[derive(Clone)]
27pub struct ProxyState {
28    pub client: reqwest::Client,
29    pub port: u16,
30    pub stats: Arc<ProxyStats>,
31    pub introspect: Arc<introspect::IntrospectState>,
32    pub anthropic_upstream: String,
33    pub openai_upstream: String,
34    pub gemini_upstream: String,
35}
36
/// Process-wide request and compression counters, reported by `/status`.
///
/// Each counter is an independent atomic updated with `Ordering::Relaxed`. Every value is exact
/// on its own, but a reader may observe one counter already bumped while a related one is not
/// yet. That is acceptable for reporting purposes.
#[derive(Default)]
pub struct ProxyStats {
    /// Number of calls to [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of calls to [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Estimated tokens removed, derived from saved bytes (see `BYTES_PER_TOKEN_ESTIMATE`).
    pub tokens_saved: AtomicU64,
    /// Sum of the `original` sizes passed to [`ProxyStats::record_compression`].
    pub bytes_original: AtomicU64,
    /// Sum of the `compressed` sizes passed to [`ProxyStats::record_compression`].
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Heuristic bytes-per-token ratio used to turn saved bytes into a token estimate. This is
    /// why `/status` reports `tokens_saved_estimated: true`.
    const BYTES_PER_TOKEN_ESTIMATE: usize = 4;

    /// Counts one incoming request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one request that was compressed from `original` to `compressed` bytes.
    ///
    /// If compression made the payload larger (`compressed > original`), both sizes are still
    /// added to the byte totals, but zero tokens are counted as saved.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = original.saturating_sub(compressed) / Self::BYTES_PER_TOKEN_ESTIMATE;
        self.tokens_saved
            .fetch_add(saved_tokens as u64, Ordering::Relaxed);
    }

    /// Returns the percentage of original bytes removed so far: `(1 - compressed / original) * 100`.
    ///
    /// Returns `0.0` before any bytes have been recorded. The result is negative if
    /// compression has, overall, made payloads larger.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
81
/// Default TCP connect timeout in seconds.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 15;

/// Default idle read timeout in seconds.
const DEFAULT_READ_IDLE_TIMEOUT_SECS: u64 = 300;

/// Reads a strictly positive number of seconds from the environment variable `var`.
///
/// Surrounding whitespace is ignored. An unset, non-UTF-8, unparsable or zero value falls
/// back to `default`.
fn env_positive_secs(var: &str, default: u64) -> u64 {
    std::env::var(var)
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .filter(|&secs| secs > 0)
        .unwrap_or(default)
}

/// TCP connect timeout (seconds). Configurable via `LEAN_CTX_PROXY_CONNECT_TIMEOUT_SECS`;
/// defaults to 15.
fn connect_timeout_secs() -> u64 {
    env_positive_secs(
        "LEAN_CTX_PROXY_CONNECT_TIMEOUT_SECS",
        DEFAULT_CONNECT_TIMEOUT_SECS,
    )
}

/// Idle read timeout (seconds) between bytes from upstream.
///
/// The default (300) is generous so that long extended-thinking phases, which still emit SSE
/// keepalives, are never cut. A truly dead connection still fails eventually. Configurable via
/// `LEAN_CTX_PROXY_READ_TIMEOUT_SECS`.
fn read_idle_timeout_secs() -> u64 {
    env_positive_secs(
        "LEAN_CTX_PROXY_READ_TIMEOUT_SECS",
        DEFAULT_READ_IDLE_TIMEOUT_SECS,
    )
}
102
103pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
104    let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
105    start_proxy_with_token(port, Some(token)).await
106}
107
108pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
109    use crate::core::config::{Config, ProxyProvider};
110
111    // A single total timeout aborts long streaming generations (e.g. Opus doing
112    // a big refactor) mid-response. Use a connect timeout plus a read (idle)
113    // timeout instead: a genuinely hung upstream still fails, but a slow-but-
114    // alive stream is never cut off. Both are configurable for edge networks.
115    let client = reqwest::Client::builder()
116        .connect_timeout(std::time::Duration::from_secs(connect_timeout_secs()))
117        .read_timeout(std::time::Duration::from_secs(read_idle_timeout_secs()))
118        .build()?;
119
120    let cfg = Config::load();
121    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
122    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
123    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
124
125    let state = ProxyState {
126        client,
127        port,
128        stats: Arc::new(ProxyStats::default()),
129        introspect: Arc::new(introspect::IntrospectState::default()),
130        anthropic_upstream: anthropic_upstream.clone(),
131        openai_upstream: openai_upstream.clone(),
132        gemini_upstream: gemini_upstream.clone(),
133    };
134
135    let mut app = Router::new()
136        .route("/health", get(health))
137        .route("/status", get(status_handler))
138        .route("/v1/messages", any(anthropic::handler))
139        .route("/v1/messages/{*rest}", any(anthropic::handler))
140        .route("/v1/chat/completions", any(openai::handler))
141        .route("/v1/responses", any(openai_responses::handler))
142        .route("/v1/responses/{*rest}", any(openai_responses::handler))
143        .route("/v1/references/{id}", get(v1_resolve_reference))
144        .fallback(fallback_router)
145        .layer(axum::middleware::from_fn(host_guard))
146        .with_state(state);
147
148    if let Some(ref token) = auth_token {
149        let expected = token.clone();
150        app = app.layer(axum::middleware::from_fn(move |req, next| {
151            let expected = expected.clone();
152            proxy_auth_guard(req, next, expected)
153        }));
154    }
155
156    let addr = SocketAddr::from(([127, 0, 0, 1], port));
157    if auth_token.is_some() {
158        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
159    } else {
160        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
161    }
162    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
163    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
164    println!("  OpenAI:    POST /v1/responses → {openai_upstream}");
165    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
166
167    let listener = tokio::net::TcpListener::bind(addr).await?;
168    axum::serve(listener, app)
169        .with_graceful_shutdown(shutdown_signal())
170        .await?;
171
172    println!("lean-ctx proxy shut down cleanly.");
173    Ok(())
174}
175
176async fn shutdown_signal() {
177    let ctrl_c = tokio::signal::ctrl_c();
178
179    #[cfg(unix)]
180    {
181        // Fall back to Ctrl-C only if the SIGTERM handler cannot be installed,
182        // rather than panicking the proxy on startup.
183        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
184            Ok(mut sigterm) => {
185                tokio::select! {
186                    _ = ctrl_c => {},
187                    _ = sigterm.recv() => {},
188                }
189            }
190            Err(e) => {
191                tracing::warn!("lean-ctx proxy: SIGTERM handler unavailable ({e}); Ctrl-C only");
192                ctrl_c.await.ok();
193            }
194        }
195    }
196
197    #[cfg(not(unix))]
198    {
199        ctrl_c.await.ok();
200    }
201
202    println!("lean-ctx proxy: received shutdown signal, draining…");
203}
204
205async fn health() -> impl IntoResponse {
206    let body = serde_json::json!({
207        "status": "ok",
208        "pid": std::process::id(),
209    });
210    (StatusCode::OK, axum::Json(body))
211}
212
213async fn v1_resolve_reference(
214    axum::extract::Path(id): axum::extract::Path<String>,
215) -> impl IntoResponse {
216    match crate::server::reference_store::resolve(&id) {
217        Some(content) => (StatusCode::OK, content),
218        None => (
219            StatusCode::NOT_FOUND,
220            "Reference expired or not found".to_string(),
221        ),
222    }
223}
224
225async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
226    use std::sync::atomic::Ordering::Relaxed;
227    let s = &state.stats;
228    let i = &state.introspect;
229
230    let last_breakdown = i
231        .last_breakdown
232        .lock()
233        .ok()
234        .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
235        .flatten();
236
237    let body = serde_json::json!({
238        "status": "running",
239        "port": state.port,
240        "requests_total": s.requests_total.load(Relaxed),
241        "requests_compressed": s.requests_compressed.load(Relaxed),
242        "tokens_saved": s.tokens_saved.load(Relaxed),
243        "tokens_saved_estimated": true,
244        "bytes_original": s.bytes_original.load(Relaxed),
245        "bytes_compressed": s.bytes_compressed.load(Relaxed),
246        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
247        "per_model": cost::snapshot(),
248        "note": "Savings are request-side (tokens removed before forwarding); they do not subtract any re-reads the agent performs. Token figures are estimates; USD uses the shared model price table.",
249        "introspect": {
250            "total_requests_analyzed": i.total_requests.load(Relaxed),
251            "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
252            "last_breakdown": last_breakdown,
253        }
254    });
255    (StatusCode::OK, axum::Json(body))
256}
257
258async fn proxy_auth_guard(
259    req: axum::extract::Request,
260    next: axum::middleware::Next,
261    expected_token: String,
262) -> Result<Response, Response> {
263    let path = req.uri().path();
264    if path == "/health" {
265        return Ok(next.run(req).await);
266    }
267
268    // Accept Bearer token (lean-ctx session token)
269    if let Some(auth) = req
270        .headers()
271        .get("authorization")
272        .and_then(|v| v.to_str().ok())
273    {
274        if let Some(token) = auth.strip_prefix("Bearer ") {
275            if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
276                return Ok(next.run(req).await);
277            }
278        }
279    }
280
281    // Accept provider API keys on provider routes (loopback-only, host_guard runs first).
282    // AI tools like Claude Code send x-api-key, not Bearer tokens. Since the proxy
283    // only binds to 127.0.0.1, the presence of a valid API key header is sufficient
284    // to authenticate the request as coming from a local AI tool.
285    if has_provider_api_key(&req) && is_provider_route(path) {
286        return Ok(next.run(req).await);
287    }
288
289    let cfg = crate::core::config::Config::load();
290    let hint = match cfg.proxy_enabled {
291        Some(true) => "lean-ctx proxy requires authentication. Use a Bearer token (LEAN_CTX_PROXY_TOKEN) or configure your AI tool's API key.",
292        Some(false) => "lean-ctx proxy is disabled but still running. Run: lean-ctx proxy cleanup",
293        None => "lean-ctx proxy is not configured. Your AI tool's ANTHROPIC_BASE_URL may be pointing here by mistake. Fix: lean-ctx proxy cleanup  OR  lean-ctx proxy enable",
294    };
295
296    let body = serde_json::json!({
297        "type": "error",
298        "error": {
299            "type": "authentication_error",
300            "message": format!("401 Unauthorized — {hint}")
301        }
302    });
303
304    Err((StatusCode::UNAUTHORIZED, axum::Json(body)).into_response())
305}
306
307fn has_provider_api_key(req: &axum::extract::Request) -> bool {
308    let headers = req.headers();
309    for key in ["x-api-key", "x-goog-api-key", "api-key"] {
310        if headers
311            .get(key)
312            .and_then(|v| v.to_str().ok())
313            .is_some_and(|v| !v.trim().is_empty())
314        {
315            return true;
316        }
317    }
318    if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
319        if auth.starts_with("Bearer sk-") || auth.starts_with("Bearer gsk_") {
320            return true;
321        }
322    }
323    false
324}
325
/// Returns `true` for paths that are proxied to an upstream LLM provider.
///
/// On these paths `proxy_auth_guard` accepts a provider API key instead of the lean-ctx Bearer
/// token.
///
/// `/v1/references` and everything under it are deliberately excluded. That endpoint is
/// answered locally from the reference store, so no upstream ever sees, let alone validates, a
/// provider key sent there. Accepting one would let any local client holding an arbitrary
/// non-blank `x-api-key` read stored references without the proxy token.
fn is_provider_route(path: &str) -> bool {
    const PROVIDER_PREFIXES: [&str; 3] = ["/v1/", "/v1beta/", "/chat/completions"];
    const LOCAL_REFERENCES: &str = "/v1/references";

    let is_local_reference = path
        .strip_prefix(LOCAL_REFERENCES)
        .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'));

    !is_local_reference
        && PROVIDER_PREFIXES
            .iter()
            .any(|prefix| path.starts_with(*prefix))
}
331
332fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
333    use subtle::ConstantTimeEq;
334    if a.len() != b.len() {
335        return false;
336    }
337    bool::from(a.ct_eq(b))
338}
339
340async fn host_guard(
341    req: axum::extract::Request,
342    next: axum::middleware::Next,
343) -> Result<Response, StatusCode> {
344    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
345        let h = host.split(':').next().unwrap_or(host);
346        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
347            return Ok(next.run(req).await);
348        }
349    }
350    Err(StatusCode::FORBIDDEN)
351}
352
353async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
354    let path = req.uri().path().to_string();
355
356    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
357        match google::handler(State(state), req).await {
358            Ok(resp) => resp,
359            Err(status) => Response::builder()
360                .status(status)
361                .body(Body::from("proxy error"))
362                .expect("BUG: building error response with valid status should never fail"),
363        }
364    } else {
365        let method = req.method().to_string();
366        eprintln!("lean-ctx proxy: unmatched {method} {path}");
367        Response::builder()
368            .status(StatusCode::NOT_FOUND)
369            .body(Body::from(format!(
370                "lean-ctx proxy: no handler for {method} {path}"
371            )))
372            .expect("BUG: building 404 response should never fail")
373    }
374}
375
/// Unit tests for the authentication helpers used by `proxy_auth_guard`.
#[cfg(test)]
mod auth_tests {
    use super::*;

    #[test]
    fn is_provider_route_v1() {
        assert!(is_provider_route("/v1/chat/completions"));
        assert!(is_provider_route("/v1/messages"));
        assert!(is_provider_route("/v1/completions"));
    }

    #[test]
    fn is_provider_route_anthropic_subpaths() {
        assert!(is_provider_route("/v1/messages/count_tokens"));
        assert!(is_provider_route("/v1/messages/batches"));
        assert!(is_provider_route("/v1/messages/batches/batch_123"));
    }

    #[test]
    fn is_provider_route_v1beta() {
        assert!(is_provider_route("/v1beta/models"));
    }

    #[test]
    fn is_provider_route_chat() {
        assert!(is_provider_route("/chat/completions"));
    }

    #[test]
    fn is_provider_route_rejects_non_provider() {
        assert!(!is_provider_route("/health"));
        assert!(!is_provider_route("/api/v2/test"));
        assert!(!is_provider_route("/"));
    }

    #[test]
    fn is_provider_route_requires_path_boundary() {
        // Prefixes end in `/`, so bare or look-alike version segments must not match.
        assert!(!is_provider_route("/v1"));
        assert!(!is_provider_route("/v1beta"));
        assert!(!is_provider_route("/v10/messages"));
    }

    /// Builds a bodiless request for `path` carrying the given headers.
    fn build_request(headers: &[(&str, &str)], path: &str) -> axum::extract::Request {
        let mut builder = axum::http::Request::builder().uri(path);
        for (k, v) in headers {
            builder = builder.header(*k, *v);
        }
        builder.body(axum::body::Body::empty()).unwrap()
    }

    #[test]
    fn has_provider_api_key_x_api_key() {
        let req = build_request(&[("x-api-key", "sk-ant-abc123")], "/v1/messages");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_x_goog() {
        let req = build_request(&[("x-goog-api-key", "AIzaSyAbc")], "/v1beta/models");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_azure() {
        let req = build_request(&[("api-key", "deadbeef")], "/v1/completions");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_bearer_sk() {
        let req = build_request(
            &[("authorization", "Bearer sk-proj-abc123")],
            "/v1/chat/completions",
        );
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_bearer_gsk() {
        let req = build_request(
            &[("authorization", "Bearer gsk_abc123")],
            "/v1/chat/completions",
        );
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_empty_rejected() {
        let req = build_request(&[("x-api-key", "  ")], "/v1/messages");
        assert!(!has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_no_headers() {
        let req = build_request(&[], "/v1/messages");
        assert!(!has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_regular_bearer_rejected() {
        let req = build_request(
            &[("authorization", "Bearer some-session-token")],
            "/v1/chat",
        );
        assert!(
            !has_provider_api_key(&req),
            "non-sk/gsk Bearer should not count as provider key"
        );
    }

    #[test]
    fn constant_time_eq_matches_only_identical_bytes() {
        assert!(constant_time_eq(b"secret-token", b"secret-token"));
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"secret-token", b"secret-tokeN"));
        assert!(!constant_time_eq(b"secret", b"secret-token"));
        assert!(!constant_time_eq(b"secret-token", b""));
    }
}