use std::error::Error as StdError;
use std::fmt;
use std::net::SocketAddr;
use std::thread;
use bytes::Bytes;
use futures::{
future,
sync::{mpsc, oneshot},
Future, Sink, Stream,
};
use http::{
header::{HeaderName, HeaderValue},
HttpTryFrom, Response,
};
use serde::Serialize;
use serde_json;
use tokio::runtime::{Builder as RtBuilder, Runtime};
use filter::Filter;
use reject::Reject;
use reply::{Reply, ReplySealed};
use route::{self, Route};
use Request;
use self::inner::OneOrTuple;
/// Starts building a test request.
///
/// The request starts out as `Request::default()` and has no remote address.
pub fn request() -> RequestBuilder {
    let req = Request::default();
    RequestBuilder {
        req,
        remote_addr: None,
    }
}
/// Starts building a websocket handshake test on top of a fresh `request()`.
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
/// Builds a test `Request` and runs a `Filter` against it directly, without
/// binding a server.
///
/// Created with `request()`.
#[derive(Debug)]
#[must_use = "RequestBuilder does nothing on its own"]
pub struct RequestBuilder {
    /// Remote address handed to `Route::new` when a filter is run.
    remote_addr: Option<SocketAddr>,
    /// The request being assembled.
    req: Request,
}
/// Builds a websocket handshake request to run against a `Filter`.
///
/// Created with `ws()`.
#[derive(Debug)]
#[must_use = "WsBuilder does nothing on its own"]
pub struct WsBuilder {
    /// The underlying HTTP request. The upgrade headers are added later, in `handshake`.
    req: RequestBuilder,
}
/// A connected test websocket client, returned by `WsBuilder::handshake`.
///
/// Socket I/O runs on a background thread. This handle only holds the channel
/// ends used to talk to that thread.
pub struct WsClient {
    /// Outgoing messages; the background thread forwards them to the socket.
    tx: mpsc::UnboundedSender<::ws::Message>,
    /// Blocking view of incoming messages, or of per-message errors.
    rx: ::futures::stream::Wait<mpsc::UnboundedReceiver<Result<::ws::Message, ::Error>>>,
}
/// Error produced by the websocket test client.
#[derive(Debug)]
pub struct WsError {
    /// The wrapped underlying error or message.
    cause: Box<StdError + Send + Sync>,
}
impl RequestBuilder {
    /// Sets the request method, e.g. `"POST"`.
    ///
    /// # Panics
    ///
    /// Panics if `method` cannot be parsed as an HTTP method.
    pub fn method(mut self, method: &str) -> Self {
        *self.req.method_mut() = method.parse().expect("valid method");
        self
    }

    /// Replaces the request URI with `p` parsed as a URI (e.g. `"/todos/33"`).
    ///
    /// # Panics
    ///
    /// Panics if `p` cannot be parsed as a URI.
    pub fn path(mut self, p: &str) -> Self {
        let uri = p.parse().expect("test request path invalid");
        *self.req.uri_mut() = uri;
        self
    }

    /// Sets a header using `insert`, which replaces any existing values for the same name.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not a valid header name or `value` is not a valid
    /// header value.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        HeaderName: HttpTryFrom<K>,
        HeaderValue: HttpTryFrom<V>,
    {
        // The conversion error types carry no `Debug` bound here, so they are
        // erased to `()` to make `expect` usable.
        let name: HeaderName = HttpTryFrom::try_from(key)
            .map_err(|_| ())
            .expect("invalid header name");
        let value = HttpTryFrom::try_from(value)
            .map_err(|_| ())
            .expect("invalid header value");
        self.req.headers_mut().insert(name, value);
        self
    }

