Skip to main content

reqwest_vcr/
lib.rs

//! Record-and-replay middleware for the reqwest HTTP client.
2//!
//! Inspired by [Ruby-VCR](https://github.com/vcr/vcr) and the
//! [Surf-VCR](https://git.sr.ht/~rjframe/surf-vcr) Rust client.
5//!
6//! # Examples
7//!
//! To record requests, initialize the client as follows:
9//! ```rust
10//! # #[cfg(feature = "reqwest-0_12")]
11//! # extern crate reqwest_0_12 as reqwest;
12//! # #[cfg(feature = "reqwest-0_13")]
13//! # extern crate reqwest_0_13 as reqwest;
14//! # #[cfg(feature = "reqwest-0_12")]
15//! # extern crate reqwest_middleware_0_4 as reqwest_middleware;
16//! # #[cfg(feature = "reqwest-0_13")]
17//! # extern crate reqwest_middleware_0_5 as reqwest_middleware;
18//! use std::path::PathBuf;
19//! use reqwest::Client;
20//! use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
21//! use reqwest_vcr::{VCRMiddleware, VCRMode};
22//!
23//! let mut bundle = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
24//! bundle.push("tests/resources/replay.vcr.json");
25//!
26//! let middleware: VCRMiddleware = VCRMiddleware::try_from(bundle.clone())
27//!     .unwrap()
28//!     .with_mode(VCRMode::Record);
29//!
30//! let vcr_client: ClientWithMiddleware = ClientBuilder::new(reqwest::Client::new())
31//!     .with(middleware)
32//!     .build();
33//! ```
34//!
//! To replay recorded VCR cassette files, replace `.with_mode(VCRMode::Record)`
//! with `.with_mode(VCRMode::Replay)`.
37
// Exactly one reqwest backend must be selected: the `extern crate ... as reqwest` /
// `as reqwest_middleware` aliases below are chosen by these features, so enabling both
// would define the aliases twice and enabling neither would leave them undefined.
#[cfg(all(feature = "reqwest-0_12", feature = "reqwest-0_13"))]
compile_error!(
    "Features `reqwest-0_12` and `reqwest-0_13` are mutually exclusive. Please enable only one."
);

#[cfg(not(any(feature = "reqwest-0_12", feature = "reqwest-0_13")))]
compile_error!("Either `reqwest-0_12` or `reqwest-0_13` feature must be enabled.");
45
46#[cfg(feature = "reqwest-0_12")]
47extern crate reqwest_0_12 as reqwest;
48#[cfg(feature = "reqwest-0_12")]
49extern crate reqwest_middleware_0_4 as reqwest_middleware;
50
51#[cfg(feature = "reqwest-0_13")]
52extern crate reqwest_0_13 as reqwest;
53#[cfg(feature = "reqwest-0_13")]
54extern crate reqwest_middleware_0_5 as reqwest_middleware;
55
56#[cfg(feature = "compress")]
57use std::io::Read;
58#[cfg(feature = "compress")]
59use std::io::Write;
60use std::{fs, path::PathBuf, str::FromStr, sync::Mutex};
61
62use base64::{engine::general_purpose, Engine};
63use reqwest_middleware::Middleware;
64use vcr_cassette::{HttpInteraction, RecorderId};
65
66pub const VERSION: &str = env!("CARGO_PKG_VERSION");
67
68lazy_static::lazy_static! {
69    static ref RECORDER: RecorderId = format!("rVCR {VERSION}");
70    static ref BASE64: String = String::from("base64");
71}
72
/// Pluggable VCR middleware for record-and-replay for reqwest items
pub struct VCRMiddleware {
    /// Cassette file location: written on drop in record mode, and shown in the
    /// error returned when replay finds no match. `None` makes drop pick a
    /// timestamped file name.
    path: Option<PathBuf>,
    /// In-memory cassette: interactions loaded at construction plus any recorded
    /// during this run.
    storage: Mutex<vcr_cassette::Cassette>,
    /// Whether requests go to the network and are recorded, or are replayed.
    mode: VCRMode,
    /// How the cassette is searched in replay mode.
    search: VCRReplaySearch,
    /// Number of replay lookups performed so far; used as the starting index
    /// of the search under `VCRReplaySearch::SkipFound`.
    skip: Mutex<usize>,
    /// Write the cassette as a zip archive on drop (only settable with the
    /// `compress` feature).
    compress: bool,
    /// Log a per-interaction mismatch report when replay finds no match.
    rich_diff: bool,
    /// Hook applied to every converted request, before it is recorded (record
    /// mode) or matched against the cassette (replay mode).
    modify_request: Option<Box<RequestModifier>>,
    /// Hook applied to every live response after conversion; the modified
    /// response is both recorded and returned to the caller.
    modify_response: Option<Box<ResponseModifier>>,
}

