Skip to main content

lean_ctx/proxy/
mod.rs

1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod openai;
6
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10
11use axum::{
12    body::Body,
13    extract::State,
14    http::{Request, StatusCode},
15    response::{IntoResponse, Response},
16    routing::{any, get},
17    Router,
18};
19
20#[derive(Clone)]
21pub struct ProxyState {
22    pub client: reqwest::Client,
23    pub port: u16,
24    pub stats: Arc<ProxyStats>,
25    pub anthropic_upstream: String,
26    pub openai_upstream: String,
27    pub gemini_upstream: String,
28}
29
/// Running counters describing the proxy's traffic and compression results.
///
/// Every field is an atomic, so a single instance can be updated through a
/// shared reference (for example behind an `Arc`). `Default` yields a value
/// with all counters at zero.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Incremented once per call to [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Incremented once per call to [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Accumulated token-savings estimate (saved bytes divided by four).
    pub tokens_saved: AtomicU64,
    /// Sum of the `original` sizes passed to `record_compression`.
    pub bytes_original: AtomicU64,
    /// Sum of the `compressed` sizes passed to `record_compression`.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records the outcome of one compression.
    ///
    /// * `original` – size before compression, in bytes.
    /// * `compressed` – size after compression, in bytes.
    ///
    /// Besides the byte totals this adds `(original - compressed) / 4` to
    /// `tokens_saved`; the subtraction saturates at zero, so a "compression"
    /// that grew the payload contributes no saved tokens.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);

        // NOTE(review): the divisor of 4 is presumably a bytes-per-token
        // heuristic — confirm against the tokenizer assumptions elsewhere.
        let saved_bytes = original.saturating_sub(compressed);
        let saved_tokens = (saved_bytes / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the share of bytes removed by compression, as a percentage.
    ///
    /// Yields `0.0` while no original bytes have been recorded. The value is
    /// negative when the recorded compressed total exceeds the original total.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
74
75pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
76    start_proxy_with_token(port, std::env::var("LEAN_CTX_PROXY_TOKEN").ok()).await
77}
78
79pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
80    use crate::core::config::{Config, ProxyProvider};
81
82    let client = reqwest::Client::builder()
83        .timeout(std::time::Duration::from_mins(2))
84        .build()?;
85
86    let cfg = Config::load();
87    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
88    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
89    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
90
91    let state = ProxyState {
92        client,
93        port,
94        stats: Arc::new(ProxyStats::default()),
95        anthropic_upstream: anthropic_upstream.clone(),
96        openai_upstream: openai_upstream.clone(),
97        gemini_upstream: gemini_upstream.clone(),
98    };
99
100    let mut app = Router::new()
101        .route("/health", get(health))
102        .route("/status", get(status_handler))
103        .route("/v1/messages", any(anthropic::handler))
104        .route("/v1/chat/completions", any(openai::handler))
105        .fallback(fallback_router)
106        .layer(axum::middleware::from_fn(host_guard))
107        .with_state(state);
108
109    if let Some(ref token) = auth_token {
110        let expected = token.clone();
111        app = app.layer(axum::middleware::from_fn(move |req, next| {
112            let expected = expected.clone();
113            proxy_auth_guard(req, next, expected)
114        }));
115    }
116
117    let addr = SocketAddr::from(([127, 0, 0, 1], port));
118    if auth_token.is_some() {
119        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
120    } else {
121        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
122    }
123    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
124    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
125    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
126
127    let listener = tokio::net::TcpListener::bind(addr).await?;
128    axum::serve(listener, app).await?;
129
130    Ok(())
131}
132
133async fn health() -> impl IntoResponse {
134    (StatusCode::OK, "ok")
135}
136
137async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
138    use std::sync::atomic::Ordering::Relaxed;
139    let s = &state.stats;
140    let body = serde_json::json!({
141        "status": "running",
142        "port": state.port,
143        "requests_total": s.requests_total.load(Relaxed),
144        "requests_compressed": s.requests_compressed.load(Relaxed),
145        "tokens_saved": s.tokens_saved.load(Relaxed),
146        "bytes_original": s.bytes_original.load(Relaxed),
147        "bytes_compressed": s.bytes_compressed.load(Relaxed),
148        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
149    });
150    (StatusCode::OK, axum::Json(body))
151}
152
153async fn proxy_auth_guard(
154    req: axum::extract::Request,
155    next: axum::middleware::Next,
156    expected_token: String,
157) -> Result<Response, StatusCode> {
158    let path = req.uri().path();
159    if path == "/health" || path == "/status" {
160        return Ok(next.run(req).await);
161    }
162
163    if let Some(auth) = req
164        .headers()
165        .get("authorization")
166        .and_then(|v| v.to_str().ok())
167    {
168        if let Some(token) = auth.strip_prefix("Bearer ") {
169            if token == expected_token {
170                return Ok(next.run(req).await);
171            }
172        }
173    }
174
175    Err(StatusCode::UNAUTHORIZED)
176}
177
178async fn host_guard(
179    req: axum::extract::Request,
180    next: axum::middleware::Next,
181) -> Result<Response, StatusCode> {
182    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
183        let h = host.split(':').next().unwrap_or(host);
184        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
185            return Ok(next.run(req).await);
186        }
187    } else {
188        return Ok(next.run(req).await);
189    }
190    Err(StatusCode::FORBIDDEN)
191}
192
193async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
194    let path = req.uri().path().to_string();
195
196    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
197        match google::handler(State(state), req).await {
198            Ok(resp) => resp,
199            Err(status) => Response::builder()
200                .status(status)
201                .body(Body::from("proxy error"))
202                .expect("BUG: building error response with valid status should never fail"),
203        }
204    } else {
205        let method = req.method().to_string();
206        eprintln!("lean-ctx proxy: unmatched {method} {path}");
207        Response::builder()
208            .status(StatusCode::NOT_FOUND)
209            .body(Body::from(format!(
210                "lean-ctx proxy: no handler for {method} {path}"
211            )))
212            .expect("BUG: building 404 response should never fail")
213    }
214}