Skip to main content

lean_ctx/proxy/
mod.rs

1pub mod anthropic;
2pub mod compress;
3pub mod cost;
4pub mod forward;
5pub mod google;
6pub mod history_prune;
7pub mod introspect;
8pub mod metrics;
9pub mod openai;
10pub mod openai_responses;
11pub mod tool_kind;
12
13use std::net::SocketAddr;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
17use axum::{
18    body::Body,
19    extract::State,
20    http::{Request, StatusCode},
21    response::{IntoResponse, Response},
22    routing::{any, get},
23    Router,
24};
25
/// Shared state handed to the router via `with_state`; cloned per request, so
/// every field is either cheap to clone or wrapped in an `Arc`.
#[derive(Clone)]
pub struct ProxyState {
    /// HTTP client built once at startup with connect and read timeouts.
    pub client: reqwest::Client,
    /// Port the proxy was started on; reported by the `/status` endpoint.
    pub port: u16,
    /// Request and compression counters shared by all clones of the state.
    pub stats: Arc<ProxyStats>,
    /// Introspection counters and last breakdown, reported under `introspect`
    /// by the `/status` endpoint.
    pub introspect: Arc<introspect::IntrospectState>,
    /// Upstream resolved from config for the Anthropic provider.
    pub anthropic_upstream: String,
    /// Upstream resolved from config for the OpenAI provider.
    pub openai_upstream: String,
    /// Upstream resolved from config for the Gemini provider.
    pub gemini_upstream: String,
}
36
/// Lock-free counters describing the traffic the proxy has handled.
///
/// Every counter is updated and read with `Ordering::Relaxed`: the values are
/// independent statistics, so a reader may observe a snapshot in which some
/// counters already include a request and others do not yet.
///
/// `Default` yields all counters at zero.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Number of requests recorded via [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of compressions recorded via [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Estimated tokens saved (bytes saved divided by four).
    pub tokens_saved: AtomicU64,
    /// Sum of the pre-compression sizes, in bytes.
    pub bytes_original: AtomicU64,
    /// Sum of the post-compression sizes, in bytes.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one incoming request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compression that turned `original` bytes into `compressed` bytes.
    ///
    /// The token saving is an estimate of one token per four bytes saved; a
    /// compression that grew the payload contributes zero saved tokens.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the overall size reduction as a percentage of the original bytes.
    ///
    /// Returns `0.0` while no original bytes have been recorded, which avoids
    /// a division by zero. The value is negative when the recorded compressed
    /// total exceeds the original total.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
