1use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::{Arc, Mutex};
17
18use anyhow::Result;
19use axum::body::Body;
20use axum::extract::State;
21use axum::http::{header, Method, StatusCode};
22use axum::response::Response;
23use axum::routing::any;
24use axum::Router;
25use futures_util::StreamExt;
26use serde_json::Value;
27
28mod optimize;
29pub use optimize::{optimize_request, Optimization, OptimizeOpts};
30
/// Largest client request body the proxy will buffer before forwarding (64 MiB).
const MAX_BODY: usize = 64 << 20;
/// How many bytes of response text the usage scanner buffers before it gives up (256 KiB).
const USAGE_SCAN_CAP: usize = 256 << 10;
35
36type Db = Arc<Mutex<rusqlite::Connection>>;
37
/// Runtime settings for the optimizing proxy started by `serve`.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Local socket address the proxy listens on.
    pub bind: SocketAddr,
    /// Base URL that every request is forwarded to. The incoming path and query are
    /// appended after any trailing `/` is trimmed.
    pub upstream: String,
    /// Only measure what would be optimized; request bodies are forwarded unchanged.
    pub dry_run: bool,
    /// Forwarded to the request optimizer as `OptimizeOpts::cache`. When the client
    /// sent no `cache_control` of its own, responses are also scanned for cache reads.
    pub cache: bool,
    /// Forwarded to the request optimizer as `OptimizeOpts::dedup`.
    pub dedup: bool,
    /// Forwarded to the request optimizer as `OptimizeOpts::prune`.
    pub prune: bool,
    /// Path of the savings-ledger database. `None` disables savings recording.
    pub db_path: Option<PathBuf>,
}
48
49struct AppState {
50 cfg: ProxyConfig,
51 client: reqwest::Client,
52 db: Option<Db>,
53}
54
55pub async fn serve(cfg: ProxyConfig) -> Result<()> {
57 let db = cfg
58 .db_path
59 .as_ref()
60 .and_then(|p| match mimir_core::db::open(p) {
61 Ok(c) => Some(Arc::new(Mutex::new(c))),
62 Err(e) => {
63 tracing::warn!("proxy: savings ledger disabled ({e})");
64 None
65 }
66 });
67 let client = reqwest::Client::builder().build()?;
68 let bind = cfg.bind;
69 let upstream = cfg.upstream.clone();
70 let dry = cfg.dry_run;
71 let state = Arc::new(AppState { cfg, client, db });
72 let app = Router::new().fallback(any(handle)).with_state(state);
73
74 let listener = tokio::net::TcpListener::bind(bind).await?;
75 eprintln!("mimir proxy listening on http://{bind} → {upstream}");
76 eprintln!(" point your client at it: export ANTHROPIC_BASE_URL=http://{bind}");
77 if dry {
78 eprintln!(" dry-run: measuring only, request bodies forwarded unchanged");
79 }
80 axum::serve(listener, app).await?;
81 Ok(())
82}
83
84async fn handle(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
85 match proxy(&state, req).await {
86 Ok(resp) => resp,
87 Err(e) => {
88 tracing::error!("proxy error: {e}");
89 Response::builder()
90 .status(StatusCode::BAD_GATEWAY)
91 .body(Body::from(format!("mimir proxy error: {e}")))
92 .unwrap()
93 }
94 }
95}
96
97async fn proxy(state: &AppState, req: axum::extract::Request) -> Result<Response> {
98 let (parts, body) = req.into_parts();
99 let path_q = parts
100 .uri
101 .path_and_query()
102 .map(|p| p.as_str())
103 .unwrap_or("/")
104 .to_string();
105 let url = format!("{}{}", state.cfg.upstream.trim_end_matches('/'), path_q);
106 let body_bytes = axum::body::to_bytes(body, MAX_BODY).await?;
107
108 let is_messages = parts.method == Method::POST && parts.uri.path().ends_with("/v1/messages");
109 let (send_bytes, we_cached) = if is_messages {
112 optimize_messages_body(state, &body_bytes)
113 } else {
114 (body_bytes.to_vec(), false)
115 };
116
117 let mut rb = state.client.request(parts.method.clone(), &url);
119 for (name, value) in parts.headers.iter() {
120 if is_hop_by_hop(name.as_str()) || name == header::HOST || name == header::CONTENT_LENGTH {
121 continue;
122 }
123 rb = rb.header(name.clone(), value.clone());
124 }
125 let upstream = rb.body(send_bytes).send().await?;
126
127 let status = upstream.status();
128 let mut builder = Response::builder().status(status);
129 for (name, value) in upstream.headers().iter() {
130 if is_hop_by_hop(name.as_str())
131 || name == header::CONTENT_LENGTH
132 || name == header::TRANSFER_ENCODING
133 {
134 continue;
135 }
136 builder = builder.header(name.clone(), value.clone());
137 }
138
139 let body = if we_cached {
142 let mut scanner = UsageScanner::new(state.db.clone());
143 let tapped = upstream.bytes_stream().inspect(move |chunk| {
144 if let Ok(bytes) = chunk {
145 scanner.feed(bytes);
146 }
147 });
148 Body::from_stream(tapped)
149 } else {
150 Body::from_stream(upstream.bytes_stream())
151 };
152 Ok(builder.body(body)?)
153}
154
155fn optimize_messages_body(state: &AppState, body_bytes: &[u8]) -> (Vec<u8>, bool) {
158 let Ok(json) = serde_json::from_slice::<Value>(body_bytes) else {
159 return (body_bytes.to_vec(), false); };
161 let we_cached = state.cfg.cache && !optimize::has_cache_control(&json);
162 let (new_json, opt) = optimize_request(
163 json,
164 OptimizeOpts {
165 cache: state.cfg.cache,
166 dedup: state.cfg.dedup,
167 prune: state.cfg.prune,
168 },
169 );
170 if state.cfg.dry_run {
171 if opt.deduped > 0 || opt.pruned > 0 {
172 tracing::info!(
173 deduped = opt.deduped,
174 pruned = opt.pruned,
175 "proxy dry-run: would optimize"
176 );
177 }
178 return (body_bytes.to_vec(), false);
179 }
180 record_request_savings(&state.db, &opt);
181 let bytes = serde_json::to_vec(&new_json).unwrap_or_else(|_| body_bytes.to_vec());
182 (bytes, we_cached)
183}
184
/// Reports whether `name` is a hop-by-hop header that a proxy must not forward.
///
/// The comparison is ASCII case-insensitive. The list follows RFC 9110 §7.6.1
/// plus the legacy `Proxy-Connection`. The real header is `Trailer`; `trailers`
/// (the spelling in RFC 2616's list) is kept so existing matches still hold.
fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    ];
    HOP_BY_HOP.iter().any(|hop| name.eq_ignore_ascii_case(hop))
}
198
199fn record_request_savings(db: &Option<Db>, opt: &Optimization) {
200 let Some(db) = db else { return };
201 let Ok(conn) = db.lock() else { return };
202 if opt.deduped > 0 {
203 let _ = mimir_core::savings::record(
204 &conn,
205 None,
206 mimir_core::savings::source::PROXY_DEDUP,
207 opt.deduped,
208 0,
209 serde_json::json!({}),
210 );
211 }
212 if opt.pruned > 0 {
213 let _ = mimir_core::savings::record(
214 &conn,
215 None,
216 mimir_core::savings::source::PROXY_PRUNE,
217 opt.pruned,
218 0,
219 serde_json::json!({}),
220 );
221 }
222}
223
224struct UsageScanner {
230 db: Option<Db>,
231 buf: String,
232 done: bool,
233}
234
235impl UsageScanner {
236 fn new(db: Option<Db>) -> Self {
237 UsageScanner {
238 db,
239 buf: String::new(),
240 done: false,
241 }
242 }
243
244 fn feed(&mut self, bytes: &[u8]) {
245 if self.done {
246 return;
247 }
248 self.buf.push_str(&String::from_utf8_lossy(bytes));
249 if let Some(cache_read) = field_u64(&self.buf, "\"cache_read_input_tokens\"") {
250 self.record(cache_read);
251 self.done = true;
252 self.buf = String::new();
253 } else if self.buf.len() > USAGE_SCAN_CAP {
254 self.done = true;
255 self.buf = String::new();
256 }
257 }
258
259 fn record(&self, cache_read: u64) {
260 if cache_read == 0 {
261 return;
262 }
263 let Some(db) = &self.db else { return };
264 let Ok(conn) = db.lock() else { return };
265 let before = cache_read as usize;
267 let after = before / 10;
268 let _ = mimir_core::savings::record(
269 &conn,
270 None,
271 mimir_core::savings::source::PROXY_CACHE,
272 before,
273 after,
274 serde_json::json!({ "measured": true }),
275 );
276 }
277}
278
/// Extracts the unsigned integer that follows the first occurrence of `key` in `s`.
///
/// `key` should include its JSON quotes (e.g. `"\"input_tokens\""`) so that it does
/// not match the tail of a longer key such as `"cache_read_input_tokens"`. The
/// separator between key and value may be any mix of `:` and ASCII whitespace, as
/// JSON allows. Returns `None` if the key is absent, no digits follow it, or the
/// number overflows `u64`.
fn field_u64(s: &str, key: &str) -> Option<u64> {
    let (_, after_key) = s.split_once(key)?;
    let digits = after_key.trim_start_matches(|c: char| c == ':' || c.is_ascii_whitespace());
    let end = digits
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(digits.len());
    digits[..end].parse().ok()
}
288
#[cfg(test)]
mod tests {
    use super::field_u64;

    /// Reads usage counters from a captured `message_start` SSE frame. The cases
    /// include a key that is the tail of a longer key and a key that is absent.
    #[test]
    fn parses_usage_field() {
        let sse = r#"event: message_start
data: {"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":2048}}}"#;
        let expectations: [(&str, Option<u64>); 3] = [
            ("\"cache_read_input_tokens\"", Some(2048)),
            ("\"input_tokens\"", Some(12)),
            ("\"missing\"", None),
        ];
        for (key, want) in expectations {
            assert_eq!(field_u64(sse, key), want);
        }
    }
}