// claude_api_test/recorder.rs

//! Live recording proxy for `claude-api-test`.
//!
//! [`Recorder`] runs an in-process HTTP forwarder on `127.0.0.1` that
//! captures every request a `claude_api::Client` sends through it,
//! tees each completed exchange to a JSONL cassette file, and relays the
//! upstream response back to the caller. Pair with
//! [`mount_cassette`](crate::mount_cassette) and
//! [`Cassette::from_path`](crate::Cassette::from_path) for replay.
//!
//! ```ignore
//! let recorder = Recorder::start(RecorderConfig {
//!     upstream: "https://api.anthropic.com".into(),
//!     cassette_path: "./cassette.jsonl".into(),
//!     ..Default::default()
//! }).await?;
//!
//! let client = claude_api::Client::builder()
//!     .api_key(env!("ANTHROPIC_API_KEY"))
//!     .base_url(recorder.url())
//!     .build()?;
//!
//! // ... drive the client; each completed exchange lands in cassette.jsonl ...
//!
//! recorder.shutdown().await?;
//! ```
25
26use std::convert::Infallible;
27use std::path::PathBuf;
28use std::sync::Arc;
29
30use bytes::Bytes;
31use http::HeaderMap;
32use http_body_util::{BodyExt, Full};
33use hyper::body::Incoming;
34use hyper::service::service_fn;
35use hyper::{Request, Response};
36use hyper_util::rt::TokioIo;
37use tokio::io::AsyncWriteExt;
38use tokio::net::TcpListener;
39use tokio::sync::{Mutex, oneshot};
40use tokio::task::JoinHandle;
41
42use crate::RecordedExchange;
43
/// Header names left out of recorded cassette entries unless
/// [`RecorderConfig::redact_headers`] is set to something else.
///
/// Exists to keep API keys and bearer tokens out of cassette files that
/// end up committed to version control.
pub const DEFAULT_REDACT_HEADERS: &[&str] = &["x-api-key", "authorization"];
48
/// Configuration for [`Recorder::start`].
#[derive(Debug, Clone)]
pub struct RecorderConfig {
    /// Upstream base URL the recorder forwards to. Trailing slashes are
    /// trimmed. Example: `"https://api.anthropic.com"`.
    pub upstream: String,
    /// Filesystem path for the JSONL cassette. The file is created if
    /// missing and truncated on start. Exchanges are then written one
    /// JSON object per line as they complete.
    pub cassette_path: PathBuf,
    /// Names of headers to omit entirely from recorded cassette entries,
    /// matched case-insensitively. Defaults to [`DEFAULT_REDACT_HEADERS`].
    ///
    /// Only *response* headers are stored in the cassette. Request headers,
    /// including the credentials they carry, are never written. This list
    /// therefore filters the recorded response headers. It does not change
    /// what is forwarded upstream or relayed back to the client.
    ///
    /// Body contents are *not* redacted -- callers should ensure prompts
    /// don't contain secrets they don't want in the cassette.
    pub redact_headers: Vec<String>,
}
64
65impl Default for RecorderConfig {
66    fn default() -> Self {
67        Self {
68            upstream: "https://api.anthropic.com".into(),
69            cassette_path: PathBuf::from("./cassette.jsonl"),
70            redact_headers: DEFAULT_REDACT_HEADERS
71                .iter()
72                .map(|s| (*s).to_owned())
73                .collect(),
74        }
75    }
76}
77
78/// Live in-process recording proxy. Drop the value to begin shutdown,
79/// or call [`Self::shutdown`] for an awaitable clean exit.
80pub struct Recorder {
81    url: String,
82    shutdown: Option<oneshot::Sender<()>>,
83    handle: Option<JoinHandle<()>>,
84}
85
86impl Recorder {
87    /// Bind to `127.0.0.1:0`, spawn a forwarder task, and return a
88    /// handle whose [`Self::url`] points at the proxy. The cassette
89    /// file at `config.cassette_path` is **truncated** on start --
90    /// each recording run produces a fresh cassette. To accumulate
91    /// across runs, copy the file off between sessions.
92    pub async fn start(config: RecorderConfig) -> std::io::Result<Self> {
93        let upstream = config.upstream.trim_end_matches('/').to_owned();
94        let listener = TcpListener::bind(("127.0.0.1", 0)).await?;
95        let local_addr = listener.local_addr()?;
96        let url = format!("http://{local_addr}");
97
98        let file = tokio::fs::OpenOptions::new()
99            .create(true)
100            .write(true)
101            .truncate(true)
102            .open(&config.cassette_path)
103            .await?;
104        let writer = Arc::new(Mutex::new(file));
105
106        // reqwest client used to forward requests upstream. Does NOT
107        // share a connection pool with whatever the user's Client has;
108        // that's fine for tests.
109        let forwarder = reqwest::Client::builder()
110            .build()
111            .map_err(std::io::Error::other)?;
112
113        let redact: Arc<Vec<String>> = Arc::new(
114            config
115                .redact_headers
116                .iter()
117                .map(|s| s.to_lowercase())
118                .collect(),
119        );
120
121        let (tx, rx) = oneshot::channel::<()>();
122
123        let handle = tokio::spawn(async move {
124            tokio::pin!(rx);
125            loop {
126                tokio::select! {
127                    _ = &mut rx => break,
128                    accept = listener.accept() => {
129                        let Ok((stream, _peer)) = accept else { continue };
130                        let upstream = upstream.clone();
131                        let writer = Arc::clone(&writer);
132                        let forwarder = forwarder.clone();
133                        let redact = Arc::clone(&redact);
134                        tokio::spawn(async move {
135                            let io = TokioIo::new(stream);
136                            let svc = service_fn(move |req| {
137                                let upstream = upstream.clone();
138                                let writer = Arc::clone(&writer);
139                                let forwarder = forwarder.clone();
140                                let redact = Arc::clone(&redact);
141                                async move {
142                                    handle_request(req, &upstream, &forwarder, writer, redact)
143                                        .await
144                                }
145                            });
146                            let _ = hyper::server::conn::http1::Builder::new()
147                                .serve_connection(io, svc)
148                                .await;
149                        });
150                    }
151                }
152            }
153        });
154
155        Ok(Self {
156            url,
157            shutdown: Some(tx),
158            handle: Some(handle),
159        })
160    }
161
162    /// Proxy URL the user should pass to
163    /// `Client::builder().base_url(...)`.
164    #[must_use]
165    pub fn url(&self) -> &str {
166        &self.url
167    }
168
169    /// Signal the forwarder to stop accepting new connections, then
170    /// await its task. Returns once the recorder has fully exited.
171    pub async fn shutdown(mut self) -> std::io::Result<()> {
172        if let Some(tx) = self.shutdown.take() {
173            let _ = tx.send(());
174        }
175        if let Some(handle) = self.handle.take() {
176            let _ = handle.await;
177        }
178        Ok(())
179    }
180}
181
182impl Drop for Recorder {
183    fn drop(&mut self) {
184        if let Some(tx) = self.shutdown.take() {
185            let _ = tx.send(());
186        }
187    }
188}
189
190async fn handle_request(
191    req: Request<Incoming>,
192    upstream: &str,
193    forwarder: &reqwest::Client,
194    writer: Arc<Mutex<tokio::fs::File>>,
195    redact: Arc<Vec<String>>,
196) -> Result<Response<Full<Bytes>>, Infallible> {
197    let method = req.method().clone();
198    let path_and_query = req
199        .uri()
200        .path_and_query()
201        .map_or_else(|| req.uri().path().to_owned(), ToString::to_string);
202    let path_only = req.uri().path().to_owned();
203    let headers = req.headers().clone();
204
205    let body_bytes = match req.into_body().collect().await {
206        Ok(b) => b.to_bytes(),
207        Err(_) => {
208            return Ok(error_response(
209                http::StatusCode::BAD_GATEWAY,
210                "recorder: failed to read request body",
211            ));
212        }
213    };
214
215    // Forward upstream.
216    let url = format!("{upstream}{path_and_query}");
217    let mut fwd = forwarder.request(method.clone(), &url);
218    for (name, value) in &headers {
219        // Hop-by-hop and host headers are unsafe to forward verbatim.
220        if matches!(name.as_str(), "host" | "content-length") {
221            continue;
222        }
223        fwd = fwd.header(name, value);
224    }
225    if !body_bytes.is_empty() {
226        fwd = fwd.body(body_bytes.to_vec());
227    }
228    let upstream_resp = match fwd.send().await {
229        Ok(r) => r,
230        Err(e) => {
231            return Ok(error_response(
232                http::StatusCode::BAD_GATEWAY,
233                &format!("recorder: upstream request failed: {e}"),
234            ));
235        }
236    };
237    let status = upstream_resp.status();
238    let upstream_headers = upstream_resp.headers().clone();
239    let resp_bytes = upstream_resp.bytes().await.unwrap_or_default();
240
241    // Capture the exchange.
242    let exchange = build_exchange(
243        method.as_str(),
244        &path_only,
245        status.as_u16(),
246        &body_bytes,
247        &upstream_headers,
248        &resp_bytes,
249        &redact,
250    );
251    // Suppress unused_variables warning for `headers` -- we kept it
252    // bound for symmetry with the response side, and to leave a hook
253    // for redaction-policy expansion (e.g. recording the Authorization
254    // *presence* without its value).
255    let _ = &headers;
256    if let Ok(line) = serde_json::to_string(&exchange) {
257        let mut guard = writer.lock().await;
258        let _ = guard.write_all(line.as_bytes()).await;
259        let _ = guard.write_all(b"\n").await;
260        let _ = guard.flush().await;
261    }
262
263    // Build the response we send back to the client.
264    let mut builder = Response::builder().status(status);
265    for (name, value) in &upstream_headers {
266        builder = builder.header(name, value);
267    }
268    let response = builder
269        .body(Full::new(resp_bytes))
270        .unwrap_or_else(|_| error_response(http::StatusCode::BAD_GATEWAY, "recorder: build error"));
271    Ok(response)
272}
273
274fn build_exchange(
275    method: &str,
276    path: &str,
277    status: u16,
278    request_body: &[u8],
279    response_headers: &HeaderMap,
280    response_body: &[u8],
281    redact: &[String],
282) -> RecordedExchange {
283    // Decode bodies as JSON when possible; bare-bytes payloads (e.g.
284    // multipart uploads) fall back to a base64-ish stand-in -- but in
285    // practice the API surface is JSON, and this recorder is scoped to
286    // claude-api whose endpoints are all JSON or SSE.
287    let request_value = if request_body.is_empty() {
288        None
289    } else {
290        Some(
291            serde_json::from_slice::<serde_json::Value>(request_body).unwrap_or_else(|_| {
292                serde_json::Value::String(format!("<{} bytes>", request_body.len()))
293            }),
294        )
295    };
296
297    // SSE responses arrive as `text/event-stream` and are not valid JSON.
298    // Store the raw wire text as a JSON string so the cassette can replay
299    // it verbatim; `mount_cassette` detects the content-type and serves
300    // the body as text rather than JSON.
301    let is_sse = response_headers
302        .get(http::header::CONTENT_TYPE)
303        .and_then(|v| v.to_str().ok())
304        .is_some_and(|ct| ct.contains("text/event-stream"));
305
306    let response_value = if is_sse {
307        // Preserve SSE wire format as a plain string value.
308        let text = String::from_utf8_lossy(response_body).into_owned();
309        serde_json::Value::String(text)
310    } else {
311        serde_json::from_slice::<serde_json::Value>(response_body).unwrap_or_else(|_| {
312            serde_json::Value::String(format!("<{} bytes>", response_body.len()))
313        })
314    };
315
316    let mut headers: Vec<(String, String)> = Vec::new();
317    for (name, value) in response_headers {
318        let name_lc = name.as_str().to_lowercase();
319        if redact.iter().any(|r| r == &name_lc) {
320            continue;
321        }
322        if let Ok(v) = value.to_str() {
323            headers.push((name_lc, v.to_owned()));
324        }
325    }
326
327    RecordedExchange {
328        method: method.to_owned(),
329        path: path.to_owned(),
330        status,
331        request: request_value,
332        response: response_value,
333        headers,
334    }
335}
336
337fn error_response(status: http::StatusCode, message: &str) -> Response<Full<Bytes>> {
338    Response::builder()
339        .status(status)
340        .header("content-type", "application/json")
341        .body(Full::new(Bytes::from(format!(
342            r#"{{"type":"error","error":{{"type":"recorder_error","message":{message:?}}}}}"#
343        ))))
344        .expect("static response is well-formed")
345}