/// Callback that mutates a request in place after conversion to cassette form.
type RequestModifier = dyn Fn(&mut vcr_cassette::Request) + Send + Sync + 'static;
/// Callback that mutates a response in place after conversion to cassette form.
type ResponseModifier = dyn Fn(&mut vcr_cassette::Response) + Send + Sync + 'static;
88
/// VCR mode switcher.
///
/// Selects whether the middleware sends requests to the network and records the
/// traffic, or answers requests from an existing cassette without network access.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VCRMode {
    /// Send requests to the network and record the interactions. The cassette file
    /// is rewritten when the middleware is dropped; interactions already loaded into
    /// the cassette are kept and new ones are appended after them.
    Record,
    /// Replay requests using local files
    Replay,
}
97
/// Strategy for finding a recorded interaction in replay mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VCRReplaySearch {
    /// Search from an offset that advances by one on every lookup: the n-th request
    /// (counting from zero, whether or not earlier lookups matched) only considers
    /// interactions from index n onward. Useful for verifying use-cases with strict
    /// request order.
    SkipFound,
    /// Search through all requests every time
    SearchAll,
}
107
/// Error type alias carrying a static message.
///
/// NOTE(review): not used anywhere in this file — `TryFrom<PathBuf>` reports its
/// failures as `String`. Confirm whether external callers still rely on it.
pub type VCRError = &'static str;
109
110/// Implements boilerplate for converting between vcr_cassette
111/// and reqwest structures.
112///
113/// Carries methods to find response in a cassette, and to record
114/// an interaction.
115impl VCRMiddleware {
116    /// Adjust mode in the middleware and return it
117    pub fn with_mode(mut self, mode: VCRMode) -> Self {
118        self.mode = mode;
119        self
120    }
121
122    pub fn with_modify_request<F>(mut self, modifier: F) -> Self
123    where
124        F: Fn(&mut vcr_cassette::Request) + Send + Sync + 'static,
125    {
126        self.modify_request.replace(Box::new(modifier));
127        self
128    }
129
130    pub fn with_modify_response<F>(mut self, modifier: F) -> Self
131    where
132        F: Fn(&mut vcr_cassette::Response) + Send + Sync + 'static,
133    {
134        self.modify_response.replace(Box::new(modifier));
135        self
136    }
137
138    /// Adjust search behavior for responses
139    pub fn with_search(mut self, search: VCRReplaySearch) -> Self {
140        self.search = search;
141        self
142    }
143
144    /// Adjust path in the middleware and return it
145    pub fn with_path(mut self, path: impl Into<PathBuf>) -> Self {
146        self.path = Some(path.into());
147        self
148    }
149
150    /// Adjust rich diff in the middleware and return it
151    pub fn with_rich_diff(mut self, rich_diff: bool) -> Self {
152        self.rich_diff = rich_diff;
153        self
154    }
155
156    /// Make VCR files to be compressed before creating
157    #[cfg(feature = "compress")]
158    pub fn compressed(mut self, compress: bool) -> Self {
159        self.compress = compress;
160        self
161    }
162
163    fn convert_version_to_vcr(&self, version: http::Version) -> vcr_cassette::Version {
164        if version == http::Version::HTTP_10 {
165            vcr_cassette::Version::Http1_0
166        } else if version == http::Version::HTTP_11 {
167            vcr_cassette::Version::Http1_1
168        } else if version == http::Version::HTTP_2 {
169            vcr_cassette::Version::Http2_0
170        } else {
171            panic!("rVCR only supports http 1.0, 1.1 and 2.0")
172        }
173    }
174
175    fn convert_version_from_vcr(&self, version: vcr_cassette::Version) -> http::Version {
176        match version {
177            vcr_cassette::Version::Http1_0 => http::Version::HTTP_10,
178            vcr_cassette::Version::Http1_1 => http::Version::HTTP_11,
179            vcr_cassette::Version::Http2_0 => http::Version::HTTP_2,
180            _ => {
181                panic!("rVCR only supports http 1.0, 1.1 and 2.0")
182            }
183        }
184    }
185
186    fn bytes_to_vcr_body(&self, body_bytes: &[u8]) -> vcr_cassette::Body {
187        // Try to parse UTF-8 string from the body;
188        // if it fails, body bytes are base64 encoded before saving
189
190        // FIXME: detecting support more encodings
191        match String::from_utf8(body_bytes.to_vec()) {
192            Ok(body_str) => vcr_cassette::Body::from_str(&body_str).unwrap(),
193            Err(e) => {
194                tracing::debug!("Can not deserialize utf-8 string: {e:?}");
195                let base64_str = general_purpose::STANDARD_NO_PAD.encode(body_bytes);
196                vcr_cassette::Body {
197                    string: base64_str,
198                    encoding: Some(BASE64.to_string()),
199                }
200            }
201        }
202    }
203
204    fn headers_to_vcr(&self, headers: &reqwest::header::HeaderMap) -> vcr_cassette::Headers {
205        let mut vcr_headers = vcr_cassette::Headers::new();
206        for (header_name, header_value) in headers {
207            let header_name_string = header_name.to_string();
208            let header_value_bytes = header_value.as_bytes();
209            let header_value = String::from_utf8(header_value_bytes.to_vec())
210                .unwrap_or_else(|_| panic!("Non utf header value for header named {header_name}; header values are supposed to be ASCII encoded"));
211            vcr_headers.insert(header_name_string, vec![header_value]);
212        }
213        vcr_headers
214    }
215
216    fn request_to_vcr(&self, req: reqwest::Request) -> vcr_cassette::Request {
217        let body = match req.body() {
218            Some(body) => match body.as_bytes() {
219                Some(body_bytes) => self.bytes_to_vcr_body(body_bytes),
220                None => vcr_cassette::Body::from_str("").unwrap(),
221            },
222            None => vcr_cassette::Body::from_str("").unwrap(),
223        };
224
225        let method_str = req.method().to_string().to_lowercase();
226
227        let method: vcr_cassette::Method = serde_json::from_str(&format!("\"{method_str}\""))
228            .unwrap_or_else(|_| panic!("Unknown HTTP method passed from reqwest: {method_str}"));
229
230        let headers = self.headers_to_vcr(req.headers());
231
232        let mut vcr_request = vcr_cassette::Request {
233            body,
234            method,
235            uri: req.url().to_owned(),
236            headers,
237        };
238
239        if let Some(ref modifier) = self.modify_request {
240            modifier(&mut vcr_request);
241        }
242
243        vcr_request
244    }
245
246    async fn response_to_vcr(&self, resp: reqwest::Response) -> vcr_cassette::Response {
247        let http_version = Some(self.convert_version_to_vcr(resp.version()));
248        let status_code = resp.status();
249        let headers = self.headers_to_vcr(resp.headers());
250        let response_text = resp.bytes().await.expect("Can not fetch response bytes");
251        let body = self.bytes_to_vcr_body(&response_text);
252
253        let status = vcr_cassette::Status {
254            code: status_code.as_u16(),
255            message: status_code
256                .canonical_reason()
257                .unwrap_or("Unknown")
258                .to_string(),
259        };
260
261        let mut vcr_response = vcr_cassette::Response {
262            body,
263            http_version,
264            status,
265            headers,
266        };
267
268        if let Some(ref modifier) = self.modify_response {
269            modifier(&mut vcr_response);
270        }
271
272        vcr_response
273    }
274
275    fn header_values_to_string(&self, header_values: Option<&Vec<String>>) -> String {
276        match header_values {
277            Some(values) => values.join(", "),
278            None => "<MISSING>".to_string(),
279        }
280    }
281
282    fn find_response_in_vcr(&self, req: vcr_cassette::Request) -> Option<vcr_cassette::Response> {
283        let cassette = self.storage.lock().unwrap();
284        let iteractions: Vec<&HttpInteraction> = match self.search {
285            VCRReplaySearch::SkipFound => {
286                let skip = *self.skip.lock().unwrap();
287                *self.skip.lock().unwrap() += 1;
288                cassette.http_interactions.iter().skip(skip).collect()
289            }
290            VCRReplaySearch::SearchAll => cassette.http_interactions.iter().collect(),
291        };
292
293        // we only want to log match failures if no match is found, so capture
294        // everything at the beginning and then output it all at once if none
295        // are found
296        let mut diff_log = if self.rich_diff {
297            Some(String::new())
298        } else {
299            None
300        };
301        for interaction in iteractions {
302            if interaction.request == req {
303                return Some(interaction.response.clone());
304            }
305            if let Some(diff) = diff_log.as_mut() {
306                diff.push_str(&format!(
307                    "Did not match {method:?} to {uri}:\n",
308                    method = interaction.request.method,
309                    uri = interaction.request.uri.as_str()
310                ));
311                if interaction.request.method != req.method {
312                    diff.push_str(&format!(
313                        "  Method differs: recorded {expected:?}, got {got:?}\n",
314                        expected = interaction.request.method,
315                        got = req.method
316                    ));
317                }
318                if interaction.request.uri != req.uri {
319                    diff.push_str("  URI differs:\n");
320                    diff.push_str(&format!(
321                        "    recorded: \"{}\"\n",
322                        interaction.request.uri.as_str()
323                    ));
324                    diff.push_str(&format!("    got:      \"{}\"\n", req.uri.as_str()));
325                }
326                if interaction.request.headers != req.headers {
327                    diff.push_str("  Headers differ:\n");
328                    for (recorded_header_name, recorded_header_values) in
329                        &interaction.request.headers
330                    {
331                        let expected = self.header_values_to_string(Some(recorded_header_values));
332                        let got =
333                            self.header_values_to_string(req.headers.get(recorded_header_name));
334                        if expected != got {
335                            diff.push_str(&format!("    {}:\n", recorded_header_name));
336                            diff.push_str(&format!("      recorded: \"{}\"\n", expected));
337                            diff.push_str(&format!("      got:      \"{}\"\n", got));
338                        }
339                    }
340                    for (got_header_name, got_header_values) in &req.headers {
341                        if !interaction.request.headers.contains_key(got_header_name) {
342                            let got = self.header_values_to_string(Some(got_header_values));
343                            diff.push_str(&format!("    {}:\n", got_header_name));
344                            diff.push_str("      recorded: <MISSING>\n");
345                            diff.push_str(&format!("      got:      \"{}\"\n", got));
346                        }
347                    }
348                }
349                if interaction.request.body != req.body {
350                    diff.push_str("  Body differs:\n");
351                    diff.push_str(&format!(
352                        "    recorded: \"{}\"\n",
353                        interaction.request.body.string
354                    ));
355                    diff.push_str(&format!("    got:      \"{}\"\n", req.body.string));
356                }
357                diff.push('\n');
358            }
359        }
360        if let Some(diff) = diff_log {
361            // tracing_test does not appear to capture multiline outputs for test
362            // assertion purposes, so we print each line out separately
363            for line in diff.split('\n') {
364                tracing::info!("{}", line);
365            }
366        }
367        None
368    }
369
370    fn vcr_to_response(&self, response: vcr_cassette::Response) -> reqwest::Response {
371        let code = response.status.code;
372        let mut builder = http::Response::builder().status(code);
373        for (header_name, header_values) in response.headers {
374            builder = builder.header(header_name, header_values.first().unwrap());
375        }
376        let http_version = self.convert_version_from_vcr(
377            response
378                .http_version
379                .unwrap_or(vcr_cassette::Version::Http1_1),
380        );
381        let builder = builder.version(http_version);
382
383        match response.body.encoding {
384            None => {
385                if !response.body.string.is_empty() {
386                    reqwest::Response::from(builder.body(response.body.string).unwrap())
387                } else {
388                    reqwest::Response::from(builder.body("".as_bytes()).unwrap())
389                }
390            }
391            Some(encoding) => {
392                if encoding == "base64" {
393                    let decoded = general_purpose::STANDARD_NO_PAD
394                        .decode(encoding)
395                        .expect("Invalid response body base64 can not be decoded");
396                    reqwest::Response::from(builder.body(decoded).unwrap())
397                } else {
398                    // FIXME: support more encodings
399                    panic!("Unsupported encoding: {encoding}");
400                }
401            }
402        }
403    }
404
405    fn record(&self, request: vcr_cassette::Request, response: vcr_cassette::Response) {
406        let mut cassette = self.storage.lock().unwrap();
407        cassette
408            .http_interactions
409            .push(vcr_cassette::HttpInteraction {
410                response,
411                request,
412                recorded_at: chrono::Utc::now().into(),
413            });
414    }
415}
416
417/// Reqwest middleware implementation
418///
419/// It receives request, converts it to internal VCR format,
420/// and saves data in the internal.
421#[async_trait::async_trait]
422impl Middleware for VCRMiddleware {
423    async fn handle(
424        &self,
425        req: reqwest::Request,
426        extensions: &mut http::Extensions,
427        next: reqwest_middleware::Next<'_>,
428    ) -> reqwest_middleware::Result<reqwest::Response> {
429        let vcr_request = self.request_to_vcr(req.try_clone().unwrap());
430
431        match self.mode {
432            VCRMode::Record => {
433                let response = next.run(req, extensions).await?;
434                let vcr_response = self.response_to_vcr(response).await;
435                let converted_response = self.vcr_to_response(vcr_response.clone());
436                self.record(vcr_request, vcr_response);
437                Ok(converted_response)
438            }
439            VCRMode::Replay => match self.find_response_in_vcr(vcr_request) {
440                None => {
441                    let message = format!(
442                        "Cannot find corresponding request in cassette {:?}",
443                        self.path,
444                    );
445                    Err(reqwest_middleware::Error::Middleware(anyhow::anyhow!(
446                        message
447                    )))
448                }
449                Some(response) => {
450                    let response = self.vcr_to_response(response);
451                    Ok(response)
452                }
453            },
454        }
455    }
456}
457
458/// Create middleware instance from Cassette
459impl From<vcr_cassette::Cassette> for VCRMiddleware {
460    fn from(cassette: vcr_cassette::Cassette) -> Self {
461        VCRMiddleware {
462            storage: Mutex::new(cassette),
463            mode: VCRMode::Replay,
464            path: None,
465            skip: Mutex::new(0),
466            search: VCRReplaySearch::SkipFound,
467            compress: false,
468            rich_diff: false,
469            modify_request: None,
470            modify_response: None,
471        }
472    }
473}
474
475/// Save cassette interactions after the run
476impl Drop for VCRMiddleware {
477    fn drop(&mut self) {
478        if self.mode == VCRMode::Record {
479            let path = self
480                .path
481                .clone()
482                .unwrap_or(format!(".rvcr-{}.vcr", chrono::Utc::now().timestamp()).into());
483            let cassette = self.storage.lock().unwrap();
484
485            let contents: String = serde_json::to_string_pretty(&*cassette).unwrap();
486
487            #[cfg(feature = "compress")]
488            if self.compress {
489                let file = std::fs::File::create(path.clone()).unwrap();
490
491                let mut zip = zip::ZipWriter::new(file);
492
493                let options = zip::write::FileOptions::default()
494                    .compression_method(zip::CompressionMethod::Bzip2)
495                    .compression_level(Some(9))
496                    .unix_permissions(0o644);
497                zip.start_file("test.vcr.json", options).unwrap();
498                zip.write_all(contents.as_bytes()).unwrap();
499                zip.finish().unwrap();
500            }
501
502            if !self.compress {
503                fs::write(path.clone(), contents.as_bytes())
504                    .unwrap_or_else(|_| panic!("Can not write cassette contents to {path:?}"));
505                tracing::info!("Written VCR cassette file at {path:?}");
506            }
507        }
508    }
509}
510
511/// Load VCR cassette for filesystem
512//
513/// For simplicity, support JSON format only for now
514impl TryFrom<PathBuf> for VCRMiddleware {
515    fn try_from(pb: PathBuf) -> Result<Self, Self::Error> {
516        let empty = vcr_cassette::Cassette {
517            http_interactions: vec![],
518            recorded_with: RECORDER.to_string(),
519        };
520
521        let mut mw = Self::from(empty);
522        mw.path = Some(pb.clone());
523        if !pb.exists() {
524            Ok(mw)
525        } else {
526            let content = fs::read(pb.clone()).map_err(|e| {
527                tracing::error!("Failed reading VCR cassette: {e}");
528                format!(
529                    "Failed to read VCR cassette from path {}",
530                    pb.to_str().unwrap()
531                )
532            })?;
533
534            #[cfg(feature = "compress")]
535            let content = {
536                let file = fs::File::open(mw.path.clone().unwrap()).unwrap();
537                match zip::ZipArchive::new(file) {
538                    Ok(mut archive) => {
539                        let mut content = content;
540                        content.clear();
541                        let contents = archive.by_name("test.vcr.json");
542                        let mut contents =
543                            contents.expect("test.vcr.json file is missing in zip archive");
544                        contents
545                            .read_to_end(&mut content)
546                            .expect("Can not read test.vcr.json from zip archive");
547                        content
548                    }
549                    Err(e) => {
550                        tracing::debug!("Failed to detect file as zip: {e:?}");
551                        content
552                    }
553                }
554            };
555
556            let cassette: vcr_cassette::Cassette =
557                serde_json::from_slice(&content).map_err(|e| {
558                    tracing::error!("Failed deserializing VCR cassette: {e}");
559                    format!(
560                        "Failed to deserialize VCR cassette from path {}",
561                        pb.to_str().unwrap()
562                    )
563                })?;
564
565            let mut mw = Self::from(cassette);
566            mw.path = Some(pb);
567            Ok(mw)
568        }
569    }
570
571    type Error = String;
572}