    /// Sets the request body to a copy of the bytes in `body`.
    pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
        let body = body.as_ref().to_vec();
        *self.req.body_mut() = body.into();
        self
    }

    /// Sets the request body to `val` serialized as JSON.
    ///
    /// Only the body changes; no `content-type` header is added.
    ///
    /// # Panics
    ///
    /// Panics if `val` cannot be serialized.
    pub fn json(mut self, val: &impl Serialize) -> Self {
        let vec = serde_json::to_vec(val).expect("json() must serialize to JSON");
        *self.req.body_mut() = vec.into();
        self
    }

    /// Runs `f` against this request and returns what it extracted.
    ///
    /// A single extracted value is returned bare instead of as a one-element
    /// tuple (see `OneOrTuple`). A rejection is returned as `Err`.
    ///
    /// # Panics
    ///
    /// Panics if called while another test filter call is in progress.
    pub fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
    where
        F: Filter,
        F::Future: Send + 'static,
        F::Extract: OneOrTuple + Send + 'static,
        F::Error: Send + 'static,
    {
        self.apply_filter(f).map(|ex| ex.one_or_tuple())
    }

    /// Returns `true` if `f` extracts successfully from this request, and
    /// `false` if it rejects.
    ///
    /// # Panics
    ///
    /// Panics if called while another test filter call is in progress.
    pub fn matches<F>(self, f: &F) -> bool
    where
        F: Filter,
        F::Future: Send + 'static,
        F::Extract: Send + 'static,
        F::Error: Send + 'static,
    {
        self.apply_filter(f).is_ok()
    }

    /// Runs `f` against this request and returns the full HTTP response, with
    /// the body buffered into `Bytes`.
    ///
    /// A rejection is logged at debug level and converted into its response,
    /// so this always returns a response.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    ///
    /// - it is called while another test filter call is in progress;
    /// - collecting the response body fails.
    pub fn reply<F>(self, f: &F) -> Response<Bytes>
    where
        F: Filter + 'static,
        F::Extract: Reply + Send,
        F::Error: Reject + Send,
    {
        assert!(!route::is_set(), "nested test filter calls");
        let route = Route::new(self.req, self.remote_addr);
        // NOTE(review): `route::set` appears to expose `route` to the filter only
        // while the closure runs. That would explain why the future is created
        // inside it and every poll below re-enters it. Confirm in `route`.
        let mut fut = route::set(&route, move || f.filter())
            .map(|rep| rep.into_response())
            .or_else(|rej| {
                debug!("rejected: {:?}", rej);
                Ok(rej.into_response())
            })
            .and_then(|res| {
                // Buffer the whole body so the caller gets a plain `Response<Bytes>`.
                let (parts, body) = res.into_parts();
                body.concat2()
                    .map(|chunk| Response::from_parts(parts, chunk.into()))
            });
        let fut = future::poll_fn(move || route::set(&route, || fut.poll()));
        block_on(fut).expect("reply shouldn't fail")
    }

    /// Shared driver for `filter` and `matches`. It runs `f` to completion on a
    /// fresh runtime with this request as the current route.
    ///
    /// # Panics
    ///
    /// Panics if called while another test filter call is in progress.
    fn apply_filter<F>(self, f: &F) -> Result<F::Extract, F::Error>
    where
        F: Filter,
        F::Future: Send + 'static,
        F::Extract: Send + 'static,
        F::Error: Send + 'static,
    {
        assert!(!route::is_set(), "nested test filter calls");
        let route = Route::new(self.req, self.remote_addr);
        // Create the future with the route installed. Then re-install the route
        // on every poll; `route::set` presumably scopes it to the closure call.
        let mut fut = route::set(&route, move || f.filter());
        let fut = future::poll_fn(move || route::set(&route, || fut.poll()));
        block_on(fut)
    }
}
impl WsBuilder {
    /// Sets the path, and optionally the query, of the handshake request,
    /// e.g. `"/chat?room=1"`.
    ///
    /// # Panics
    ///
    /// Panics if `p` cannot be parsed as a URI.
    pub fn path(self, p: &str) -> Self {
        WsBuilder {
            req: self.req.path(p),
        }
    }

    /// Sets a header on the handshake request, replacing existing values for
    /// the same name.
    ///
    /// # Panics
    ///
    /// Panics if the header name or value is invalid.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: HttpTryFrom<K>,
        HeaderValue: HttpTryFrom<V>,
    {
        WsBuilder {
            req: self.req.header(key, value),
        }
    }