81
/// TCP connect timeout in seconds, read from `LEAN_CTX_PROXY_CONNECT_TIMEOUT_SECS`.
///
/// The value is trimmed and parsed as an unsigned integer. An unset variable,
/// an unparsable value, or `0` all fall back to 15 seconds.
fn connect_timeout_secs() -> u64 {
    const FALLBACK_SECS: u64 = 15;
    match std::env::var("LEAN_CTX_PROXY_CONNECT_TIMEOUT_SECS") {
        Ok(raw) => match raw.trim().parse::<u64>() {
            Ok(secs) if secs > 0 => secs,
            _ => FALLBACK_SECS,
        },
        Err(_) => FALLBACK_SECS,
    }
}
90
/// Idle read timeout in seconds: the longest allowed gap between bytes from
/// upstream. Read from `LEAN_CTX_PROXY_READ_TIMEOUT_SECS`.
///
/// The default is deliberately generous so long extended-thinking phases
/// (which still emit SSE keepalives) are never cut, while a truly dead
/// connection eventually fails. An unset variable, an unparsable value, or
/// `0` all fall back to 300 seconds.
fn read_idle_timeout_secs() -> u64 {
    const FALLBACK_SECS: u64 = 300;
    match std::env::var("LEAN_CTX_PROXY_READ_TIMEOUT_SECS") {
        Ok(raw) => match raw.trim().parse::<u64>() {
            Ok(secs) if secs > 0 => secs,
            _ => FALLBACK_SECS,
        },
        Err(_) => FALLBACK_SECS,
    }
}
102
103pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
104    let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
105    start_proxy_with_token(port, Some(token)).await
106}
107
108pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
109    use crate::core::config::{Config, ProxyProvider};
110
111    // A single total timeout aborts long streaming generations (e.g. Opus doing
112    // a big refactor) mid-response. Use a connect timeout plus a read (idle)
113    // timeout instead: a genuinely hung upstream still fails, but a slow-but-
114    // alive stream is never cut off. Both are configurable for edge networks.
115    let client = reqwest::Client::builder()
116        .connect_timeout(std::time::Duration::from_secs(connect_timeout_secs()))
117        .read_timeout(std::time::Duration::from_secs(read_idle_timeout_secs()))
118        .build()?;
119
120    let cfg = Config::load();
121    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
122    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
123    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
124
125    let state = ProxyState {
126        client,
127        port,
128        stats: Arc::new(ProxyStats::default()),
129        introspect: Arc::new(introspect::IntrospectState::default()),
130        anthropic_upstream: anthropic_upstream.clone(),
131        openai_upstream: openai_upstream.clone(),
132        gemini_upstream: gemini_upstream.clone(),
133    };
134
135    let mut app = Router::new()
136        .route("/health", get(health))
137        .route("/status", get(status_handler))
138        .route("/v1/messages", any(anthropic::handler))
139        .route("/v1/messages/{*rest}", any(anthropic::handler))
140        .route("/v1/chat/completions", any(openai::handler))
141        .route("/v1/responses", any(openai_responses::handler))
142        .route("/v1/responses/{*rest}", any(openai_responses::handler))
143        // Bare provider endpoints (no `/v1` prefix). Clients whose base URL points
144        // at the proxy root — notably OpenCode via `@ai-sdk/openai`, whose
145        // Responses-API requests hit `/responses` — dispatch here. The
146        // `normalize_provider_path` layer rewrites the URI to its canonical
147        // `/v1/...` form before the handler forwards upstream (#353).
148        .route("/messages", any(anthropic::handler))
149        .route("/messages/{*rest}", any(anthropic::handler))
150        .route("/chat/completions", any(openai::handler))
151        .route("/responses", any(openai_responses::handler))
152        .route("/responses/{*rest}", any(openai_responses::handler))
153        .route("/v1/references/{id}", get(v1_resolve_reference))
154        .fallback(fallback_router)
155        .layer(axum::middleware::from_fn(host_guard))
156        .with_state(state);
157
158    if let Some(ref token) = auth_token {
159        let expected = token.clone();
160        app = app.layer(axum::middleware::from_fn(move |req, next| {
161            let expected = expected.clone();
162            proxy_auth_guard(req, next, expected)
163        }));
164    }
165
166    // Outermost layer (runs first): normalize bare provider endpoints to their
167    // canonical `/v1/...` form so auth, routing and upstream forwarding all agree,
168    // regardless of whether the client's base URL includes `/v1` (#353).
169    app = app.layer(axum::middleware::from_fn(normalize_provider_path));
170
171    let addr = SocketAddr::from(([127, 0, 0, 1], port));
172    if auth_token.is_some() {
173        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
174    } else {
175        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
176    }
177    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
178    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
179    println!("  OpenAI:    POST /v1/responses → {openai_upstream}");
180    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
181
182    let listener = tokio::net::TcpListener::bind(addr).await?;
183    axum::serve(listener, app)
184        .with_graceful_shutdown(shutdown_signal())
185        .await?;
186
187    println!("lean-ctx proxy shut down cleanly.");
188    Ok(())
189}
190
191async fn shutdown_signal() {
192    let ctrl_c = tokio::signal::ctrl_c();
193
194    #[cfg(unix)]
195    {
196        // Fall back to Ctrl-C only if the SIGTERM handler cannot be installed,
197        // rather than panicking the proxy on startup.
198        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
199            Ok(mut sigterm) => {
200                tokio::select! {
201                    _ = ctrl_c => {},
202                    _ = sigterm.recv() => {},
203                }
204            }
205            Err(e) => {
206                tracing::warn!("lean-ctx proxy: SIGTERM handler unavailable ({e}); Ctrl-C only");
207                ctrl_c.await.ok();
208            }
209        }
210    }
211
212    #[cfg(not(unix))]
213    {
214        ctrl_c.await.ok();
215    }
216
217    println!("lean-ctx proxy: received shutdown signal, draining…");
218}
219
220async fn health() -> impl IntoResponse {
221    let body = serde_json::json!({
222        "status": "ok",
223        "pid": std::process::id(),
224    });
225    (StatusCode::OK, axum::Json(body))
226}
227
228async fn v1_resolve_reference(
229    axum::extract::Path(id): axum::extract::Path<String>,
230) -> impl IntoResponse {
231    match crate::server::reference_store::resolve(&id) {
232        Some(content) => (StatusCode::OK, content),
233        None => (
234            StatusCode::NOT_FOUND,
235            "Reference expired or not found".to_string(),
236        ),
237    }
238}
239
240async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
241    use std::sync::atomic::Ordering::Relaxed;
242    let s = &state.stats;
243    let i = &state.introspect;
244
245    let last_breakdown = i
246        .last_breakdown
247        .lock()
248        .ok()
249        .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
250        .flatten();
251
252    let body = serde_json::json!({
253        "status": "running",
254        "port": state.port,
255        "requests_total": s.requests_total.load(Relaxed),
256        "requests_compressed": s.requests_compressed.load(Relaxed),
257        "tokens_saved": s.tokens_saved.load(Relaxed),
258        "tokens_saved_estimated": true,
259        "bytes_original": s.bytes_original.load(Relaxed),
260        "bytes_compressed": s.bytes_compressed.load(Relaxed),
261        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
262        "per_model": cost::snapshot(),
263        "note": "Savings are request-side (tokens removed before forwarding); they do not subtract any re-reads the agent performs. Token figures are estimates; USD uses the shared model price table.",
264        "introspect": {
265            "total_requests_analyzed": i.total_requests.load(Relaxed),
266            "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
267            "last_breakdown": last_breakdown,
268        }
269    });
270    (StatusCode::OK, axum::Json(body))
271}
272
273async fn proxy_auth_guard(
274    req: axum::extract::Request,
275    next: axum::middleware::Next,
276    expected_token: String,
277) -> Result<Response, Response> {
278    let path = req.uri().path();
279    if path == "/health" {
280        return Ok(next.run(req).await);
281    }
282
283    // Accept Bearer token (lean-ctx session token)
284    if let Some(auth) = req
285        .headers()
286        .get("authorization")
287        .and_then(|v| v.to_str().ok())
288    {
289        if let Some(token) = auth.strip_prefix("Bearer ") {
290            if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
291                return Ok(next.run(req).await);
292            }
293        }
294    }
295
296    // Accept provider API keys on provider routes (loopback-only, host_guard runs first).
297    // AI tools like Claude Code send x-api-key, not Bearer tokens. Since the proxy
298    // only binds to 127.0.0.1, the presence of a valid API key header is sufficient
299    // to authenticate the request as coming from a local AI tool.
300    if has_provider_api_key(&req) && is_provider_route(path) {
301        return Ok(next.run(req).await);
302    }
303
304    let cfg = crate::core::config::Config::load();
305    let hint = match cfg.proxy_enabled {
306        Some(true) => "lean-ctx proxy requires authentication. Use a Bearer token (LEAN_CTX_PROXY_TOKEN) or configure your AI tool's API key.",
307        Some(false) => "lean-ctx proxy is disabled but still running. Run: lean-ctx proxy cleanup",
308        None => "lean-ctx proxy is not configured. Your AI tool's ANTHROPIC_BASE_URL may be pointing here by mistake. Fix: lean-ctx proxy cleanup  OR  lean-ctx proxy enable",
309    };
310
311    let body = serde_json::json!({
312        "type": "error",
313        "error": {
314            "type": "authentication_error",
315            "message": format!("401 Unauthorized — {hint}")
316        }
317    });
318
319    Err((StatusCode::UNAUTHORIZED, axum::Json(body)).into_response())
320}
321
322fn has_provider_api_key(req: &axum::extract::Request) -> bool {
323    let headers = req.headers();
324    for key in ["x-api-key", "x-goog-api-key", "api-key"] {
325        if headers
326            .get(key)
327            .and_then(|v| v.to_str().ok())
328            .is_some_and(|v| !v.trim().is_empty())
329        {
330            return true;
331        }
332    }
333    if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
334        if auth.starts_with("Bearer sk-") || auth.starts_with("Bearer gsk_") {
335            return true;
336        }
337    }
338    false
339}
340
/// Returns `true` when `path` targets a provider API endpoint.
///
/// Anything under `/v1/` or `/v1beta/` qualifies. The bare endpoints
/// (`/chat/completions`, `/responses`, `/messages`) qualify only on an exact
/// match or when followed by a `/` sub-path. Lookalikes such as `/responsesx`
/// are rejected, matching the segment-boundary rule `canonical_provider_path`
/// applies when rewriting bare endpoints.
fn is_provider_route(path: &str) -> bool {
    const BARE_ENDPOINTS: [&str; 3] = ["/chat/completions", "/responses", "/messages"];

    if path.starts_with("/v1/") || path.starts_with("/v1beta/") {
        return true;
    }
    BARE_ENDPOINTS.iter().any(|&endpoint| {
        path.strip_prefix(endpoint)
            // Require a segment boundary right after the endpoint name.
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'))
    })
}
348
/// Maps a bare provider endpoint to its canonical `/v1/...` form, preserving any
/// sub-path. Returns `None` when the path is already canonical or not a known
/// provider endpoint.
///
/// Some OpenAI-compatible clients treat the configured base URL as the API root
/// and append the bare endpoint, so they send `POST /responses` or
/// `/chat/completions` instead of `/v1/responses` — notably OpenCode via
/// `@ai-sdk/openai`, whose Responses-API requests land on `/responses`. The proxy
/// and every upstream only know the `/v1/...` paths, so an un-prefixed request
/// would 401 (not a provider route) and then 404 (no handler). (#353)
///
/// This runs for every request, so the match is done with `strip_prefix`
/// rather than by building a temporary `"{bare}/"` string per table entry; a
/// `String` is allocated only when a rewrite actually happens.
fn canonical_provider_path(path: &str) -> Option<String> {
    const BARE_TO_CANONICAL: &[(&str, &str)] = &[
        ("/responses", "/v1/responses"),
        ("/chat/completions", "/v1/chat/completions"),
        ("/messages", "/v1/messages"),
    ];
    BARE_TO_CANONICAL.iter().find_map(|&(bare, canonical)| {
        let rest = path.strip_prefix(bare)?;
        // `rest` is either empty (exact match) or must begin a new segment;
        // anything else (e.g. `/responsesx`) is a different path.
        (rest.is_empty() || rest.starts_with('/')).then(|| format!("{canonical}{rest}"))
    })
}
375
376/// Returns the canonicalized URI for a bare provider endpoint (query preserved),
377/// or `None` when no rewrite is needed. Pure, so the rewrite is unit-testable
378/// without constructing axum middleware plumbing.
379fn normalized_provider_uri(uri: &axum::http::Uri) -> Option<axum::http::Uri> {
380    let canonical = canonical_provider_path(uri.path())?;
381    let new_path_and_query = match uri.query() {
382        Some(q) => format!("{canonical}?{q}"),
383        None => canonical,
384    };
385    new_path_and_query.parse::<axum::http::Uri>().ok()
386}
387
388/// Rewrites the request URI in place when it targets a bare provider endpoint, so
389/// downstream auth (`is_provider_route`), routing and upstream forwarding all see
390/// the canonical `/v1/...` path. (#353)
391async fn normalize_provider_path(
392    mut req: axum::extract::Request,
393    next: axum::middleware::Next,
394) -> Response {
395    if let Some(uri) = normalized_provider_uri(req.uri()) {
396        *req.uri_mut() = uri;
397    }
398    next.run(req).await
399}
400
401fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
402    use subtle::ConstantTimeEq;
403    if a.len() != b.len() {
404        return false;
405    }
406    bool::from(a.ct_eq(b))
407}
408
409async fn host_guard(
410    req: axum::extract::Request,
411    next: axum::middleware::Next,
412) -> Result<Response, StatusCode> {
413    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
414        let h = host.split(':').next().unwrap_or(host);
415        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
416            return Ok(next.run(req).await);
417        }
418    }
419    Err(StatusCode::FORBIDDEN)
420}
421
422async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
423    let path = req.uri().path().to_string();
424
425    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
426        match google::handler(State(state), req).await {
427            Ok(resp) => resp,
428            Err(status) => Response::builder()
429                .status(status)
430                .body(Body::from("proxy error"))
431                .expect("BUG: building error response with valid status should never fail"),
432        }
433    } else {
434        let method = req.method().to_string();
435        eprintln!("lean-ctx proxy: unmatched {method} {path}");
436        Response::builder()
437            .status(StatusCode::NOT_FOUND)
438            .body(Body::from(format!(
439                "lean-ctx proxy: no handler for {method} {path}"
440            )))
441            .expect("BUG: building 404 response should never fail")
442    }
443}
444
#[cfg(test)]
mod auth_tests {
    use super::*;

