1pub mod anthropic;
2pub mod compress;
3pub mod cost;
4pub mod forward;
5pub mod google;
6pub mod history_prune;
7pub mod introspect;
8pub mod metrics;
9pub mod openai;
10pub mod openai_responses;
11pub mod tool_kind;
12
13use std::net::SocketAddr;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
17use axum::{
18 body::Body,
19 extract::State,
20 http::{Request, StatusCode},
21 response::{IntoResponse, Response},
22 routing::{any, get},
23 Router,
24};
25
26#[derive(Clone)]
27pub struct ProxyState {
28 pub client: reqwest::Client,
29 pub port: u16,
30 pub stats: Arc<ProxyStats>,
31 pub introspect: Arc<introspect::IntrospectState>,
32 pub anthropic_upstream: String,
33 pub openai_upstream: String,
34 pub gemini_upstream: String,
35}
36
/// Lock-free counters describing what the proxy has processed so far.
///
/// Each field is an independent atomic updated with `Ordering::Relaxed`, so a
/// reader that loads several fields may observe them at slightly different
/// moments. `Default` yields all-zero counters.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Incremented once per [`ProxyStats::record_request`] call.
    pub requests_total: AtomicU64,
    /// Incremented once per [`ProxyStats::record_compression`] call.
    pub requests_compressed: AtomicU64,
    /// Running estimate of tokens saved (bytes saved divided by four).
    pub tokens_saved: AtomicU64,
    /// Sum of the `original` sizes passed to `record_compression`.
    pub bytes_original: AtomicU64,
    /// Sum of the `compressed` sizes passed to `record_compression`.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one incoming request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compressed request.
    ///
    /// `original` and `compressed` are the payload sizes in bytes before and
    /// after compression. The token estimate uses four bytes per token and
    /// saturates at zero when `compressed` exceeds `original`.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the cumulative size reduction as a percentage.
    ///
    /// Yields `0.0` while nothing has been recorded (avoids dividing by zero).
    /// The value is negative if the recorded compressed bytes exceed the
    /// recorded original bytes.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
