use std::convert::TryFrom;
use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::net::SocketAddr;
#[cfg(feature = "websocket")]
use std::pin::Pin;
#[cfg(feature = "websocket")]
use std::task::{self, Poll};
use bytes::Bytes;
#[cfg(feature = "websocket")]
use futures::StreamExt;
use futures::{future, FutureExt, TryFutureExt};
use http::{
header::{HeaderName, HeaderValue},
Response,
};
use serde::Serialize;
use serde_json;
#[cfg(feature = "websocket")]
use tokio::sync::{mpsc, oneshot};
use crate::filter::Filter;
use crate::reject::IsReject;
use crate::reply::Reply;
use crate::route::{self, Route};
use crate::Request;
use self::inner::OneOrTuple;
/// Starts a new test request.
///
/// The returned [`RequestBuilder`] wraps `Request::default()` and has no
/// remote address set; use the builder methods to customize it before
/// running it against a filter.
pub fn request() -> RequestBuilder {
    let req = Request::default();
    RequestBuilder {
        remote_addr: None,
        req,
    }
}
/// Starts a new test websocket handshake builder.
///
/// The builder wraps a fresh [`request()`]; call [`WsBuilder::handshake`]
/// to actually perform the upgrade.
#[cfg(feature = "websocket")]
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
/// A builder for a test request that can be run against a filter.
///
/// Created by `request()`. Every setter consumes and returns the builder,
/// so calls can be chained.
#[must_use = "RequestBuilder does nothing on its own"]
#[derive(Debug)]
pub struct RequestBuilder {
/// Peer address passed to `Route::new` when the request is run; `None`
/// unless `remote_addr()` was called.
remote_addr: Option<SocketAddr>,
/// The request being assembled (method, URI, headers, extensions, body).
req: Request,
}
/// A builder for a test websocket handshake.
///
/// Created by `ws()`. It is a thin wrapper that forwards `path` and
/// `header` to the inner `RequestBuilder`.
#[cfg(feature = "websocket")]
#[must_use = "WsBuilder does nothing on its own"]
#[derive(Debug)]
pub struct WsBuilder {
/// The underlying request that will be sent as the upgrade request.
req: RequestBuilder,
}
/// The client half of a test websocket connection, returned by a
/// successful `WsBuilder::handshake`.
#[cfg(feature = "websocket")]
pub struct WsClient {
/// Outbound queue: messages sent here are forwarded to the websocket
/// sink by the task spawned in `handshake`.
tx: mpsc::UnboundedSender<crate::ws::Message>,
/// Inbound queue: items read from the websocket stream by the task
/// spawned in `handshake`.
rx: mpsc::UnboundedReceiver<Result<crate::ws::Message, crate::error::Error>>,
}
/// An error produced by the websocket test helpers.
#[derive(Debug)]
pub struct WsError {
/// The underlying error; its `Display` output is embedded in this
/// type's own `Display` output.
cause: Box<dyn StdError + Send + Sync>,
}
impl RequestBuilder {
    /// Sets the HTTP method of the request.
    ///
    /// # Panics
    ///
    /// Panics with `"valid method"` if `method` cannot be parsed.
    pub fn method(mut self, method: &str) -> Self {
        let parsed = method.parse().expect("valid method");
        *self.req.method_mut() = parsed;
        self
    }

    /// Sets the request URI by parsing `p`.
    ///
    /// # Panics
    ///
    /// Panics with `"test request path invalid"` if `p` cannot be parsed.
    pub fn path(mut self, p: &str) -> Self {
        *self.req.uri_mut() = p.parse().expect("test request path invalid");
        self
    }

    /// Sets a request header, converting `key` and `value` with `TryFrom`.
    ///
    /// The pair is stored with the header map's `insert`.
    ///
    /// # Panics
    ///
    /// Panics if either the name or the value fails to convert.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        // The conversion error types are not required to be `Debug`, so they
        // are erased to `()` before `expect` formats them.
        let name = <HeaderName as TryFrom<K>>::try_from(key)
            .map_err(|_| ())
            .expect("invalid header name");
        let value = <HeaderValue as TryFrom<V>>::try_from(value)
            .map_err(|_| ())
            .expect("invalid header value");

