1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod history_prune;
6pub mod introspect;
7pub mod metrics;
8pub mod openai;
9
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use axum::{
15 body::Body,
16 extract::State,
17 http::{Request, StatusCode},
18 response::{IntoResponse, Response},
19 routing::{any, get},
20 Router,
21};
22
23#[derive(Clone)]
24pub struct ProxyState {
25 pub client: reqwest::Client,
26 pub port: u16,
27 pub stats: Arc<ProxyStats>,
28 pub introspect: Arc<introspect::IntrospectState>,
29 pub anthropic_upstream: String,
30 pub openai_upstream: String,
31 pub gemini_upstream: String,
32}
33
/// Lock-free counters describing proxy traffic and compression savings.
///
/// Every counter starts at zero (`Default` is derived; `AtomicU64::default()`
/// is `0`). All updates and reads use `Ordering::Relaxed`: each counter is
/// individually consistent, but a reader may observe one counter updated
/// before another from the same `record_compression` call.
#[derive(Default)]
pub struct ProxyStats {
    /// Number of calls to [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of calls to [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Estimated tokens saved, computed as saved bytes divided by four.
    pub tokens_saved: AtomicU64,
    /// Sum of the `original` sizes passed to `record_compression`.
    pub bytes_original: AtomicU64,
    /// Sum of the `compressed` sizes passed to `record_compression`.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compression with its `original` and `compressed` sizes in bytes.
    ///
    /// The token estimate never goes negative: when `compressed` exceeds
    /// `original` the saving saturates to zero, while the byte totals still
    /// reflect the real sizes.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        // `usize` is at most 64 bits wide on supported targets, so these casts are lossless.
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the cumulative size reduction as a percentage.
    ///
    /// Yields `0.0` while no bytes have been recorded, and a negative value
    /// when the recorded compressed total exceeds the original total.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
78
79pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
80 let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
81 start_proxy_with_token(port, Some(token)).await
82}
83
84pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
85 use crate::core::config::{Config, ProxyProvider};
86
87 let client = reqwest::Client::builder()
88 .timeout(std::time::Duration::from_mins(2))
89 .build()?;
90
91 let cfg = Config::load();
92 let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
93 let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
94 let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
95
96 let state = ProxyState {
97 client,
98 port,
99 stats: Arc::new(ProxyStats::default()),
100 introspect: Arc::new(introspect::IntrospectState::default()),
101 anthropic_upstream: anthropic_upstream.clone(),
102 openai_upstream: openai_upstream.clone(),
103 gemini_upstream: gemini_upstream.clone(),
104 };
105
106 let mut app = Router::new()
107 .route("/health", get(health))
108 .route("/status", get(status_handler))
109 .route("/v1/messages", any(anthropic::handler))
110 .route("/v1/chat/completions", any(openai::handler))
111 .route("/v1/references/{id}", get(v1_resolve_reference))
112 .fallback(fallback_router)
113 .layer(axum::middleware::from_fn(host_guard))
114 .with_state(state);
115
116 if let Some(ref token) = auth_token {
117 let expected = token.clone();
118 app = app.layer(axum::middleware::from_fn(move |req, next| {
119 let expected = expected.clone();
120 proxy_auth_guard(req, next, expected)
121 }));
122 }
123
124 let addr = SocketAddr::from(([127, 0, 0, 1], port));
125 if auth_token.is_some() {
126 println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
127 } else {
128 println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
129 }
130 println!(" Anthropic: POST /v1/messages → {anthropic_upstream}");
131 println!(" OpenAI: POST /v1/chat/completions → {openai_upstream}");
132 println!(" Gemini: POST /v1beta/models/... → {gemini_upstream}");
133
134 let listener = tokio::net::TcpListener::bind(addr).await?;
135 axum::serve(listener, app)
136 .with_graceful_shutdown(shutdown_signal())
137 .await?;
138
139 println!("lean-ctx proxy shut down cleanly.");
140 Ok(())
141}
142
143async fn shutdown_signal() {
144 let ctrl_c = tokio::signal::ctrl_c();
145
146 #[cfg(unix)]
147 {
148 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
151 Ok(mut sigterm) => {
152 tokio::select! {
153 _ = ctrl_c => {},
154 _ = sigterm.recv() => {},
155 }
156 }
157 Err(e) => {
158 tracing::warn!("lean-ctx proxy: SIGTERM handler unavailable ({e}); Ctrl-C only");
159 ctrl_c.await.ok();
160 }
161 }
162 }
163
164 #[cfg(not(unix))]
165 {
166 ctrl_c.await.ok();
167 }
168
169 println!("lean-ctx proxy: received shutdown signal, draining…");
170}
171
172async fn health() -> impl IntoResponse {
173 let body = serde_json::json!({
174 "status": "ok",
175 "pid": std::process::id(),
176 });
177 (StatusCode::OK, axum::Json(body))
178}
179
180async fn v1_resolve_reference(
181 axum::extract::Path(id): axum::extract::Path<String>,
182) -> impl IntoResponse {
183 match crate::server::reference_store::resolve(&id) {
184 Some(content) => (StatusCode::OK, content),
185 None => (
186 StatusCode::NOT_FOUND,
187 "Reference expired or not found".to_string(),
188 ),
189 }
190}
191
192async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
193 use std::sync::atomic::Ordering::Relaxed;
194 let s = &state.stats;
195 let i = &state.introspect;
196
197 let last_breakdown = i
198 .last_breakdown
199 .lock()
200 .ok()
201 .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
202 .flatten();
203
204 let body = serde_json::json!({
205 "status": "running",
206 "port": state.port,
207 "requests_total": s.requests_total.load(Relaxed),
208 "requests_compressed": s.requests_compressed.load(Relaxed),
209 "tokens_saved": s.tokens_saved.load(Relaxed),
210 "bytes_original": s.bytes_original.load(Relaxed),
211 "bytes_compressed": s.bytes_compressed.load(Relaxed),
212 "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
213 "introspect": {
214 "total_requests_analyzed": i.total_requests.load(Relaxed),
215 "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
216 "last_breakdown": last_breakdown,
217 }
218 });
219 (StatusCode::OK, axum::Json(body))
220}
221
222async fn proxy_auth_guard(
223 req: axum::extract::Request,
224 next: axum::middleware::Next,
225 expected_token: String,
226) -> Result<Response, Response> {
227 let path = req.uri().path();
228 if path == "/health" {
229 return Ok(next.run(req).await);
230 }
231
232 if let Some(auth) = req
234 .headers()
235 .get("authorization")
236 .and_then(|v| v.to_str().ok())
237 {
238 if let Some(token) = auth.strip_prefix("Bearer ") {
239 if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
240 return Ok(next.run(req).await);
241 }
242 }
243 }
244
245 if has_provider_api_key(&req) && is_provider_route(path) {
250 return Ok(next.run(req).await);
251 }
252
253 let cfg = crate::core::config::Config::load();
254 let hint = match cfg.proxy_enabled {
255 Some(true) => "lean-ctx proxy requires authentication. Use a Bearer token (LEAN_CTX_PROXY_TOKEN) or configure your AI tool's API key.",
256 Some(false) => "lean-ctx proxy is disabled but still running. Run: lean-ctx proxy cleanup",
257 None => "lean-ctx proxy is not configured. Your AI tool's ANTHROPIC_BASE_URL may be pointing here by mistake. Fix: lean-ctx proxy cleanup OR lean-ctx proxy enable",
258 };
259
260 let body = serde_json::json!({
261 "type": "error",
262 "error": {
263 "type": "authentication_error",
264 "message": format!("401 Unauthorized — {hint}")
265 }
266 });
267
268 Err((StatusCode::UNAUTHORIZED, axum::Json(body)).into_response())
269}
270
271fn has_provider_api_key(req: &axum::extract::Request) -> bool {
272 let headers = req.headers();
273 for key in ["x-api-key", "x-goog-api-key", "api-key"] {
274 if headers
275 .get(key)
276 .and_then(|v| v.to_str().ok())
277 .is_some_and(|v| !v.trim().is_empty())
278 {
279 return true;
280 }
281 }
282 if let Some(auth) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
283 if auth.starts_with("Bearer sk-") || auth.starts_with("Bearer gsk_") {
284 return true;
285 }
286 }
287 false
288}
289
/// Reports whether `path` begins with one of the provider API prefixes:
/// `/v1/`, `/v1beta/` or `/chat/completions`.
fn is_provider_route(path: &str) -> bool {
    const PROVIDER_PREFIXES: [&str; 3] = ["/v1/", "/v1beta/", "/chat/completions"];
    PROVIDER_PREFIXES
        .iter()
        .any(|&prefix| path.starts_with(prefix))
}
295
296fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
297 use subtle::ConstantTimeEq;
298 if a.len() != b.len() {
299 return false;
300 }
301 bool::from(a.ct_eq(b))
302}
303
304async fn host_guard(
305 req: axum::extract::Request,
306 next: axum::middleware::Next,
307) -> Result<Response, StatusCode> {
308 if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
309 let h = host.split(':').next().unwrap_or(host);
310 if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
311 return Ok(next.run(req).await);
312 }
313 }
314 Err(StatusCode::FORBIDDEN)
315}
316
317async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
318 let path = req.uri().path().to_string();
319
320 if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
321 match google::handler(State(state), req).await {
322 Ok(resp) => resp,
323 Err(status) => Response::builder()
324 .status(status)
325 .body(Body::from("proxy error"))
326 .expect("BUG: building error response with valid status should never fail"),
327 }
328 } else {
329 let method = req.method().to_string();
330 eprintln!("lean-ctx proxy: unmatched {method} {path}");
331 Response::builder()
332 .status(StatusCode::NOT_FOUND)
333 .body(Body::from(format!(
334 "lean-ctx proxy: no handler for {method} {path}"
335 )))
336 .expect("BUG: building 404 response should never fail")
337 }
338}
339
#[cfg(test)]
mod auth_tests {
    use super::*;