81
/// Connect timeout in seconds for the upstream HTTP client.
///
/// Reads `LEAN_CTX_PROXY_CONNECT_TIMEOUT_SECS`; a missing, unparsable, or zero
/// value falls back to 15 seconds. Surrounding whitespace is ignored.
fn connect_timeout_secs() -> u64 {
    const FALLBACK_SECS: u64 = 15;

    match std::env::var("LEAN_CTX_PROXY_CONNECT_TIMEOUT_SECS") {
        Ok(raw) => match raw.trim().parse::<u64>() {
            Ok(secs) if secs > 0 => secs,
            _ => FALLBACK_SECS,
        },
        Err(_) => FALLBACK_SECS,
    }
}
90
/// Read timeout in seconds for the upstream HTTP client.
///
/// Reads `LEAN_CTX_PROXY_READ_TIMEOUT_SECS`; a missing, unparsable, or zero
/// value falls back to 300 seconds. Surrounding whitespace is ignored.
fn read_idle_timeout_secs() -> u64 {
    const FALLBACK_SECS: u64 = 300;

    match std::env::var("LEAN_CTX_PROXY_READ_TIMEOUT_SECS") {
        Ok(raw) => match raw.trim().parse::<u64>() {
            Ok(secs) if secs > 0 => secs,
            _ => FALLBACK_SECS,
        },
        Err(_) => FALLBACK_SECS,
    }
}
102
103pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
104 let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
105 start_proxy_with_token(port, Some(token)).await
106}
107
108pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
109 use crate::core::config::{Config, ProxyProvider};
110
111 let client = reqwest::Client::builder()
116 .connect_timeout(std::time::Duration::from_secs(connect_timeout_secs()))
117 .read_timeout(std::time::Duration::from_secs(read_idle_timeout_secs()))
118 .build()?;
119
120 let cfg = Config::load();
121 let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
122 let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
123 let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
124
125 let state = ProxyState {
126 client,
127 port,
128 stats: Arc::new(ProxyStats::default()),
129 introspect: Arc::new(introspect::IntrospectState::default()),
130 anthropic_upstream: anthropic_upstream.clone(),
131 openai_upstream: openai_upstream.clone(),
132 gemini_upstream: gemini_upstream.clone(),
133 };
134
135 let mut app = Router::new()
136 .route("/health", get(health))
137 .route("/status", get(status_handler))
138 .route("/v1/messages", any(anthropic::handler))
139 .route("/v1/messages/{*rest}", any(anthropic::handler))
140 .route("/v1/chat/completions", any(openai::handler))
141 .route("/v1/responses", any(openai_responses::handler))
142 .route("/v1/responses/{*rest}", any(openai_responses::handler))
143 .route("/v1/references/{id}", get(v1_resolve_reference))
144 .fallback(fallback_router)
145 .layer(axum::middleware::from_fn(host_guard))
146 .with_state(state);
147
148 if let Some(ref token) = auth_token {
149 let expected = token.clone();
150 app = app.layer(axum::middleware::from_fn(move |req, next| {
151 let expected = expected.clone();
152 proxy_auth_guard(req, next, expected)
153 }));
154 }
155
156 let addr = SocketAddr::from(([127, 0, 0, 1], port));
157 if auth_token.is_some() {
158 println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
159 } else {
160 println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
161 }
162 println!(" Anthropic: POST /v1/messages → {anthropic_upstream}");
163 println!(" OpenAI: POST /v1/chat/completions → {openai_upstream}");
164 println!(" OpenAI: POST /v1/responses → {openai_upstream}");
165 println!(" Gemini: POST /v1beta/models/... → {gemini_upstream}");
166
167 let listener = tokio::net::TcpListener::bind(addr).await?;
168 axum::serve(listener, app)
169 .with_graceful_shutdown(shutdown_signal())
170 .await?;
171
172 println!("lean-ctx proxy shut down cleanly.");
173 Ok(())
174}
175
176async fn shutdown_signal() {
177 let ctrl_c = tokio::signal::ctrl_c();
178
179 #[cfg(unix)]
180 {
181 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
184 Ok(mut sigterm) => {
185 tokio::select! {
186 _ = ctrl_c => {},
187 _ = sigterm.recv() => {},
188 }
189 }
190 Err(e) => {
191 tracing::warn!("lean-ctx proxy: SIGTERM handler unavailable ({e}); Ctrl-C only");
192 ctrl_c.await.ok();
193 }
194 }
195 }
196
197 #[cfg(not(unix))]
198 {
199 ctrl_c.await.ok();
200 }
201
202 println!("lean-ctx proxy: received shutdown signal, draining…");
203}
204
205async fn health() -> impl IntoResponse {
206 let body = serde_json::json!({
207 "status": "ok",
208 "pid": std::process::id(),
209 });
210 (StatusCode::OK, axum::Json(body))
211}
212
213async fn v1_resolve_reference(
214 axum::extract::Path(id): axum::extract::Path<String>,
215) -> impl IntoResponse {
216 match crate::server::reference_store::resolve(&id) {
217 Some(content) => (StatusCode::OK, content),
218 None => (
219 StatusCode::NOT_FOUND,
220 "Reference expired or not found".to_string(),
221 ),
222 }
223}
224
225async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
226 use std::sync::atomic::Ordering::Relaxed;
227 let s = &state.stats;
228 let i = &state.introspect;
229
230 let last_breakdown = i
231 .last_breakdown
232 .lock()
233 .ok()
234 .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
235 .flatten();
236
237 let body = serde_json::json!({
238 "status": "running",
239 "port": state.port,
240 "requests_total": s.requests_total.load(Relaxed),
241 "requests_compressed": s.requests_compressed.load(Relaxed),
242 "tokens_saved": s.tokens_saved.load(Relaxed),
243 "tokens_saved_estimated": true,
244 "bytes_original": s.bytes_original.load(Relaxed),
245 "bytes_compressed": s.bytes_compressed.load(Relaxed),
246 "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
247 "per_model": cost::snapshot(),
248 "note": "Savings are request-side (tokens removed before forwarding); they do not subtract any re-reads the agent performs. Token figures are estimates; USD uses the shared model price table.",
249 "introspect": {
250 "total_requests_analyzed": i.total_requests.load(Relaxed),
251 "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
252 "last_breakdown": last_breakdown,
253 }
254 });
255 (StatusCode::OK, axum::Json(body))
256}
257
258async fn proxy_auth_guard(
259 req: axum::extract::Request,
260 next: axum::middleware::Next,
261 expected_token: String,
262) -> Result<Response, Response> {
263 let path = req.uri().path();
264 if path == "/health" {
265 return Ok(next.run(req).await);
266 }
267
268 if let Some(auth) = req
270 .headers()
271 .get("authorization")
272 .and_then(|v| v.to_str().ok())
273 {
274 if let Some(token) = auth.strip_prefix("Bearer ") {
275 if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
276 return Ok(next.run(req).await);
277 }
278 }
279 }
280
281 if has_provider_api_key(&req) && is_provider_route(path) {
286 return Ok(next.run(req).await);
287 }
288
289 let cfg = crate::core::config::Config::load();
290 let hint = match cfg.proxy_enabled {
291 Some(true) => "lean-ctx proxy requires authentication. Use a Bearer token (LEAN_CTX_PROXY_TOKEN) or configure your AI tool's API key.",
292 Some(false) => "lean-ctx proxy is disabled but still running. Run: lean-ctx proxy cleanup",
293 None => "lean-ctx proxy is not configured. Your AI tool's ANTHROPIC_BASE_URL may be pointing here by mistake. Fix: lean-ctx proxy cleanup OR lean-ctx proxy enable",
294 };
295
296 let body = serde_json::json!({
297 "type": "error",
298 "error": {
299 "type": "authentication_error",
300 "message": format!("401 Unauthorized — {hint}")
301 }
302 });
303
304 Err((StatusCode::UNAUTHORIZED, axum::Json(body)).into_response())
305}
306
307fn has_provider_api_key(req: &axum::extract::Request) -> bool {
308 let headers = req.headers();
309 for key in ["x-api-key", "x-goog-api-key", "api-key"] {
310 if headers
311 .get(key)
312 .and_then(|v| v.to_str().ok())
313 .is_some_and(|v| !v.trim().is_empty())
314 {
315 return true;
316 }
317 }
318 if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
319 if auth.starts_with("Bearer sk-") || auth.starts_with("Bearer gsk_") {
320 return true;
321 }
322 }
323 false
324}
325
/// Reports whether `path` is forwarded to an upstream provider.
///
/// Provider routes are those under `/v1/`, `/v1beta/`, or `/chat/completions`.
/// `/v1/references` is excluded: that endpoint is answered locally from the
/// reference store, so no upstream ever validates a presented provider key.
/// Treating it as a provider route would let any non-empty API-key header
/// bypass the proxy token for locally stored content.
fn is_provider_route(path: &str) -> bool {
    let served_locally = path == "/v1/references" || path.starts_with("/v1/references/");
    if served_locally {
        return false;
    }

    ["/v1/", "/v1beta/", "/chat/completions"]
        .iter()
        .any(|prefix| path.starts_with(prefix))
}
331
332fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
333 use subtle::ConstantTimeEq;
334 if a.len() != b.len() {
335 return false;
336 }
337 bool::from(a.ct_eq(b))
338}
339
340async fn host_guard(
341 req: axum::extract::Request,
342 next: axum::middleware::Next,
343) -> Result<Response, StatusCode> {
344 if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
345 let h = host.split(':').next().unwrap_or(host);
346 if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
347 return Ok(next.run(req).await);
348 }
349 }
350 Err(StatusCode::FORBIDDEN)
351}
352
353async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
354 let path = req.uri().path().to_string();
355
356 if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
357 match google::handler(State(state), req).await {
358 Ok(resp) => resp,
359 Err(status) => Response::builder()
360 .status(status)
361 .body(Body::from("proxy error"))
362 .expect("BUG: building error response with valid status should never fail"),
363 }
364 } else {
365 let method = req.method().to_string();
366 eprintln!("lean-ctx proxy: unmatched {method} {path}");
367 Response::builder()
368 .status(StatusCode::NOT_FOUND)
369 .body(Body::from(format!(
370 "lean-ctx proxy: no handler for {method} {path}"
371 )))
372 .expect("BUG: building 404 response should never fail")
373 }
374}
375
#[cfg(test)]
mod auth_tests {
    use super::*;