    /// Builds a body-less request for `path` carrying the given headers.
    fn build_request(headers: &[(&str, &str)], path: &str) -> axum::extract::Request {
        headers
            .iter()
            .fold(axum::http::Request::builder().uri(path), |builder, (k, v)| {
                builder.header(*k, *v)
            })
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Asserts that every path in `paths` is classified as a provider route.
    fn assert_provider_routes(paths: &[&str]) {
        for path in paths {
            assert!(is_provider_route(path));
        }
    }

    #[test]
    fn is_provider_route_v1() {
        assert_provider_routes(&["/v1/chat/completions", "/v1/messages", "/v1/completions"]);
    }

    #[test]
    fn is_provider_route_anthropic_subpaths() {
        assert_provider_routes(&[
            "/v1/messages/count_tokens",
            "/v1/messages/batches",
            "/v1/messages/batches/batch_123",
        ]);
    }

    #[test]
    fn is_provider_route_v1beta() {
        assert_provider_routes(&["/v1beta/models"]);
    }

    #[test]
    fn is_provider_route_chat() {
        assert_provider_routes(&["/chat/completions"]);
    }

    #[test]
    fn is_provider_route_rejects_non_provider() {
        for path in ["/health", "/api/v2/test", "/"] {
            assert!(!is_provider_route(path));
        }
    }

    #[test]
    fn has_provider_api_key_x_api_key() {
        let req = build_request(&[("x-api-key", "sk-ant-abc123")], "/v1/messages");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_x_goog() {
        let req = build_request(&[("x-goog-api-key", "AIzaSyAbc")], "/v1beta/models");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_azure() {
        let req = build_request(&[("api-key", "deadbeef")], "/v1/completions");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_bearer_sk() {
        let headers = [("authorization", "Bearer sk-proj-abc123")];
        let req = build_request(&headers, "/v1/chat/completions");
        assert!(has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_empty_rejected() {
        // A whitespace-only key is treated as absent.
        let req = build_request(&[("x-api-key", "  ")], "/v1/messages");
        assert!(!has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_no_headers() {
        let req = build_request(&[], "/v1/messages");
        assert!(!has_provider_api_key(&req));
    }

    #[test]
    fn has_provider_api_key_regular_bearer_rejected() {
        let headers = [("authorization", "Bearer some-session-token")];
        let req = build_request(&headers, "/v1/chat");
        assert!(
            !has_provider_api_key(&req),
            "non-sk/gsk Bearer should not count as provider key"
        );
    }

    // --- #353: bare provider endpoints (OpenCode / @ai-sdk/openai) ---

    #[test]
    fn is_provider_route_bare_responses_and_messages() {
        // Clients that point their base URL at the proxy root (no `/v1`) send the
        // bare endpoint; auth must still recognise it as a provider route.
        assert_provider_routes(&[
            "/responses",
            "/responses/resp_123/input_items",
            "/messages",
        ]);
    }

    #[test]
    fn canonical_provider_path_rewrites_bare_endpoints() {
        let cases = [
            ("/responses", "/v1/responses"),
            ("/chat/completions", "/v1/chat/completions"),
            ("/messages", "/v1/messages"),
        ];
        for (bare, canonical) in cases {
            assert_eq!(canonical_provider_path(bare).as_deref(), Some(canonical));
        }
    }

    #[test]
    fn canonical_provider_path_preserves_subpaths() {
        let cases = [
            ("/responses/resp_abc/cancel", "/v1/responses/resp_abc/cancel"),
            ("/messages/batches/batch_1", "/v1/messages/batches/batch_1"),
        ];
        for (bare, canonical) in cases {
            assert_eq!(canonical_provider_path(bare).as_deref(), Some(canonical));
        }
    }

    #[test]
    fn canonical_provider_path_ignores_already_canonical_and_unknown() {
        let untouched = [
            // Already canonical → no rewrite (avoids `/v1/v1/...`).
            "/v1/responses",
            "/v1/chat/completions",
            // Unrelated paths are untouched.
            "/health",
            "/responsesx",
            "/",
        ];
        for path in untouched {
            assert_eq!(canonical_provider_path(path), None);
        }
    }

    #[test]
    fn normalized_provider_uri_rewrites_path_and_preserves_query() {
        let original: axum::http::Uri = "/responses?stream=true".parse().unwrap();
        let rewritten = normalized_provider_uri(&original).expect("bare /responses must rewrite");
        assert_eq!(rewritten.path(), "/v1/responses");
        assert_eq!(rewritten.query(), Some("stream=true"));
        assert_eq!(
            rewritten.path_and_query().map(|pq| pq.as_str()),
            Some("/v1/responses?stream=true")
        );
    }

    #[test]
    fn normalized_provider_uri_noop_for_canonical() {
        let canonical: axum::http::Uri = "/v1/responses".parse().unwrap();
        assert!(normalized_provider_uri(&canonical).is_none());
    }
}