        self.req.headers_mut().insert(name, value);
        self
    }

    /// Sets the remote (peer) address reported for this request.
    pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
        self.remote_addr = Some(addr);
        self
    }

    /// Stores `ext` in the request's extensions.
    pub fn extension<T>(mut self, ext: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.req.extensions_mut().insert(ext);
        self
    }

    /// Sets the request body to a copy of `body` and sets `content-length`
    /// to its size in bytes.
    pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
        let bytes = body.as_ref().to_vec();
        let content_length = bytes.len().to_string();
        *self.req.body_mut() = bytes.into();
        self.header("content-length", content_length)
    }

    /// Serializes `val` as JSON into the request body, and sets both
    /// `content-length` and `content-type: application/json`.
    ///
    /// # Panics
    ///
    /// Panics with `"json() must serialize to JSON"` if serialization fails.
    pub fn json(mut self, val: &impl Serialize) -> Self {
        let payload = serde_json::to_vec(val).expect("json() must serialize to JSON");
        let content_length = payload.len().to_string();
        *self.req.body_mut() = payload.into();
        self.header("content-length", content_length)
            .header("content-type", "application/json")
    }

    /// Runs this request through `f` and returns what the filter extracted.
    ///
    /// A single extracted value is unwrapped from its one-element tuple;
    /// zero or several values are returned as `()` or the tuple itself
    /// (see `OneOrTuple`). A rejection is returned as `Err`.
    ///
    /// # Panics
    ///
    /// Panics if called while another test filter call is already active.
    pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
    where
        F: Filter,
        F::Future: Send + 'static,
        F::Extract: OneOrTuple + Send + 'static,
        F::Error: Send + 'static,
    {
        match self.apply_filter(f).await {
            Ok(extracted) => Ok(extracted.one_or_tuple()),
            Err(rejection) => Err(rejection),
        }
    }

    /// Returns `true` when `f` accepts this request, `false` when it rejects.
    ///
    /// # Panics
    ///
    /// Panics if called while another test filter call is already active.
    pub async fn matches<F>(self, f: &F) -> bool
    where
        F: Filter,
        F::Future: Send + 'static,
        F::Extract: Send + 'static,
        F::Error: Send + 'static,
    {
        let outcome = self.apply_filter(f).await;
        outcome.is_ok()
    }

    /// Runs this request through `f` and returns the full response with its
    /// body collected into `Bytes`.
    ///
    /// A rejection is logged at debug level and converted into a response
    /// instead of being returned as an error.
    ///
    /// # Panics
    ///
    /// Panics if called while another test filter call is already active, or
    /// with `"reply shouldn't fail"` if collecting the body fails.
    pub async fn reply<F>(self, f: &F) -> Response<Bytes>
    where
        F: Filter + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        assert!(!route::is_set(), "nested test filter calls");

        let route = Route::new(self.req, self.remote_addr);

        // The filter future is created with the route installed, because
        // filters may read the current route while being constructed.
        let filtered = route::set(&route, move || f.filter(crate::filter::Internal));
        let mut pending = Box::pin(filtered.then(|outcome| {
            let response = match outcome {
                Ok(reply) => reply.into_response(),
                Err(rejection) => {
                    log::debug!("rejected: {:?}", rejection);
                    rejection.into_response()
                }
            };
            let (head, body) = response.into_parts();
            hyper::body::to_bytes(body).map_ok(|bytes| Response::from_parts(head, bytes.into()))
        }));

        // Every poll must also run with the route installed.
        future::poll_fn(move |cx| route::set(&route, || pending.as_mut().poll(cx)))
            .await
            .expect("reply shouldn't fail")
    }

    /// Shared driver for `filter` and `matches`: builds the route for this
    /// request and returns a future that polls `f` with that route set.
    ///
    /// # Panics
    ///
    /// Panics if a route is already set (a nested test filter call).
    fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>>
    where
        F: Filter,
        F::Future: Send + 'static,
        F::Extract: Send + 'static,
        F::Error: Send + 'static,
    {
        assert!(!route::is_set(), "nested test filter calls");

        let route = Route::new(self.req, self.remote_addr);
        let filtered = route::set(&route, move || f.filter(crate::filter::Internal));
        let mut pending = Box::pin(filtered);

        // The returned future owns the route and installs it on each poll.
        future::poll_fn(move |cx| route::set(&route, || pending.as_mut().poll(cx)))
    }
}
#[cfg(feature = "websocket")]
impl WsBuilder {
    /// Sets the request URI of the handshake request by parsing `p`.
    ///
    /// Both the path and any query string in `p` are sent to the server
    /// during [`handshake`](WsBuilder::handshake).
    ///
    /// # Panics
    ///
    /// Panics if `p` cannot be parsed (see `RequestBuilder::path`).
    pub fn path(self, p: &str) -> Self {
        WsBuilder {
            req: self.req.path(p),
        }
    }

