Skip to main content

minion_engine/sandbox/
proxy.rs

1//! Secure API Key Proxy
2//!
3//! A lightweight HTTP reverse proxy that runs on the host and intercepts
4//! API requests from the Docker sandbox container. Instead of passing
5//! sensitive API keys as environment variables into the container, the
6//! proxy injects authentication headers on the fly.
7//!
8//! Architecture:
9//! ```text
10//! Container                           Host Proxy                  Upstream
11//! ─────────                           ──────────                  ────────
12//! ANTHROPIC_BASE_URL=                 127.0.0.1:<port>
13//!   http://host.docker.internal:PORT  /v1/messages ──────────►  api.anthropic.com
14//!                                     + x-api-key: sk-ant-...    (HTTPS)
15//! ```
16//!
17//! The container never sees the actual API key.
18
19use std::sync::Arc;
20
21use anyhow::{Context, Result};
22use reqwest::Client;
23use tokio::net::TcpListener;
24use tokio::sync::oneshot;
25use tokio::task::JoinHandle;
26
/// Secrets to inject into proxied requests.
///
/// The proxy wraps this in an `Arc` at startup and hands a clone to every
/// connection handler; handlers only read from it.
struct ProxySecrets {
    /// Value sent upstream as the `x-api-key` header.
    anthropic_api_key: String,
}
31
/// A running API proxy instance.
///
/// Created by [`ApiProxy::start`]. Call [`ApiProxy::stop`] to signal shutdown
/// and wait for the accept loop to finish; merely dropping the value sends the
/// shutdown signal but does not wait for the task.
pub struct ApiProxy {
    /// Local TCP port the listener is bound to.
    port: u16,
    /// Sender used to tell the accept loop to exit; `None` once it has been used.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Handle of the spawned accept-loop task; `None` once `stop` has taken it.
    join_handle: Option<JoinHandle<()>>,
}
38
39impl ApiProxy {
40    /// Start the proxy on a random available port.
41    ///
42    /// The proxy will forward requests to `api.anthropic.com` and inject
43    /// the `x-api-key` header with the provided API key.
44    pub async fn start(anthropic_api_key: String) -> Result<Self> {
45        let listener = TcpListener::bind("127.0.0.1:0")
46            .await
47            .context("Failed to bind proxy to localhost")?;
48
49        let port = listener.local_addr()?.port();
50
51        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
52
53        let secrets = Arc::new(ProxySecrets { anthropic_api_key });
54        let client = Client::builder()
55            .timeout(std::time::Duration::from_secs(300))
56            .build()
57            .context("Failed to create HTTP client for proxy")?;
58
59        let handle = tokio::spawn(run_proxy(listener, secrets, client, shutdown_rx));
60
61        tracing::info!(port, "API key proxy started — secrets never enter the container");
62
63        Ok(Self {
64            port,
65            shutdown_tx: Some(shutdown_tx),
66            join_handle: Some(handle),
67        })
68    }
69
70    /// The port the proxy is listening on.
71    pub fn port(&self) -> u16 {
72        self.port
73    }
74
75    /// Gracefully stop the proxy.
76    pub async fn stop(mut self) {
77        if let Some(tx) = self.shutdown_tx.take() {
78            let _ = tx.send(());
79        }
80        if let Some(handle) = self.join_handle.take() {
81            let _ = handle.await;
82        }
83        tracing::info!("API key proxy stopped");
84    }
85}
86
87impl Drop for ApiProxy {
88    fn drop(&mut self) {
89        // Best-effort shutdown if stop() was not called explicitly
90        if let Some(tx) = self.shutdown_tx.take() {
91            let _ = tx.send(());
92        }
93    }
94}
95
96/// Main proxy loop: accept connections and forward requests.
97async fn run_proxy(
98    listener: TcpListener,
99    secrets: Arc<ProxySecrets>,
100    client: Client,
101    mut shutdown_rx: oneshot::Receiver<()>,
102) {
103    loop {
104        tokio::select! {
105            accept = listener.accept() => {
106                match accept {
107                    Ok((stream, _addr)) => {
108                        let secrets = Arc::clone(&secrets);
109                        let client = client.clone();
110                        tokio::spawn(handle_connection(stream, secrets, client));
111                    }
112                    Err(e) => {
113                        tracing::warn!(error = %e, "Proxy accept error");
114                    }
115                }
116            }
117            _ = &mut shutdown_rx => {
118                tracing::debug!("Proxy received shutdown signal");
119                break;
120            }
121        }
122    }
123}
124
125/// Handle a single HTTP connection (may have multiple requests via keep-alive,
126/// but we use a simple one-request-per-connection model for simplicity).
127async fn handle_connection(
128    mut stream: tokio::net::TcpStream,
129    secrets: Arc<ProxySecrets>,
130    client: Client,
131) {
132    use tokio::io::{AsyncReadExt, AsyncWriteExt};
133
134    // Read the full request (up to 10MB for large prompts)
135    let mut buf = Vec::with_capacity(8192);
136    let mut tmp = [0u8; 8192];
137
138    // Read until we have the full headers + body
139    loop {
140        match stream.read(&mut tmp).await {
141            Ok(0) => return, // connection closed
142            Ok(n) => {
143                buf.extend_from_slice(&tmp[..n]);
144                // Check if we have the complete request
145                if let Some(body_start) = find_header_end(&buf) {
146                    // Parse Content-Length to know how much body to expect
147                    let headers_str = String::from_utf8_lossy(&buf[..body_start]);
148                    let content_length = parse_content_length(&headers_str);
149                    let body_received = buf.len() - body_start;
150                    if body_received >= content_length {
151                        break; // Full request received
152                    }
153                }
154                // Safety: don't read more than 10MB
155                if buf.len() > 10 * 1024 * 1024 {
156                    let resp = b"HTTP/1.1 413 Payload Too Large\r\nContent-Length: 0\r\n\r\n";
157                    let _ = stream.write_all(resp).await;
158                    return;
159                }
160            }
161            Err(_) => return,
162        }
163    }
164
165    // Parse the request
166    let header_end = match find_header_end(&buf) {
167        Some(pos) => pos,
168        None => return,
169    };
170
171    let headers_str = String::from_utf8_lossy(&buf[..header_end]).to_string();
172    let body = &buf[header_end..];
173
174    // Parse first line: METHOD PATH HTTP/1.x
175    let first_line = headers_str.lines().next().unwrap_or("");
176    let parts: Vec<&str> = first_line.split_whitespace().collect();
177    if parts.len() < 3 {
178        let resp = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
179        let _ = stream.write_all(resp).await;
180        return;
181    }
182
183    let method = parts[0];
184    let path = parts[1];
185
186    // Build upstream URL — all paths go to api.anthropic.com
187    let upstream_url = format!("https://api.anthropic.com{path}");
188
189    // Build the upstream request
190    let req_method = match method {
191        "GET" => reqwest::Method::GET,
192        "POST" => reqwest::Method::POST,
193        "PUT" => reqwest::Method::PUT,
194        "DELETE" => reqwest::Method::DELETE,
195        "PATCH" => reqwest::Method::PATCH,
196        "OPTIONS" => reqwest::Method::OPTIONS,
197        "HEAD" => reqwest::Method::HEAD,
198        _ => {
199            let resp = b"HTTP/1.1 405 Method Not Allowed\r\nContent-Length: 0\r\n\r\n";
200            let _ = stream.write_all(resp).await;
201            return;
202        }
203    };
204
205    let mut upstream_req = client.request(req_method, &upstream_url);
206
207    // Forward relevant headers (skip Host, Connection, and any existing auth)
208    for line in headers_str.lines().skip(1) {
209        if line.is_empty() {
210            break;
211        }
212        if let Some((key, value)) = line.split_once(':') {
213            let key_lower = key.trim().to_lowercase();
214            // Skip hop-by-hop headers and auth (we inject our own)
215            if matches!(
216                key_lower.as_str(),
217                "host" | "connection" | "x-api-key" | "authorization"
218            ) {
219                continue;
220            }
221            upstream_req = upstream_req.header(key.trim(), value.trim());
222        }
223    }
224
225    // Inject the API key
226    upstream_req = upstream_req.header("x-api-key", &secrets.anthropic_api_key);
227
228    // Add body if present
229    if !body.is_empty() {
230        upstream_req = upstream_req.body(body.to_vec());
231    }
232
233    // Send upstream request
234    let upstream_resp = match upstream_req.send().await {
235        Ok(resp) => resp,
236        Err(e) => {
237            tracing::warn!(error = %e, url = %upstream_url, "Proxy upstream request failed");
238            let error_body = format!("Proxy error: {e}");
239            let resp = format!(
240                "HTTP/1.1 502 Bad Gateway\r\nContent-Length: {}\r\n\r\n{error_body}",
241                error_body.len()
242            );
243            let _ = stream.write_all(resp.as_bytes()).await;
244            return;
245        }
246    };
247
248    // Build response back to client
249    let status = upstream_resp.status();
250    let resp_headers = upstream_resp.headers().clone();
251    let resp_body = match upstream_resp.bytes().await {
252        Ok(b) => b,
253        Err(e) => {
254            tracing::warn!(error = %e, "Failed to read upstream response body");
255            let _ = stream
256                .write_all(b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n")
257                .await;
258            return;
259        }
260    };
261
262    // Write HTTP response
263    let mut response = format!("HTTP/1.1 {}\r\n", status);
264    for (key, value) in resp_headers.iter() {
265        let key_lower = key.as_str().to_lowercase();
266        // Skip hop-by-hop headers
267        if matches!(key_lower.as_str(), "connection" | "transfer-encoding") {
268            continue;
269        }
270        if let Ok(v) = value.to_str() {
271            response.push_str(&format!("{}: {}\r\n", key, v));
272        }
273    }
274    // Ensure Content-Length is set
275    response.push_str(&format!("Content-Length: {}\r\n", resp_body.len()));
276    response.push_str("Connection: close\r\n");
277    response.push_str("\r\n");
278
279    let _ = stream.write_all(response.as_bytes()).await;
280    let _ = stream.write_all(&resp_body).await;
281}
282
/// Locate the end of the HTTP header section.
///
/// Returns the offset of the first byte after the first `\r\n\r\n` sequence
/// in `buf`, or `None` when no such sequence is present.
fn find_header_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    buf.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
        .map(|start| start + TERMINATOR.len())
}
292
/// Extract the `Content-Length` value from a raw header block.
///
/// The header name is matched case-insensitively and only the first matching
/// line is considered. Returns `0` when the header is absent or its value
/// does not parse as a `usize`.
fn parse_content_length(headers: &str) -> usize {
    headers
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .map_or(0, |(_, value)| value.trim().parse::<usize>().unwrap_or(0))
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// The returned offset points just past the blank line, at the first body byte.
    #[test]
    fn find_header_end_works() {
        let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\nbody";
        assert_eq!(find_header_end(request), Some(37));
    }

    /// Headers without the terminating blank line are reported as incomplete.
    #[test]
    fn find_header_end_returns_none_when_incomplete() {
        let partial = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
        assert_eq!(find_header_end(partial), None);
    }

    /// A present `Content-Length` header is parsed to its numeric value.
    #[test]
    fn parse_content_length_works() {
        let raw = "POST /v1/messages HTTP/1.1\r\nContent-Type: application/json\r\nContent-Length: 1234\r\n";
        let parsed = parse_content_length(raw);
        assert_eq!(parsed, 1234);
    }

    /// Without a `Content-Length` header the length is taken to be zero.
    #[test]
    fn parse_content_length_missing_returns_zero() {
        let raw = "GET /v1/models HTTP/1.1\r\nHost: api.anthropic.com\r\n";
        let parsed = parse_content_length(raw);
        assert_eq!(parsed, 0);
    }

    /// Starting yields a non-zero port and stopping completes.
    #[tokio::test]
    async fn proxy_starts_and_stops() {
        let key = "test-key-123".to_string();
        let proxy = ApiProxy::start(key).await.unwrap();
        assert!(proxy.port() > 0);
        proxy.stop().await;
    }

    /// The proxy answers an HTTP request on its port. Any well-formed reply
    /// counts: a forwarded upstream response or the proxy's own 502.
    #[tokio::test]
    async fn proxy_responds_to_requests() {
        let proxy = ApiProxy::start("test-key-abc".to_string()).await.unwrap();
        let url = format!("http://127.0.0.1:{}/v1/models", proxy.port());

        let outcome = reqwest::Client::new().get(url).send().await;
        assert!(outcome.is_ok());

        proxy.stop().await;
    }
}