use std::{net, str::FromStr, sync::mpsc, thread};
#[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar};
use crate::channel::bstream;
#[cfg(feature = "ws")]
use crate::io::Filter;
use crate::io::Io;
use crate::server::Server;
use crate::service::ServiceFactory;
#[cfg(feature = "ws")]
use crate::ws::{error::WsClientError, WsClient, WsConnection};
use crate::{rt::System, time::sleep, time::Millis, time::Seconds, util::Bytes};
use super::client::{Client, ClientRequest, ClientResponse, Connector};
use super::error::{HttpError, PayloadError};
use super::header::{self, HeaderMap, HeaderName, HeaderValue};
use super::payload::Payload;
use super::{Method, Request, Uri, Version};
/// Builder for [`Request`] values used in tests.
///
/// The builder state lives in an `Option`: [`TestRequest::finish`] takes it out and
/// [`TestRequest::take`] moves it into a new builder. Calling any builder method on an
/// emptied builder panics with "cannot reuse test request builder".
#[derive(Debug)]
pub struct TestRequest(Option<Inner>);
/// Request parts accumulated by [`TestRequest`] until `finish` assembles the request.
#[derive(Debug)]
struct Inner {
    version: Version,
    method: Method,
    uri: Uri,
    // Headers are appended in insertion order by `TestRequest::header`.
    headers: HeaderMap,
    // Cookies added via `TestRequest::cookie`; `finish` serializes the jar's delta
    // into a single `Cookie` header.
    #[cfg(feature = "cookie")]
    cookies: CookieJar,
    // Request body; `finish` substitutes an empty stream when this is `None`.
    payload: Option<Payload>,
}
impl Default for TestRequest {
    /// A builder for `GET /` over HTTP/1.1 with no headers, cookies or payload.
    fn default() -> TestRequest {
        let inner = Inner {
            version: Version::HTTP_11,
            method: Method::GET,
            uri: Uri::from_str("/").unwrap(),
            headers: HeaderMap::new(),
            #[cfg(feature = "cookie")]
            cookies: CookieJar::new(),
            payload: None,
        };
        TestRequest(Some(inner))
    }
}
impl TestRequest {
pub fn with_uri(path: &str) -> TestRequest {
TestRequest::default().uri(path).take()
}
pub fn with_header<K, V>(key: K, value: V) -> TestRequest
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
{
TestRequest::default().header(key, value).take()
}
pub fn version(&mut self, ver: Version) -> &mut Self {
parts(&mut self.0).version = ver;
self
}
pub fn method(&mut self, meth: Method) -> &mut Self {
parts(&mut self.0).method = meth;
self
}
pub fn uri(&mut self, path: &str) -> &mut Self {
parts(&mut self.0).uri = Uri::from_str(path).unwrap();
self
}
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
{
if let Ok(key) = HeaderName::try_from(key) {
if let Ok(value) = HeaderValue::try_from(value) {
parts(&mut self.0).headers.append(key, value);
return self;
}
}
panic!("Cannot create header");
}
#[cfg(feature = "cookie")]
pub fn cookie<C>(&mut self, cookie: C) -> &mut Self
where
C: Into<Cookie<'static>>,
{
parts(&mut self.0).cookies.add(cookie.into());
self
}
pub fn set_payload<B: Into<Bytes>>(&mut self, data: B) -> &mut Self {
let payload = bstream::empty(Some(data.into()));
parts(&mut self.0).payload = Some(payload.into());
self
}
pub fn take(&mut self) -> TestRequest {
TestRequest(self.0.take())
}
pub fn finish(&mut self) -> Request {
let inner = self.0.take().expect("cannot reuse test request builder");
let mut req = if let Some(pl) = inner.payload {
Request::with_payload(pl)
} else {
Request::with_payload(bstream::empty(None).into())
};
let head = req.head_mut();
head.uri = inner.uri;
head.method = inner.method;
head.version = inner.version;
head.headers = inner.headers;
if let Some(conn) = head.headers.get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if s.to_lowercase().contains("upgrade") {
head.set_upgrade()
}
}
}
#[cfg(feature = "cookie")]
{
use percent_encoding::percent_encode;
use std::fmt::Write as FmtWrite;
let mut cookie = String::new();
for c in inner.cookies.delta() {
let name = percent_encode(c.name().as_bytes(), super::helpers::USERINFO);
let value = percent_encode(c.value().as_bytes(), super::helpers::USERINFO);
let _ = write!(cookie, "; {name}={value}");
}
if !cookie.is_empty() {
head.headers.insert(
super::header::COOKIE,
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
);
}
}
req
}
}
/// Borrow the builder state mutably.
///
/// # Panics
///
/// Panics with "cannot reuse test request builder" if the state has already been
/// moved out (by `finish` or `take`).
#[inline]
fn parts(parts: &mut Option<Inner>) -> &mut Inner {
    match parts {
        Some(inner) => inner,
        None => panic!("cannot reuse test request builder"),
    }
}
/// Start a test server on an ephemeral `127.0.0.1` port and return a handle to it.
///
/// The server runs on a dedicated thread inside its own [`System`]. It uses a single
/// worker and has signal handling disabled. `factory` builds the service for the
/// `"test"` listener. The returned [`TestServer`] carries a client configured with a
/// 90 second request timeout and a 90 second connect timeout.
///
/// # Panics
///
/// Panics if the listener cannot be bound or the server fails to start. The underlying
/// startup error (or the server thread's own panic) is reported, rather than an opaque
/// channel-receive failure.
pub fn server<F, R>(factory: F) -> TestServer
where
    F: Fn() -> R + Send + Clone + 'static,
    R: ServiceFactory<Io> + 'static,
{
    let (tx, rx) = mpsc::channel();
    let handle = thread::spawn(move || {
        let sys = System::new("test-server");
        let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap();
        let local_addr = tcp.local_addr().unwrap();
        let system = sys.system();
        sys.run(move || {
            let srv = crate::server::build()
                .listen("test", tcp, move |_| factory())?
                .set_tag("test", "HTTP-TEST-SRV")
                .workers(1)
                .disable_signals()
                .run();
            crate::rt::spawn(async move {
                // Short delay before handing the server back to the caller.
                sleep(Millis(125)).await;
                tx.send((system, srv, local_addr)).unwrap();
            });
            Ok(())
        })
    });

    let (system, server, addr) = match rx.recv() {
        Ok(started) => started,
        // The sender was dropped without sending, which means the server thread finished
        // before the server came up. Join it to surface the real cause instead of a bare
        // `RecvError`.
        Err(_) => match handle.join() {
            Ok(Err(err)) => panic!("test server failed to start: {err:?}"),
            Ok(_) => panic!("test server thread exited before the server started"),
            Err(payload) => std::panic::resume_unwind(payload),
        },
    };

    TestServer {
        addr,
        system,
        server,
        client: Client::build().finish(),
    }
    .set_client_timeout(Seconds(90), Millis(90_000))
}
/// Handle to a server started by [`server`], together with an HTTP client for it.
///
/// Dropping the handle stops the server's [`System`].
#[derive(Debug)]
pub struct TestServer {
    // Local address the test listener is bound to.
    addr: net::SocketAddr,
    // Client used by `request` / `srequest`; replaced by `set_client_timeout`.
    client: Client,
    // System owning the server thread; stopped by `stop` and on drop.
    system: System,
    // Handle to the running server; stopped by `stop`.
    server: Server,
}
impl TestServer {
    /// Rebuild the client used by `request` / `srequest`.
    ///
    /// `timeout` is the client request timeout and `connect_timeout` the connector
    /// timeout. With the `openssl` feature the connector does not verify peer
    /// certificates, advertises `h2` and `http/1.1` via ALPN, and raises the HTTP/2
    /// header-list and continuation-frame limits.
    pub fn set_client_timeout(mut self, timeout: Seconds, connect_timeout: Millis) -> Self {
        #[cfg(feature = "openssl")]
        let connector = {
            use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};

            let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
            ssl.set_verify(SslVerifyMode::NONE);
            // An ALPN failure is logged but does not abort client construction.
            if let Err(e) = ssl.set_alpn_protos(b"\x02h2\x08http/1.1") {
                log::error!("Cannot set alpn protocol: {e:?}");
            }
            Connector::default()
                .timeout(connect_timeout)
                .openssl(ssl.build())
                .configure_http2(|cfg| {
                    cfg.max_header_list_size(256 * 1024);
                    cfg.max_header_continuation_frames(96);
                })
                .finish()
        };
        #[cfg(not(feature = "openssl"))]
        let connector = Connector::default().timeout(connect_timeout).finish();

