Skip to main content

mimir_proxy/
lib.rs

1//! Mimir's optional local API proxy ("Headroom"-style, our own implementation).
2//!
3//! It sits between the agent and `api.anthropic.com`: forwards every request
4//! verbatim (streaming SSE back untouched) EXCEPT it may rewrite the body of
5//! `POST /v1/messages` to add prompt-cache breakpoints, dedup repeated blocks,
6//! and (optionally) prune stale tool results. Cache savings are MEASURED from
7//! the response's `usage.cache_read_input_tokens`; dedup/prune are measured at
8//! the request.
9//!
10//! It is OFF by default (you start it with `mimir proxy`) and it is a
11//! man-in-the-middle on your prompts/completions — see docs/proxy.md. Auth
12//! headers pass through untouched; the proxy never reads your API key.
13
14use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::{Arc, Mutex};
17
18use anyhow::Result;
19use axum::body::Body;
20use axum::extract::State;
21use axum::http::{header, Method, StatusCode};
22use axum::response::Response;
23use axum::routing::any;
24use axum::Router;
25use futures_util::StreamExt;
26use serde_json::Value;
27
28mod optimize;
29pub use optimize::{optimize_request, Optimization, OptimizeOpts};
30
/// Largest request body (the `/v1/messages` JSON) the proxy will buffer: 64 MiB.
/// A bigger body makes `axum::body::to_bytes` fail, which surfaces as a 502.
const MAX_BODY: usize = 64 << 20;
/// Response bytes (256 KiB) the usage tap accumulates while looking for the usage block before it gives up on that response.
const USAGE_SCAN_CAP: usize = 256 << 10;
35
36type Db = Arc<Mutex<rusqlite::Connection>>;
37
/// Runtime configuration for [`serve`].
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Local address the proxy listens on.
    pub bind: SocketAddr,
    /// Base URL requests are forwarded to (e.g. `https://api.anthropic.com`).
    /// A trailing `/` is ignored; the incoming path and query are appended.
    pub upstream: String,
    /// Measure only: `/v1/messages` bodies are analysed but forwarded
    /// unchanged, and no savings are recorded.
    pub dry_run: bool,
    /// Add prompt-cache breakpoints to `/v1/messages` requests. Cache reads
    /// are measured only for requests that carried no `cache_control` of
    /// their own.
    pub cache: bool,
    /// Deduplicate repeated blocks in `/v1/messages` requests.
    pub dedup: bool,
    /// Prune stale tool results in `/v1/messages` requests.
    pub prune: bool,
    /// Savings-ledger database. `None` disables recording, and so does a
    /// failure to open it.
    pub db_path: Option<PathBuf>,
}
48
49struct AppState {
50    cfg: ProxyConfig,
51    client: reqwest::Client,
52    db: Option<Db>,
53}
54
55/// Run the proxy until the process is killed.
56pub async fn serve(cfg: ProxyConfig) -> Result<()> {
57    let db = cfg
58        .db_path
59        .as_ref()
60        .and_then(|p| match mimir_core::db::open(p) {
61            Ok(c) => Some(Arc::new(Mutex::new(c))),
62            Err(e) => {
63                tracing::warn!("proxy: savings ledger disabled ({e})");
64                None
65            }
66        });
67    let client = reqwest::Client::builder().build()?;
68    let bind = cfg.bind;
69    let upstream = cfg.upstream.clone();
70    let dry = cfg.dry_run;
71    let state = Arc::new(AppState { cfg, client, db });
72    let app = Router::new().fallback(any(handle)).with_state(state);
73
74    let listener = tokio::net::TcpListener::bind(bind).await?;
75    eprintln!("mimir proxy listening on http://{bind} → {upstream}");
76    eprintln!("  point your client at it:  export ANTHROPIC_BASE_URL=http://{bind}");
77    if dry {
78        eprintln!("  dry-run: measuring only, request bodies forwarded unchanged");
79    }
80    axum::serve(listener, app).await?;
81    Ok(())
82}
83
84async fn handle(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
85    match proxy(&state, req).await {
86        Ok(resp) => resp,
87        Err(e) => {
88            tracing::error!("proxy error: {e}");
89            Response::builder()
90                .status(StatusCode::BAD_GATEWAY)
91                .body(Body::from(format!("mimir proxy error: {e}")))
92                .unwrap()
93        }
94    }
95}
96
97async fn proxy(state: &AppState, req: axum::extract::Request) -> Result<Response> {
98    let (parts, body) = req.into_parts();
99    let path_q = parts
100        .uri
101        .path_and_query()
102        .map(|p| p.as_str())
103        .unwrap_or("/")
104        .to_string();
105    let url = format!("{}{}", state.cfg.upstream.trim_end_matches('/'), path_q);
106    let body_bytes = axum::body::to_bytes(body, MAX_BODY).await?;
107
108    let is_messages = parts.method == Method::POST && parts.uri.path().ends_with("/v1/messages");
109    // we_cached: true only on requests where WE introduced caching (so the
110    // measured cache reads downstream are ours to claim).
111    let (send_bytes, we_cached) = if is_messages {
112        optimize_messages_body(state, &body_bytes)
113    } else {
114        (body_bytes.to_vec(), false)
115    };
116
117    // Forward upstream, copying headers (minus hop-by-hop, host, content-length).
118    let mut rb = state.client.request(parts.method.clone(), &url);
119    for (name, value) in parts.headers.iter() {
120        if is_hop_by_hop(name.as_str()) || name == header::HOST || name == header::CONTENT_LENGTH {
121            continue;
122        }
123        rb = rb.header(name.clone(), value.clone());
124    }
125    let upstream = rb.body(send_bytes).send().await?;
126
127    let status = upstream.status();
128    let mut builder = Response::builder().status(status);
129    for (name, value) in upstream.headers().iter() {
130        if is_hop_by_hop(name.as_str())
131            || name == header::CONTENT_LENGTH
132            || name == header::TRANSFER_ENCODING
133        {
134            continue;
135        }
136        builder = builder.header(name.clone(), value.clone());
137    }
138
139    // Tap the stream to record real cache savings, but only when we added the
140    // breakpoints (otherwise the cache hits aren't ours to claim).
141    let body = if we_cached {
142        let mut scanner = UsageScanner::new(state.db.clone());
143        let tapped = upstream.bytes_stream().inspect(move |chunk| {
144            if let Ok(bytes) = chunk {
145                scanner.feed(bytes);
146            }
147        });
148        Body::from_stream(tapped)
149    } else {
150        Body::from_stream(upstream.bytes_stream())
151    };
152    Ok(builder.body(body)?)
153}
154
155/// Optimize a `/v1/messages` body. Returns the bytes to forward and whether we
156/// introduced caching (and so should measure the response's cache reads).
157fn optimize_messages_body(state: &AppState, body_bytes: &[u8]) -> (Vec<u8>, bool) {
158    let Ok(json) = serde_json::from_slice::<Value>(body_bytes) else {
159        return (body_bytes.to_vec(), false); // not JSON we understand — pass through
160    };
161    let we_cached = state.cfg.cache && !optimize::has_cache_control(&json);
162    let (new_json, opt) = optimize_request(
163        json,
164        OptimizeOpts {
165            cache: state.cfg.cache,
166            dedup: state.cfg.dedup,
167            prune: state.cfg.prune,
168        },
169    );
170    if state.cfg.dry_run {
171        if opt.deduped > 0 || opt.pruned > 0 {
172            tracing::info!(
173                deduped = opt.deduped,
174                pruned = opt.pruned,
175                "proxy dry-run: would optimize"
176            );
177        }
178        return (body_bytes.to_vec(), false);
179    }
180    record_request_savings(&state.db, &opt);
181    let bytes = serde_json::to_vec(&new_json).unwrap_or_else(|_| body_bytes.to_vec());
182    (bytes, we_cached)
183}
184
/// Lower-case names of hop-by-hop headers, which a proxy must not forward.
///
/// - `trailer`: the real header name (RFC 9110 §6.6.2).
/// - `trailers`: kept because RFC 2616 §13.5.1 misspelled `trailer` that way
///   (erratum 4522) and some peers copied it.
/// - `proxy-connection`: a legacy field that RFC 9110 §7.6.1 tells
///   intermediaries to remove.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];