    /// Serves `f` on an ephemeral `127.0.0.1` port and performs a websocket
    /// handshake against it.
    ///
    /// A background thread, named after the calling thread, runs the server,
    /// the client connection, and message forwarding. The returned `WsClient`
    /// talks to that thread over channels.
    ///
    /// # Errors
    ///
    /// Returns `WsError` if the HTTP request or the connection upgrade fails.
    ///
    /// # Panics
    ///
    /// Panics if the background thread cannot be spawned, or if it exits
    /// before reporting the handshake result.
    pub fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: Reject + Send,
    {
        // Carries the handshake outcome back to the calling thread.
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        // Messages sent by the test; forwarded to the socket.
        let (wr_tx, wr_rx) = mpsc::unbounded();
        // Messages (or errors) read from the socket; delivered to the test.
        let (rd_tx, rd_rx) = mpsc::unbounded();
        let test_thread = ::std::thread::current();
        let test_name = test_thread.name().unwrap_or("<unknown>");
        // Naming the thread after the test makes a panic on it identify the test.
        thread::Builder::new()
            .name(test_name.into())
            .spawn(move || {
                use tungstenite::protocol;
                let (addr, srv) = ::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));
                let srv = srv.map_err(|err| panic!("server error: {:?}", err));
                // Standard upgrade headers. The key is the fixed sample nonce
                // from RFC 6455 (base64 of "the sample nonce").
                let mut req = self
                    .req
                    .header("connection", "upgrade")
                    .header("upgrade", "websocket")
                    .header("sec-websocket-version", "13")
                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                    .req;
                // The hyper client needs an absolute URI that points at the bound
                // server. Keep the query string as well as the path, so filters
                // that read the query see what the test set.
                let query = match req.uri().query() {
                    Some(q) => format!("?{}", q),
                    None => String::new(),
                };
                let uri = format!("http://{}{}{}", addr, req.uri().path(), query)
                    .parse()
                    .expect("addr + path is valid URI");
                *req.uri_mut() = uri;
                let mut rt = new_rt();
                rt.spawn(srv);
                let upgrade = ::hyper::Client::builder()
                    .build(AddrConnect(addr))
                    .request(req)
                    .and_then(|res| res.into_body().on_upgrade());
                let upgraded = match rt.block_on(upgrade) {
                    Ok(up) => {
                        let _ = upgraded_tx.send(Ok(()));
                        up
                    }
                    Err(err) => {
                        let _ = upgraded_tx.send(Err(err));
                        return;
                    }
                };
                let io = protocol::WebSocket::from_raw_socket(
                    upgraded,
                    protocol::Role::Client,
                    Default::default(),
                );
                let (tx, rx) = ::ws::WebSocket::new(io).split();
                // Pump test -> socket.
                let write = wr_rx
                    .map_err(|()| {
                        unreachable!("mpsc::Receiver doesn't error");
                    })
                    .forward(tx.sink_map_err(|_| ()))
                    .map(|_| ());
                // Pump socket -> test. Per-message errors are wrapped as items
                // so the reader sees them.
                let read = rx
                    .then(|result| Ok(result))
                    .forward(rd_tx.sink_map_err(|_| ()))
                    .map(|_| ());
                rt.block_on(write.join(read)).expect("websocket forward");
            })
            .expect("websocket handshake thread");
        match upgraded_rx.wait() {
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx.wait(),
            }),
            Ok(Err(err)) => Err(WsError::new(err)),
            Err(_canceled) => panic!("websocket handshake thread panicked"),
        }
    }
}
impl WsClient {
    /// Sends a text message to the server.
    ///
    /// # Panics
    ///
    /// Panics if the background thread has stopped forwarding writes.
    pub fn send_text(&mut self, text: impl Into<String>) {
        self.send(::ws::Message::text(text));
    }

    /// Queues `msg` for the background thread to write to the socket.
    ///
    /// # Panics
    ///
    /// Panics if the channel's receiving end has been dropped, meaning the
    /// background thread has stopped forwarding writes.
    pub fn send(&mut self, msg: ::ws::Message) {
        self.tx
            .unbounded_send(msg)
            .expect("websocket send: background thread is no longer forwarding writes");
    }

    /// Blocks until the next message arrives.
    ///
    /// # Errors
    ///
    /// - Returns an error wrapping the socket error if reading failed.
    /// - Returns a `"closed"` error if the connection has ended.
    pub fn recv(&mut self) -> Result<::ws::Message, WsError> {
        match self.rx.next() {
            Some(Ok(Ok(msg))) => Ok(msg),
            Some(Ok(Err(err))) => Err(WsError::new(err)),
            Some(Err(())) => unreachable!("mpsc Receiver never errors"),
            None => Err(WsError::new("closed")),
        }
    }

