Skip to main content

lean_ctx/proxy/
mod.rs

1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod introspect;
6pub mod openai;
7
8use std::net::SocketAddr;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use axum::{
13    body::Body,
14    extract::State,
15    http::{Request, StatusCode},
16    response::{IntoResponse, Response},
17    routing::{any, get},
18    Router,
19};
20
/// Shared state handed to every route handler through axum's `State` extractor.
///
/// axum requires `Clone` for router state; the counters and introspection data
/// sit behind `Arc`, so every clone reads and updates the same values.
#[derive(Clone)]
pub struct ProxyState {
    /// HTTP client for outgoing requests (presumably used by the provider
    /// handlers to forward traffic upstream — confirm in the handler modules).
    pub client: reqwest::Client,
    /// Local port the proxy was started on; reported by `/status`.
    pub port: u16,
    /// Request / compression counters, shared across all clones of the state.
    pub stats: Arc<ProxyStats>,
    /// Prompt-introspection data, reported under `introspect` in `/status`.
    pub introspect: Arc<introspect::IntrospectState>,
    /// Upstream target for Anthropic-style requests, resolved from the proxy config.
    pub anthropic_upstream: String,
    /// Upstream target for OpenAI-style requests, resolved from the proxy config.
    pub openai_upstream: String,
    /// Upstream target for Gemini-style requests, resolved from the proxy config.
    pub gemini_upstream: String,
}
31
/// Counters describing how much traffic the proxy has seen and compressed.
///
/// Every counter is an independent atomic updated with `Ordering::Relaxed`.
/// Each counter is exact on its own, but a concurrent reader may observe one
/// counter already bumped and another not yet (for example,
/// `requests_compressed` incremented before `bytes_original`).
#[derive(Default)]
pub struct ProxyStats {
    /// Incremented once per call to [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Incremented once per call to [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Estimated tokens saved, derived from bytes saved via a fixed heuristic.
    pub tokens_saved: AtomicU64,
    /// Running sum of the `original` sizes passed to `record_compression`.
    pub bytes_original: AtomicU64,
    /// Running sum of the `compressed` sizes passed to `record_compression`.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Rough bytes-per-token ratio used to turn saved bytes into a token estimate.
    const BYTES_PER_TOKEN: usize = 4;