        self.client = Client::build()
            .timeout(timeout)
            .connector(connector)
            .finish();
        self
    }

    /// Address the server is listening on.
    pub fn addr(&self) -> net::SocketAddr {
        self.addr
    }

    /// Absolute `http://localhost:<port>` URL for `uri`.
    pub fn url(&self, uri: &str) -> String {
        self.localhost_url("http", uri)
    }

    /// Absolute `https://localhost:<port>` URL for `uri`.
    pub fn surl(&self, uri: &str) -> String {
        self.localhost_url("https", uri)
    }

    /// Join `<scheme>://localhost:<port>` with `uri`, inserting a `/` when `uri` lacks one.
    fn localhost_url(&self, scheme: &str, uri: &str) -> String {
        let sep = if uri.starts_with('/') { "" } else { "/" };
        format!("{scheme}://localhost:{}{sep}{uri}", self.addr.port())
    }

    /// Start building a plain-HTTP request to `path` on this server.
    pub fn request<S: AsRef<str>>(&self, method: Method, path: S) -> ClientRequest {
        let target = self.url(path.as_ref());
        self.client.request(method, target.as_str())
    }

    /// Start building an HTTPS request to `path` on this server.
    pub fn srequest<S: AsRef<str>>(&self, method: Method, path: S) -> ClientRequest {
        let target = self.surl(path.as_ref());
        self.client.request(method, target.as_str())
    }

    /// Read the whole response body, allowing at most 10 MiB.
    pub async fn load_body(
        &mut self,
        mut response: ClientResponse,
    ) -> Result<Bytes, PayloadError> {
        let max_size = 10 * 1024 * 1024;
        response.body().limit(max_size).await
    }

    /// Open a websocket connection to `/`.
    #[cfg(feature = "ws")]
    pub async fn ws(&mut self) -> Result<WsConnection<impl Filter>, WsClientError> {
        self.ws_at("/").await
    }

    /// Open a websocket connection to `path`, with a 30 second timeout.
    ///
    /// # Panics
    ///
    /// Panics if the websocket client cannot be built.
    #[cfg(feature = "ws")]
    pub async fn ws_at(
        &mut self,
        path: &str,
    ) -> Result<WsConnection<impl Filter>, WsClientError> {
        WsClient::build(self.url(path))
            .address(self.addr)
            .timeout(Seconds(30))
            .finish()
            .unwrap()
            .connect()
            .await
    }

    /// Open an OpenSSL websocket connection to `/`.
    #[cfg(all(feature = "openssl", feature = "ws"))]
    pub async fn wss(
        &mut self,
    ) -> Result<
        WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
        WsClientError,
    > {
        self.wss_at("/").await
    }

    /// Open an OpenSSL websocket connection to `path`.
    ///
    /// Peer certificates are not verified and only `http/1.1` is advertised via ALPN.
    ///
    /// # Panics
    ///
    /// Panics if the SSL connector or the websocket client cannot be built.
    #[cfg(all(feature = "openssl", feature = "ws"))]
    pub async fn wss_at(
        &mut self,
        path: &str,
    ) -> Result<
        WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
        WsClientError,
    > {
        use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};

        let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
        ssl.set_verify(SslVerifyMode::NONE);
        if let Err(e) = ssl.set_alpn_protos(b"\x08http/1.1") {
            log::error!("Cannot set alpn protocol: {e:?}");
        }

        WsClient::build(self.url(path))
            .address(self.addr)
            .timeout(Seconds(30))
            .openssl(ssl.build())
            .take()
            .finish()
            .unwrap()
            .connect()
            .await
    }

    /// Stop the server, awaiting `Server::stop(true)`, then stop its system.
    pub async fn stop(&self) {
        self.server.stop(true).await;
        self.system.stop();
    }
}
impl Drop for TestServer {
    /// Stops the server's `System` when the handle goes out of scope.
    fn drop(&mut self) {
        self.system.stop();
    }
}