use std::{
cell::RefCell,
collections::HashMap,
fmt::{Debug, Display},
marker::PhantomData,
ops::Deref,
rc::Rc,
sync::Arc,
};
use bytes::Bytes;
use derive_more::derive::{Display, From};
use http::StatusCode;
use reqwest::{Client, RequestBuilder};
use tokio::sync::Mutex;
use url::Url;
use crate::net::{Request, Response};
use super::{Error, Result};
/// Ordered collection of plans registered on a mock network.
type Plans = Vec<Plan>;
/// A single expected request together with the canned response returned for it.
#[derive(Debug)]
struct Plan {
    /// Criteria that must *all* hold for a request to match this plan.
    match_request: Vec<MatchRequest>,
    /// Response handed back when a request matches.
    response: reqwest::Response,
}
impl Plan {
#[tracing::instrument]
fn matches(&self, request: &Request) -> bool {
let url = request.url();
let is_match =
self.match_request.iter().all(|criteria| match criteria {
MatchRequest::Body(body) => {
request.body().and_then(reqwest::Body::as_bytes) == Some(body)
}
MatchRequest::Header { name, value } => {
request
.headers()
.iter()
.any(|(request_header_name, request_header_value)| {
let Ok(request_header_value) = request_header_value.to_str() else {
return false;
};
request_header_name.as_str() == name && request_header_value == value
})
}
MatchRequest::Method(method) => request.method() == http::Method::from(method),
MatchRequest::Scheme(scheme) => url.scheme() == scheme,
MatchRequest::Host(host) => url.host_str() == Some(host),
MatchRequest::Path(path) => url.path() == path,
MatchRequest::Fragment(fragment) => url.fragment() == Some(fragment),
MatchRequest::Query { name, value } => url.query_pairs().into_iter().any(
|(request_query_name, request_query_value)| {
request_query_name.as_ref() == name && request_query_value.as_ref() == value
},
),
});
tracing::debug!(?is_match);
is_match
}
}
impl Display for Plan {
    /// Writes each criterion followed by a space, then `=> ` and the `Debug`
    /// form of the response, terminated by a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.match_request
            .iter()
            .try_for_each(|criterion| write!(f, "{criterion} "))?;
        writeln!(f, "=> {:?}", self.response)
    }
}
/// Entry point for HTTP requests that can be replaced by mocks in tests.
///
/// A `Net` without plans sends requests for real; one holding plans answers
/// from them instead. Clones share the same plans through an `Arc`.
#[derive(Debug, Clone, Default)]
pub struct Net {
    /// `None` for a real network, `Some(shared plans)` for a mock.
    plans: Option<Arc<Mutex<RefCell<Plans>>>>,
}
impl Net {
    /// Creates a `Net` with no mock plans, i.e. one that sends real requests.
    pub(super) const fn new() -> Self {
        Net { plans: None }
    }
}
impl Net {
#[must_use]
pub fn client(&self) -> Client {
Default::default()
}
#[tracing::instrument(skip_all)]
pub async fn send(&self, request: impl Into<RequestBuilder>) -> Result<Response> {
let request: RequestBuilder = request.into();
tracing::debug!(?request);
let Some(plans) = &self.plans else {
tracing::debug!("no plans - sending real request");
return request.send().await.map_err(Error::from);
};
tracing::debug!("build request");
let request = request.build()?;
let index = plans
.lock()
.await
.deref()
.borrow()
.iter()
.position(|plan| plan.matches(&request));
match index {
Some(i) => {
let plan = plans.lock().await.borrow_mut().remove(i);
let response = plan.response;
if response.status().is_success() {
tracing::debug!(?request, "matched success response");
Ok(response)
} else {
tracing::debug!(?request, "matched error response");
Err(crate::net::Error::ResponseError { response })
}
}
None => {
tracing::warn!(?request, "unexpected mock request");
Err(Error::UnexpectedMockRequest { request })
}
}
}
#[must_use]
pub fn delete(&self, url: impl Into<String>) -> ReqBuilder {
ReqBuilder::new(self, NetMethod::Delete, url)
}
#[must_use]
pub fn get(&self, url: impl Into<String>) -> ReqBuilder {
ReqBuilder::new(self, NetMethod::Get, url)
}
#[must_use]
pub fn head(&self, url: impl Into<String>) -> ReqBuilder {
ReqBuilder::new(self, NetMethod::Head, url)
}
#[must_use]
pub fn patch(&self, url: impl Into<String>) -> ReqBuilder {
ReqBuilder::new(self, NetMethod::Patch, url)
}
#[must_use]
pub fn post(&self, url: impl Into<String>) -> ReqBuilder {
ReqBuilder::new(self, NetMethod::Post, url)
}
#[must_use]
pub fn put(&self, url: impl Into<String>) -> ReqBuilder {
ReqBuilder::new(self, NetMethod::Put, url)
}
}
impl MockNet {
pub async fn try_from(net: Net) -> std::result::Result<Self, super::Error> {
match &net.plans {
Some(plans) => Ok(MockNet {
plans: Rc::new(RefCell::new(plans.lock().await.take())),
}),
None => Err(super::Error::NetIsNotAMock),
}
}
}
/// HTTP methods supported by [`Net`] and the mock matchers.
///
/// `Display` renders the variant name (e.g. `Put`).
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum NetMethod {
    Delete,
    Get,
    Head,
    Patch,
    Post,
    Put,
}
impl From<&NetMethod> for http::Method {
    /// Maps each variant onto the equivalent `http::Method` constant.
    fn from(value: &NetMethod) -> Self {
        match *value {
            NetMethod::Delete => Self::DELETE,
            NetMethod::Get => Self::GET,
            NetMethod::Head => Self::HEAD,
            NetMethod::Patch => Self::PATCH,
            NetMethod::Post => Self::POST,
            NetMethod::Put => Self::PUT,
        }
    }
}
/// Fluent builder for a request sent through a borrowed [`Net`].
pub struct ReqBuilder<'net> {
    /// The `Net` the finished request is sent through.
    net: &'net Net,
    /// Target URL, before query parameters are appended.
    url: String,
    /// HTTP method of the request.
    method: NetMethod,
    /// Headers to set, in insertion order.
    headers: Vec<(String, String)>,
    /// Query parameters to append, in insertion order.
    query: Vec<(String, String)>,
    /// Optional request body.
    body: Option<Bytes>,
}
impl<'net> ReqBuilder<'net> {
#[must_use]
fn new(net: &'net Net, method: NetMethod, url: impl Into<String>) -> Self {
Self {
net,
url: url.into(),
method,
headers: vec![],
query: vec![],
body: None,
}
}
#[must_use]
pub fn with(self, f: impl FnOnce(Self) -> Self) -> Self {
f(self)
}
#[must_use]
pub fn with_option<T>(self, option: Option<T>) -> WithOption<'net, T> {
WithOption {
req_builder: self,
option,
}
}
#[must_use]
pub fn with_result<T, E>(self, result: std::result::Result<T, E>) -> WithResult<'net, T, E> {
WithResult {
req_builder: self,
result,
}
}
#[tracing::instrument(skip_all)]
pub async fn send(self) -> Result<Response> {
let client = self.net.client();
let mut url = self.url;
tracing::trace!(?url);
if !self.query.is_empty() {
url.push('?');
for (i, (name, value)) in self.query.into_iter().enumerate() {
tracing::trace!(?name, ?value, "query parameters");
if i > 0 {
url.push('&');
}
url.push_str(&name);
url.push('=');
url.push_str(&value);
}
tracing::trace!(?url, "with query parameters");
}
let mut req = match self.method {
NetMethod::Delete => client.delete(url),
NetMethod::Get => client.get(url),
NetMethod::Head => client.head(url),
NetMethod::Patch => client.patch(url),
NetMethod::Post => client.post(url),
NetMethod::Put => client.put(url),
};
for (name, value) in self.headers.into_iter() {
req = req.header(name, value);
}
if let Some(bytes) = self.body {
req = req.body(bytes);
}
tracing::debug!(?req);
self.net.send(req).await
}
#[must_use]
pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.headers.push((name.into(), value.into()));
self
}
#[must_use]
pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
self.headers.extend(headers);
self
}
#[must_use]
pub fn body(mut self, bytes: impl Into<Bytes>) -> Self {
self.body = Some(bytes.into());
tracing::trace!(body = ?self.body);
self
}
#[must_use]
pub fn user_agent(self, user_agent: impl Into<String>) -> Self {
self.header(http::header::USER_AGENT.to_string(), user_agent)
}
#[must_use]
pub fn basic_auth(
self,
username: impl Into<String>,
password: Option<impl Into<String>>,
) -> Self {
let value = basic_auth_header_value(username, password);
self.header(http::header::AUTHORIZATION.to_string(), value)
}
#[must_use]
pub fn bearer_auth<T>(self, token: T) -> Self
where
T: std::fmt::Display,
{
self.header(
http::header::AUTHORIZATION.to_string(),
format!("Bearer {token}"),
)
}
#[must_use]
pub fn query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.query.push((key.into(), value.into()));
self
}
}
fn basic_auth_header_value(
username: impl Into<String>,
password: Option<impl Into<String>>,
) -> String {
let username = username.into();
let password = password.map(|p| p.into());
let value = {
use base64::prelude::BASE64_STANDARD;
use base64::write::EncoderWriter;
use std::io::Write;
let mut buf = b"Basic ".to_vec();
{
let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD);
let _ = write!(encoder, "{username}:");
if let Some(password) = password {
let _ = write!(encoder, "{password}");
}
}
String::from_utf8(buf).expect("should always be valid utf8")
};
value
}
#[derive(Debug, Clone, Default)]
pub struct MockNet {
plans: Rc<RefCell<Plans>>,
}
impl MockNet {
pub fn client(&self) -> Client {
Default::default()
}
#[must_use]
pub fn on(&self) -> WhenRequest<WhenBuildRequest> {
WhenRequest::new(self)
}
fn _when(&self, plan: Plan) {
self.plans.borrow_mut().push(plan);
}
pub fn reset(&self) {
tracing::debug!("reset plans");
self.plans.take();
}
}
impl From<MockNet> for Net {
fn from(mock_net: MockNet) -> Self {
Self {
plans: Some(Arc::new(Mutex::new(RefCell::new(mock_net.plans.take())))),
}
}
}
impl Drop for MockNet {
#[cfg_attr(test, mutants::skip)]
#[tracing::instrument]
fn drop(&mut self) {
if std::thread::panicking() {
return;
}
let unused = self.plans.take();
if !unused.is_empty() {
log_unused_plans(&unused);
assert!(
unused.is_empty(),
"{} expected requests were not made",
unused.len()
);
}
}
}
impl Drop for Net {
#[cfg_attr(test, mutants::skip)]
#[tracing::instrument]
fn drop(&mut self) {
if std::thread::panicking() {
return;
}
if let Some(plans) = &self.plans {
let unused = plans.try_lock().expect("lock plans").take();
if !unused.is_empty() {
log_unused_plans(&unused);
assert!(
unused.is_empty(),
"{} expected requests were not made",
unused.len()
);
}
}
}
}
/// Prints the unused plans to stderr, one ` - <plan>` entry per plan.
///
/// Prints nothing when `unused` is empty.
#[cfg_attr(test, mutants::skip)]
fn log_unused_plans(unused: &[Plan]) {
    if unused.is_empty() {
        return;
    }
    let entries: Vec<String> = unused.iter().map(|plan| format!(" - {plan}")).collect();
    let listing = entries.join("\n");
    eprintln!(
        "Net::drop(): {} expected requests were not made:\n{}",
        unused.len(),
        listing
    );
}
impl Net {
    /// Panics if this mock `Net` still holds plans that were never matched.
    ///
    /// The remaining plans are removed and printed to stderr before the
    /// panic, so this consumes them. Does nothing for a real (plan-less) `Net`.
    ///
    /// # Panics
    ///
    /// If unused plans remain, or if the plans lock is currently held.
    #[cfg_attr(test, mutants::skip)]
    pub fn assert_no_unused_plans(&self) {
        let Some(shared) = &self.plans else {
            return;
        };
        let unused = shared.try_lock().expect("lock plans").take();
        if unused.is_empty() {
            return;
        }
        log_unused_plans(&unused);
        panic!("{} expected requests were not made", unused.len());
    }
}
impl MockNet {
    /// Panics if any registered plan is still outstanding.
    ///
    /// The remaining plans are removed and printed to stderr before the
    /// panic, so this consumes them.
    #[cfg_attr(test, mutants::skip)]
    pub fn assert_no_unused_plans(&self) {
        let unused = self.plans.take();
        if unused.is_empty() {
            return;
        }
        log_unused_plans(&unused);
        panic!("{} expected requests were not made", unused.len());
    }
}
/// One criterion a request must meet to match a plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatchRequest {
    /// The request body equals these bytes exactly.
    Body(bytes::Bytes),
    /// The URL fragment equals this string.
    Fragment(String),
    /// Some header with this name has exactly this value.
    Header { name: String, value: String },
    /// The URL host equals this string.
    Host(String),
    /// The request uses this method.
    Method(NetMethod),
    /// The URL path equals this string.
    Path(String),
    /// Some decoded query pair has this name and value.
    Query { name: String, value: String },
    /// The URL scheme equals this string.
    Scheme(String),
}
impl Display for MatchRequest {
    /// Renders the criterion in a compact, URL-like form for diagnostics
    /// (for example when unused plans are reported).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Body(body) => write!(f, "Body: {body:?}"),
            Self::Fragment(fragment) => write!(f, "#{fragment}"),
            Self::Header { name, value } => write!(f, "({name}: {value})"),
            Self::Host(host) => write!(f, "@{host}"),
            Self::Method(method) => write!(f, "{method}"),
            Self::Path(path) => write!(f, "/{path}"),
            // Previously emitted a stray `)`, e.g. `?a=b)`.
            Self::Query { name, value } => write!(f, "?{name}={value}"),
            Self::Scheme(scheme) => write!(f, "{scheme}://"),
        }
    }
}
/// One part of a canned mock response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RespondWith {
    /// The response status code.
    Status(StatusCode),
    /// A response header.
    Header { name: String, value: String },
    /// The response body.
    Body(bytes::Bytes),
}
/// Errors detected while describing a mock request.
#[derive(Clone, Debug, Display, From)]
pub enum MockError {
    /// The URL given to a mock matcher could not be parsed.
    // In derive_more the tuple field is named `_0`. The previous `{}`, 0 form
    // formatted the integer literal `0` and always printed "url parse: 0".
    #[display("url parse: {_0}")]
    UrlParse(#[from] url::ParseError),
}
impl std::error::Error for MockError {}
/// Type-state marker for [`WhenRequest`] builders.
pub trait WhenState {}
/// State in which request-matching criteria are being added.
// `Clone` is required for `WhenRequest`'s derived `Clone` to apply. The
// derive adds a `State: Clone` bound, which the markers did not satisfy.
#[derive(Debug, Clone)]
pub struct WhenBuildRequest;
impl WhenState for WhenBuildRequest {}
/// State in which the canned response is being described.
#[derive(Debug, Clone)]
pub struct WhenBuildResponse;
impl WhenState for WhenBuildResponse {}
/// Builder that registers one expected request (a plan) on a [`MockNet`].
#[derive(Debug, Clone)]
pub struct WhenRequest<'net, State>
where
    State: WhenState,
{
    /// Tracks which build phase the builder is in.
    _state: PhantomData<State>,
    /// The mock the finished plan is registered on.
    net: &'net MockNet,
    /// Criteria the request must meet.
    match_on: Vec<MatchRequest>,
    /// Parts of the canned response.
    respond_with: Vec<RespondWith>,
    /// First error encountered while building, reported when the plan is
    /// registered.
    error: Option<MockError>,
}
impl<'net> WhenRequest<'net, WhenBuildRequest> {
fn new(net: &'net MockNet) -> Self {
Self {
_state: PhantomData,
net,
match_on: vec![],
respond_with: vec![],
error: None,
}
}
#[must_use]
pub fn get(self, url: impl Into<String>) -> Self {
self._url(NetMethod::Get, url)
}
#[must_use]
pub fn post(self, url: impl Into<String>) -> Self {
self._url(NetMethod::Post, url)
}
#[must_use]
pub fn put(self, url: impl Into<String>) -> Self {
self._url(NetMethod::Put, url)
}
#[must_use]
pub fn delete(self, url: impl Into<String>) -> Self {
self._url(NetMethod::Delete, url)
}
#[must_use]
pub fn head(self, url: impl Into<String>) -> Self {
self._url(NetMethod::Head, url)
}
#[must_use]
pub fn patch(self, url: impl Into<String>) -> Self {
self._url(NetMethod::Patch, url)
}
#[tracing::instrument(skip_all)]
fn _url(mut self, method: NetMethod, url: impl Into<String>) -> Self {
self.match_on.push(MatchRequest::Method(method));
match Url::parse(&url.into()) {
Ok(url) => {
self.match_on
.push(MatchRequest::Scheme(url.scheme().into()));
if url.has_host() {
if let Some(host) = url.host_str() {
self.match_on.push(MatchRequest::Host(host.into()));
}
}
self.match_on.push(MatchRequest::Path(url.path().into()));
if let Some(fragment) = url.fragment() {
self.match_on.push(MatchRequest::Fragment(fragment.into()));
}
url.query_pairs().into_iter().for_each(|(key, value)| {
self.match_on.push(MatchRequest::Query {
name: key.into(),
value: value.into(),
})
});
}
Err(err) => {
self.error.replace(err.into());
}
}
tracing::debug!(match_on = ?self.match_on, error = ?self.error);
self
}
#[must_use]
pub fn query(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
let name = name.into();
let value = value.into();
self.match_on.push(MatchRequest::Query { name, value });
self
}
#[must_use]
pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.match_on.push(MatchRequest::Header {
name: name.into(),
value: value.into(),
});
self
}
#[must_use]
pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
for (name, value) in headers {
self.match_on.push(MatchRequest::Header { name, value });
}
self
}
#[must_use]
pub fn basic_auth(
self,
username: impl Into<String>,
password: Option<impl Into<String>>,
) -> Self {
let value = basic_auth_header_value(username, password);
self.header(http::header::AUTHORIZATION.to_string(), value)
}
#[must_use]
pub fn bearer_auth(self, token: impl Into<String>) -> Self {
self.header(
http::header::AUTHORIZATION.to_string(),
format!("Bearer {}", token.into()),
)
}
#[must_use]
pub fn user_agent(self, agent: impl Into<String>) -> Self {
self.header(http::header::USER_AGENT.to_string(), agent)
}
#[must_use]
pub fn body(mut self, body: impl Into<bytes::Bytes>) -> Self {
self.match_on.push(MatchRequest::Body(body.into()));
self
}
#[must_use]
pub fn respond(self, status: StatusCode) -> WhenRequest<'net, WhenBuildResponse> {
WhenRequest::<WhenBuildResponse> {
_state: PhantomData,
net: self.net,
match_on: self.match_on,
respond_with: vec![RespondWith::Status(status)],
error: self.error,
}
}
}
impl WhenRequest<'_, WhenBuildResponse> {
    /// Adds a response header.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let part = RespondWith::Header {
            name: name.into(),
            value: value.into(),
        };
        self.respond_with.push(part);
        self
    }

    /// Adds every header in `headers` to the response.
    #[must_use]
    pub fn headers(mut self, headers: impl Into<HashMap<String, String>>) -> Self {
        let map: HashMap<String, String> = headers.into();
        self.respond_with.extend(
            map.into_iter()
                .map(|(name, value)| RespondWith::Header { name, value }),
        );
        self
    }

    /// Sets the response body and registers the plan (see `mock`).
    ///
    /// # Errors
    ///
    /// As for `mock`.
    pub fn body(mut self, body: impl Into<bytes::Bytes>) -> Result<()> {
        let bytes = body.into();
        self.respond_with.push(RespondWith::Body(bytes));
        self.mock()
    }

    /// Builds the canned response and registers the plan on the mock.
    ///
    /// If several bodies were given, the last one is used. With no body, the
    /// body is empty.
    ///
    /// # Errors
    ///
    /// `InvalidMock` if an error was recorded while describing the request
    /// (e.g. an unparsable URL). Otherwise, any error from building the
    /// `http` response (e.g. an invalid header).
    pub fn mock(self) -> Result<()> {
        if let Some(error) = self.error {
            return Err(crate::net::Error::InvalidMock(error));
        }
        let mut payload = bytes::Bytes::default();
        let builder = self.respond_with.into_iter().fold(
            http::response::Builder::default(),
            |builder, part| match part {
                RespondWith::Status(status) => builder.status(status),
                RespondWith::Header { name, value } => builder.header(name, value),
                RespondWith::Body(bytes) => {
                    payload = bytes;
                    builder
                }
            },
        );
        let response = builder.body(payload)?;
        let plan = Plan {
            match_request: self.match_on,
            response: response.into(),
        };
        self.net._when(plan);
        Ok(())
    }
}
/// A [`ReqBuilder`] paired with an `Option`, allowing the request to be
/// customised according to whether the option holds a value.
pub struct WithOption<'net, T> {
    req_builder: ReqBuilder<'net>,
    option: Option<T>,
}
impl<'net, T> WithOption<'net, T> {
    /// Applies `f_some` with the value when it is `Some`; otherwise returns the
    /// builder unchanged.
    pub fn some(
        self,
        f_some: impl FnOnce(ReqBuilder<'net>, T) -> ReqBuilder<'net>,
    ) -> ReqBuilder<'net> {
        let Self {
            req_builder,
            option,
        } = self;
        if let Some(value) = option {
            f_some(req_builder, value)
        } else {
            req_builder
        }
    }

