// lean_ctx/proxy/mod.rs

1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod introspect;
6pub mod metrics;
7pub mod openai;
8
9use std::net::SocketAddr;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12
13use axum::{
14    body::Body,
15    extract::State,
16    http::{Request, StatusCode},
17    response::{IntoResponse, Response},
18    routing::{any, get},
19    Router,
20};
21
/// Shared state handed to the proxy's request handlers.
///
/// The struct is `Clone`; the statistics and introspection data sit behind
/// `Arc`, so clones share the same underlying counters.
#[derive(Clone)]
pub struct ProxyState {
    /// HTTP client built once at startup (with a request timeout) for upstream calls.
    pub client: reqwest::Client,
    /// Port the proxy was started on; echoed back by the `/status` endpoint.
    pub port: u16,
    /// Request / compression counters reported by `/status`.
    pub stats: Arc<ProxyStats>,
    /// Introspection data (request analysis totals, last breakdown) reported by `/status`.
    pub introspect: Arc<introspect::IntrospectState>,
    /// Upstream resolved from the proxy config for the Anthropic provider.
    pub anthropic_upstream: String,
    /// Upstream resolved from the proxy config for the OpenAI provider.
    pub openai_upstream: String,
    /// Upstream resolved from the proxy config for the Gemini provider.
    pub gemini_upstream: String,
}
32
/// Lock-free counters describing the traffic handled by the proxy.
///
/// Every field is an [`AtomicU64`] so the struct can be shared behind an `Arc`
/// and updated from concurrent request handlers. `Default` yields all-zero
/// counters; it is derived because `AtomicU64::default()` is already `0`.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Number of requests recorded via `record_request`.
    pub requests_total: AtomicU64,
    /// Number of requests recorded via `record_compression`.
    pub requests_compressed: AtomicU64,
    /// Estimated tokens saved, accumulated as (bytes saved / 4) per compression.
    pub tokens_saved: AtomicU64,
    /// Sum of the `original` sizes passed to `record_compression`.
    pub bytes_original: AtomicU64,
    /// Sum of the `compressed` sizes passed to `record_compression`.
    pub bytes_compressed: AtomicU64,
}
52
53impl ProxyStats {
54    pub fn record_request(&self) {
55        self.requests_total.fetch_add(1, Ordering::Relaxed);
56    }
57
58    pub fn record_compression(&self, original: usize, compressed: usize) {
59        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
60        self.bytes_original
61            .fetch_add(original as u64, Ordering::Relaxed);
62        self.bytes_compressed
63            .fetch_add(compressed as u64, Ordering::Relaxed);
64        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
65        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
66    }
67
68    pub fn compression_ratio(&self) -> f64 {
69        let original = self.bytes_original.load(Ordering::Relaxed);
70        if original == 0 {
71            return 0.0;
72        }
73        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
74        (1.0 - compressed as f64 / original as f64) * 100.0
75    }
76}
77
78pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
79    let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
80    start_proxy_with_token(port, Some(token)).await
81}
82
83pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
84    use crate::core::config::{Config, ProxyProvider};
85
86    let client = reqwest::Client::builder()
87        .timeout(std::time::Duration::from_mins(2))
88        .build()?;
89
90    let cfg = Config::load();
91    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
92    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
93    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
94
95    let state = ProxyState {
96        client,
97        port,
98        stats: Arc::new(ProxyStats::default()),
99        introspect: Arc::new(introspect::IntrospectState::default()),
100        anthropic_upstream: anthropic_upstream.clone(),
101        openai_upstream: openai_upstream.clone(),
102        gemini_upstream: gemini_upstream.clone(),
103    };
104
105    let mut app = Router::new()
106        .route("/health", get(health))
107        .route("/status", get(status_handler))
108        .route("/v1/messages", any(anthropic::handler))
109        .route("/v1/chat/completions", any(openai::handler))
110        .route("/v1/references/{id}", get(v1_resolve_reference))
111        .fallback(fallback_router)
112        .layer(axum::middleware::from_fn(host_guard))
113        .with_state(state);
114
115    if let Some(ref token) = auth_token {
116        let expected = token.clone();
117        app = app.layer(axum::middleware::from_fn(move |req, next| {
118            let expected = expected.clone();
119            proxy_auth_guard(req, next, expected)
120        }));
121    }
122
123    let addr = SocketAddr::from(([127, 0, 0, 1], port));
124    if auth_token.is_some() {
125        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
126    } else {
127        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
128    }
129    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
130    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
131    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
132
133    let listener = tokio::net::TcpListener::bind(addr).await?;
134    axum::serve(listener, app)
135        .with_graceful_shutdown(shutdown_signal())
136        .await?;
137
138    println!("lean-ctx proxy shut down cleanly.");
139    Ok(())
140}
141
142async fn shutdown_signal() {
143    let ctrl_c = tokio::signal::ctrl_c();
144
145    #[cfg(unix)]
146    {
147        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
148            .expect("failed to install SIGTERM handler");
149        tokio::select! {
150            _ = ctrl_c => {},
151            _ = sigterm.recv() => {},
152        }
153    }
154
155    #[cfg(not(unix))]
156    {
157        ctrl_c.await.ok();
158    }
159
160    println!("lean-ctx proxy: received shutdown signal, draining…");
161}
162
163async fn health() -> impl IntoResponse {
164    let body = serde_json::json!({
165        "status": "ok",
166        "pid": std::process::id(),
167    });
168    (StatusCode::OK, axum::Json(body))
169}
170
171async fn v1_resolve_reference(
172    axum::extract::Path(id): axum::extract::Path<String>,
173) -> impl IntoResponse {
174    match crate::server::reference_store::resolve(&id) {
175        Some(content) => (StatusCode::OK, content),
176        None => (
177            StatusCode::NOT_FOUND,
178            "Reference expired or not found".to_string(),
179        ),
180    }
181}
182
183async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
184    use std::sync::atomic::Ordering::Relaxed;
185    let s = &state.stats;
186    let i = &state.introspect;
187
188    let last_breakdown = i
189        .last_breakdown
190        .lock()
191        .ok()
192        .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
193        .flatten();
194
195    let body = serde_json::json!({
196        "status": "running",
197        "port": state.port,
198        "requests_total": s.requests_total.load(Relaxed),
199        "requests_compressed": s.requests_compressed.load(Relaxed),
200        "tokens_saved": s.tokens_saved.load(Relaxed),
201        "bytes_original": s.bytes_original.load(Relaxed),
202        "bytes_compressed": s.bytes_compressed.load(Relaxed),
203        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
204        "introspect": {
205            "total_requests_analyzed": i.total_requests.load(Relaxed),
206            "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
207            "last_breakdown": last_breakdown,
208        }
209    });
210    (StatusCode::OK, axum::Json(body))
211}
212
213async fn proxy_auth_guard(
214    req: axum::extract::Request,
215    next: axum::middleware::Next,
216    expected_token: String,
217) -> Result<Response, Response> {
218    let path = req.uri().path();
219    if path == "/health" {
220        return Ok(next.run(req).await);
221    }
222
223    // Accept Bearer token (lean-ctx session token)
224    if let Some(auth) = req
225        .headers()
226        .get("authorization")
227        .and_then(|v| v.to_str().ok())
228    {
229        if let Some(token) = auth.strip_prefix("Bearer ") {
230            if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
231                return Ok(next.run(req).await);
232            }
233        }
234    }
235
236    // Accept provider API keys on provider routes (loopback-only, host_guard runs first).
237    // AI tools like Claude Code send x-api-key, not Bearer tokens. Since the proxy
238    // only binds to 127.0.0.1, the presence of a valid API key header is sufficient
239    // to authenticate the request as coming from a local AI tool.
240    if has_provider_api_key(&req) && is_provider_route(path) {
241        return Ok(next.run(req).await);
242    }
243
244    let cfg = crate::core::config::Config::load();
245    let hint = match cfg.proxy_enabled {
246        Some(true) => "lean-ctx proxy requires authentication. Use a Bearer token (LEAN_CTX_PROXY_TOKEN) or configure your AI tool's API key.",
247        Some(false) => "lean-ctx proxy is disabled but still running. Run: lean-ctx proxy cleanup",
248        None => "lean-ctx proxy is not configured. Your AI tool's ANTHROPIC_BASE_URL may be pointing here by mistake. Fix: lean-ctx proxy cleanup  OR  lean-ctx proxy enable",
249    };
250
251    let body = serde_json::json!({
252        "type": "error",
253        "error": {
254            "type": "authentication_error",
255            "message": format!("401 Unauthorized — {hint}")
256        }
257    });
258
259    Err((StatusCode::UNAUTHORIZED, axum::Json(body)).into_response())
260}
261
262fn has_provider_api_key(req: &axum::extract::Request) -> bool {
263    let headers = req.headers();
264    for key in ["x-api-key", "x-goog-api-key", "api-key"] {
265        if headers
266            .get(key)
267            .and_then(|v| v.to_str().ok())
268            .is_some_and(|v| !v.trim().is_empty())
269        {
270            return true;
271        }
272    }
273    if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
274        if auth.starts_with("Bearer sk-") || auth.starts_with("Bearer gsk_") {
275            return true;
276        }
277    }
278    false
279}
280
/// Reports whether `path` begins with one of the provider API prefixes
/// (`/v1/`, `/v1beta/` or `/chat/completions`).
fn is_provider_route(path: &str) -> bool {
    const PROVIDER_PREFIXES: [&str; 3] = ["/v1/", "/v1beta/", "/chat/completions"];
    PROVIDER_PREFIXES
        .iter()
        .any(|&prefix| path.starts_with(prefix))
}
286
287fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
288    use subtle::ConstantTimeEq;
289    if a.len() != b.len() {
290        return false;
291    }
292    bool::from(a.ct_eq(b))
293}
294
295async fn host_guard(
296    req: axum::extract::Request,
297    next: axum::middleware::Next,
298) -> Result<Response, StatusCode> {
299    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
300        let h = host.split(':').next().unwrap_or(host);
301        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
302            return Ok(next.run(req).await);
303        }
304    }
305    Err(StatusCode::FORBIDDEN)
306}
307
308async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
309    let path = req.uri().path().to_string();
310
311    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
312        match google::handler(State(state), req).await {
313            Ok(resp) => resp,
314            Err(status) => Response::builder()
315                .status(status)
316                .body(Body::from("proxy error"))
317                .expect("BUG: building error response with valid status should never fail"),
318        }
319    } else {
320        let method = req.method().to_string();
321        eprintln!("lean-ctx proxy: unmatched {method} {path}");
322        Response::builder()
323            .status(StatusCode::NOT_FOUND)
324            .body(Body::from(format!(
325                "lean-ctx proxy: no handler for {method} {path}"
326            )))
327            .expect("BUG: building 404 response should never fail")
328    }
329}
330
#[cfg(test)]
mod auth_tests {
    use super::*;

