use std::{net, str::FromStr, sync::mpsc, thread, time};
#[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar};
use ntex_tls::TlsConfig;
use uuid::Uuid;
use crate::channel::bstream;
use crate::client::{Client, ClientRequest, ClientResponse, Connector};
#[cfg(feature = "ws")]
use crate::io::Filter;
use crate::io::{Io, IoConfig};
use crate::server::Server;
use crate::service::{ServiceFactory, cfg::SharedCfg};
#[cfg(feature = "ws")]
use crate::ws::{WsClient, WsConnection, error::WsClientError};
use crate::{rt::System, time::Millis, time::Seconds, time::sleep, util::Bytes};
use super::error::{HttpError, PayloadError};
use super::header::{self, HeaderMap, HeaderName, HeaderValue};
use super::payload::Payload;
use super::{Method, Request, Uri, Version};
/// Builder for [`Request`] values used in tests.
///
/// The inner state becomes `None` once it is moved out by [`TestRequest::take`]
/// or consumed by [`TestRequest::finish`]; any later setter call or `finish`
/// then panics with "cannot reuse test request builder".
#[derive(Debug)]
pub struct TestRequest(Option<Inner>);
/// Accumulated request parts; copied into the request head by `TestRequest::finish`.
#[derive(Debug)]
struct Inner {
/// HTTP version for the request head.
version: Version,
/// Request method.
method: Method,
/// Request target.
uri: Uri,
/// Headers appended via `TestRequest::header`.
headers: HeaderMap,
/// Cookies added via `TestRequest::cookie`; serialized into a single
/// `Cookie` header when the request is finished.
#[cfg(feature = "cookie")]
cookies: CookieJar,
/// Optional body; `finish` substitutes an empty stream when `None`.
payload: Option<Payload>,
}
impl Default for TestRequest {
    /// A `GET /` HTTP/1.1 request with no headers, no cookies and no body.
    fn default() -> TestRequest {
        let inner = Inner {
            version: Version::HTTP_11,
            method: Method::GET,
            uri: Uri::from_str("/").unwrap(),
            headers: HeaderMap::new(),
            #[cfg(feature = "cookie")]
            cookies: CookieJar::new(),
            payload: None,
        };
        TestRequest(Some(inner))
    }
}
impl TestRequest {
#[must_use]
pub fn with_uri(path: &str) -> TestRequest {
TestRequest::default().uri(path).take()
}
#[must_use]
pub fn with_header<K, V>(key: K, value: V) -> TestRequest
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
{
TestRequest::default().header(key, value).take()
}
pub fn version(&mut self, ver: Version) -> &mut Self {
parts(&mut self.0).version = ver;
self
}
pub fn method(&mut self, meth: Method) -> &mut Self {
parts(&mut self.0).method = meth;
self
}
pub fn uri(&mut self, path: &str) -> &mut Self {
parts(&mut self.0).uri = Uri::from_str(path).unwrap();
self
}
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
{
if let Ok(key) = HeaderName::try_from(key)
&& let Ok(value) = HeaderValue::try_from(value)
{
parts(&mut self.0).headers.append(key, value);
return self;
}
panic!("Cannot create header");
}
#[cfg(feature = "cookie")]
pub fn cookie<C>(&mut self, cookie: C) -> &mut Self
where
C: Into<Cookie<'static>>,
{
parts(&mut self.0).cookies.add(cookie.into());
self
}
pub fn set_payload<B: Into<Bytes>>(&mut self, data: B) -> &mut Self {
let payload = bstream::empty(Some(data.into()));
parts(&mut self.0).payload = Some(payload.into());
self
}
#[must_use]
pub fn take(&mut self) -> TestRequest {
TestRequest(self.0.take())
}
#[must_use]
pub fn finish(&mut self) -> Request {
let inner = self.0.take().expect("cannot reuse test request builder");
let mut req = if let Some(pl) = inner.payload {
Request::with_payload(pl)
} else {
Request::with_payload(bstream::empty(None).into())
};
let head = req.head_mut();
head.uri = inner.uri;
head.method = inner.method;
head.version = inner.version;
head.headers = inner.headers;
if let Some(conn) = head.headers.get(header::CONNECTION)
&& let Ok(s) = conn.to_str()
&& s.to_lowercase().contains("upgrade")
{
head.set_upgrade();
}
#[cfg(feature = "cookie")]
{
use percent_encoding::percent_encode;
use std::fmt::Write as FmtWrite;
let mut cookie = String::new();
for c in inner.cookies.delta() {
let name = percent_encode(c.name().as_bytes(), super::helpers::USERINFO);
let value = percent_encode(c.value().as_bytes(), super::helpers::USERINFO);
let _ = write!(cookie, "; {name}={value}");
}
if !cookie.is_empty() {
head.headers.insert(
super::header::COOKIE,
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
);
}
}
req
}
}
/// Returns the live builder state, panicking if the builder was already consumed.
#[inline]
fn parts(parts: &mut Option<Inner>) -> &mut Inner {
    match parts {
        Some(inner) => inner,
        None => panic!("cannot reuse test request builder"),
    }
}
/// Starts a test HTTP server with a default shared configuration.
///
/// The configuration is named "HTTP-TEST-SRV" and contains default I/O, TLS
/// and HTTP/2 settings. See [`server_with_config`] for the startup details.
pub async fn server<F, R>(factory: F) -> TestServer
where
    F: AsyncFn() -> R + Send + Clone + 'static,
    R: ServiceFactory<Io, SharedCfg> + 'static,
{
    let config = SharedCfg::new("HTTP-TEST-SRV")
        .add(IoConfig::new())
        .add(TlsConfig::new())
        .add(ntex_h2::ServiceConfig::new());
    server_with_config(factory, config).await
}
pub async fn server_with_config<F, R, U>(factory: F, cfg: U) -> TestServer
where
F: AsyncFn() -> R + Send + Clone + 'static,
R: ServiceFactory<Io, SharedCfg> + 'static,
U: Into<SharedCfg>,
{
let sys = System::current().config();
let name = System::current().name().to_string();
let id = Uuid::now_v7();
let cfg = cfg.into();
let (tx, rx) = mpsc::channel();
log::debug!("Starting {name:?} http server {id:?}");
thread::spawn(move || {
let sys = System::with_config(&name, sys);
let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap();
let local_addr = tcp.local_addr().unwrap();
sys.run(move || {
let srv = crate::server::build()
.listen("test", tcp, async move |_| factory().await)?
.config("test", cfg)
.workers(1)
.disable_signals()
.run();
crate::rt::spawn(async move {
tx.send((System::current(), srv, local_addr)).unwrap();
});
Ok(())
})
});
let (system, server, addr) = rx.recv().unwrap();
sleep(Millis(25)).await;
TestServer::create(id, system, server, addr, Seconds(90), Millis(90_000)).await
}
/// Handle to a running test HTTP server plus a client preconfigured to talk to it.
///
/// Dropping the handle stops the server and its system.
#[derive(Debug)]
pub struct TestServer {
/// Identifier used in log messages.
id: Uuid,
/// Client-side shared configuration; also used when building websocket clients.
cfg: SharedCfg,
/// Address the server listens on.
addr: net::SocketAddr,
/// HTTP client used by `request` / `srequest`.
client: Client,
/// System running the server thread; stopped on drop.
system: System,
/// Server handle; stopped by `stop` and on drop.
server: Server,
}
impl TestServer {
    /// Wraps an already running server and builds an HTTP client for it.
    ///
    /// `timeout` is used as the client's TLS handshake timeout and
    /// `connect_timeout` as its connect timeout.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub async fn create(
        id: Uuid,
        system: System,
        server: Server,
        addr: net::SocketAddr,
        timeout: Seconds,
        connect_timeout: Millis,
    ) -> Self {
        let cfg = Self::client_config(timeout, connect_timeout);
        let client = Self::create_client(cfg.clone()).await;
        TestServer {
            id,
            cfg,
            addr,
            client,
            system,
            server,
        }
    }

    /// Rebuilds the client configuration and the client with new timeouts.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub async fn set_client_timeout(
        mut self,
        timeout: Seconds,
        connect_timeout: Millis,
    ) -> Self {
        self.cfg = Self::client_config(timeout, connect_timeout);
        self.client = Self::create_client(self.cfg.clone()).await;
        self
    }

    /// Builds the client-side shared configuration.
    ///
    /// Shared by `create` and `set_client_timeout` so the two cannot drift apart.
    fn client_config(timeout: Seconds, connect_timeout: Millis) -> SharedCfg {
        SharedCfg::new("TEST-CLIENT")
            .add(IoConfig::new().set_connect_timeout(connect_timeout))
            .add(TlsConfig::new().set_handshake_timeout(timeout))
            .add(
                ntex_h2::ServiceConfig::new()
                    .set_max_header_list_size(256 * 1024)
                    .set_max_header_continuation_frames(96),
            )
            .into()
    }

    /// Builds an OpenSSL connector with certificate verification disabled.
    ///
    /// Verification is disabled because test servers use self-signed
    /// certificates. `alpn` is in ALPN wire format (length-prefixed protocol
    /// names). A failure to set ALPN is logged, not fatal.
    #[cfg(feature = "openssl")]
    fn insecure_ssl_connector(alpn: &[u8]) -> tls_openssl::ssl::SslConnector {
        use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};

        let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
        builder.set_verify(SslVerifyMode::NONE);
        let _ = builder
            .set_alpn_protos(alpn)
            .map_err(|e| log::error!("Cannot set alpn protocol: {e:?}"));
        builder.build()
    }

    /// Builds the HTTP client.
    ///
    /// With the `openssl` feature, the client advertises h2 and http/1.1.
    async fn create_client(cfg: SharedCfg) -> Client {
        let connector = {
            #[cfg(feature = "openssl")]
            {
                Connector::default()
                    .openssl(Self::insecure_ssl_connector(b"\x02h2\x08http/1.1"))
            }
            #[cfg(not(feature = "openssl"))]
            {
                Connector::default()
            }
        };
        Client::builder()
            .connector::<&str>(connector)
            .build(cfg)
            .await
            .unwrap()
    }

    /// Returns the address the server listens on.
    pub fn addr(&self) -> net::SocketAddr {
        self.addr
    }

    /// Returns an `http://localhost:<port>` URL for `uri`, inserting a leading `/` if missing.
    pub fn url(&self, uri: &str) -> String {
        self.absolute_url("http", uri)
    }

    /// Returns an `https://localhost:<port>` URL for `uri`, inserting a leading `/` if missing.
    pub fn surl(&self, uri: &str) -> String {
        self.absolute_url("https", uri)
    }

    /// Joins scheme, `localhost`, the server port and `uri`, ensuring exactly one `/` before `uri`.
    fn absolute_url(&self, scheme: &str, uri: &str) -> String {
        let sep = if uri.starts_with('/') { "" } else { "/" };
        format!("{scheme}://localhost:{}{sep}{uri}", self.addr.port())
    }

    /// Creates a plain-HTTP client request to `path` on this server.
    pub fn request<S: AsRef<str>>(&self, method: Method, path: S) -> ClientRequest {
        self.client
            .request(method, self.url(path.as_ref()).as_str())
    }

    /// Creates an HTTPS client request to `path` on this server.
    pub fn srequest<S: AsRef<str>>(&self, method: Method, path: S) -> ClientRequest {
        self.client
            .request(method, self.surl(path.as_ref()).as_str())
    }

    /// Reads the full response body, limited to 10 MiB (10_485_760 bytes).
    pub async fn load_body(&self, response: ClientResponse) -> Result<Bytes, PayloadError> {
        response.body().limit(10_485_760).await
    }

    /// Opens a websocket connection to `/`.
    #[cfg(feature = "ws")]
    pub async fn ws(&self) -> Result<WsConnection<impl Filter>, WsClientError> {
        self.ws_at("/").await
    }

    /// Opens a websocket connection to `path` with a 30 second timeout.
    ///
    /// # Panics
    ///
    /// Panics if the websocket client cannot be built.
    #[cfg(feature = "ws")]
    pub async fn ws_at(
        &self,
        path: &str,
    ) -> Result<WsConnection<impl Filter>, WsClientError> {
        WsClient::builder(self.url(path))
            .address(self.addr)
            .timeout(Seconds(30))
            .build(self.cfg.clone())
            .await
            .unwrap()
            .connect()
            .await
    }

    /// Opens a TLS websocket connection to `/`.
    #[cfg(all(feature = "openssl", feature = "ws"))]
    pub async fn wss(
        &self,
    ) -> Result<
        WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
        WsClientError,
    > {
        self.wss_at("/").await
    }

    /// Opens a TLS websocket connection to `path`.
    ///
    /// Uses a 30 second timeout and advertises only http/1.1 via ALPN.
    ///
    /// # Panics
    ///
    /// Panics if the websocket client cannot be built.
    #[cfg(all(feature = "openssl", feature = "ws"))]
    pub async fn wss_at(
        &self,
        path: &str,
    ) -> Result<
        WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
        WsClientError,
    > {
        WsClient::builder(self.url(path))
            .address(self.addr)
            .timeout(Seconds(30))
            .openssl(Self::insecure_ssl_connector(b"\x08http/1.1"))
            .take()
            .build(self.cfg.clone())
            .await
            .unwrap()
            .connect()
            .await
    }

    /// Gracefully stops the server and waits for the stop to complete.
    pub async fn stop(self) {
        self.server.stop(true).await;
    }
}
impl Drop for TestServer {
    /// Requests a non-graceful server stop, then stops the owning system.
    ///
    /// Short blocking pauses give each step time to take effect.
    fn drop(&mut self) {
        let settle = |ms: u64| thread::sleep(time::Duration::from_millis(ms));

        log::debug!("Stopping test http server {:?}", self.id);
        // NOTE(review): the returned future is dropped without being awaited;
        // this assumes `stop` issues the command eagerly — confirm.
        drop(self.server.stop(false));
        settle(75);
        self.system.stop();
        settle(25);
    }
}