claude_api_test/
recorder.rs1use std::convert::Infallible;
27use std::path::PathBuf;
28use std::sync::Arc;
29
30use bytes::Bytes;
31use http::HeaderMap;
32use http_body_util::{BodyExt, Full};
33use hyper::body::Incoming;
34use hyper::service::service_fn;
35use hyper::{Request, Response};
36use hyper_util::rt::TokioIo;
37use tokio::io::AsyncWriteExt;
38use tokio::net::TcpListener;
39use tokio::sync::{Mutex, oneshot};
40use tokio::task::JoinHandle;
41
42use crate::RecordedExchange;
43
/// Header names that [`RecorderConfig::default`] places in
/// `redact_headers`, i.e. headers left out of the recorded cassette by default.
pub const DEFAULT_REDACT_HEADERS: &[&str] = &["x-api-key", "authorization"];
48
/// Settings consumed by [`Recorder::start`].
#[derive(Debug, Clone)]
pub struct RecorderConfig {
    /// Base URL that requests are forwarded to; a trailing `/` is ignored.
    pub upstream: String,
    /// File that recorded exchanges are written to, one JSON object per line.
    /// It is truncated when the recorder starts.
    pub cassette_path: PathBuf,
    /// Response header names to leave out of the cassette. They are matched
    /// case-insensitively.
    pub redact_headers: Vec<String>,
}
64
65impl Default for RecorderConfig {
66 fn default() -> Self {
67 Self {
68 upstream: "https://api.anthropic.com".into(),
69 cassette_path: PathBuf::from("./cassette.jsonl"),
70 redact_headers: DEFAULT_REDACT_HEADERS
71 .iter()
72 .map(|s| (*s).to_owned())
73 .collect(),
74 }
75 }
76}
77
78pub struct Recorder {
81 url: String,
82 shutdown: Option<oneshot::Sender<()>>,
83 handle: Option<JoinHandle<()>>,
84}
85
86impl Recorder {
87 pub async fn start(config: RecorderConfig) -> std::io::Result<Self> {
93 let upstream = config.upstream.trim_end_matches('/').to_owned();
94 let listener = TcpListener::bind(("127.0.0.1", 0)).await?;
95 let local_addr = listener.local_addr()?;
96 let url = format!("http://{local_addr}");
97
98 let file = tokio::fs::OpenOptions::new()
99 .create(true)
100 .write(true)
101 .truncate(true)
102 .open(&config.cassette_path)
103 .await?;
104 let writer = Arc::new(Mutex::new(file));
105
106 let forwarder = reqwest::Client::builder()
110 .build()
111 .map_err(std::io::Error::other)?;
112
113 let redact: Arc<Vec<String>> = Arc::new(
114 config
115 .redact_headers
116 .iter()
117 .map(|s| s.to_lowercase())
118 .collect(),
119 );
120
121 let (tx, rx) = oneshot::channel::<()>();
122
123 let handle = tokio::spawn(async move {
124 tokio::pin!(rx);
125 loop {
126 tokio::select! {
127 _ = &mut rx => break,
128 accept = listener.accept() => {
129 let Ok((stream, _peer)) = accept else { continue };
130 let upstream = upstream.clone();
131 let writer = Arc::clone(&writer);
132 let forwarder = forwarder.clone();
133 let redact = Arc::clone(&redact);
134 tokio::spawn(async move {
135 let io = TokioIo::new(stream);
136 let svc = service_fn(move |req| {
137 let upstream = upstream.clone();
138 let writer = Arc::clone(&writer);
139 let forwarder = forwarder.clone();
140 let redact = Arc::clone(&redact);
141 async move {
142 handle_request(req, &upstream, &forwarder, writer, redact)
143 .await
144 }
145 });
146 let _ = hyper::server::conn::http1::Builder::new()
147 .serve_connection(io, svc)
148 .await;
149 });
150 }
151 }
152 }
153 });
154
155 Ok(Self {
156 url,
157 shutdown: Some(tx),
158 handle: Some(handle),
159 })
160 }
161
162 #[must_use]
165 pub fn url(&self) -> &str {
166 &self.url
167 }
168
169 pub async fn shutdown(mut self) -> std::io::Result<()> {
172 if let Some(tx) = self.shutdown.take() {
173 let _ = tx.send(());
174 }
175 if let Some(handle) = self.handle.take() {
176 let _ = handle.await;
177 }
178 Ok(())
179 }
180}
181
182impl Drop for Recorder {
183 fn drop(&mut self) {
184 if let Some(tx) = self.shutdown.take() {
185 let _ = tx.send(());
186 }
187 }
188}
189
190async fn handle_request(
191 req: Request<Incoming>,
192 upstream: &str,
193 forwarder: &reqwest::Client,
194 writer: Arc<Mutex<tokio::fs::File>>,
195 redact: Arc<Vec<String>>,
196) -> Result<Response<Full<Bytes>>, Infallible> {
197 let method = req.method().clone();
198 let path_and_query = req
199 .uri()
200 .path_and_query()
201 .map_or_else(|| req.uri().path().to_owned(), ToString::to_string);
202 let path_only = req.uri().path().to_owned();
203 let headers = req.headers().clone();
204
205 let body_bytes = match req.into_body().collect().await {
206 Ok(b) => b.to_bytes(),
207 Err(_) => {
208 return Ok(error_response(
209 http::StatusCode::BAD_GATEWAY,
210 "recorder: failed to read request body",
211 ));
212 }
213 };
214
215 let url = format!("{upstream}{path_and_query}");
217 let mut fwd = forwarder.request(method.clone(), &url);
218 for (name, value) in &headers {
219 if matches!(name.as_str(), "host" | "content-length") {
221 continue;
222 }
223 fwd = fwd.header(name, value);
224 }
225 if !body_bytes.is_empty() {
226 fwd = fwd.body(body_bytes.to_vec());
227 }
228 let upstream_resp = match fwd.send().await {
229 Ok(r) => r,
230 Err(e) => {
231 return Ok(error_response(
232 http::StatusCode::BAD_GATEWAY,
233 &format!("recorder: upstream request failed: {e}"),
234 ));
235 }
236 };
237 let status = upstream_resp.status();
238 let upstream_headers = upstream_resp.headers().clone();
239 let resp_bytes = upstream_resp.bytes().await.unwrap_or_default();
240
241 let exchange = build_exchange(
243 method.as_str(),
244 &path_only,
245 status.as_u16(),
246 &body_bytes,
247 &upstream_headers,
248 &resp_bytes,
249 &redact,
250 );
251 let _ = &headers;
256 if let Ok(line) = serde_json::to_string(&exchange) {
257 let mut guard = writer.lock().await;
258 let _ = guard.write_all(line.as_bytes()).await;
259 let _ = guard.write_all(b"\n").await;
260 let _ = guard.flush().await;
261 }
262
263 let mut builder = Response::builder().status(status);
265 for (name, value) in &upstream_headers {
266 builder = builder.header(name, value);
267 }
268 let response = builder
269 .body(Full::new(resp_bytes))
270 .unwrap_or_else(|_| error_response(http::StatusCode::BAD_GATEWAY, "recorder: build error"));
271 Ok(response)
272}
273
274fn build_exchange(
275 method: &str,
276 path: &str,
277 status: u16,
278 request_body: &[u8],
279 response_headers: &HeaderMap,
280 response_body: &[u8],
281 redact: &[String],
282) -> RecordedExchange {
283 let request_value = if request_body.is_empty() {
288 None
289 } else {
290 Some(
291 serde_json::from_slice::<serde_json::Value>(request_body).unwrap_or_else(|_| {
292 serde_json::Value::String(format!("<{} bytes>", request_body.len()))
293 }),
294 )
295 };
296
297 let is_sse = response_headers
302 .get(http::header::CONTENT_TYPE)
303 .and_then(|v| v.to_str().ok())
304 .is_some_and(|ct| ct.contains("text/event-stream"));
305
306 let response_value = if is_sse {
307 let text = String::from_utf8_lossy(response_body).into_owned();
309 serde_json::Value::String(text)
310 } else {
311 serde_json::from_slice::<serde_json::Value>(response_body).unwrap_or_else(|_| {
312 serde_json::Value::String(format!("<{} bytes>", response_body.len()))
313 })
314 };
315
316 let mut headers: Vec<(String, String)> = Vec::new();
317 for (name, value) in response_headers {
318 let name_lc = name.as_str().to_lowercase();
319 if redact.iter().any(|r| r == &name_lc) {
320 continue;
321 }
322 if let Ok(v) = value.to_str() {
323 headers.push((name_lc, v.to_owned()));
324 }
325 }
326
327 RecordedExchange {
328 method: method.to_owned(),
329 path: path.to_owned(),
330 status,
331 request: request_value,
332 response: response_value,
333 headers,
334 }
335}
336
337fn error_response(status: http::StatusCode, message: &str) -> Response<Full<Bytes>> {
338 Response::builder()
339 .status(status)
340 .header("content-type", "application/json")
341 .body(Full::new(Bytes::from(format!(
342 r#"{{"type":"error","error":{{"type":"recorder_error","message":{message:?}}}}}"#
343 ))))
344 .expect("static response is well-formed")
345}