    /// Blocks until the connection ends. Succeeds only if no further message
    /// arrives first.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    ///
    /// - a message is received before the stream ends;
    /// - reading from the socket fails.
    pub fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next() {
            None => Ok(()),
            Some(Ok(Ok(msg))) => Err(WsError::new(format!("received message: {:?}", msg))),
            Some(Ok(Err(err))) => Err(WsError::new(err)),
            Some(Err(())) => unreachable!("mpsc Receiver never errors"),
        }
    }
}
impl fmt::Debug for WsClient {
    /// Prints only the type name; the channel handles are not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("WsClient")
    }
}
impl WsError {
    /// Wraps anything convertible into a boxed error as a `WsError`. This
    /// includes other errors as well as `&str` and `String` messages.
    fn new<E: Into<Box<StdError + Send + Sync>>>(cause: E) -> Self {
        let boxed: Box<StdError + Send + Sync> = cause.into();
        WsError { cause: boxed }
    }
}
impl fmt::Display for WsError {
    /// Renders as `websocket error: <cause>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("websocket error: ")?;
        write!(f, "{}", self.cause)
    }
}

impl StdError for WsError {
    /// Fixed, cause-independent description.
    fn description(&self) -> &str {
        "websocket error"
    }
}
/// A hyper connector that ignores the requested destination and always dials
/// the socket address it wraps.
struct AddrConnect(SocketAddr);

impl ::hyper::client::connect::Connect for AddrConnect {
    type Transport = ::tokio::net::tcp::TcpStream;
    type Error = ::std::io::Error;
    // The mapping step is a plain function pointer, so the future type can be
    // named here.
    type Future = ::futures::future::Map<
        ::tokio::net::tcp::ConnectFuture,
        fn(Self::Transport) -> (Self::Transport, ::hyper::client::connect::Connected),
    >;

    fn connect(&self, _dst: ::hyper::client::connect::Destination) -> Self::Future {
        // Pairs a freshly connected stream with default connection metadata.
        fn attach_metadata(
            stream: ::tokio::net::tcp::TcpStream,
        ) -> (::tokio::net::tcp::TcpStream, ::hyper::client::connect::Connected) {
            (stream, ::hyper::client::connect::Connected::new())
        }
        let pending = ::tokio::net::tcp::TcpStream::connect(&self.0);
        pending.map(attach_metadata as fn(_) -> _)
    }
}
/// Creates a tokio runtime for one test, with one core thread and one blocking
/// thread.
///
/// Runtime thread names are prefixed with the current thread's name (or
/// `<unknown>`), so they can be traced back to the test.
fn new_rt() -> Runtime {
    let current = ::std::thread::current();
    let label = current.name().unwrap_or("<unknown>");
    let prefix = format!("test {}; warp-test-runtime-", label);
    RtBuilder::new()
        .name_prefix(prefix)
        .core_threads(1)
        .blocking_threads(1)
        .build()
        .expect("new rt")
}
/// Drives `fut` to completion on a freshly created test runtime and returns
/// its result. The runtime exists only for the duration of this call.
fn block_on<F>(fut: F) -> Result<F::Item, F::Error>
where
    F: Future + Send + 'static,
    F::Item: Send + 'static,
    F::Error: Send + 'static,
{
    new_rt().block_on(fut)
}
mod inner {
    /// Flattens a filter's extracted tuple. A one-element tuple unwraps to its
    /// element; `()` and tuples of two or more elements pass through unchanged.
    pub trait OneOrTuple {
        type Output;
        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();
        fn one_or_tuple(self) -> Self::Output {
            self
        }
    }

    macro_rules! one_or_tuple {
        // Base case: a one-element tuple yields its only element.
        ($only:ident) => {
            impl<$only> OneOrTuple for ($only,) {
                type Output = $only;
                fn one_or_tuple(self) -> Self::Output {
                    let (value,) = self;
                    value
                }
            }
        };
        // Recursive case: pass the full-arity tuple through, then handle the tail.
        ($head:ident, $( $tail:ident ),*) => {
            impl<$head, $($tail),*> OneOrTuple for ($head, $($tail),*) {
                type Output = Self;
                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }
            one_or_tuple!($( $tail ),*);
        };
    }

    one_or_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
}