use crate::response::{Body, Chunked};
use crate::{Mock, Request, SERVER_ADDRESS_INTERNAL};
use std::fmt::Display;
use std::io;
use std::io::Write;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::mpsc;
use std::sync::Mutex;
use std::thread;
impl Mock {
fn method_matches(&self, request: &Request) -> bool {
self.method == request.method
}
fn path_matches(&self, request: &Request) -> bool {
self.path.matches_value(&request.path)
}
fn headers_match(&self, request: &Request) -> bool {
self.headers.iter().all(|&(ref field, ref expected)| {
expected.matches_values(&request.find_header_values(field))
})
}
fn body_matches(&self, request: &Request) -> bool {
let raw_body = &request.body;
let safe_body = &String::from_utf8_lossy(raw_body);
self.body.matches_value(safe_body) || self.body.matches_binary_value(raw_body)
}
#[allow(clippy::missing_const_for_fn)]
fn is_missing_hits(&self) -> bool {
match (self.expected_hits_at_least, self.expected_hits_at_most) {
(Some(_at_least), Some(at_most)) => self.hits < at_most,
(Some(at_least), None) => self.hits < at_least,
(None, Some(at_most)) => self.hits < at_most,
(None, None) => self.hits < 1,
}
}
}
/// A mock "equals" a request when its method, path, headers and body all match.
///
/// The checks run in that order and stop at the first one that fails.
impl<'a> PartialEq<Request> for &'a mut Mock {
    fn eq(&self, request: &Request) -> bool {
        if !self.method_matches(request) || !self.path_matches(request) {
            return false;
        }
        self.headers_match(request) && self.body_matches(request)
    }
}
/// Bookkeeping shared by the public API and the server's listener thread.
pub struct State {
    /// Address the listener is bound to; stays `None` until `try_start` succeeds in binding.
    pub listening_addr: Option<SocketAddr>,
    /// Registered mocks, scanned in order when a request arrives.
    pub mocks: Vec<Mock>,
    /// Requests for which no mock matched.
    pub unmatched_requests: Vec<Request>,
}