    /// Builds a body-less request for `path` carrying the given headers.
    fn build_request(headers: &[(&str, &str)], path: &str) -> axum::extract::Request {
        headers
            .iter()
            .fold(
                axum::http::Request::builder().uri(path),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Shorthand: does a request with these headers count as carrying a provider key?
    fn carries_key(headers: &[(&str, &str)], path: &str) -> bool {
        has_provider_api_key(&build_request(headers, path))
    }

    #[test]
    fn is_provider_route_v1() {
        for path in ["/v1/chat/completions", "/v1/messages", "/v1/completions"] {
            assert!(is_provider_route(path));
        }
    }

    #[test]
    fn is_provider_route_v1beta() {
        assert!(is_provider_route("/v1beta/models"));
    }

    #[test]
    fn is_provider_route_chat() {
        assert!(is_provider_route("/chat/completions"));
    }

    #[test]
    fn is_provider_route_rejects_non_provider() {
        for path in ["/health", "/api/v2/test", "/"] {
            assert!(!is_provider_route(path));
        }
    }

    #[test]
    fn has_provider_api_key_x_api_key() {
        assert!(carries_key(&[("x-api-key", "sk-ant-abc123")], "/v1/messages"));
    }

    #[test]
    fn has_provider_api_key_x_goog() {
        assert!(carries_key(&[("x-goog-api-key", "AIzaSyAbc")], "/v1beta/models"));
    }

    #[test]
    fn has_provider_api_key_azure() {
        assert!(carries_key(&[("api-key", "deadbeef")], "/v1/completions"));
    }

    #[test]
    fn has_provider_api_key_bearer_sk() {
        assert!(carries_key(
            &[("authorization", "Bearer sk-proj-abc123")],
            "/v1/chat/completions",
        ));
    }

    #[test]
    fn has_provider_api_key_empty_rejected() {
        assert!(!carries_key(&[("x-api-key", " ")], "/v1/messages"));
    }

    #[test]
    fn has_provider_api_key_no_headers() {
        assert!(!carries_key(&[], "/v1/messages"));
    }

    #[test]
    fn has_provider_api_key_regular_bearer_rejected() {
        let found = carries_key(&[("authorization", "Bearer some-session-token")], "/v1/chat");
        assert!(
            !found,
            "non-sk/gsk Bearer should not count as provider key"
        );
    }
}