    /// Sets a header on the handshake request.
    ///
    /// # Panics
    ///
    /// Panics if the name or value fails to convert (see
    /// `RequestBuilder::header`).
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        WsBuilder {
            req: self.req.header(key, value),
        }
    }

    /// Performs a websocket handshake against `f` and returns the client
    /// half of the connection.
    ///
    /// A background task is spawned with `tokio::spawn`. It serves `f` on an
    /// ephemeral `127.0.0.1` port, sends the upgrade request to it over a
    /// real TCP connection, and then shuttles messages between the socket
    /// and the returned [`WsClient`].
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] wrapping the client error if the request or the
    /// connection upgrade fails.
    ///
    /// # Panics
    ///
    /// Panics with `"websocket handshake thread panicked"` if the background
    /// task ends without reporting a handshake result.
    pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Clone + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        // Reports the handshake outcome back to the caller.
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        // Caller -> socket.
        let (wr_tx, wr_rx) = mpsc::unbounded_channel();
        // Socket -> caller.
        let (rd_tx, rd_rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            use tokio_tungstenite::tungstenite::protocol;

            let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));

            // A fixed key is sufficient here (it is base64 for
            // "the sample nonce").
            let mut req = self
                .req
                .header("connection", "upgrade")
                .header("upgrade", "websocket")
                .header("sec-websocket-version", "13")
                .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                .req;

            // Re-target the request at the ephemeral server. The query
            // string is carried over explicitly: `Uri::path()` does not
            // include it, and dropping it would make query-based filters
            // impossible to test through a websocket handshake.
            let query_string = match req.uri().query() {
                Some(query) => format!("?{}", query),
                None => String::new(),
            };
            let uri = format!("http://{}{}{}", addr, req.uri().path(), query_string)
                .parse()
                .expect("addr + path is valid URI");
            *req.uri_mut() = uri;

            tokio::spawn(srv);

            let upgrade = ::hyper::Client::builder()
                .build(AddrConnect(addr))
                .request(req)
                .and_then(|res| res.into_body().on_upgrade());

            let upgraded = match upgrade.await {
                Ok(up) => {
                    // The caller may have gone away; nothing to do then.
                    let _ = upgraded_tx.send(Ok(()));
                    up
                }
                Err(err) => {
                    let _ = upgraded_tx.send(Err(err));
                    return;
                }
            };

            let ws = crate::ws::WebSocket::from_raw_socket(
                upgraded,
                protocol::Role::Client,
                Default::default(),
            )
            .await;

            let (tx, rx) = ws.split();

            // Forward everything the caller sends to the socket.
            let write = wr_rx.map(Ok).forward(tx).map(|_| ());

            // Forward inbound messages to the caller. The stream stops at
            // the first error or close message, and that item itself is not
            // forwarded; the caller observes the end as a closed channel.
            let read = rx
                .take_while(|result| match result {
                    Err(_) => future::ready(false),
                    Ok(m) => future::ready(!m.is_close()),
                })
                .for_each(move |item| {
                    rd_tx.send(item).expect("ws receive error");
                    future::ready(())
                });

            future::join(write, read).await;
        });

        match upgraded_rx.await {
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx,
            }),
            Ok(Err(err)) => Err(WsError::new(err)),
            Err(_canceled) => panic!("websocket handshake thread panicked"),
        }
    }
}
#[cfg(feature = "websocket")]
impl WsClient {
    /// Sends a text message to the server.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`send`](WsClient::send).
    pub async fn send_text(&mut self, text: impl Into<String>) {
        let msg = crate::ws::Message::text(text);
        self.send(msg).await;
    }