    /// Builds an empty-bodied request for `path` carrying the given headers.
    fn build_request(headers: &[(&str, &str)], path: &str) -> axum::extract::Request {
        headers
            .iter()
            .fold(
                axum::http::Request::builder().uri(path),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(axum::body::Body::empty())
            .unwrap()
    }

    // --- is_provider_route -------------------------------------------------

    #[test]
    fn is_provider_route_v1() {
        for path in ["/v1/chat/completions", "/v1/messages", "/v1/completions"] {
            assert!(is_provider_route(path), "{path} should be a provider route");
        }
    }

    #[test]
    fn is_provider_route_v1beta() {
        assert!(is_provider_route("/v1beta/models"));
    }

    #[test]
    fn is_provider_route_chat() {
        assert!(is_provider_route("/chat/completions"));
    }

    #[test]
    fn is_provider_route_rejects_non_provider() {
        for path in ["/health", "/api/v2/test", "/"] {
            assert!(!is_provider_route(path), "{path} should not be a provider route");
        }
    }

    // --- has_provider_api_key: accepted headers ----------------------------

    #[test]
    fn has_provider_api_key_x_api_key() {
        let request = build_request(&[("x-api-key", "sk-ant-abc123")], "/v1/messages");
        assert!(has_provider_api_key(&request));
    }

    #[test]
    fn has_provider_api_key_x_goog() {
        let request = build_request(&[("x-goog-api-key", "AIzaSyAbc")], "/v1beta/models");
        assert!(has_provider_api_key(&request));
    }

    #[test]
    fn has_provider_api_key_azure() {
        let request = build_request(&[("api-key", "deadbeef")], "/v1/completions");
        assert!(has_provider_api_key(&request));
    }

    #[test]
    fn has_provider_api_key_bearer_sk() {
        let headers = [("authorization", "Bearer sk-proj-abc123")];
        let request = build_request(&headers, "/v1/chat/completions");
        assert!(has_provider_api_key(&request));
    }

    // --- has_provider_api_key: rejected headers ----------------------------

    #[test]
    fn has_provider_api_key_empty_rejected() {
        let request = build_request(&[("x-api-key", "  ")], "/v1/messages");
        assert!(!has_provider_api_key(&request));
    }

    #[test]
    fn has_provider_api_key_no_headers() {
        let request = build_request(&[], "/v1/messages");
        assert!(!has_provider_api_key(&request));
    }

    #[test]
    fn has_provider_api_key_regular_bearer_rejected() {
        let headers = [("authorization", "Bearer some-session-token")];
        let request = build_request(&headers, "/v1/chat");
        assert!(
            !has_provider_api_key(&request),
            "non-sk/gsk Bearer should not count as provider key"
        );
    }
}