1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod introspect;
6pub mod metrics;
7pub mod openai;
8
9use std::net::SocketAddr;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12
13use axum::{
14 body::Body,
15 extract::State,
16 http::{Request, StatusCode},
17 response::{IntoResponse, Response},
18 routing::{any, get},
19 Router,
20};
21
22#[derive(Clone)]
23pub struct ProxyState {
24 pub client: reqwest::Client,
25 pub port: u16,
26 pub stats: Arc<ProxyStats>,
27 pub introspect: Arc<introspect::IntrospectState>,
28 pub anthropic_upstream: String,
29 pub openai_upstream: String,
30 pub gemini_upstream: String,
31}
32
/// Atomic counters describing proxy traffic and compression savings.
///
/// All counters start at zero through the derived [`Default`] and are updated
/// with `Ordering::Relaxed`, so values read together (for example by
/// [`ProxyStats::compression_ratio`]) are not a consistent snapshot when other
/// threads are recording concurrently.
#[derive(Default)]
pub struct ProxyStats {
    /// Number of requests recorded via [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of compressions recorded via [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Running estimate of tokens saved (saved bytes divided by four).
    pub tokens_saved: AtomicU64,
    /// Sum of the `original` sizes passed to `record_compression`.
    pub bytes_original: AtomicU64,
    /// Sum of the `compressed` sizes passed to `record_compression`.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compression that turned `original` bytes into `compressed` bytes.
    ///
    /// The token estimate is `(original - compressed) / 4`, saturating at zero
    /// so a payload that grew never reduces `tokens_saved`.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the percentage of bytes removed across all recorded compressions.
    ///
    /// Yields `0.0` while no original bytes have been recorded. The value is
    /// negative when the accumulated compressed size exceeds the original size.
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
77
78pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
79 let token = crate::core::session_token::resolve_proxy_token("LEAN_CTX_PROXY_TOKEN");
80 start_proxy_with_token(port, Some(token)).await
81}
82
83pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
84 use crate::core::config::{Config, ProxyProvider};
85
86 let client = reqwest::Client::builder()
87 .timeout(std::time::Duration::from_mins(2))
88 .build()?;
89
90 let cfg = Config::load();
91 let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
92 let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
93 let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
94
95 let state = ProxyState {
96 client,
97 port,
98 stats: Arc::new(ProxyStats::default()),
99 introspect: Arc::new(introspect::IntrospectState::default()),
100 anthropic_upstream: anthropic_upstream.clone(),
101 openai_upstream: openai_upstream.clone(),
102 gemini_upstream: gemini_upstream.clone(),
103 };
104
105 let mut app = Router::new()
106 .route("/health", get(health))
107 .route("/status", get(status_handler))
108 .route("/v1/messages", any(anthropic::handler))
109 .route("/v1/chat/completions", any(openai::handler))
110 .route("/v1/references/{id}", get(v1_resolve_reference))
111 .fallback(fallback_router)
112 .layer(axum::middleware::from_fn(host_guard))
113 .with_state(state);
114
115 if let Some(ref token) = auth_token {
116 let expected = token.clone();
117 app = app.layer(axum::middleware::from_fn(move |req, next| {
118 let expected = expected.clone();
119 proxy_auth_guard(req, next, expected)
120 }));
121 }
122
123 let addr = SocketAddr::from(([127, 0, 0, 1], port));
124 if auth_token.is_some() {
125 println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
126 } else {
127 println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
128 }
129 println!(" Anthropic: POST /v1/messages → {anthropic_upstream}");
130 println!(" OpenAI: POST /v1/chat/completions → {openai_upstream}");
131 println!(" Gemini: POST /v1beta/models/... → {gemini_upstream}");
132
133 let listener = tokio::net::TcpListener::bind(addr).await?;
134 axum::serve(listener, app)
135 .with_graceful_shutdown(shutdown_signal())
136 .await?;
137
138 println!("lean-ctx proxy shut down cleanly.");
139 Ok(())
140}
141
142async fn shutdown_signal() {
143 let ctrl_c = tokio::signal::ctrl_c();
144
145 #[cfg(unix)]
146 {
147 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
148 .expect("failed to install SIGTERM handler");
149 tokio::select! {
150 _ = ctrl_c => {},
151 _ = sigterm.recv() => {},
152 }
153 }
154
155 #[cfg(not(unix))]
156 {
157 ctrl_c.await.ok();
158 }
159
160 println!("lean-ctx proxy: received shutdown signal, draining…");
161}
162
163async fn health() -> impl IntoResponse {
164 let body = serde_json::json!({
165 "status": "ok",
166 "pid": std::process::id(),
167 });
168 (StatusCode::OK, axum::Json(body))
169}
170
171async fn v1_resolve_reference(
172 axum::extract::Path(id): axum::extract::Path<String>,
173) -> impl IntoResponse {
174 match crate::server::reference_store::resolve(&id) {
175 Some(content) => (StatusCode::OK, content),
176 None => (
177 StatusCode::NOT_FOUND,
178 "Reference expired or not found".to_string(),
179 ),
180 }
181}
182
183async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
184 use std::sync::atomic::Ordering::Relaxed;
185 let s = &state.stats;
186 let i = &state.introspect;
187
188 let last_breakdown = i
189 .last_breakdown
190 .lock()
191 .ok()
192 .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
193 .flatten();
194
195 let body = serde_json::json!({
196 "status": "running",
197 "port": state.port,
198 "requests_total": s.requests_total.load(Relaxed),
199 "requests_compressed": s.requests_compressed.load(Relaxed),
200 "tokens_saved": s.tokens_saved.load(Relaxed),
201 "bytes_original": s.bytes_original.load(Relaxed),
202 "bytes_compressed": s.bytes_compressed.load(Relaxed),
203 "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
204 "introspect": {
205 "total_requests_analyzed": i.total_requests.load(Relaxed),
206 "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
207 "last_breakdown": last_breakdown,
208 }
209 });
210 (StatusCode::OK, axum::Json(body))
211}
212
213async fn proxy_auth_guard(
214 req: axum::extract::Request,
215 next: axum::middleware::Next,
216 expected_token: String,
217) -> Result<Response, StatusCode> {
218 let path = req.uri().path();
219 if path == "/health" || path == "/status" {
220 return Ok(next.run(req).await);
221 }
222
223 if let Some(auth) = req
224 .headers()
225 .get("authorization")
226 .and_then(|v| v.to_str().ok())
227 {
228 if let Some(token) = auth.strip_prefix("Bearer ") {
229 if constant_time_eq(token.as_bytes(), expected_token.as_bytes()) {
230 return Ok(next.run(req).await);
231 }
232 }
233 }
234
235 Err(StatusCode::UNAUTHORIZED)
236}
237
238fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
239 use subtle::ConstantTimeEq;
240 if a.len() != b.len() {
241 return false;
242 }
243 bool::from(a.ct_eq(b))
244}
245
246async fn host_guard(
247 req: axum::extract::Request,
248 next: axum::middleware::Next,
249) -> Result<Response, StatusCode> {
250 if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
251 let h = host.split(':').next().unwrap_or(host);
252 if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
253 return Ok(next.run(req).await);
254 }
255 }
256 Err(StatusCode::FORBIDDEN)
258}
259
260async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
261 let path = req.uri().path().to_string();
262
263 if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
264 match google::handler(State(state), req).await {
265 Ok(resp) => resp,
266 Err(status) => Response::builder()
267 .status(status)
268 .body(Body::from("proxy error"))
269 .expect("BUG: building error response with valid status should never fail"),
270 }
271 } else {
272 let method = req.method().to_string();
273 eprintln!("lean-ctx proxy: unmatched {method} {path}");
274 Response::builder()
275 .status(StatusCode::NOT_FOUND)
276 .body(Body::from(format!(
277 "lean-ctx proxy: no handler for {method} {path}"
278 )))
279 .expect("BUG: building 404 response should never fail")
280 }
281}