use std::convert::TryFrom;
use std::str::FromStr;
use std::sync::mpsc;
use std::{io, net, thread, time};
use bytes::Bytes;
use futures::Stream;
#[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar};
use crate::codec::{AsyncRead, AsyncWrite, Framed};
use crate::rt::{net::TcpStream, System};
use crate::server::{Server, StreamServiceFactory};
use super::client::error::WsClientError;
use super::client::{Client, ClientRequest, ClientResponse, Connector};
use super::error::{HttpError, PayloadError};
use super::header::{HeaderMap, HeaderName, IntoHeaderValue};
use super::payload::Payload;
use super::{Method, Request, Uri, Version};
/// Test `Request` builder.
///
/// The builder state is held in an `Option` so that `take()` and `finish()`
/// can move it out; any builder method called after the state has been moved
/// out panics with "cannot reuse test request builder".
pub struct TestRequest(Option<Inner>);
/// State accumulated by `TestRequest` until it is turned into a `Request`.
struct Inner {
/// HTTP version written to the request head.
version: Version,
/// HTTP method written to the request head.
method: Method,
/// Request URI written to the request head.
uri: Uri,
/// Headers written to the request head.
headers: HeaderMap,
/// Cookies that `finish()` serializes into the `Cookie` header.
#[cfg(feature = "cookie")]
cookies: CookieJar,
/// Optional request payload; `finish()` substitutes an empty h1 payload when `None`.
payload: Option<Payload>,
}
impl Default for TestRequest {
    /// Creates a builder describing a `GET /` request over HTTP/1.1 with an
    /// empty header map, an empty cookie jar (when the `cookie` feature is
    /// enabled) and no payload.
    fn default() -> TestRequest {
        let initial = Inner {
            version: Version::HTTP_11,
            method: Method::GET,
            uri: Uri::from_str("/").unwrap(),
            headers: HeaderMap::new(),
            #[cfg(feature = "cookie")]
            cookies: CookieJar::new(),
            payload: None,
        };
        TestRequest(Some(initial))
    }
}
impl TestRequest {
pub fn with_uri(path: &str) -> TestRequest {
TestRequest::default().uri(path).take()
}
pub fn with_header<K, V>(key: K, value: V) -> TestRequest
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue,
{
TestRequest::default().header(key, value).take()
}
pub fn version(&mut self, ver: Version) -> &mut Self {
parts(&mut self.0).version = ver;
self
}
pub fn method(&mut self, meth: Method) -> &mut Self {
parts(&mut self.0).method = meth;
self
}
pub fn uri(&mut self, path: &str) -> &mut Self {
parts(&mut self.0).uri = Uri::from_str(path).unwrap();
self
}
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue,
{
if let Ok(key) = HeaderName::try_from(key) {
if let Ok(value) = value.try_into() {
parts(&mut self.0).headers.append(key, value);
return self;
}
}
panic!("Can not create header");
}
#[cfg(feature = "cookie")]
pub fn cookie<'a>(&mut self, cookie: Cookie<'a>) -> &mut Self {
parts(&mut self.0).cookies.add(cookie.into_owned());
self
}
pub fn set_payload<B: Into<Bytes>>(&mut self, data: B) -> &mut Self {
let mut payload = crate::http::h1::Payload::empty();
payload.unread_data(data.into());
parts(&mut self.0).payload = Some(payload.into());
self
}
pub fn take(&mut self) -> TestRequest {
TestRequest(self.0.take())
}
pub fn finish(&mut self) -> Request {
let inner = self.0.take().expect("cannot reuse test request builder");
let mut req = if let Some(pl) = inner.payload {
Request::with_payload(pl)
} else {
Request::with_payload(crate::http::h1::Payload::empty().into())
};
let head = req.head_mut();
head.uri = inner.uri;
head.method = inner.method;
head.version = inner.version;
head.headers = inner.headers;
#[cfg(feature = "cookie")]
{
use crate::http::header::HeaderValue;
use percent_encoding::percent_encode;
use std::fmt::Write as FmtWrite;
let mut cookie = String::new();
for c in inner.cookies.delta() {
let name = percent_encode(c.name().as_bytes(), super::helpers::USERINFO);
let value =
percent_encode(c.value().as_bytes(), super::helpers::USERINFO);
let _ = write!(&mut cookie, "; {}={}", name, value);
}
if !cookie.is_empty() {
head.headers.insert(
super::header::COOKIE,
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
);
}
}
req
}
}
/// Returns the builder state, panicking if it has already been moved out.
#[inline]
fn parts(parts: &mut Option<Inner>) -> &mut Inner {
    match parts {
        Some(inner) => inner,
        None => panic!("cannot reuse test request builder"),
    }
}
/// Starts a test server for `factory` and returns a handle to it.
///
/// A new thread creates a `System` named "test-server", binds a TCP listener
/// to `127.0.0.1:0`, and starts a server with one worker and signal handling
/// disabled. The system handle and the bound address are sent back over a
/// channel before the thread enters `run()`.
///
/// The returned [`TestServer`] carries a client whose connector and request
/// timeouts are both 30 seconds. With the `openssl` feature enabled the
/// connector uses an SSL connector with peer verification disabled and the
/// `h2` / `http/1.1` ALPN protocols configured.
///
/// # Panics
///
/// Panics if nothing is received from the server thread (for example because
/// it exited before sending its address).
pub fn server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
    let (ready_tx, ready_rx) = mpsc::channel();

    // The server runs in its own thread.
    thread::spawn(move || {
        let mut runtime = System::new("test-server");
        let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
        let bound_addr = listener.local_addr().unwrap();

        runtime.exec(|| {
            Server::build()
                .listen("test", listener, factory)?
                .workers(1)
                .disable_signals()
                .start();
            Ok::<_, io::Error>(())
        })?;

        ready_tx.send((System::current(), bound_addr)).unwrap();
        runtime.run()
    });

    let (system, addr) = ready_rx.recv().unwrap();

    // Shared by the connector and the client.
    let timeout = time::Duration::from_millis(30000);

    let connector = {
        #[cfg(feature = "openssl")]
        {
            use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode};

            let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
            builder.set_verify(SslVerifyMode::NONE);
            // An ALPN failure is logged and otherwise ignored.
            if let Err(e) = builder.set_alpn_protos(b"\x02h2\x08http/1.1") {
                log::error!("Can not set alpn protocol: {:?}", e);
            }
            Connector::default()
                .timeout(timeout)
                .openssl(builder.build())
                .finish()
        }
        #[cfg(not(feature = "openssl"))]
        {
            Connector::default().timeout(timeout).finish()
        }
    };

    let client = Client::build()
        .timeout(timeout)
        .connector(connector)
        .finish();

    TestServer {
        addr,
        client,
        system,
    }
}
/// Handle to a running test server; dropping it stops the associated system.
pub struct TestServer {
/// Address the server's listener is bound to.
addr: net::SocketAddr,
/// HTTP client used to build requests against the server.
client: Client,
/// Handle of the system running the server; stopped on drop.
system: System,
}
impl TestServer {
    /// Returns the server's socket address.
    pub fn addr(&self) -> net::SocketAddr {
        self.addr
    }

