use std::{
collections::HashMap,
path::PathBuf,
fmt,
io,
};
use async_std::{
prelude::*,
sync::RwLock,
fs,
};
use serde::{Serialize, Deserialize};
use surf::{
http::{self, Method, Version},
middleware::{Middleware, Next},
Client,
Request, Response,
StatusCode,
Url,
};
use once_cell::sync::OnceCell;
/// Traffic loaded from one cassette file. Requests and responses are kept in
/// two parallel vectors: the response recorded for `requests[i]` is
/// `responses[i]`.
type Session = (Vec<VcrRequest>, Vec<VcrResponse>);
/// Process-wide registry of cassettes, keyed by cassette file path.
///
/// An entry holds `Some(session)` once the file has been loaded for replay and
/// `None` when the path was registered for recording. In record mode the inner
/// lock is taken for writing while a record is appended to the file.
static CASSETTES: OnceCell<RwLock<HashMap<PathBuf, RwLock::<Option<Session>>>>>
= OnceCell::new();
/// Callback that may rewrite a captured request (see
/// `VcrMiddleware::with_modify_request`).
type RequestModifier = dyn Fn(&mut VcrRequest) + Send + Sync + 'static;
/// Callback that may rewrite a captured response (see
/// `VcrMiddleware::with_modify_response`).
type ResponseModifier = dyn Fn(&mut VcrResponse) + Send + Sync + 'static;
/// Surf middleware that either records HTTP exchanges into a cassette file or
/// answers requests from a previously recorded cassette.
pub struct VcrMiddleware {
/// Whether exchanges are recorded or replayed.
mode: VcrMode,
/// Path of the cassette file; also the key of this cassette in `CASSETTES`.
file: PathBuf,
/// Optional hook applied to every captured request before it is recorded
/// or looked up.
modify_request: Option<Box<RequestModifier>>,
/// Optional hook applied to every captured response before it is recorded.
modify_response: Option<Box<ResponseModifier>>,
}
#[surf::utils::async_trait]
impl Middleware for VcrMiddleware {
    /// Records or replays a single HTTP exchange, depending on `self.mode`.
    ///
    /// The outgoing request is first captured as a [`VcrRequest`] and passed
    /// through the request modifier, if one is set.
    ///
    /// * `VcrMode::Record`: the request is forwarded to the next middleware,
    ///   the response is captured (and passed through the response modifier),
    ///   and the pair is appended to the cassette file as one YAML document.
    ///   The response handed back to the caller is the one received from the
    ///   next middleware, not the modified copy.
    /// * `VcrMode::Replay`: the first recorded request equal to the captured
    ///   request is looked up and its recorded response is returned; nothing
    ///   is forwarded.
    ///
    /// # Errors
    ///
    /// Propagates errors from capturing the request/response, from the next
    /// middleware, from YAML serialization and from writing the cassette. In
    /// replay mode a request with no recorded match yields a `NotFound` error
    /// wrapping [`VcrError::Lookup`].
    ///
    /// # Panics
    ///
    /// Panics if the cassette registry has not been initialised or has no
    /// entry for `self.file`, and, in replay mode, if that entry holds no
    /// loaded session.
    async fn handle(&self, mut req: Request, client: Client, next: Next<'_>)
        -> surf::Result<Response>
    {
        let mut request = VcrRequest::from_request(&mut req).await?;
        if let Some(modifier) = &self.modify_request {
            modifier(&mut request);
        }

        match self.mode {
            VcrMode::Record => {
                let mut res = next.run(req, client).await?;
                let mut response =
                    VcrResponse::try_from_response(&mut res).await?;
                if let Some(modifier) = &self.modify_response {
                    modifier(&mut response);
                }

                let doc = serde_yaml::to_string(&(
                    SerdeWrapper::Request(request),
                    SerdeWrapper::Response(response),
                ))?;

                let recorders = CASSETTES
                    .get()
                    .expect("cassette registry is initialised by VcrMiddleware::new")
                    .read()
                    .await;
                // Per-cassette write lock: serialises appends to the same file.
                let lock = recorders[&self.file].write().await;

                let mut file = fs::OpenOptions::new()
                    .create(true)
                    .append(true)
                    .open(&self.file)
                    .await?;
                file.write_all(doc.as_bytes()).await?;
                // Flush while the lock is still held so the record is written
                // out before another writer may start, and so that a failed
                // write is reported instead of being lost when the file
                // handle is dropped.
                file.flush().await?;
                drop(file);
                drop(lock);

                Ok(res)
            }
            VcrMode::Replay => {
                let cassettes = CASSETTES
                    .get()
                    .expect("cassette registry is initialised by VcrMiddleware::new")
                    .read()
                    .await;
                let sessions = &cassettes[&self.file].read().await;
                // The panic message is only built when the session is missing.
                let (requests, responses) = sessions
                    .as_ref()
                    .unwrap_or_else(|| panic!("Missing session: {:?}", self.file));

                match requests.iter().position(|recorded| recorded == &request) {
                    Some(pos) => Ok(Response::from(&responses[pos])),
                    None => Err(surf::Error::new(
                        StatusCode::NotFound,
                        VcrError::Lookup(Request::from(request)),
                    )),
                }
            }
        }
    }
}
impl VcrMiddleware {
pub async fn new<P>(mode: VcrMode, recording: P) -> Result<Self, VcrError>
where P: Into<PathBuf>,
{
let recording = recording.into();
if mode == VcrMode::Replay {
let _ = CASSETTES.set(RwLock::new(HashMap::new()));
let mut cassettes = CASSETTES.get().unwrap().write().await;
let recording_exists = cassettes.contains_key(&recording)
&& cassettes[&recording].read().await.is_some();
if ! recording_exists {
let mut requests = vec![];
let mut responses = vec![];
let replays = fs::read_to_string(&recording).await?;
for replay in replays.split("\n---\n") {
let (request, response) = serde_yaml::from_str(replay)?;
let req = match request {
SerdeWrapper::Request(r) => r,
_ => panic!("Invalid request"),
};
let resp = match response {
SerdeWrapper::Response(r) => r,
_ => panic!("Invalid response"),
};
requests.push(req);
responses.push(resp);
}
cassettes.insert(
recording.clone(),
RwLock::new(Some((requests, responses)))
);
}
} else { let _ = CASSETTES.set(RwLock::new(HashMap::new()));
let mut recorders = CASSETTES.get().unwrap().write().await;
recorders.insert(recording.clone(), RwLock::new(None));
}
Ok(Self {
mode,
file: recording,
modify_request: None,
modify_response: None
})
}
pub fn with_modify_request<F>(mut self, modifier: F) -> Self
where F: Fn(&mut VcrRequest) + Send + Sync + 'static
{
self.modify_request.replace(Box::new(modifier));
self
}
pub fn with_modify_response<F>(mut self, modifier: F) -> Self
where F: Fn(&mut VcrResponse) + Send + Sync + 'static
{
self.modify_response.replace(Box::new(modifier));
self
}
}
/// Body of a recorded request or response.
///
/// Serialized untagged, i.e. without a wrapper naming the variant.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Body {
/// Raw bytes; produced by `From<&[u8]>` when the data is not valid UTF-8.
Bytes(Vec<u8>),
/// Text; produced by `From<&[u8]>` when the data is valid UTF-8.
Str(String),
}
impl From<&[u8]> for Body {
    /// Stores `bytes` as [`Body::Str`] when they are valid UTF-8 and as
    /// [`Body::Bytes`] otherwise. The data is copied in both cases.
    fn from(bytes: &[u8]) -> Self {
        std::str::from_utf8(bytes).map_or_else(
            |_| Body::Bytes(bytes.to_vec()),
            |text| Body::Str(text.to_owned()),
        )
    }
}
/// Operating mode of a `VcrMiddleware`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum VcrMode {
/// Forward requests to the next middleware and append each exchange to the
/// cassette file.
Record,
/// Answer requests from the exchanges stored in the cassette file.
Replay,
}
/// Serializable snapshot of an HTTP request, as stored in a cassette.
///
/// In replay mode, equality between snapshots decides which recorded response
/// is returned.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct VcrRequest {
/// HTTP method of the request.
pub method: Method,
/// Target URL of the request.
pub url: Url,
/// Header values keyed by header name.
pub headers: HashMap<String, Vec<String>>,
/// Request body.
pub body: Body,
}
impl VcrRequest {
    /// Captures `req` as a [`VcrRequest`].
    ///
    /// The body is read to the end and then set on `req` again, so the
    /// request can still be sent afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the request body fails.
    async fn from_request(req: &mut Request) -> surf::Result<VcrRequest> {
        // Capture the headers before the body is taken and re-attached below.
        let mut headers = HashMap::new();
        for name in req.header_names() {
            // `header` returns an `Option` over the whole value list, so the
            // list itself has to be iterated to keep every value of a
            // repeated header rather than a single one.
            let values = req
                .header(name)
                .into_iter()
                .flat_map(|values| values.iter())
                .map(|value| value.as_str().to_string())
                .collect::<Vec<String>>();
            headers.insert(name.to_string(), values);
        }

        let orig_body = req.take_body().into_bytes().await?;
        let body = Body::from(orig_body.as_slice());
        req.set_body(orig_body.as_slice());

        Ok(Self {
            method: req.method(),
            url: req.url().to_owned(),
            headers,
            body,
        })
    }
}
impl From<VcrRequest> for Request {
fn from(req: VcrRequest) -> Request {
let mut request = http::Request::new(req.method, req.url);
for name in req.headers.keys() {
let values = &req.headers[name];
for value in values.iter() {
request.append_header(name.as_str(), value);
}
}
match &req.body {
Body::Bytes(b) => request.set_body(b.as_slice()),
Body::Str(s) => request.set_body(s.as_str()),
}
Request::from(request)
}
}
/// Serializable snapshot of an HTTP response, as stored in a cassette.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct VcrResponse {
/// HTTP status of the response.
pub status: StatusCode,
/// HTTP version reported by the response, if any.
pub version: Option<Version>,
/// Header values keyed by header name.
pub headers: HashMap<String, Vec<String>>,
/// Response body.
pub body: Body,
}
impl VcrResponse {
    /// Captures `resp` as a [`VcrResponse`].
    ///
    /// The body is read to the end and then set on `resp` again, so the
    /// caller can still consume it afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the response body fails.
    async fn try_from_response(resp: &mut Response)
        -> surf::Result<VcrResponse>
    {
        // Capture the headers before the body is read and re-attached below.
        let mut headers = HashMap::new();
        for name in resp.header_names() {
            // `header` returns an `Option` over the whole value list, so the
            // list itself has to be iterated to keep every value of a
            // repeated header (e.g. several cookies) rather than a single one.
            let values = resp
                .header(name)
                .into_iter()
                .flat_map(|values| values.iter())
                .map(|value| value.as_str().to_string())
                .collect::<Vec<String>>();
            headers.insert(name.to_string(), values);
        }

        let orig_body = resp.body_bytes().await?;
        let body = Body::from(orig_body.as_slice());
        resp.set_body(orig_body.as_slice());

        Ok(Self {
            status: resp.status(),
            version: resp.version(),
            headers,
            body,
        })
    }
}
impl From<&VcrResponse> for Response {
    /// Rebuilds a surf [`Response`] from a snapshot: same status and version,
    /// every stored header value appended under its name, and the stored body.
    fn from(resp: &VcrResponse) -> Response {
        let mut response = http::Response::new(resp.status);
        response.set_version(resp.version);

        for (name, values) in &resp.headers {
            for value in values {
                response.append_header(name.as_str(), value);
            }
        }

        match &resp.body {
            Body::Str(text) => response.set_body(text.as_str()),
            Body::Bytes(bytes) => response.set_body(bytes.as_slice()),
        }

        Response::from(response)
    }
}
/// Tags the two halves of a recorded exchange. A cassette record is the
/// serialized pair `(SerdeWrapper::Request, SerdeWrapper::Response)`.
#[derive(Debug, Deserialize, Serialize)]
enum SerdeWrapper {
/// The request half of an exchange.
Request(VcrRequest),
/// The response half of an exchange.
Response(VcrResponse),
}
/// Errors produced while loading a cassette or replaying from it.
#[derive(Debug)]
pub enum VcrError {
/// An I/O error, e.g. from reading the cassette file.
File(io::Error),
/// A cassette record could not be parsed as YAML.
Parse(serde_yaml::Error),
/// In replay mode, no recorded request matched the carried request.
Lookup(surf::Request),
}
// No `source()` override: `Display` already prints the wrapped error's message.
impl std::error::Error for VcrError {}
impl fmt::Display for VcrError {
    /// Formats the error for display.
    ///
    /// `File` and `Parse` print the message of the wrapped error; `Lookup`
    /// prints the URL of the unmatched request followed by its pretty-printed
    /// debug representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fully-qualified calls select `Display` explicitly instead of
            // depending on which formatting traits happen to be imported.
            Self::File(e) => fmt::Display::fmt(e, f),
            Self::Parse(e) => fmt::Display::fmt(e, f),
            Self::Lookup(req) => {
                write!(f, "Request not found at {}: {:#?}", req.url(), req)
            }
        }
    }
}
impl From<io::Error> for VcrError {
fn from(e: io::Error) -> Self { Self::File(e) }
}
impl From<serde_yaml::Error> for VcrError {
fn from(e: serde_yaml::Error) -> Self { Self::Parse(e) }
}
// Tests for the middleware. They rely on cassette files under
// `test-sessions/`, which are not part of this file.
#[cfg(test)]
mod tests {
use super::*;
// Creating a replay middleware loads the cassette into the global registry;
// the first stored exchange must equal the request/response built here.
#[async_std::test]
async fn read_recording_from_disk() -> Result<(), VcrError> {
let vcr = VcrMiddleware::new(
VcrMode::Replay,
"test-sessions/simple.yml"
).await?;
let mut req_headers = HashMap::new();
req_headers.insert(
"X-some-header".to_owned(),
vec!["hello".to_owned()]
);
let req = VcrRequest {
method: Method::Get,
url: Url::parse("https://example.com").unwrap(),
headers: req_headers,
body: Body::Str("My Request".to_owned()),
};
let mut res_headers = HashMap::new();
res_headers.insert(
"X-some-header".to_owned(),
vec!["goodbye".to_owned()]
);
let res = VcrResponse {
status: StatusCode::Ok,
version: None,
headers: res_headers,
body: Body::Str("A Response".to_owned()),
};
// Inspect the registry directly rather than going through a client.
let cassettes = CASSETTES.get().unwrap().read().await;
let sessions = &cassettes[&vcr.file].read().await;
let (requests, responses) = sessions.as_ref().unwrap();
assert_eq!(req, requests[0]);
assert_eq!(res, responses[0]);
Ok(())
}
// Sends a request through a replay middleware and checks the replayed
// response against the expected status, headers and body.
#[async_std::test]
async fn replay_recorded_communications() -> Result<(), VcrError> {
let vcr = VcrMiddleware::new(
VcrMode::Replay,
"test-sessions/simple.yml"
).await?
// The request modifier runs before the cassette lookup, so the header value
// sent below is replaced by "(secret)" when matching.
.with_modify_request(|res| {
*res.headers.get_mut("secret-header").unwrap() =
vec![String::from("(secret)")];
});
let client = surf::Client::new().with(vcr);
let req = surf::get("https://example.com")
.header("X-some-header", "another hello")
.header("secret-header", "sensitive data")
.build();
let mut res = client.send(req).await.unwrap();
let mut res_headers = HashMap::new();
res_headers.insert(
"x-some-header".to_owned(),
vec!["another goodbye".to_owned()]
);
res_headers.insert(
"content-type".to_owned(),
vec!["text/plain;charset=utf-8".to_owned()]
);
res_headers.insert(
"date".to_owned(),
vec!["Fri, 28 May 2021 00:44:58 GMT".to_owned()]
);
let expected = VcrResponse {
status: StatusCode::Ok,
version: None,
headers: res_headers,
body: Body::Str("A Response".to_owned()),
};
assert_eq!(
VcrResponse::try_from_response(&mut res).await.unwrap(),
expected
);
Ok(())
}
// Records an exchange into a fresh cassette, then replays it from that
// cassette and compares the two responses.
#[async_std::test]
async fn record_communication_in_write_mode() -> Result<(), VcrError> {
let path = "test-sessions/record-test.yml";
// Start from a clean cassette; a missing file is not an error here.
let _ = async_std::fs::remove_file("test-sessions/record-test.yml")
.await;
// Replaces the session key with a placeholder in the recorded request.
fn hide_session_key(req: &mut VcrRequest) {
req.headers.entry("session-key".into())
.and_modify(|val| *val = vec!["(some key)".into()]);
}
// Replaces the cookie header with a placeholder in the recorded response.
fn hide_cookie(res: &mut VcrResponse) {
res.headers.entry("Set-Cookie".into())
.and_modify(|val| *val = vec!["(erased)".into()]);
}
// A replay middleware is added after the recorder so the recorded request is
// answered from `simple.yml`.
let outer = VcrMiddleware::new(
VcrMode::Replay,
"test-sessions/simple.yml",
).await?;
let vcr = VcrMiddleware::new(VcrMode::Record, path).await?
.with_modify_request(hide_session_key)
.with_modify_response(hide_cookie);
let client = surf::Client::new()
.with(vcr)
.with(outer);
let req = surf::get("https://example.com")
.header("X-some-header", "another hello")
.header("Content-Type", "application/octet-stream")
.header("session-key", "00112233445566778899AABBCCDDEEFF")
.build();
let mut expected_res = client.send(req).await.unwrap();
// Second pass: replay from the cassette that was just written. The request
// carries the placeholder session key, matching what was recorded.
let client = surf::Client::new()
.with(VcrMiddleware::new(VcrMode::Replay, path).await?);
let req = surf::get("https://example.com")
.header("X-some-header", "another hello")
.header("Content-Type", "application/octet-stream")
.header("session-key", "(some key)")
.build();
let mut res = client.send(req).await.unwrap();
let res = VcrResponse::try_from_response(&mut res).await.unwrap();
assert_eq!(
res,
VcrResponse::try_from_response(&mut expected_res).await.unwrap()
);
// The original cookie value must not appear in the replayed response.
let cookies = &res.headers["set-cookie"];
assert!(! cookies.contains(&"cookie2=val2; Expires=date2".into()));
Ok(())
}
}