    /// Applies `f_none` when the option is `None`; otherwise returns the
    /// builder unchanged.
    pub fn none(
        self,
        f_none: impl FnOnce(ReqBuilder<'net>) -> ReqBuilder<'net>,
    ) -> ReqBuilder<'net> {
        let Self {
            req_builder,
            option,
        } = self;
        if option.is_none() {
            f_none(req_builder)
        } else {
            req_builder
        }
    }

    /// Applies `f_some` with the value when `Some`, or `f_none` when `None`.
    pub fn either(
        self,
        f_some: impl FnOnce(ReqBuilder<'net>, T) -> ReqBuilder<'net>,
        f_none: impl FnOnce(ReqBuilder<'net>) -> ReqBuilder<'net>,
    ) -> ReqBuilder<'net> {
        let Self {
            req_builder,
            option,
        } = self;
        if let Some(value) = option {
            f_some(req_builder, value)
        } else {
            f_none(req_builder)
        }
    }
}
pub struct WithResult<'net, T, E> {
req_builder: ReqBuilder<'net>,
result: std::result::Result<T, E>,
}
impl<'net, T, E> WithResult<'net, T, E> {
pub fn ok(
self,
f_ok: impl FnOnce(ReqBuilder<'net>, T) -> ReqBuilder<'net>,
) -> ReqBuilder<'net> {
match self.result {
Ok(ok) => f_ok(self.req_builder, ok),
Err(_) => self.req_builder,
}
}
pub fn err(
self,
f_err: impl FnOnce(ReqBuilder<'net>, E) -> ReqBuilder<'net>,
) -> ReqBuilder<'net> {
match self.result {
Err(err) => f_err(self.req_builder, err),
Ok(_) => self.req_builder,
}
}
pub fn either(
self,
f_ok: impl FnOnce(ReqBuilder<'net>, T) -> ReqBuilder<'net>,
f_err: impl FnOnce(ReqBuilder<'net>, E) -> ReqBuilder<'net>,
) -> ReqBuilder<'net> {
match self.result {
Ok(ok) => f_ok(self.req_builder, ok),
Err(err) => f_err(self.req_builder, err),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` is `Sized + Send + Sync + Unpin`.
    fn is_normal<T: Sized + Send + Sync + Unpin>() {}

    #[test]
    fn normal_types() {
        is_normal::<Net>();
        is_normal::<MatchRequest>();
        is_normal::<Plan>();
    }

    #[test]
    fn plan_display() {
        // The body ("contents") does not appear in the response's `Debug`
        // output checked below.
        let response: reqwest::Response = http::response::Builder::default()
            .status(204)
            .header("foo", "bar")
            .header("baz", "buck")
            .body("contents")
            .expect("body")
            .into();
        let criteria = vec![
            MatchRequest::Method(NetMethod::Put),
            MatchRequest::Header {
                name: "alpha".into(),
                value: "1".into(),
            },
            MatchRequest::Body("req body".into()),
        ];
        let plan = Plan {
            match_request: criteria,
            response,
        };

        let expected = [
            "Put",
            "(alpha: 1)",
            "Body: b\"req body\"",
            "=>",
            "Response {",
            "url: \"http://no.url.provided.local/\",",
            "status: 204,",
            "headers: {\"foo\": \"bar\", \"baz\": \"buck\"}",
            "}\n",
        ]
        .join(" ");
        assert_eq!(plan.to_string(), expected);
    }
}