1pub mod anthropic;
2pub mod compress;
3pub mod forward;
4pub mod google;
5pub mod introspect;
6pub mod openai;
7
8use std::net::SocketAddr;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use axum::{
13 body::Body,
14 extract::State,
15 http::{Request, StatusCode},
16 response::{IntoResponse, Response},
17 routing::{any, get},
18 Router,
19};
20
21#[derive(Clone)]
22pub struct ProxyState {
23 pub client: reqwest::Client,
24 pub port: u16,
25 pub stats: Arc<ProxyStats>,
26 pub introspect: Arc<introspect::IntrospectState>,
27 pub anthropic_upstream: String,
28 pub openai_upstream: String,
29 pub gemini_upstream: String,
30}
31
/// Counters describing the traffic the proxy has handled.
///
/// Every field is an atomic, so the counters can be updated through a shared
/// reference. `Default` yields all-zero counters.
#[derive(Debug, Default)]
pub struct ProxyStats {
    /// Number of requests recorded via [`ProxyStats::record_request`].
    pub requests_total: AtomicU64,
    /// Number of compressions recorded via [`ProxyStats::record_compression`].
    pub requests_compressed: AtomicU64,
    /// Estimated tokens saved, counted as one token per four bytes saved.
    pub tokens_saved: AtomicU64,
    /// Sum of the pre-compression sizes, in bytes.
    pub bytes_original: AtomicU64,
    /// Sum of the post-compression sizes, in bytes.
    pub bytes_compressed: AtomicU64,
}

impl ProxyStats {
    /// Counts one incoming request.
    pub fn record_request(&self) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one compression that turned `original` bytes into `compressed` bytes.
    ///
    /// The token estimate never goes negative: when `compressed` exceeds
    /// `original`, zero tokens are credited.
    pub fn record_compression(&self, original: usize, compressed: usize) {
        self.requests_compressed.fetch_add(1, Ordering::Relaxed);
        self.bytes_original
            .fetch_add(original as u64, Ordering::Relaxed);
        self.bytes_compressed
            .fetch_add(compressed as u64, Ordering::Relaxed);

        // Rough estimate: four bytes per token.
        let saved_tokens = (original.saturating_sub(compressed) / 4) as u64;
        self.tokens_saved.fetch_add(saved_tokens, Ordering::Relaxed);
    }

