1#[cfg(all(feature = "reqwest-0_12", feature = "reqwest-0_13"))]
39compile_error!(
40 "Features `reqwest-0_12` and `reqwest-0_13` are mutually exclusive. Please enable only one."
41);
42
43#[cfg(not(any(feature = "reqwest-0_12", feature = "reqwest-0_13")))]
44compile_error!("Either `reqwest-0_12` or `reqwest-0_13` feature must be enabled.");
45
46#[cfg(feature = "reqwest-0_12")]
47extern crate reqwest_0_12 as reqwest;
48#[cfg(feature = "reqwest-0_12")]
49extern crate reqwest_middleware_0_4 as reqwest_middleware;
50
51#[cfg(feature = "reqwest-0_13")]
52extern crate reqwest_0_13 as reqwest;
53#[cfg(feature = "reqwest-0_13")]
54extern crate reqwest_middleware_0_5 as reqwest_middleware;
55
56#[cfg(feature = "compress")]
57use std::io::Read;
58#[cfg(feature = "compress")]
59use std::io::Write;
60use std::{fs, path::PathBuf, str::FromStr, sync::Mutex};
61
62use base64::{engine::general_purpose, Engine};
63use reqwest_middleware::Middleware;
64use vcr_cassette::{HttpInteraction, RecorderId};
65
66pub const VERSION: &str = env!("CARGO_PKG_VERSION");
67
68lazy_static::lazy_static! {
69 static ref RECORDER: RecorderId = format!("rVCR {VERSION}");
70 static ref BASE64: String = String::from("base64");
71}
72
73pub struct VCRMiddleware {
75 path: Option<PathBuf>,
76 storage: Mutex<vcr_cassette::Cassette>,
77 mode: VCRMode,
78 search: VCRReplaySearch,
79 skip: Mutex<usize>,
80 compress: bool,
81 rich_diff: bool,
82 modify_request: Option<Box<RequestModifier>>,
83 modify_response: Option<Box<ResponseModifier>>,
84}
85
86type RequestModifier = dyn Fn(&mut vcr_cassette::Request) + Send + Sync + 'static;
87type ResponseModifier = dyn Fn(&mut vcr_cassette::Response) + Send + Sync + 'static;
88
/// Operating mode of the VCR middleware.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum VCRMode {
    /// Forward every request to the next handler and store the
    /// request/response pair in the cassette.
    Record,
    /// Answer requests from the cassette without forwarding them.
    Replay,
}
97
/// Strategy used to locate a recorded interaction while replaying.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum VCRReplaySearch {
    /// Ignore as many leading interactions as lookups have already been
    /// performed (the counter advances by one on every lookup), then return
    /// the first remaining interaction whose request matches.
    SkipFound,
    /// Scan the whole cassette from the beginning on every lookup and return
    /// the first interaction whose request matches.
    SearchAll,
}
107
/// Error type expressed as a static string message.
// NOTE(review): not referenced in the visible part of this file; presumably
// kept for API compatibility — confirm before removing.
pub type VCRError = &'static str;
109
110impl VCRMiddleware {
116 pub fn with_mode(mut self, mode: VCRMode) -> Self {
118 self.mode = mode;
119 self
120 }
121
122 pub fn with_modify_request<F>(mut self, modifier: F) -> Self
123 where
124 F: Fn(&mut vcr_cassette::Request) + Send + Sync + 'static,
125 {
126 self.modify_request.replace(Box::new(modifier));
127 self
128 }
129
130 pub fn with_modify_response<F>(mut self, modifier: F) -> Self
131 where
132 F: Fn(&mut vcr_cassette::Response) + Send + Sync + 'static,
133 {
134 self.modify_response.replace(Box::new(modifier));
135 self
136 }
137
138 pub fn with_search(mut self, search: VCRReplaySearch) -> Self {
140 self.search = search;
141 self
142 }
143
144 pub fn with_path(mut self, path: impl Into<PathBuf>) -> Self {
146 self.path = Some(path.into());
147 self
148 }
149
150 pub fn with_rich_diff(mut self, rich_diff: bool) -> Self {
152 self.rich_diff = rich_diff;
153 self
154 }
155
156 #[cfg(feature = "compress")]
158 pub fn compressed(mut self, compress: bool) -> Self {
159 self.compress = compress;
160 self
161 }
162
163 fn convert_version_to_vcr(&self, version: http::Version) -> vcr_cassette::Version {
164 if version == http::Version::HTTP_10 {
165 vcr_cassette::Version::Http1_0
166 } else if version == http::Version::HTTP_11 {
167 vcr_cassette::Version::Http1_1
168 } else if version == http::Version::HTTP_2 {
169 vcr_cassette::Version::Http2_0
170 } else {
171 panic!("rVCR only supports http 1.0, 1.1 and 2.0")
172 }
173 }
174
175 fn convert_version_from_vcr(&self, version: vcr_cassette::Version) -> http::Version {
176 match version {
177 vcr_cassette::Version::Http1_0 => http::Version::HTTP_10,
178 vcr_cassette::Version::Http1_1 => http::Version::HTTP_11,
179 vcr_cassette::Version::Http2_0 => http::Version::HTTP_2,
180 _ => {
181 panic!("rVCR only supports http 1.0, 1.1 and 2.0")
182 }
183 }
184 }
185
186 fn bytes_to_vcr_body(&self, body_bytes: &[u8]) -> vcr_cassette::Body {
187 match String::from_utf8(body_bytes.to_vec()) {
192 Ok(body_str) => vcr_cassette::Body::from_str(&body_str).unwrap(),
193 Err(e) => {
194 tracing::debug!("Can not deserialize utf-8 string: {e:?}");
195 let base64_str = general_purpose::STANDARD_NO_PAD.encode(body_bytes);
196 vcr_cassette::Body {
197 string: base64_str,
198 encoding: Some(BASE64.to_string()),
199 }
200 }
201 }
202 }
203
204 fn headers_to_vcr(&self, headers: &reqwest::header::HeaderMap) -> vcr_cassette::Headers {
205 let mut vcr_headers = vcr_cassette::Headers::new();
206 for (header_name, header_value) in headers {
207 let header_name_string = header_name.to_string();
208 let header_value_bytes = header_value.as_bytes();
209 let header_value = String::from_utf8(header_value_bytes.to_vec())
210 .unwrap_or_else(|_| panic!("Non utf header value for header named {header_name}; header values are supposed to be ASCII encoded"));
211 vcr_headers.insert(header_name_string, vec![header_value]);
212 }
213 vcr_headers
214 }
215
216 fn request_to_vcr(&self, req: reqwest::Request) -> vcr_cassette::Request {
217 let body = match req.body() {
218 Some(body) => match body.as_bytes() {
219 Some(body_bytes) => self.bytes_to_vcr_body(body_bytes),
220 None => vcr_cassette::Body::from_str("").unwrap(),
221 },
222 None => vcr_cassette::Body::from_str("").unwrap(),
223 };
224
225 let method_str = req.method().to_string().to_lowercase();
226
227 let method: vcr_cassette::Method = serde_json::from_str(&format!("\"{method_str}\""))
228 .unwrap_or_else(|_| panic!("Unknown HTTP method passed from reqwest: {method_str}"));
229
230 let headers = self.headers_to_vcr(req.headers());
231
232 let mut vcr_request = vcr_cassette::Request {
233 body,
234 method,
235 uri: req.url().to_owned(),
236 headers,
237 };
238
239 if let Some(ref modifier) = self.modify_request {
240 modifier(&mut vcr_request);
241 }
242
243 vcr_request
244 }
245
246 async fn response_to_vcr(&self, resp: reqwest::Response) -> vcr_cassette::Response {
247 let http_version = Some(self.convert_version_to_vcr(resp.version()));
248 let status_code = resp.status();
249 let headers = self.headers_to_vcr(resp.headers());
250 let response_text = resp.bytes().await.expect("Can not fetch response bytes");
251 let body = self.bytes_to_vcr_body(&response_text);
252
253 let status = vcr_cassette::Status {
254 code: status_code.as_u16(),
255 message: status_code
256 .canonical_reason()
257 .unwrap_or("Unknown")
258 .to_string(),
259 };
260
261 let mut vcr_response = vcr_cassette::Response {
262 body,
263 http_version,
264 status,
265 headers,
266 };
267
268 if let Some(ref modifier) = self.modify_response {
269 modifier(&mut vcr_response);
270 }
271
272 vcr_response
273 }
274
275 fn header_values_to_string(&self, header_values: Option<&Vec<String>>) -> String {
276 match header_values {
277 Some(values) => values.join(", "),
278 None => "<MISSING>".to_string(),
279 }
280 }
281
282 fn find_response_in_vcr(&self, req: vcr_cassette::Request) -> Option<vcr_cassette::Response> {
283 let cassette = self.storage.lock().unwrap();
284 let iteractions: Vec<&HttpInteraction> = match self.search {
285 VCRReplaySearch::SkipFound => {
286 let skip = *self.skip.lock().unwrap();
287 *self.skip.lock().unwrap() += 1;
288 cassette.http_interactions.iter().skip(skip).collect()
289 }
290 VCRReplaySearch::SearchAll => cassette.http_interactions.iter().collect(),
291 };
292
293 let mut diff_log = if self.rich_diff {
297 Some(String::new())
298 } else {
299 None
300 };
301 for interaction in iteractions {
302 if interaction.request == req {
303 return Some(interaction.response.clone());
304 }
305 if let Some(diff) = diff_log.as_mut() {
306 diff.push_str(&format!(
307 "Did not match {method:?} to {uri}:\n",
308 method = interaction.request.method,
309 uri = interaction.request.uri.as_str()
310 ));
311 if interaction.request.method != req.method {
312 diff.push_str(&format!(
313 " Method differs: recorded {expected:?}, got {got:?}\n",
314 expected = interaction.request.method,
315 got = req.method
316 ));
317 }
318 if interaction.request.uri != req.uri {
319 diff.push_str(" URI differs:\n");
320 diff.push_str(&format!(
321 " recorded: \"{}\"\n",
322 interaction.request.uri.as_str()
323 ));
324 diff.push_str(&format!(" got: \"{}\"\n", req.uri.as_str()));
325 }
326 if interaction.request.headers != req.headers {
327 diff.push_str(" Headers differ:\n");
328 for (recorded_header_name, recorded_header_values) in
329 &interaction.request.headers
330 {
331 let expected = self.header_values_to_string(Some(recorded_header_values));
332 let got =
333 self.header_values_to_string(req.headers.get(recorded_header_name));
334 if expected != got {
335 diff.push_str(&format!(" {}:\n", recorded_header_name));
336 diff.push_str(&format!(" recorded: \"{}\"\n", expected));
337 diff.push_str(&format!(" got: \"{}\"\n", got));
338 }
339 }
340 for (got_header_name, got_header_values) in &req.headers {
341 if !interaction.request.headers.contains_key(got_header_name) {
342 let got = self.header_values_to_string(Some(got_header_values));
343 diff.push_str(&format!(" {}:\n", got_header_name));
344 diff.push_str(" recorded: <MISSING>\n");
345 diff.push_str(&format!(" got: \"{}\"\n", got));
346 }
347 }
348 }
349 if interaction.request.body != req.body {
350 diff.push_str(" Body differs:\n");
351 diff.push_str(&format!(
352 " recorded: \"{}\"\n",
353 interaction.request.body.string
354 ));
355 diff.push_str(&format!(" got: \"{}\"\n", req.body.string));
356 }
357 diff.push('\n');
358 }
359 }
360 if let Some(diff) = diff_log {
361 for line in diff.split('\n') {
364 tracing::info!("{}", line);
365 }
366 }
367 None
368 }
369
370 fn vcr_to_response(&self, response: vcr_cassette::Response) -> reqwest::Response {
371 let code = response.status.code;
372 let mut builder = http::Response::builder().status(code);
373 for (header_name, header_values) in response.headers {
374 builder = builder.header(header_name, header_values.first().unwrap());
375 }
376 let http_version = self.convert_version_from_vcr(
377 response
378 .http_version
379 .unwrap_or(vcr_cassette::Version::Http1_1),
380 );
381 let builder = builder.version(http_version);
382
383 match response.body.encoding {
384 None => {
385 if !response.body.string.is_empty() {
386 reqwest::Response::from(builder.body(response.body.string).unwrap())
387 } else {
388 reqwest::Response::from(builder.body("".as_bytes()).unwrap())
389 }
390 }
391 Some(encoding) => {
392 if encoding == "base64" {
393 let decoded = general_purpose::STANDARD_NO_PAD
394 .decode(encoding)
395 .expect("Invalid response body base64 can not be decoded");
396 reqwest::Response::from(builder.body(decoded).unwrap())
397 } else {
398 panic!("Unsupported encoding: {encoding}");
400 }
401 }
402 }
403 }
404
405 fn record(&self, request: vcr_cassette::Request, response: vcr_cassette::Response) {
406 let mut cassette = self.storage.lock().unwrap();
407 cassette
408 .http_interactions
409 .push(vcr_cassette::HttpInteraction {
410 response,
411 request,
412 recorded_at: chrono::Utc::now().into(),
413 });
414 }
415}
416
417#[async_trait::async_trait]
422impl Middleware for VCRMiddleware {
423 async fn handle(
424 &self,
425 req: reqwest::Request,
426 extensions: &mut http::Extensions,
427 next: reqwest_middleware::Next<'_>,
428 ) -> reqwest_middleware::Result<reqwest::Response> {
429 let vcr_request = self.request_to_vcr(req.try_clone().unwrap());
430
431 match self.mode {
432 VCRMode::Record => {
433 let response = next.run(req, extensions).await?;
434 let vcr_response = self.response_to_vcr(response).await;
435 let converted_response = self.vcr_to_response(vcr_response.clone());
436 self.record(vcr_request, vcr_response);
437 Ok(converted_response)
438 }
439 VCRMode::Replay => match self.find_response_in_vcr(vcr_request) {
440 None => {
441 let message = format!(
442 "Cannot find corresponding request in cassette {:?}",
443 self.path,
444 );
445 Err(reqwest_middleware::Error::Middleware(anyhow::anyhow!(
446 message
447 )))
448 }
449 Some(response) => {
450 let response = self.vcr_to_response(response);
451 Ok(response)
452 }
453 },
454 }
455 }
456}
457
458impl From<vcr_cassette::Cassette> for VCRMiddleware {
460 fn from(cassette: vcr_cassette::Cassette) -> Self {
461 VCRMiddleware {
462 storage: Mutex::new(cassette),
463 mode: VCRMode::Replay,
464 path: None,
465 skip: Mutex::new(0),
466 search: VCRReplaySearch::SkipFound,
467 compress: false,
468 rich_diff: false,
469 modify_request: None,
470 modify_response: None,
471 }
472 }
473}
474
475impl Drop for VCRMiddleware {
477 fn drop(&mut self) {
478 if self.mode == VCRMode::Record {
479 let path = self
480 .path
481 .clone()
482 .unwrap_or(format!(".rvcr-{}.vcr", chrono::Utc::now().timestamp()).into());
483 let cassette = self.storage.lock().unwrap();
484
485 let contents: String = serde_json::to_string_pretty(&*cassette).unwrap();
486
487 #[cfg(feature = "compress")]
488 if self.compress {
489 let file = std::fs::File::create(path.clone()).unwrap();
490
491 let mut zip = zip::ZipWriter::new(file);
492
493 let options = zip::write::FileOptions::default()
494 .compression_method(zip::CompressionMethod::Bzip2)
495 .compression_level(Some(9))
496 .unix_permissions(0o644);
497 zip.start_file("test.vcr.json", options).unwrap();
498 zip.write_all(contents.as_bytes()).unwrap();
499 zip.finish().unwrap();
500 }
501
502 if !self.compress {
503 fs::write(path.clone(), contents.as_bytes())
504 .unwrap_or_else(|_| panic!("Can not write cassette contents to {path:?}"));
505 tracing::info!("Written VCR cassette file at {path:?}");
506 }
507 }
508 }
509}
510
511impl TryFrom<PathBuf> for VCRMiddleware {
515 fn try_from(pb: PathBuf) -> Result<Self, Self::Error> {
516 let empty = vcr_cassette::Cassette {
517 http_interactions: vec![],
518 recorded_with: RECORDER.to_string(),
519 };
520
521 let mut mw = Self::from(empty);
522 mw.path = Some(pb.clone());
523 if !pb.exists() {
524 Ok(mw)
525 } else {
526 let content = fs::read(pb.clone()).map_err(|e| {
527 tracing::error!("Failed reading VCR cassette: {e}");
528 format!(
529 "Failed to read VCR cassette from path {}",
530 pb.to_str().unwrap()
531 )
532 })?;
533
534 #[cfg(feature = "compress")]
535 let content = {
536 let file = fs::File::open(mw.path.clone().unwrap()).unwrap();
537 match zip::ZipArchive::new(file) {
538 Ok(mut archive) => {
539 let mut content = content;
540 content.clear();
541 let contents = archive.by_name("test.vcr.json");
542 let mut contents =
543 contents.expect("test.vcr.json file is missing in zip archive");
544 contents
545 .read_to_end(&mut content)
546 .expect("Can not read test.vcr.json from zip archive");
547 content
548 }
549 Err(e) => {
550 tracing::debug!("Failed to detect file as zip: {e:?}");
551 content
552 }
553 }
554 };
555
556 let cassette: vcr_cassette::Cassette =
557 serde_json::from_slice(&content).map_err(|e| {
558 tracing::error!("Failed deserializing VCR cassette: {e}");
559 format!(
560 "Failed to deserialize VCR cassette from path {}",
561 pb.to_str().unwrap()
562 )
563 })?;
564
565 let mut mw = Self::from(cassette);
566 mw.path = Some(pb);
567 Ok(mw)
568 }
569 }
570
571 type Error = String;
572}