    /// Counts one request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compressed request whose payload went from `original` to
    /// `compressed` in size.
    ///
    /// Tokens saved are estimated as `(original - compressed) / BYTES_PER_TOKEN`,
    /// floored at zero when compression did not shrink the payload.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        // usize -> u64 is lossless on every target with pointers of at most 64 bits.
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = original.saturating_sub(compressed) / Self::BYTES_PER_TOKEN;
        self.tokens_saved
            .fetch_add(saved_tokens as u64, Ordering::Relaxed);
    }

    /// Percentage of bytes removed by compression across all recorded requests.
    ///
    /// Returns `0.0` before anything has been recorded. The value is negative if
    /// compression grew payloads in aggregate.
    #[must_use]
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
76
77pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
78    start_proxy_with_token(port, std::env::var("LEAN_CTX_PROXY_TOKEN").ok()).await
79}
80
81pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
82    use crate::core::config::{Config, ProxyProvider};
83
84    let client = reqwest::Client::builder()
85        .timeout(std::time::Duration::from_mins(2))
86        .build()?;
87
88    let cfg = Config::load();
89    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
90    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
91    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
92
93    let state = ProxyState {
94        client,
95        port,
96        stats: Arc::new(ProxyStats::default()),
97        introspect: Arc::new(introspect::IntrospectState::default()),
98        anthropic_upstream: anthropic_upstream.clone(),
99        openai_upstream: openai_upstream.clone(),
100        gemini_upstream: gemini_upstream.clone(),
101    };
102
103    let mut app = Router::new()
104        .route("/health", get(health))
105        .route("/status", get(status_handler))
106        .route("/v1/messages", any(anthropic::handler))
107        .route("/v1/chat/completions", any(openai::handler))
108        .fallback(fallback_router)
109        .layer(axum::middleware::from_fn(host_guard))
110        .with_state(state);
111
112    if let Some(ref token) = auth_token {
113        let expected = token.clone();
114        app = app.layer(axum::middleware::from_fn(move |req, next| {
115            let expected = expected.clone();
116            proxy_auth_guard(req, next, expected)
117        }));
118    }
119
120    let addr = SocketAddr::from(([127, 0, 0, 1], port));
121    if auth_token.is_some() {
122        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
123    } else {
124        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
125    }
126    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
127    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
128    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
129
130    let listener = tokio::net::TcpListener::bind(addr).await?;
131    axum::serve(listener, app)
132        .with_graceful_shutdown(shutdown_signal())
133        .await?;
134
135    println!("lean-ctx proxy shut down cleanly.");
136    Ok(())
137}
138
139async fn shutdown_signal() {
140    let ctrl_c = tokio::signal::ctrl_c();
141
142    #[cfg(unix)]
143    {
144        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
145            .expect("failed to install SIGTERM handler");
146        tokio::select! {
147            _ = ctrl_c => {},
148            _ = sigterm.recv() => {},
149        }
150    }
151
152    #[cfg(not(unix))]
153    {
154        ctrl_c.await.ok();
155    }
156
157    println!("lean-ctx proxy: received shutdown signal, draining…");
158}
159
160async fn health() -> impl IntoResponse {
161    let body = serde_json::json!({
162        "status": "ok",
163        "pid": std::process::id(),
164    });
165    (StatusCode::OK, axum::Json(body))
166}
167
168async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
169    use std::sync::atomic::Ordering::Relaxed;
170    let s = &state.stats;
171    let i = &state.introspect;
172
173    let last_breakdown = i
174        .last_breakdown
175        .lock()
176        .ok()
177        .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
178        .flatten();
179
180    let body = serde_json::json!({
181        "status": "running",
182        "port": state.port,
183        "requests_total": s.requests_total.load(Relaxed),
184        "requests_compressed": s.requests_compressed.load(Relaxed),
185        "tokens_saved": s.tokens_saved.load(Relaxed),
186        "bytes_original": s.bytes_original.load(Relaxed),
187        "bytes_compressed": s.bytes_compressed.load(Relaxed),
188        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
189        "introspect": {
190            "total_requests_analyzed": i.total_requests.load(Relaxed),
191            "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
192            "last_breakdown": last_breakdown,
193        }
194    });
195    (StatusCode::OK, axum::Json(body))
196}
197
198async fn proxy_auth_guard(
199    req: axum::extract::Request,
200    next: axum::middleware::Next,
201    expected_token: String,
202) -> Result<Response, StatusCode> {
203    let path = req.uri().path();
204    if path == "/health" || path == "/status" {
205        return Ok(next.run(req).await);
206    }
207
208    if let Some(auth) = req
209        .headers()
210        .get("authorization")
211        .and_then(|v| v.to_str().ok())
212    {
213        if let Some(token) = auth.strip_prefix("Bearer ") {
214            if token == expected_token {
215                return Ok(next.run(req).await);
216            }
217        }
218    }
219
220    Err(StatusCode::UNAUTHORIZED)
221}
222
223async fn host_guard(
224    req: axum::extract::Request,
225    next: axum::middleware::Next,
226) -> Result<Response, StatusCode> {
227    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
228        let h = host.split(':').next().unwrap_or(host);
229        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
230            return Ok(next.run(req).await);
231        }
232    } else {
233        return Ok(next.run(req).await);
234    }
235    Err(StatusCode::FORBIDDEN)
236}
237
238async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
239    let path = req.uri().path().to_string();
240
241    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
242        match google::handler(State(state), req).await {
243            Ok(resp) => resp,
244            Err(status) => Response::builder()
245                .status(status)
246                .body(Body::from("proxy error"))
247                .expect("BUG: building error response with valid status should never fail"),
248        }
249    } else {
250        let method = req.method().to_string();
251        eprintln!("lean-ctx proxy: unmatched {method} {path}");
252        Response::builder()
253            .status(StatusCode::NOT_FOUND)
254            .body(Body::from(format!(
255                "lean-ctx proxy: no handler for {method} {path}"
256            )))
257            .expect("BUG: building 404 response should never fail")
258    }
259}