//! `lean_ctx/proxy/mod.rs` — local HTTP proxy: routing, request guards, and compression statistics.

1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod introspect;
6pub mod metrics;
7pub mod openai;
8
9use std::net::SocketAddr;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12
13use axum::{
14    body::Body,
15    extract::State,
16    http::{Request, StatusCode},
17    response::{IntoResponse, Response},
18    routing::{any, get},
19    Router,
20};
21
22#[derive(Clone)]
23pub struct ProxyState {
24    pub client: reqwest::Client,
25    pub port: u16,
26    pub stats: Arc<ProxyStats>,
27    pub introspect: Arc<introspect::IntrospectState>,
28    pub anthropic_upstream: String,
29    pub openai_upstream: String,
30    pub gemini_upstream: String,
31}
32
/// Atomic counters describing the traffic the proxy has handled.
///
/// Every field is an `AtomicU64`, so the counters can be updated through a
/// shared reference. `Default` yields all-zero counters.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Number of requests recorded via [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of compressions recorded via [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Running estimate of tokens saved (saved bytes divided by four).
    pub tokens_saved: AtomicU64,
    /// Sum of the pre-compression sizes, in bytes.
    pub bytes_original: AtomicU64,
    /// Sum of the post-compression sizes, in bytes.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one more handled request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compression that turned `original` bytes into `compressed` bytes.
    ///
    /// The token estimate is `(original - compressed) / 4`, presumably a
    /// rough four-bytes-per-token heuristic. The subtraction saturates, so a
    /// payload that grew contributes zero saved tokens rather than wrapping.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the overall share of bytes removed, as a percentage.
    ///
    /// Computed as `(1 - bytes_compressed / bytes_original) * 100`. Returns
    /// `0.0` while no original bytes have been recorded (avoids dividing by
    /// zero); the value is negative when the compressed total exceeds the
    /// original total.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