    /// Builds a body-less request for `path` carrying the given headers.
    fn build_request(headers: &[(&str, &str)], path: &str) -> axum::extract::Request {
        headers
            .iter()
            .fold(
                axum::http::Request::builder().uri(path),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(axum::body::Body::empty())
            .unwrap()
    }

    // --- is_provider_route -------------------------------------------------

    #[test]
    fn is_provider_route_v1() {
        for path in ["/v1/chat/completions", "/v1/messages", "/v1/completions"] {
            assert!(is_provider_route(path));
        }
    }

    #[test]
    fn is_provider_route_anthropic_subpaths() {
        for path in [
            "/v1/messages/count_tokens",
            "/v1/messages/batches",
            "/v1/messages/batches/batch_123",
        ] {
            assert!(is_provider_route(path));
        }
    }

    #[test]
    fn is_provider_route_v1beta() {
        assert!(is_provider_route("/v1beta/models"));
    }

    #[test]
    fn is_provider_route_chat() {
        assert!(is_provider_route("/chat/completions"));
    }

    #[test]
    fn is_provider_route_rejects_non_provider() {
        for path in ["/health", "/api/v2/test", "/"] {
            assert!(!is_provider_route(path));
        }
    }

    // --- has_provider_api_key ----------------------------------------------

    #[test]
    fn has_provider_api_key_x_api_key() {
        let headers = [("x-api-key", "sk-ant-abc123")];
        assert!(has_provider_api_key(&build_request(&headers, "/v1/messages")));
    }

    #[test]
    fn has_provider_api_key_x_goog() {
        let headers = [("x-goog-api-key", "AIzaSyAbc")];
        assert!(has_provider_api_key(&build_request(&headers, "/v1beta/models")));
    }

    #[test]
    fn has_provider_api_key_azure() {
        let headers = [("api-key", "deadbeef")];
        assert!(has_provider_api_key(&build_request(&headers, "/v1/completions")));
    }

    #[test]
    fn has_provider_api_key_bearer_sk() {
        let headers = [("authorization", "Bearer sk-proj-abc123")];
        assert!(has_provider_api_key(&build_request(
            &headers,
            "/v1/chat/completions"
        )));
    }

    #[test]
    fn has_provider_api_key_empty_rejected() {
        let headers = [("x-api-key", " ")];
        assert!(!has_provider_api_key(&build_request(&headers, "/v1/messages")));
    }

    #[test]
    fn has_provider_api_key_no_headers() {
        assert!(!has_provider_api_key(&build_request(&[], "/v1/messages")));
    }

    #[test]
    fn has_provider_api_key_regular_bearer_rejected() {
        let headers = [("authorization", "Bearer some-session-token")];
        let req = build_request(&headers, "/v1/chat");
        assert!(
            !has_provider_api_key(&req),
            "non-sk/gsk Bearer should not count as provider key"
        );
    }
}