/// Whether `name` is a hop-by-hop header. Comparison is ASCII
/// case-insensitive and allocation-free.
///
/// This proxy re-frames bodies on both legs and never forwards trailers. A
/// forwarded `Trailer` announcement would therefore promise fields that never
/// arrive, so it is dropped together with the other hop-by-hop headers.
fn is_hop_by_hop(name: &str) -> bool {
    HOP_BY_HOP.iter().any(|hop| name.eq_ignore_ascii_case(hop))
}
198
199fn record_request_savings(db: &Option<Db>, opt: &Optimization) {
200    let Some(db) = db else { return };
201    let Ok(conn) = db.lock() else { return };
202    if opt.deduped > 0 {
203        let _ = mimir_core::savings::record(
204            &conn,
205            None,
206            mimir_core::savings::source::PROXY_DEDUP,
207            opt.deduped,
208            0,
209            serde_json::json!({}),
210        );
211    }
212    if opt.pruned > 0 {
213        let _ = mimir_core::savings::record(
214            &conn,
215            None,
216            mimir_core::savings::source::PROXY_PRUNE,
217            opt.pruned,
218            0,
219            serde_json::json!({}),
220        );
221    }
222}
223
224/// Scans the start of a response for the `usage` block and records the REAL
225/// cache savings once. Forwards nothing — it's a passive tap.
226///
227/// Assumes the Anthropic SSE/JSON shape (a `"cache_read_input_tokens": N` field
228/// near the start). On any other shape it simply finds nothing and records 0.
229struct UsageScanner {
230    db: Option<Db>,
231    buf: String,
232    done: bool,
233}
234
235impl UsageScanner {
236    fn new(db: Option<Db>) -> Self {
237        UsageScanner {
238            db,
239            buf: String::new(),
240            done: false,
241        }
242    }
243
244    fn feed(&mut self, bytes: &[u8]) {
245        if self.done {
246            return;
247        }
248        self.buf.push_str(&String::from_utf8_lossy(bytes));
249        if let Some(cache_read) = field_u64(&self.buf, "\"cache_read_input_tokens\"") {
250            self.record(cache_read);
251            self.done = true;
252            self.buf = String::new();
253        } else if self.buf.len() > USAGE_SCAN_CAP {
254            self.done = true;
255            self.buf = String::new();
256        }
257    }
258
259    fn record(&self, cache_read: u64) {
260        if cache_read == 0 {
261            return;
262        }
263        let Some(db) = &self.db else { return };
264        let Ok(conn) = db.lock() else { return };
265        // Cached input bills at ~10% → ~90% saved on these tokens.
266        let before = cache_read as usize;
267        let after = before / 10;
268        let _ = mimir_core::savings::record(
269            &conn,
270            None,
271            mimir_core::savings::source::PROXY_CACHE,
272            before,
273            after,
274            serde_json::json!({ "measured": true }),
275        );
276    }
277}
278
/// Parse the unsigned integer value of `key` out of a possibly incomplete
/// JSON fragment, e.g. `field_u64(buf, "\"cache_read_input_tokens\"")`.
///
/// `key` must include its surrounding quotes so that `"input_tokens"` does
/// not match inside `"cache_read_input_tokens"`. Occurrences that are not
/// followed by `:` and a number (e.g. a `null` value) are skipped.
///
/// Returns `None` when no occurrence carries a complete number yet. That
/// includes digits that run to the very end of `s`: in a streamed response
/// the next chunk may extend them (`…":20` followed by `48,…`), and parsing
/// early would record a truncated count.
fn field_u64(s: &str, key: &str) -> Option<u64> {
    for (at, _) in s.match_indices(key) {
        let Some(value) = s[at + key.len()..].trim_start().strip_prefix(':') else {
            continue; // not used as an object key here (or the `:` hasn't arrived yet)
        };
        let value = value.trim_start();
        match value.find(|c: char| !c.is_ascii_digit()) {
            // Value still unterminated (or not yet started): wait for more input.
            None => return None,
            // Not a number, e.g. `null`; try the next occurrence.
            Some(0) => continue,
            Some(end) => return value[..end].parse().ok(),
        }
    }
    None
}
288
#[cfg(test)]
mod tests {
    use super::field_u64;

    /// First event of an Anthropic streaming response; it carries the
    /// prompt-side `usage` block the proxy measures.
    const MESSAGE_START: &str = r#"event: message_start
data: {"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":2048}}}"#;

    #[test]
    fn parses_usage_field() {
        let lookup = |key: &str| field_u64(MESSAGE_START, key);
        assert_eq!(lookup("\"cache_read_input_tokens\""), Some(2048));
        assert_eq!(lookup("\"input_tokens\""), Some(12));
        assert_eq!(lookup("\"missing\""), None);
    }
}