77
78pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
79    start_proxy_with_token(port, std::env::var("LEAN_CTX_PROXY_TOKEN").ok()).await
80}
81
82pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
83    use crate::core::config::{Config, ProxyProvider};
84
85    let client = reqwest::Client::builder()
86        .timeout(std::time::Duration::from_mins(2))
87        .build()?;
88
89    let cfg = Config::load();
90    let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
91    let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
92    let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
93
94    let state = ProxyState {
95        client,
96        port,
97        stats: Arc::new(ProxyStats::default()),
98        introspect: Arc::new(introspect::IntrospectState::default()),
99        anthropic_upstream: anthropic_upstream.clone(),
100        openai_upstream: openai_upstream.clone(),
101        gemini_upstream: gemini_upstream.clone(),
102    };
103
104    let mut app = Router::new()
105        .route("/health", get(health))
106        .route("/status", get(status_handler))
107        .route("/v1/messages", any(anthropic::handler))
108        .route("/v1/chat/completions", any(openai::handler))
109        .route("/v1/references/{id}", get(v1_resolve_reference))
110        .fallback(fallback_router)
111        .layer(axum::middleware::from_fn(host_guard))
112        .with_state(state);
113
114    if let Some(ref token) = auth_token {
115        let expected = token.clone();
116        app = app.layer(axum::middleware::from_fn(move |req, next| {
117            let expected = expected.clone();
118            proxy_auth_guard(req, next, expected)
119        }));
120    }
121
122    let addr = SocketAddr::from(([127, 0, 0, 1], port));
123    if auth_token.is_some() {
124        println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
125    } else {
126        println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
127    }
128    println!("  Anthropic: POST /v1/messages → {anthropic_upstream}");
129    println!("  OpenAI:    POST /v1/chat/completions → {openai_upstream}");
130    println!("  Gemini:    POST /v1beta/models/... → {gemini_upstream}");
131
132    let listener = tokio::net::TcpListener::bind(addr).await?;
133    axum::serve(listener, app)
134        .with_graceful_shutdown(shutdown_signal())
135        .await?;
136
137    println!("lean-ctx proxy shut down cleanly.");
138    Ok(())
139}
140
141async fn shutdown_signal() {
142    let ctrl_c = tokio::signal::ctrl_c();
143
144    #[cfg(unix)]
145    {
146        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
147            .expect("failed to install SIGTERM handler");
148        tokio::select! {
149            _ = ctrl_c => {},
150            _ = sigterm.recv() => {},
151        }
152    }
153
154    #[cfg(not(unix))]
155    {
156        ctrl_c.await.ok();
157    }
158
159    println!("lean-ctx proxy: received shutdown signal, draining…");
160}
161
162async fn health() -> impl IntoResponse {
163    let body = serde_json::json!({
164        "status": "ok",
165        "pid": std::process::id(),
166    });
167    (StatusCode::OK, axum::Json(body))
168}
169
170async fn v1_resolve_reference(
171    axum::extract::Path(id): axum::extract::Path<String>,
172) -> impl IntoResponse {
173    match crate::server::reference_store::resolve(&id) {
174        Some(content) => (StatusCode::OK, content),
175        None => (
176            StatusCode::NOT_FOUND,
177            "Reference expired or not found".to_string(),
178        ),
179    }
180}
181
182async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
183    use std::sync::atomic::Ordering::Relaxed;
184    let s = &state.stats;
185    let i = &state.introspect;
186
187    let last_breakdown = i
188        .last_breakdown
189        .lock()
190        .ok()
191        .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
192        .flatten();
193
194    let body = serde_json::json!({
195        "status": "running",
196        "port": state.port,
197        "requests_total": s.requests_total.load(Relaxed),
198        "requests_compressed": s.requests_compressed.load(Relaxed),
199        "tokens_saved": s.tokens_saved.load(Relaxed),
200        "bytes_original": s.bytes_original.load(Relaxed),
201        "bytes_compressed": s.bytes_compressed.load(Relaxed),
202        "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
203        "introspect": {
204            "total_requests_analyzed": i.total_requests.load(Relaxed),
205            "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
206            "last_breakdown": last_breakdown,
207        }
208    });
209    (StatusCode::OK, axum::Json(body))
210}
211
212async fn proxy_auth_guard(
213    req: axum::extract::Request,
214    next: axum::middleware::Next,
215    expected_token: String,
216) -> Result<Response, StatusCode> {
217    let path = req.uri().path();
218    if path == "/health" || path == "/status" {
219        return Ok(next.run(req).await);
220    }
221
222    if let Some(auth) = req
223        .headers()
224        .get("authorization")
225        .and_then(|v| v.to_str().ok())
226    {
227        if let Some(token) = auth.strip_prefix("Bearer ") {
228            if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
229                return Ok(next.run(req).await);
230            }
231        }
232    }
233
234    Err(StatusCode::UNAUTHORIZED)
235}
236
237fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
238    use subtle::ConstantTimeEq;
239    if a.len() != b.len() {
240        return false;
241    }
242    bool::from(a.ct_eq(b))
243}
244
245async fn host_guard(
246    req: axum::extract::Request,
247    next: axum::middleware::Next,
248) -> Result<Response, StatusCode> {
249    if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
250        let h = host.split(':').next().unwrap_or(host);
251        if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
252            return Ok(next.run(req).await);
253        }
254    } else {
255        return Ok(next.run(req).await);
256    }
257    Err(StatusCode::FORBIDDEN)
258}
259
260async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
261    let path = req.uri().path().to_string();
262
263    if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
264        match google::handler(State(state), req).await {
265            Ok(resp) => resp,
266            Err(status) => Response::builder()
267                .status(status)
268                .body(Body::from("proxy error"))
269                .expect("BUG: building error response with valid status should never fail"),
270        }
271    } else {
272        let method = req.method().to_string();
273        eprintln!("lean-ctx proxy: unmatched {method} {path}");
274        Response::builder()
275            .status(StatusCode::NOT_FOUND)
276            .body(Body::from(format!(
277                "lean-ctx proxy: no handler for {method} {path}"
278            )))
279            .expect("BUG: building 404 response should never fail")
280    }
281}