1use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::{Arc, Mutex};
17
18use anyhow::Result;
19use axum::body::Body;
20use axum::extract::State;
21use axum::http::{header, Method, StatusCode};
22use axum::response::Response;
23use axum::routing::any;
24use axum::Router;
25use futures_util::StreamExt;
26use serde_json::Value;
27
28mod optimize;
29pub use optimize::{optimize_request, Optimization, OptimizeOpts};
30
/// Largest request body, in bytes, that is buffered before forwarding (64 MiB).
const MAX_BODY: usize = 64 << 20;
/// Amount of response text, in bytes, the usage scanner buffers before it gives up (256 KiB).
const USAGE_SCAN_CAP: usize = 256 << 10;
35
36type Db = Arc<Mutex<rusqlite::Connection>>;
37
/// Runtime settings for the proxy, handed to [`serve`] by the caller.
///
/// Derives `Debug` so a configuration can be logged or compared in assertions.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Local address the proxy listens on.
    pub bind: SocketAddr,
    /// Base URL requests are forwarded to; the incoming path and query are
    /// appended after any trailing `/` is trimmed.
    pub upstream: String,
    /// When `true`, optimizations are only measured and logged; request bodies
    /// are forwarded unchanged and nothing is recorded in the ledger.
    pub dry_run: bool,
    /// Forwarded to the optimizer as `OptimizeOpts::cache`.
    pub cache: bool,
    /// Forwarded to the optimizer as `OptimizeOpts::dedup`.
    pub dedup: bool,
    /// Forwarded to the optimizer as `OptimizeOpts::prune`.
    pub prune: bool,
    /// Path of the savings-ledger database. `None`, or a database that fails
    /// to open, disables ledger recording.
    pub db_path: Option<PathBuf>,
}
48
49struct AppState {
50 cfg: ProxyConfig,
51 client: reqwest::Client,
52 db: Option<Db>,
53}
54
55pub async fn serve(cfg: ProxyConfig) -> Result<()> {
57 let db = cfg
58 .db_path
59 .as_ref()
60 .and_then(|p| match mimir_core::db::open(p) {
61 Ok(c) => Some(Arc::new(Mutex::new(c))),
62 Err(e) => {
63 tracing::warn!("proxy: savings ledger disabled ({e})");
64 None
65 }
66 });
67 let client = reqwest::Client::builder().build()?;
68 let bind = cfg.bind;
69 let upstream = cfg.upstream.clone();
70 let dry = cfg.dry_run;
71
72 if let Some(db) = db.clone() {
80 tokio::spawn(async move {
81 let period = std::time::Duration::from_secs(120);
82 let mut tick = tokio::time::interval_at(tokio::time::Instant::now() + period, period);
83 loop {
84 tick.tick().await;
85 let db = db.clone();
86 let _ = tokio::task::spawn_blocking(move || {
87 if let Ok(conn) = db.lock() {
88 if let Err(e) = mimir_core::db::checkpoint_truncate(&conn) {
89 tracing::warn!("proxy: wal checkpoint failed ({e})");
90 }
91 }
92 })
93 .await;
94 }
95 });
96 }
97
98 let state = Arc::new(AppState { cfg, client, db });
99 let app = Router::new().fallback(any(handle)).with_state(state);
100
101 let listener = tokio::net::TcpListener::bind(bind).await?;
102 eprintln!("mimir proxy listening on http://{bind} → {upstream}");
103 eprintln!(" point your client at it: export ANTHROPIC_BASE_URL=http://{bind}");
104 if dry {
105 eprintln!(" dry-run: measuring only, request bodies forwarded unchanged");
106 }
107 axum::serve(listener, app).await?;
108 Ok(())
109}
110
111async fn handle(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
112 match proxy(&state, req).await {
113 Ok(resp) => resp,
114 Err(e) => {
115 tracing::error!("proxy error: {e}");
116 Response::builder()
117 .status(StatusCode::BAD_GATEWAY)
118 .body(Body::from(format!("mimir proxy error: {e}")))
119 .unwrap()
120 }
121 }
122}
123
124async fn proxy(state: &AppState, req: axum::extract::Request) -> Result<Response> {
125 let (parts, body) = req.into_parts();
126 let path_q = parts
127 .uri
128 .path_and_query()
129 .map(|p| p.as_str())
130 .unwrap_or("/")
131 .to_string();
132 let url = format!("{}{}", state.cfg.upstream.trim_end_matches('/'), path_q);
133 let body_bytes = axum::body::to_bytes(body, MAX_BODY).await?;
134
135 let is_messages = parts.method == Method::POST && parts.uri.path().ends_with("/v1/messages");
136 let (send_bytes, we_cached) = if is_messages {
139 optimize_messages_body(state, &body_bytes)
140 } else {
141 (body_bytes.to_vec(), false)
142 };
143
144 let mut rb = state.client.request(parts.method.clone(), &url);
146 for (name, value) in parts.headers.iter() {
147 if is_hop_by_hop(name.as_str()) || name == header::HOST || name == header::CONTENT_LENGTH {
148 continue;
149 }
150 rb = rb.header(name.clone(), value.clone());
151 }
152 let upstream = rb.body(send_bytes).send().await?;
153
154 let status = upstream.status();
155 let mut builder = Response::builder().status(status);
156 for (name, value) in upstream.headers().iter() {
157 if is_hop_by_hop(name.as_str())
158 || name == header::CONTENT_LENGTH
159 || name == header::TRANSFER_ENCODING
160 {
161 continue;
162 }
163 builder = builder.header(name.clone(), value.clone());
164 }
165
166 let body = if we_cached {
169 let mut scanner = UsageScanner::new(state.db.clone());
170 let tapped = upstream.bytes_stream().inspect(move |chunk| {
171 if let Ok(bytes) = chunk {
172 scanner.feed(bytes);
173 }
174 });
175 Body::from_stream(tapped)
176 } else {
177 Body::from_stream(upstream.bytes_stream())
178 };
179 Ok(builder.body(body)?)
180}
181
182fn optimize_messages_body(state: &AppState, body_bytes: &[u8]) -> (Vec<u8>, bool) {
185 let Ok(json) = serde_json::from_slice::<Value>(body_bytes) else {
186 return (body_bytes.to_vec(), false); };
188 let we_cached = state.cfg.cache && !optimize::has_cache_control(&json);
189 let (new_json, opt) = optimize_request(
190 json,
191 OptimizeOpts {
192 cache: state.cfg.cache,
193 dedup: state.cfg.dedup,
194 prune: state.cfg.prune,
195 },
196 );
197 if state.cfg.dry_run {
198 if opt.deduped > 0 || opt.pruned > 0 {
199 tracing::info!(
200 deduped = opt.deduped,
201 pruned = opt.pruned,
202 "proxy dry-run: would optimize"
203 );
204 }
205 return (body_bytes.to_vec(), false);
206 }
207 record_request_savings(&state.db, &opt);
208 let bytes = serde_json::to_vec(&new_json).unwrap_or_else(|_| body_bytes.to_vec());
209 (bytes, we_cached)
210}
211
/// Reports whether `name` is a hop-by-hop header that must not be forwarded.
///
/// The comparison is ASCII case-insensitive and does not allocate. The list
/// includes `trailer` (the actual header name) alongside the historical
/// `trailers` spelling, and the non-standard but widely sent
/// `proxy-connection`.
fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    ];
    HOP_BY_HOP.iter().any(|&h| name.eq_ignore_ascii_case(h))
}
225
226fn record_request_savings(db: &Option<Db>, opt: &Optimization) {
227 let Some(db) = db else { return };
228 let Ok(conn) = db.lock() else { return };
229 if opt.deduped > 0 {
230 let _ = mimir_core::savings::record(
231 &conn,
232 None,
233 mimir_core::savings::source::PROXY_DEDUP,
234 opt.deduped,
235 0,
236 serde_json::json!({}),
237 );
238 }
239 if opt.pruned > 0 {
240 let _ = mimir_core::savings::record(
241 &conn,
242 None,
243 mimir_core::savings::source::PROXY_PRUNE,
244 opt.pruned,
245 0,
246 serde_json::json!({}),
247 );
248 }
249}
250
251struct UsageScanner {
257 db: Option<Db>,
258 buf: String,
259 done: bool,
260}
261
262impl UsageScanner {
263 fn new(db: Option<Db>) -> Self {
264 UsageScanner {
265 db,
266 buf: String::new(),
267 done: false,
268 }
269 }
270
271 fn feed(&mut self, bytes: &[u8]) {
272 if self.done {
273 return;
274 }
275 self.buf.push_str(&String::from_utf8_lossy(bytes));
276 if let Some(cache_read) = field_u64(&self.buf, "\"cache_read_input_tokens\"") {
277 self.record(cache_read);
278 self.done = true;
279 self.buf = String::new();
280 } else if self.buf.len() > USAGE_SCAN_CAP {
281 self.done = true;
282 self.buf = String::new();
283 }
284 }
285
286 fn record(&self, cache_read: u64) {
287 if cache_read == 0 {
288 return;
289 }
290 let Some(db) = &self.db else { return };
291 let Ok(conn) = db.lock() else { return };
292 let before = cache_read as usize;
294 let after = before / 10;
295 let _ = mimir_core::savings::record(
296 &conn,
297 None,
298 mimir_core::savings::source::PROXY_CACHE,
299 before,
300 after,
301 serde_json::json!({ "measured": true }),
302 );
303 }
304}
305
/// Extracts the unsigned integer that follows the first occurrence of `key`
/// in `s`, without parsing the surrounding JSON.
///
/// `key` is matched literally, so pass it with its quotes (for example
/// `"\"input_tokens\""`). Any mix of `:` and ASCII whitespace between the key
/// and the value is skipped, so both compact and pretty-printed JSON work.
///
/// Returns `None` when the key is absent, when no digits follow it (for
/// example a `null` value), or when the number does not fit in a `u64`.
fn field_u64(s: &str, key: &str) -> Option<u64> {
    let (_, after_key) = s.split_once(key)?;
    let value = after_key.trim_start_matches(|c: char| c == ':' || c.is_ascii_whitespace());
    let digits = value
        .find(|c: char| !c.is_ascii_digit())
        .map_or(value, |end| &value[..end]);
    digits.parse().ok()
}
315
#[cfg(test)]
mod tests {
    use super::field_u64;

    /// `field_u64` pulls usage counters out of a raw SSE `message_start`
    /// event and reports `None` for a key that is not present.
    #[test]
    fn parses_usage_field() {
        let sse = r#"event: message_start
data: {"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":2048}}}"#;
        let cases = [
            ("\"cache_read_input_tokens\"", Some(2048)),
            ("\"input_tokens\"", Some(12)),
            ("\"missing\"", None),
        ];
        for (key, expected) in cases {
            assert_eq!(field_u64(sse, key), expected);
        }
    }
}