1#![cfg_attr(docsrs, feature(doc_cfg))]
56
57use std::path::Path;
58
59use serde::{Deserialize, Serialize};
60
61pub mod recorder;
62pub use recorder::{DEFAULT_REDACT_HEADERS, Recorder, RecorderConfig};
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[non_exhaustive]
67pub struct RecordedExchange {
68 pub method: String,
70 pub path: String,
72 pub status: u16,
74 #[serde(default)]
78 pub request: Option<serde_json::Value>,
79 pub response: serde_json::Value,
82 #[serde(default)]
85 pub headers: Vec<(String, String)>,
86}
87
88impl RecordedExchange {
89 #[must_use]
92 pub fn new(
93 method: impl Into<String>,
94 path: impl Into<String>,
95 status: u16,
96 response: serde_json::Value,
97 ) -> Self {
98 Self {
99 method: method.into(),
100 path: path.into(),
101 status,
102 request: None,
103 response,
104 headers: Vec::new(),
105 }
106 }
107
108 #[must_use]
110 pub fn with_request(mut self, body: serde_json::Value) -> Self {
111 self.request = Some(body);
112 self
113 }
114
115 #[must_use]
117 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
118 self.headers.push((name.into(), value.into()));
119 self
120 }
121}
122
123#[derive(Debug, Clone, Default)]
126pub struct Cassette {
127 exchanges: Vec<RecordedExchange>,
128 skip_request_match: bool,
129}
130
131impl Cassette {
132 #[must_use]
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 #[must_use]
141 pub fn from_exchanges(exchanges: Vec<RecordedExchange>) -> Self {
142 Self {
143 exchanges,
144 skip_request_match: false,
145 }
146 }
147
148 pub async fn from_path(path: impl AsRef<Path>) -> std::io::Result<Self> {
152 let text = tokio::fs::read_to_string(path).await?;
153 Self::parse_jsonl(&text).map_err(std::io::Error::other)
154 }
155
156 pub fn from_path_sync(path: impl AsRef<Path>) -> std::io::Result<Self> {
159 let text = std::fs::read_to_string(path)?;
160 Self::parse_jsonl(&text).map_err(std::io::Error::other)
161 }
162
163 pub fn parse_jsonl(jsonl: &str) -> serde_json::Result<Self> {
166 let mut exchanges = Vec::new();
167 for (line_no, line) in jsonl.lines().enumerate() {
168 let trimmed = line.trim();
169 if trimmed.is_empty() || trimmed.starts_with('#') {
170 continue;
171 }
172 let exchange: RecordedExchange = serde_json::from_str(trimmed).map_err(|e| {
173 let msg = format!("cassette parse failed at line {}: {}", line_no + 1, e);
174 serde::de::Error::custom(msg)
175 })?;
176 exchanges.push(exchange);
177 }
178 Ok(Self {
179 exchanges,
180 skip_request_match: false,
181 })
182 }
183
184 pub fn push(&mut self, exchange: RecordedExchange) -> &mut Self {
186 self.exchanges.push(exchange);
187 self
188 }
189
190 #[must_use]
194 pub fn skip_request_match(mut self) -> Self {
195 self.skip_request_match = true;
196 self
197 }
198
199 #[must_use]
201 pub fn exchanges(&self) -> &[RecordedExchange] {
202 &self.exchanges
203 }
204
205 #[must_use]
207 pub fn len(&self) -> usize {
208 self.exchanges.len()
209 }
210
211 #[must_use]
213 pub fn is_empty(&self) -> bool {
214 self.exchanges.is_empty()
215 }
216
217 pub fn to_jsonl(&self) -> serde_json::Result<String> {
219 let mut out = String::new();
220 for ex in &self.exchanges {
221 out.push_str(&serde_json::to_string(ex)?);
222 out.push('\n');
223 }
224 Ok(out)
225 }
226}
227
228pub async fn mount_cassette(server: &wiremock::MockServer, cassette: &Cassette) {
241 use wiremock::matchers::{body_json, method, path};
242 use wiremock::{Mock, ResponseTemplate};
243
244 for ex in &cassette.exchanges {
245 let is_sse = ex.headers.iter().any(|(k, v)| {
251 k.eq_ignore_ascii_case("content-type") && v.contains("text/event-stream")
252 });
253
254 let mut response = if is_sse {
255 let body = ex.response.as_str().unwrap_or("").as_bytes().to_owned();
262 ResponseTemplate::new(ex.status).set_body_raw(body, "text/event-stream")
263 } else {
264 ResponseTemplate::new(ex.status).set_body_json(ex.response.clone())
265 };
266
267 for (k, v) in &ex.headers {
272 if is_sse && k.eq_ignore_ascii_case("content-type") {
273 continue;
275 }
276 response = response.insert_header(k.as_str(), v.as_str());
277 }
278
279 let mock_builder = Mock::given(method(ex.method.as_str())).and(path(ex.path.as_str()));
280 let mock = match (&ex.request, cassette.skip_request_match) {
281 (Some(body), false) => mock_builder.and(body_json(body)).respond_with(response),
282 _ => mock_builder.respond_with(response),
283 };
284 mock.mount(server).await;
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A minimal two-event SSE stream used by the JSONL round-trip test.
    fn tiny_sse_corpus() -> &'static str {
        concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_sse\",\"type\":\"message\",",
            "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-haiku-4-5-20251001\",",
            "\"usage\":{\"input_tokens\":3,\"output_tokens\":0}}}\n",
            "\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n",
            "\n",
        )
    }

    #[test]
    fn parse_jsonl_round_trips() {
        let jsonl = r#"
# leading comment, ignored
{"method":"POST","path":"/v1/messages","status":200,"request":{"model":"x"},"response":{"id":"msg_1"}}
{"method":"GET","path":"/v1/models","status":200,"request":null,"response":{"data":[]}}
"#;
        let c = Cassette::parse_jsonl(jsonl).unwrap();
        assert_eq!(c.len(), 2);

        let first = &c.exchanges()[0];
        assert_eq!(first.method, "POST");
        assert_eq!(first.request, Some(json!({"model": "x"})));
        assert_eq!(first.response, json!({"id": "msg_1"}));

        let second = &c.exchanges()[1];
        assert_eq!(second.path, "/v1/models");
        // An explicit `null` request body deserializes to `None`, not `Some(Null)`.
        assert!(second.request.is_none());

        // `headers` is absent from both lines and falls back to its serde default.
        assert!(c.exchanges().iter().all(|ex| ex.headers.is_empty()));

        let serialized = c.to_jsonl().unwrap();
        let again = Cassette::parse_jsonl(&serialized).unwrap();
        assert_eq!(again.len(), 2);
        for (before, after) in c.exchanges().iter().zip(again.exchanges()) {
            assert_eq!(before.method, after.method);
            assert_eq!(before.path, after.path);
            assert_eq!(before.status, after.status);
            assert_eq!(before.request, after.request);
            assert_eq!(before.response, after.response);
            assert_eq!(before.headers, after.headers);
        }
    }

    #[test]
    fn empty_cassette_is_empty() {
        let c = Cassette::new();
        assert!(c.is_empty());
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn empty_cassette_serializes_to_empty_string() {
        assert_eq!(Cassette::new().to_jsonl().unwrap(), "");
    }

    #[test]
    fn cassette_parse_error_includes_line_number() {
        let jsonl = "not-json\n";
        let err = Cassette::parse_jsonl(jsonl).unwrap_err();
        assert!(format!("{err}").contains("line 1"));
    }

    #[test]
    fn cassette_parse_error_line_number_counts_skipped_lines() {
        let jsonl = concat!(
            "# comment\n",
            "\n",
            "{\"method\":\"GET\",\"path\":\"/v1/models\",\"status\":200,\"response\":{}}\n",
            "not-json\n",
        );
        let err = Cassette::parse_jsonl(jsonl).unwrap_err();
        assert!(format!("{err}").contains("line 4"), "unexpected error: {err}");
    }

    #[test]
    fn skip_request_match_flag_is_set() {
        assert!(!Cassette::new().skip_request_match);
        let c = Cassette::new().skip_request_match();
        assert!(c.skip_request_match);
    }

    #[test]
    fn from_exchanges_constructs_directly() {
        let ex = RecordedExchange {
            method: "POST".into(),
            path: "/v1/x".into(),
            status: 200,
            request: Some(json!({"k": 1})),
            response: json!({"ok": true}),
            headers: vec![("request-id".into(), "req_1".into())],
        };
        let c = Cassette::from_exchanges(vec![ex]);
        assert_eq!(c.len(), 1);
        assert!(!c.skip_request_match);
    }

    #[test]
    fn builders_populate_optional_fields() {
        let ex = RecordedExchange::new("POST", "/v1/messages", 201, json!({"id": "msg_1"}))
            .with_request(json!({"model": "x"}))
            .with_header("request-id", "req_1")
            .with_header("x-extra", "1");
        assert_eq!(ex.method, "POST");
        assert_eq!(ex.path, "/v1/messages");
        assert_eq!(ex.status, 201);
        assert_eq!(ex.request, Some(json!({"model": "x"})));
        assert_eq!(ex.response, json!({"id": "msg_1"}));
        assert_eq!(
            ex.headers,
            vec![
                ("request-id".to_owned(), "req_1".to_owned()),
                ("x-extra".to_owned(), "1".to_owned()),
            ]
        );
    }

    #[test]
    fn push_appends_in_order() {
        let mut c = Cassette::new();
        c.push(RecordedExchange::new("GET", "/a", 200, json!({})))
            .push(RecordedExchange::new("GET", "/b", 200, json!({})));
        let paths: Vec<&str> = c.exchanges().iter().map(|ex| ex.path.as_str()).collect();
        assert_eq!(paths, ["/a", "/b"]);
    }

    /// SSE exchanges store the raw stream as a JSON string and must survive
    /// a JSONL round trip unchanged, along with their content-type header.
    #[test]
    fn sse_exchange_round_trips_through_jsonl() {
        let sse = tiny_sse_corpus();
        let ex = RecordedExchange {
            method: "POST".into(),
            path: "/v1/messages".into(),
            status: 200,
            request: Some(json!({"stream": true})),
            response: json!(sse),
            headers: vec![
                ("content-type".into(), "text/event-stream".into()),
                ("request-id".into(), "req_sse_1".into()),
            ],
        };

        let cassette = Cassette::from_exchanges(vec![ex]);
        let jsonl = cassette.to_jsonl().unwrap();
        let again = Cassette::parse_jsonl(&jsonl).unwrap();

        assert_eq!(again.len(), 1);
        let entry = &again.exchanges()[0];
        assert_eq!(entry.status, 200);
        assert_eq!(entry.response.as_str().unwrap(), sse);
        assert!(
            entry
                .headers
                .iter()
                .any(|(k, v)| k == "content-type" && v.contains("text/event-stream"))
        );
    }

    /// End-to-end: a mounted SSE exchange is consumed by the real client's
    /// streaming path and aggregated into a complete message.
    #[tokio::test]
    async fn mount_cassette_replays_sse_response() {
        use claude_api::Client;
        use claude_api::messages::CreateMessageRequest;
        use claude_api::types::ModelId;

        let sse = concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_sse_replay\",\"type\":\"message\",",
            "\"role\":\"assistant\",\"content\":[],\"model\":\"claude-haiku-4-5-20251001\",",
            "\"usage\":{\"input_tokens\":3,\"output_tokens\":0}}}\n",
            "\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
            "\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hi\"}}\n",
            "\n",
            "event: content_block_stop\n",
            "data: {\"type\":\"content_block_stop\",\"index\":0}\n",
            "\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":3,\"output_tokens\":1}}\n",
            "\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n",
            "\n",
        );

        let cassette = Cassette::from_exchanges(vec![RecordedExchange {
            method: "POST".into(),
            path: "/v1/messages".into(),
            status: 200,
            request: None,
            response: json!(sse),
            headers: vec![("content-type".into(), "text/event-stream".into())],
        }]);

        let server = wiremock::MockServer::start().await;
        mount_cassette(&server, &cassette).await;

        let client = Client::builder()
            .api_key("sk-ant-test")
            .base_url(server.uri())
            .build()
            .unwrap();

        let req = CreateMessageRequest::builder()
            .model(ModelId::HAIKU_4_5)
            .max_tokens(8)
            .user("hi")
            .build()
            .unwrap();

        let stream = client.messages().create_stream(req).await.unwrap();
        let msg = stream.aggregate().await.unwrap();

        assert_eq!(msg.id, "msg_sse_replay");
        assert_eq!(
            msg.stop_reason,
            Some(claude_api::types::StopReason::EndTurn)
        );
        assert_eq!(msg.usage.output_tokens, 1);
    }
}