Skip to main content

mimir_proxy/
lib.rs

//! Mimir's optional local API proxy ("Headroom"-style, our own implementation).
//!
//! It sits between the agent and `api.anthropic.com`: it forwards every request
//! verbatim (streaming SSE responses back untouched) EXCEPT that it may rewrite
//! the body of `POST /v1/messages` to add prompt-cache breakpoints, dedup
//! repeated blocks, and (optionally) prune stale tool results. Cache savings are
//! MEASURED from the response's `usage.cache_read_input_tokens`; dedup/prune
//! savings are measured on the request.
//!
//! It is OFF by default (start it with `mimir proxy`), and it is a
//! man-in-the-middle on your prompts/completions — see docs/proxy.md. Auth
//! headers pass through untouched; the proxy never reads your API key.
13
14use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::{Arc, Mutex};
17
18use anyhow::Result;
19use axum::body::Body;
20use axum::extract::State;
21use axum::http::{header, Method, StatusCode};
22use axum::response::Response;
23use axum::routing::any;
24use axum::Router;
25use futures_util::StreamExt;
26use serde_json::Value;
27
28mod optimize;
29pub use optimize::{optimize_request, Optimization, OptimizeOpts};
30
/// Largest request body (the messages JSON) the proxy will buffer: 64 MiB.
const MAX_BODY: usize = 64 << 20;
/// Byte budget for locating the usage block at the head of a response (256 KiB);
/// past this the scan is abandoned.
const USAGE_SCAN_CAP: usize = 256 << 10;
35
36type Db = Arc<Mutex<rusqlite::Connection>>;
37
/// Settings for one proxy instance, consumed by [`serve`].
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Local address to listen on (clients point `ANTHROPIC_BASE_URL` here).
    pub bind: SocketAddr,
    /// Base URL requests are forwarded to; a trailing `/` is ignored.
    pub upstream: String,
    /// Run the optimizer and log what it would change, but forward every body
    /// unchanged and record nothing.
    pub dry_run: bool,
    /// Ask the optimizer to add prompt-cache breakpoints to `/v1/messages`.
    pub cache: bool,
    /// Ask the optimizer to dedup repeated blocks in `/v1/messages`.
    pub dedup: bool,
    /// Ask the optimizer to prune stale tool results in `/v1/messages`.
    pub prune: bool,
    /// Savings-ledger database; `None` (or a failed open) disables recording.
    pub db_path: Option<PathBuf>,
}
48
49struct AppState {
50    cfg: ProxyConfig,
51    client: reqwest::Client,
52    db: Option<Db>,
53}
54
55/// Run the proxy until the process is killed.
56pub async fn serve(cfg: ProxyConfig) -> Result<()> {
57    let db = cfg
58        .db_path
59        .as_ref()
60        .and_then(|p| match mimir_core::db::open(p) {
61            Ok(c) => Some(Arc::new(Mutex::new(c))),
62            Err(e) => {
63                tracing::warn!("proxy: savings ledger disabled ({e})");
64                None
65            }
66        });
67    let client = reqwest::Client::builder().build()?;
68    let bind = cfg.bind;
69    let upstream = cfg.upstream.clone();
70    let dry = cfg.dry_run;
71
72    // Reclaim the WAL on an idle timer. The proxy is long-lived and its
73    // per-request savings writes leave a -wal that PASSIVE autocheckpoints may
74    // never fully reset under a steady reader. A periodic TRUNCATE checkpoint
75    // keeps it bounded; best-effort and non-fatal. Unlike the sync daemons this
76    // reuses the shared savings connection (autocommit between requests, so no
77    // read mark blocks it), and runs the blocking checkpoint on spawn_blocking
78    // so a large TRUNCATE never stalls a runtime worker.
79    if let Some(db) = db.clone() {
80        tokio::spawn(async move {
81            let period = std::time::Duration::from_secs(120);
82            let mut tick = tokio::time::interval_at(tokio::time::Instant::now() + period, period);
83            loop {
84                tick.tick().await;
85                let db = db.clone();
86                let _ = tokio::task::spawn_blocking(move || {
87                    if let Ok(conn) = db.lock() {
88                        if let Err(e) = mimir_core::db::checkpoint_truncate(&conn) {
89                            tracing::warn!("proxy: wal checkpoint failed ({e})");
90                        }
91                    }
92                })
93                .await;
94            }
95        });
96    }
97
98    let state = Arc::new(AppState { cfg, client, db });
99    let app = Router::new().fallback(any(handle)).with_state(state);
100
101    let listener = tokio::net::TcpListener::bind(bind).await?;
102    eprintln!("mimir proxy listening on http://{bind} → {upstream}");
103    eprintln!("  point your client at it:  export ANTHROPIC_BASE_URL=http://{bind}");
104    if dry {
105        eprintln!("  dry-run: measuring only, request bodies forwarded unchanged");
106    }
107    axum::serve(listener, app).await?;
108    Ok(())
109}
110
111async fn handle(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
112    match proxy(&state, req).await {
113        Ok(resp) => resp,
114        Err(e) => {
115            tracing::error!("proxy error: {e}");
116            Response::builder()
117                .status(StatusCode::BAD_GATEWAY)
118                .body(Body::from(format!("mimir proxy error: {e}")))
119                .unwrap()
120        }
121    }
122}
123
124async fn proxy(state: &AppState, req: axum::extract::Request) -> Result<Response> {
125    let (parts, body) = req.into_parts();
126    let path_q = parts
127        .uri
128        .path_and_query()
129        .map(|p| p.as_str())
130        .unwrap_or("/")
131        .to_string();
132    let url = format!("{}{}", state.cfg.upstream.trim_end_matches('/'), path_q);
133    let body_bytes = axum::body::to_bytes(body, MAX_BODY).await?;
134
135    let is_messages = parts.method == Method::POST && parts.uri.path().ends_with("/v1/messages");
136    // we_cached: true only on requests where WE introduced caching (so the
137    // measured cache reads downstream are ours to claim).
138    let (send_bytes, we_cached) = if is_messages {
139        optimize_messages_body(state, &body_bytes)
140    } else {
141        (body_bytes.to_vec(), false)
142    };
143
144    // Forward upstream, copying headers (minus hop-by-hop, host, content-length).
145    let mut rb = state.client.request(parts.method.clone(), &url);
146    for (name, value) in parts.headers.iter() {
147        if is_hop_by_hop(name.as_str()) || name == header::HOST || name == header::CONTENT_LENGTH {
148            continue;
149        }
150        rb = rb.header(name.clone(), value.clone());
151    }
152    let upstream = rb.body(send_bytes).send().await?;
153
154    let status = upstream.status();
155    let mut builder = Response::builder().status(status);
156    for (name, value) in upstream.headers().iter() {
157        if is_hop_by_hop(name.as_str())
158            || name == header::CONTENT_LENGTH
159            || name == header::TRANSFER_ENCODING
160        {
161            continue;
162        }
163        builder = builder.header(name.clone(), value.clone());
164    }
165
166    // Tap the stream to record real cache savings, but only when we added the
167    // breakpoints (otherwise the cache hits aren't ours to claim).
168    let body = if we_cached {
169        let mut scanner = UsageScanner::new(state.db.clone());
170        let tapped = upstream.bytes_stream().inspect(move |chunk| {
171            if let Ok(bytes) = chunk {
172                scanner.feed(bytes);
173            }
174        });
175        Body::from_stream(tapped)
176    } else {
177        Body::from_stream(upstream.bytes_stream())
178    };
179    Ok(builder.body(body)?)
180}
181
182/// Optimize a `/v1/messages` body. Returns the bytes to forward and whether we
183/// introduced caching (and so should measure the response's cache reads).
184fn optimize_messages_body(state: &AppState, body_bytes: &[u8]) -> (Vec<u8>, bool) {
185    let Ok(json) = serde_json::from_slice::<Value>(body_bytes) else {
186        return (body_bytes.to_vec(), false); // not JSON we understand — pass through
187    };
188    let we_cached = state.cfg.cache && !optimize::has_cache_control(&json);
189    let (new_json, opt) = optimize_request(
190        json,
191        OptimizeOpts {
192            cache: state.cfg.cache,
193            dedup: state.cfg.dedup,
194            prune: state.cfg.prune,
195        },
196    );
197    if state.cfg.dry_run {
198        if opt.deduped > 0 || opt.pruned > 0 {
199            tracing::info!(
200                deduped = opt.deduped,
201                pruned = opt.pruned,
202                "proxy dry-run: would optimize"
203            );
204        }
205        return (body_bytes.to_vec(), false);
206    }
207    record_request_savings(&state.db, &opt);
208    let bytes = serde_json::to_vec(&new_json).unwrap_or_else(|_| body_bytes.to_vec());
209    (bytes, we_cached)
210}
211
/// Header names (lowercase) that describe a single transport hop and must not
/// be forwarded by a proxy (RFC 9110 §7.6.1).
///
/// Matches the list used by Go's `net/http/httputil.ReverseProxy`: `trailer`
/// is the real field name. RFC 2616 §13.5.1 misspelled it "Trailers", which is
/// kept for compatibility. `proxy-connection` is non-standard but still sent
/// by some clients.
const HOP_BY_HOP: [&str; 10] = [
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];