impl State {
    /// Builds an empty state: not listening, no mocks, no unmatched requests.
    #[allow(clippy::missing_const_for_fn)]
    fn new() -> Self {
        State {
            listening_addr: None,
            mocks: vec![],
            unmatched_requests: vec![],
        }
    }
}
lazy_static! {
/// Process-wide server state behind a mutex. It is locked both by the public
/// functions (`address`, `try_start`) and by the listener thread while it
/// matches requests against mocks.
pub static ref STATE: Mutex<State> = Mutex::new(State::new());
}
/// Returns the socket address the mock server is listening on, starting the server
/// first if it is not running yet.
///
/// # Panics
///
/// Panics if the `STATE` mutex is poisoned, or if no listener could be bound.
pub fn address() -> SocketAddr {
    try_start();
    // The guard is a temporary and is dropped at the end of this statement. The
    // second `expect` below therefore cannot poison the mutex.
    let listening_addr = STATE.lock().expect("state lock").listening_addr;
    listening_addr.expect("server should be listening")
}
/// Returns the server's base URL: `http://` followed by the address from `address()`.
///
/// Starts the server if needed and panics under the same conditions as `address()`.
pub fn url() -> String {
    let addr = address();
    format!("http://{}", addr)
}
/// Starts the mock server on a background thread unless it is already listening.
///
/// The thread first tries to bind `SERVER_ADDRESS_INTERNAL`. If that fails, it falls back
/// to an OS-assigned port on `127.0.0.1`. The bound address, or `None` when both binds
/// fail, is sent back over a channel and stored in `STATE.listening_addr`. The thread
/// then serves connections one at a time:
/// * parsed requests go to `handle_request`;
/// * unparsable ones are answered via `respond_with_error`.
///
/// The `STATE` lock is held until the address has been received. Concurrent callers
/// therefore cannot start a second server.
///
/// # Panics
///
/// Panics if the `STATE` mutex is poisoned.
pub fn try_start() {
    let mut state = STATE.lock().unwrap();
    if state.listening_addr.is_some() {
        return;
    }

    let (addr_tx, addr_rx) = mpsc::channel();

    thread::spawn(move || {
        let bound = TcpListener::bind(SERVER_ADDRESS_INTERNAL).or_else(|err| {
            warn!("{}", err);
            TcpListener::bind("127.0.0.1:0")
        });

        let listener = match bound {
            Ok(listener) => listener,
            Err(err) => {
                error!("{}", err);
                addr_tx.send(None).unwrap();
                return;
            }
        };
        let local_addr = listener.local_addr().unwrap();
        addr_tx.send(Some(local_addr)).unwrap();

        debug!("Server is listening at {}", local_addr);

        for incoming in listener.incoming() {
            let stream = match incoming {
                Ok(stream) => stream,
                Err(_) => {
                    debug!("Could not read from stream");
                    continue;
                }
            };

            let request = Request::from(&stream);
            debug!("Request received: {}", request);

            if !request.is_ok() {
                let reason = request
                    .error()
                    .map_or("Could not parse the request.", |err| err.as_str());
                debug!("Could not parse request because: {}", reason);
                respond_with_error(stream, request.version, reason);
                continue;
            }

            handle_request(request, stream);
        }
    });

    // A disconnected channel (listener thread panicked before sending) also yields `None`.
    state.listening_addr = addr_rx.recv().unwrap_or(None);
}
/// Dispatches a successfully parsed request. Currently this only means looking up and
/// replying with a matching mock.
fn handle_request(request: Request, stream: TcpStream) {
    handle_match_mock(request, stream)
}
/// Picks the mock that should answer `request`, counts the hit, and writes the response.
///
/// Of all mocks equal to `request` (per the `PartialEq<Request>` impl), the first one
/// whose `is_missing_hits()` is true is chosen. If none qualifies, the last match in
/// `state.mocks` order is used. With no match at all, a not-found response is sent and
/// the request is stored in `state.unmatched_requests`.
///
/// NOTE(review): the `STATE` guard lives for the whole function. That includes writing
/// the response to `stream` (and running any body callback), so a slow client blocks
/// every other `STATE` user until the write finishes.
fn handle_match_mock(request: Request, stream: TcpStream) {
    let mut state = STATE.lock().unwrap();

    let mut candidates = state
        .mocks
        .iter_mut()
        .filter(|mock| mock == &request)
        .collect::<Vec<_>>();

    let preferred = candidates.iter().position(|mock| mock.is_missing_hits());
    let chosen = match preferred {
        Some(index) => candidates.get_mut(index),
        None => candidates.last_mut(),
    };

    match chosen {
        Some(mock) => {
            debug!("Mock found");
            mock.hits += 1;
            respond_with_mock(stream, request.version, mock, request.is_head());
        }
        None => {
            debug!("Mock not found");
            respond_with_mock_not_found(stream, request.version);
            state.unmatched_requests.push(request);
        }
    }
}
/// Writes a response whose optional body is a string.
///
/// Thin wrapper over `respond_bytes`: the text is copied into a `Body::Bytes`. Write
/// errors are reported on stderr and not propagated.
fn respond(
    stream: TcpStream,
    version: (u8, u8),
    status: impl Display,
    headers: Option<&Vec<(String, String)>>,
    body: Option<&str>,
) {
    let owned_body = body.map(|text| text.as_bytes().to_vec()).map(Body::Bytes);
    let outcome = respond_bytes(stream, version, status, headers, owned_body.as_ref());
    if let Err(e) = outcome {
        eprintln!("warning: Mock response write error: {}", e);
    }
}
/// Writes a full HTTP response to `stream`: status line, headers, blank line, and body.
///
/// * `headers` are emitted verbatim and in order.
/// * A `Body::Bytes` body gets an automatic `content-length` header, unless the caller
///   already supplied one. The check ignores case, because HTTP header field names are
///   case-insensitive; a mock header such as `Content-Length` must not be duplicated.
/// * A `Body::Fn` body gets `transfer-encoding: chunked`. Its callback writes through
///   a `Chunked` wrapper, which is finished afterwards.
/// * With `None`, only the status line and headers are written.
///
/// # Errors
///
/// Returns any I/O error raised while writing or flushing `stream`, or by the body callback.
fn respond_bytes(
    mut stream: TcpStream,
    version: (u8, u8),
    status: impl Display,
    headers: Option<&Vec<(String, String)>>,
    body: Option<&Body>,
) -> io::Result<()> {
    let mut response = Vec::from(format!("HTTP/{}.{} {}\r\n", version.0, version.1, status));
    let mut has_content_length_header = false;

    if let Some(headers) = headers {
        for (key, value) in headers {
            response.extend(key.as_bytes());
            response.extend(b": ");
            response.extend(value.as_bytes());
            response.extend(b"\r\n");
        }
        // Field names are case-insensitive (RFC 7230 §3.2). A case-sensitive comparison
        // would add a second, possibly conflicting, Content-Length header.
        has_content_length_header = headers
            .iter()
            .any(|(key, _)| key.eq_ignore_ascii_case("content-length"));
    }

    match body {
        Some(Body::Bytes(bytes)) => {
            if !has_content_length_header {
                response.extend(format!("content-length: {}\r\n", bytes.len()).as_bytes());
            }
        }
        Some(Body::Fn(_)) => {
            response.extend(b"transfer-encoding: chunked\r\n");
        }
        None => {}
    }

    // End of the header section.
    response.extend(b"\r\n");
    stream.write_all(&response)?;

    match body {
        Some(Body::Bytes(bytes)) => {
            stream.write_all(bytes)?;
        }
        Some(Body::Fn(cb)) => {
            let mut chunked = Chunked::new(&mut stream);
            cb(&mut chunked)?;
            chunked.finish()?;
        }
        None => {}
    }

    stream.flush()
}
/// Sends `mock`'s configured response: its status and headers, plus its body unless
/// `skip_body` is set. The caller in this file passes `request.is_head()`.
///
/// Write errors are reported on stderr and not propagated.
fn respond_with_mock(stream: TcpStream, version: (u8, u8), mock: &Mock, skip_body: bool) {
    let response = &mock.response;
    let body = Some(&response.body).filter(|_| !skip_body);
    let outcome = respond_bytes(
        stream,
        version,
        &response.status,
        Some(&response.headers),
        body,
    );
    if let Err(e) = outcome {
        eprintln!("warning: Mock response write error: {}", e);
    }
}
/// Replies with `501 Mock Not Found`, an explicit `content-length: 0`, and no body.
fn respond_with_mock_not_found(stream: TcpStream, version: (u8, u8)) {
    let headers: Vec<(String, String)> =
        vec![(String::from("content-length"), String::from("0"))];
    respond(stream, version, "501 Mock Not Found", Some(&headers), None);
}
/// Replies with `422 Mock Error`, using `message` as the plain response body.
fn respond_with_error(stream: TcpStream, version: (u8, u8), message: &str) {
    let status = "422 Mock Error";
    let body = Some(message);
    respond(stream, version, status, None, body);
}