    /// Returns the overall size reduction as a percentage of the original bytes.
    ///
    /// Yields `0.0` while no original bytes have been recorded. The value is
    /// negative when the recorded compressed total exceeds the original total.
    #[must_use]
    pub fn compression_ratio(&self) -> f64 {
        let original = self.bytes_original.load(Ordering::Relaxed);
        if original == 0 {
            return 0.0;
        }
        let compressed = self.bytes_compressed.load(Ordering::Relaxed);
        (1.0 - compressed as f64 / original as f64) * 100.0
    }
}
76
77pub async fn start_proxy(port: u16) -> anyhow::Result<()> {
78 start_proxy_with_token(port, std::env::var("LEAN_CTX_PROXY_TOKEN").ok()).await
79}
80
81pub async fn start_proxy_with_token(port: u16, auth_token: Option<String>) -> anyhow::Result<()> {
82 use crate::core::config::{Config, ProxyProvider};
83
84 let client = reqwest::Client::builder()
85 .timeout(std::time::Duration::from_mins(2))
86 .build()?;
87
88 let cfg = Config::load();
89 let anthropic_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Anthropic);
90 let openai_upstream = cfg.proxy.resolve_upstream(ProxyProvider::OpenAi);
91 let gemini_upstream = cfg.proxy.resolve_upstream(ProxyProvider::Gemini);
92
93 let state = ProxyState {
94 client,
95 port,
96 stats: Arc::new(ProxyStats::default()),
97 introspect: Arc::new(introspect::IntrospectState::default()),
98 anthropic_upstream: anthropic_upstream.clone(),
99 openai_upstream: openai_upstream.clone(),
100 gemini_upstream: gemini_upstream.clone(),
101 };
102
103 let mut app = Router::new()
104 .route("/health", get(health))
105 .route("/status", get(status_handler))
106 .route("/v1/messages", any(anthropic::handler))
107 .route("/v1/chat/completions", any(openai::handler))
108 .fallback(fallback_router)
109 .layer(axum::middleware::from_fn(host_guard))
110 .with_state(state);
111
112 if let Some(ref token) = auth_token {
113 let expected = token.clone();
114 app = app.layer(axum::middleware::from_fn(move |req, next| {
115 let expected = expected.clone();
116 proxy_auth_guard(req, next, expected)
117 }));
118 }
119
120 let addr = SocketAddr::from(([127, 0, 0, 1], port));
121 if auth_token.is_some() {
122 println!("lean-ctx proxy listening on http://{addr} (token auth enabled)");
123 } else {
124 println!("lean-ctx proxy listening on http://{addr} (no auth — set LEAN_CTX_PROXY_TOKEN to enable)");
125 }
126 println!(" Anthropic: POST /v1/messages → {anthropic_upstream}");
127 println!(" OpenAI: POST /v1/chat/completions → {openai_upstream}");
128 println!(" Gemini: POST /v1beta/models/... → {gemini_upstream}");
129
130 let listener = tokio::net::TcpListener::bind(addr).await?;
131 axum::serve(listener, app).await?;
132
133 Ok(())
134}
135
136async fn health() -> impl IntoResponse {
137 (StatusCode::OK, "ok")
138}
139
140async fn status_handler(State(state): State<ProxyState>) -> impl IntoResponse {
141 use std::sync::atomic::Ordering::Relaxed;
142 let s = &state.stats;
143 let i = &state.introspect;
144
145 let last_breakdown = i
146 .last_breakdown
147 .lock()
148 .ok()
149 .and_then(|guard| guard.as_ref().map(|b| serde_json::to_value(b).ok()))
150 .flatten();
151
152 let body = serde_json::json!({
153 "status": "running",
154 "port": state.port,
155 "requests_total": s.requests_total.load(Relaxed),
156 "requests_compressed": s.requests_compressed.load(Relaxed),
157 "tokens_saved": s.tokens_saved.load(Relaxed),
158 "bytes_original": s.bytes_original.load(Relaxed),
159 "bytes_compressed": s.bytes_compressed.load(Relaxed),
160 "compression_ratio_pct": format!("{:.1}", s.compression_ratio()),
161 "introspect": {
162 "total_requests_analyzed": i.total_requests.load(Relaxed),
163 "total_system_prompt_tokens": i.total_system_prompt_tokens.load(Relaxed),
164 "last_breakdown": last_breakdown,
165 }
166 });
167 (StatusCode::OK, axum::Json(body))
168}
169
170async fn proxy_auth_guard(
171 req: axum::extract::Request,
172 next: axum::middleware::Next,
173 expected_token: String,
174) -> Result<Response, StatusCode> {
175 let path = req.uri().path();
176 if path == "/health" || path == "/status" {
177 return Ok(next.run(req).await);
178 }
179
180 if let Some(auth) = req
181 .headers()
182 .get("authorization")
183 .and_then(|v| v.to_str().ok())
184 {
185 if let Some(token) = auth.strip_prefix("Bearer ") {
186 if token == expected_token {
187 return Ok(next.run(req).await);
188 }
189 }
190 }
191
192 Err(StatusCode::UNAUTHORIZED)
193}
194
195async fn host_guard(
196 req: axum::extract::Request,
197 next: axum::middleware::Next,
198) -> Result<Response, StatusCode> {
199 if let Some(host) = req.headers().get("host").and_then(|v| v.to_str().ok()) {
200 let h = host.split(':').next().unwrap_or(host);
201 if matches!(h, "127.0.0.1" | "localhost" | "[::1]") {
202 return Ok(next.run(req).await);
203 }
204 } else {
205 return Ok(next.run(req).await);
206 }
207 Err(StatusCode::FORBIDDEN)
208}
209
210async fn fallback_router(State(state): State<ProxyState>, req: Request<Body>) -> Response {
211 let path = req.uri().path().to_string();
212
213 if path.starts_with("/v1beta/models/") || path.starts_with("/v1/models/") {
214 match google::handler(State(state), req).await {
215 Ok(resp) => resp,
216 Err(status) => Response::builder()
217 .status(status)
218 .body(Body::from("proxy error"))
219 .expect("BUG: building error response with valid status should never fail"),
220 }
221 } else {
222 let method = req.method().to_string();
223 eprintln!("lean-ctx proxy: unmatched {method} {path}");
224 Response::builder()
225 .status(StatusCode::NOT_FOUND)
226 .body(Body::from(format!(
227 "lean-ctx proxy: no handler for {method} {path}"
228 )))
229 .expect("BUG: building 404 response should never fail")
230 }
231}