/// True if `name` is a hop-by-hop header (case-insensitive, no allocation).
fn is_hop_by_hop(name: &str) -> bool {
    HOP_BY_HOP.iter().any(|h| h.eq_ignore_ascii_case(name))
}
225
226fn record_request_savings(db: &Option<Db>, opt: &Optimization) {
227    let Some(db) = db else { return };
228    let Ok(conn) = db.lock() else { return };
229    if opt.deduped > 0 {
230        let _ = mimir_core::savings::record(
231            &conn,
232            None,
233            mimir_core::savings::source::PROXY_DEDUP,
234            opt.deduped,
235            0,
236            serde_json::json!({}),
237        );
238    }
239    if opt.pruned > 0 {
240        let _ = mimir_core::savings::record(
241            &conn,
242            None,
243            mimir_core::savings::source::PROXY_PRUNE,
244            opt.pruned,
245            0,
246            serde_json::json!({}),
247        );
248    }
249}
250
251/// Scans the start of a response for the `usage` block and records the REAL
252/// cache savings once. Forwards nothing — it's a passive tap.
253///
254/// Assumes the Anthropic SSE/JSON shape (a `"cache_read_input_tokens": N` field
255/// near the start). On any other shape it simply finds nothing and records 0.
256struct UsageScanner {
257    db: Option<Db>,
258    buf: String,
259    done: bool,
260}
261
262impl UsageScanner {
263    fn new(db: Option<Db>) -> Self {
264        UsageScanner {
265            db,
266            buf: String::new(),
267            done: false,
268        }
269    }
270
271    fn feed(&mut self, bytes: &[u8]) {
272        if self.done {
273            return;
274        }
275        self.buf.push_str(&String::from_utf8_lossy(bytes));
276        if let Some(cache_read) = field_u64(&self.buf, "\"cache_read_input_tokens\"") {
277            self.record(cache_read);
278            self.done = true;
279            self.buf = String::new();
280        } else if self.buf.len() > USAGE_SCAN_CAP {
281            self.done = true;
282            self.buf = String::new();
283        }
284    }
285
286    fn record(&self, cache_read: u64) {
287        if cache_read == 0 {
288            return;
289        }
290        let Some(db) = &self.db else { return };
291        let Ok(conn) = db.lock() else { return };
292        // Cached input bills at ~10% → ~90% saved on these tokens.
293        let before = cache_read as usize;
294        let after = before / 10;
295        let _ = mimir_core::savings::record(
296            &conn,
297            None,
298            mimir_core::savings::source::PROXY_CACHE,
299            before,
300            after,
301            serde_json::json!({ "measured": true }),
302        );
303    }
304}
305
/// Parse the unsigned integer value of `key` (a quoted JSON key such as
/// `"\"cache_read_input_tokens\""`) out of a possibly incomplete JSON fragment.
///
/// Every occurrence of `key` is tried in order, and the first one followed by
/// `:`/whitespace and a terminated run of digits wins. A digit run that
/// reaches the end of `s` is NOT accepted. When `s` is a stream prefix cut
/// mid-number (`..."k":20` | `48}`), the value may continue in the next chunk,
/// so it is only trusted once a terminator (`,`, `}`, whitespace, …) follows.
/// Returns `None` when no occurrence yields such a value.
fn field_u64(s: &str, key: &str) -> Option<u64> {
    s.match_indices(key).find_map(|(i, _)| {
        let rest = s[i + key.len()..]
            .trim_start_matches(|c: char| c == ':' || c.is_ascii_whitespace());
        let end = rest.find(|c: char| !c.is_ascii_digit())?;
        rest.get(..end)?.parse().ok()
    })
}
315
#[cfg(test)]
mod tests {
    use super::field_u64;

    /// `field_u64` extracts integer usage counts from an Anthropic
    /// `message_start` SSE event and reports absent keys as `None`.
    #[test]
    fn parses_usage_field() {
        let event = concat!(
            "event: message_start\n",
            r#"data: {"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":2048}}}"#,
        );
        let lookup = |key: &str| field_u64(event, key);
        assert_eq!(lookup("\"cache_read_input_tokens\""), Some(2048));
        assert_eq!(lookup("\"input_tokens\""), Some(12));
        assert_eq!(lookup("\"missing\""), None);
    }
}