    /// Builds an `http://localhost:<port>` URL for `uri`, inserting a `/`
    /// when `uri` does not start with one.
    pub fn url(&self, uri: &str) -> String {
        let separator = if uri.starts_with('/') { "" } else { "/" };
        format!("http://localhost:{}{}{}", self.addr.port(), separator, uri)
    }

    /// Builds an `https://localhost:<port>` URL for `uri`, inserting a `/`
    /// when `uri` does not start with one.
    pub fn surl(&self, uri: &str) -> String {
        let separator = if uri.starts_with('/') { "" } else { "/" };
        format!("https://localhost:{}{}{}", self.addr.port(), separator, uri)
    }

    /// Creates a client request with `method` for `path` on this server
    /// using the `http` scheme.
    pub fn request<S: AsRef<str>>(&self, method: Method, path: S) -> ClientRequest {
        let target = self.url(path.as_ref());
        self.client.request(method, target.as_str())
    }

    /// Creates a client request with `method` for `path` on this server
    /// using the `https` scheme.
    pub fn srequest<S: AsRef<str>>(&self, method: Method, path: S) -> ClientRequest {
        let target = self.surl(path.as_ref());
        self.client.request(method, target.as_str())
    }

    /// Reads the whole response body, with the body limit set to
    /// 10_485_760 bytes (10 MiB).
    pub async fn load_body<S>(
        &mut self,
        mut response: ClientResponse<S>,
    ) -> Result<Bytes, PayloadError>
    where
        S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
    {
        let body = response.body().limit(10_485_760);
        body.await
    }

    /// Opens a WebSocket connection to `path` on this server and returns the
    /// framed transport, discarding the handshake response.
    pub async fn ws_at(
        &mut self,
        path: &str,
    ) -> Result<Framed<impl AsyncRead + AsyncWrite, crate::ws::Codec>, WsClientError>
    {
        let target = self.url(path);
        let handshake = self.client.ws(target).connect();
        match handshake.await {
            Ok((_, framed)) => Ok(framed),
            Err(err) => Err(err),
        }
    }

    /// Opens a WebSocket connection to `/` on this server.
    pub async fn ws(
        &mut self,
    ) -> Result<Framed<impl AsyncRead + AsyncWrite, crate::ws::Codec>, WsClientError>
    {
        let connection = self.ws_at("/").await;
        connection
    }

    /// Stops the system running the server.
    fn stop(&mut self) {
        let system = &self.system;
        system.stop();
    }
}
impl Drop for TestServer {
fn drop(&mut self) {
self.stop()
}
}