    /// Queues `msg` to be written to the server.
    ///
    /// # Panics
    ///
    /// Panics if the outbound channel's receiving side has been dropped.
    pub async fn send(&mut self, msg: crate::ws::Message) {
        self.tx.send(msg).unwrap();
    }

    /// Waits for the next message from the server.
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] wrapping the received error, or a `"closed"`
    /// error once the inbound channel has ended.
    pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
        match self.rx.next().await {
            Some(Ok(msg)) => Ok(msg),
            Some(Err(err)) => Err(WsError::new(err)),
            // The inbound channel has ended: the websocket is closed.
            None => Err(WsError::new("closed")),
        }
    }

    /// Asserts that the connection has closed: succeeds only if the inbound
    /// channel ends without yielding another item.
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] if a message (`"received message: ..."`) or an
    /// error arrives instead.
    pub async fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next().await {
            // Closed successfully.
            None => Ok(()),
            Some(Ok(msg)) => Err(WsError::new(format!("received message: {:?}", msg))),
            Some(Err(err)) => Err(WsError::new(err)),
        }
    }
}
#[cfg(feature = "websocket")]
impl fmt::Debug for WsClient {
    /// Formats as the bare struct name; the channel fields are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WsClient");
        builder.finish()
    }
}
#[cfg(feature = "websocket")]
impl WsError {
    /// Wraps anything convertible into a boxed error (including strings)
    /// as a `WsError`.
    fn new<E>(cause: E) -> Self
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let cause = cause.into();
        WsError { cause }
    }
}
impl fmt::Display for WsError {
    /// Writes `websocket error: ` followed by the underlying cause.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("websocket error: {}", self.cause))
    }
}
impl StdError for WsError {
    /// Returns a fixed, cause-independent description; use `Display` to see
    /// the underlying cause.
    fn description(&self) -> &str {
        const DESCRIPTION: &str = "websocket error";
        DESCRIPTION
    }
}
/// Connector used by the websocket handshake client: holds the socket
/// address that every connection is made to.
#[cfg(feature = "websocket")]
#[derive(Clone)]
struct AddrConnect(SocketAddr);
#[cfg(feature = "websocket")]
impl tower_service::Service<::http::Uri> for AddrConnect {
    type Response = ::tokio::net::TcpStream;
    type Error = ::std::io::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: there is no state to wait on.
    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Opens a TCP connection to the stored address; the requested URI is
    /// ignored.
    fn call(&mut self, _: ::http::Uri) -> Self::Future {
        let addr = self.0;
        Box::pin(tokio::net::TcpStream::connect(addr))
    }
}
mod inner {
    /// Converts a tuple into its most convenient form:
    ///
    /// - `()` stays `()`,
    /// - a one-element tuple `(T,)` becomes the bare `T`,
    /// - tuples of two to sixteen elements are returned unchanged.
    pub trait OneOrTuple {
        /// The converted type.
        type Output;

        /// Performs the conversion.
        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();

        fn one_or_tuple(self) -> Self::Output {}
    }

    /// Implements `OneOrTuple` for the tuple of all given type parameters,
    /// then recurses on the list without its first parameter.
    macro_rules! one_or_tuple {
        // Base case: a single parameter yields the unwrapping impl.
        ($only:ident) => {
            impl<$only> OneOrTuple for ($only,) {
                type Output = $only;

                fn one_or_tuple(self) -> Self::Output {
                    let (value,) = self;
                    value
                }
            }
        };
        // Two or more parameters: the tuple is its own output.
        ($head:ident, $( $tail:ident ),*) => {
            impl<$head, $( $tail ),*> OneOrTuple for ($head, $( $tail ),*) {
                type Output = Self;

                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }

            one_or_tuple!($( $tail ),*);
        };
    }

    one_or_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
}