use std::convert::TryFrom as _;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures_util::future::{BoxFuture, FutureExt};
use http::Uri;
use hyper_util::client::legacy::connect::{Connected, Connection};
use hyper_util::rt::TokioIo;
use log::info;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio::time::Sleep;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateSec1KeyDer, ServerName};
use tokio_rustls::rustls::{self, ClientConfig, RootCertStore, ServerConfig};
use tokio_rustls::TlsConnector;
use tower_service::Service;
use crate::handler::NewHandler;
use crate::test::async_test::{AsyncTestClient, AsyncTestServerInner};
use crate::test::{self, TestClient, TestServerData};
use crate::tls::rustls_wrap;
/// Builds the rustls server configuration used by the TLS test servers.
///
/// The certificate (`tls_cert.der`) and its SEC1-encoded private key
/// (`tls_key.der`) are compiled into the binary with `include_bytes!`.
/// Client authentication is disabled.
///
/// # Panics
///
/// Panics if rustls rejects the embedded certificate/key pair.
fn server_config() -> ServerConfig {
    let cert_chain = vec![CertificateDer::from_slice(include_bytes!("tls_cert.der"))];
    let key_der = PrivateSec1KeyDer::from(&include_bytes!("tls_key.der")[..]);
    let builder = ServerConfig::builder().with_no_client_auth();
    builder
        .with_single_cert(cert_chain, key_der.into())
        .expect("Unable to create TLS server config")
}
/// TLS-enabled test server.
///
/// Cloning is cheap: every clone shares the same `TestServerData` through an `Arc`.
#[derive(Clone)]
pub struct TestServer {
    // Shared server state; all clones point at the same instance.
    data: Arc<TestServerData>,
}
impl test::Server for TestServer {
    /// Runs `future` to completion using the shared server data and returns its output.
    fn run_future<F, O>(&self, future: F) -> O
    where
        F: Future<Output = O>,
    {
        let data = &self.data;
        data.run_future(future)
    }

    /// Returns the `Sleep` produced by the shared server data's `request_expiry`.
    /// Presumably it bounds how long a single request may take — confirm in `TestServerData`.
    fn request_expiry(&self) -> Sleep {
        let data = &self.data;
        data.request_expiry()
    }
}
impl TestServer {
    /// Starts a TLS test server for `new_handler` with a timeout value of `10`.
    pub fn new<NH: NewHandler + 'static>(new_handler: NH) -> anyhow::Result<TestServer> {
        Self::with_timeout(new_handler, 10)
    }

    /// Starts a TLS test server for `new_handler` with the given `timeout`.
    ///
    /// `timeout` is passed straight to `TestServerData::new`. It is presumably in
    /// seconds (the async server uses `Duration::from_secs(10)`) — confirm there.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `TestServerData::new`.
    pub fn with_timeout<NH: NewHandler + 'static>(
        new_handler: NH,
        timeout: u64,
    ) -> anyhow::Result<TestServer> {
        let tls_wrapper = rustls_wrap(server_config());
        let shared = TestServerData::new(new_handler, timeout, tls_wrapper)?;
        Ok(Self {
            data: Arc::new(shared),
        })
    }

    /// Creates a test client that connects to this server over TLS via [`TestConnect`].
    pub fn client(&self) -> TestClient<Self, TestConnect> {
        let data = &self.data;
        data.client(self)
    }

    /// Hands `future` to the shared server data's `spawn`
    /// (presumably running it on the server's runtime).
    pub fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let data = &self.data;
        data.spawn(future)
    }
}
/// Async TLS-enabled test server.
///
/// Cloning is cheap: every clone shares the same `AsyncTestServerInner` through an `Arc`.
#[derive(Clone)]
pub struct AsyncTestServer {
    // Shared server state; all clones point at the same instance.
    inner: Arc<AsyncTestServerInner>,
}
impl AsyncTestServer {
    /// Starts an async TLS test server for `new_handler` with a 10-second timeout.
    pub async fn new<NH: NewHandler + 'static>(new_handler: NH) -> anyhow::Result<AsyncTestServer> {
        let default_timeout = Duration::from_secs(10);
        Self::new_with_timeout(new_handler, default_timeout).await
    }

    /// Starts an async TLS test server for `new_handler` with the given `timeout`.
    ///
    /// `timeout` is passed straight to `AsyncTestServerInner::new`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AsyncTestServerInner::new`.
    pub async fn new_with_timeout<NH: NewHandler + 'static>(
        new_handler: NH,
        timeout: Duration,
    ) -> anyhow::Result<AsyncTestServer> {
        let tls_wrapper = rustls_wrap(server_config());
        let shared = AsyncTestServerInner::new(new_handler, timeout, tls_wrapper).await?;
        Ok(Self {
            inner: Arc::new(shared),
        })
    }

    /// Creates an async test client that connects to this server over TLS.
    pub fn client(&self) -> AsyncTestClient<crate::tls::test::TestConnect> {
        let inner = &self.inner;
        inner.client()
    }
}
/// Newtype around a client-side [`TlsStream`] that lets it be used as a
/// hyper-util [`Connection`]. Reads and writes are delegated to the wrapped stream.
#[allow(missing_docs)]
#[pin_project]
pub struct TlsConnectionStream<IO>(#[pin] TlsStream<IO>);
impl<IO: AsyncRead + AsyncWrite + Connection + Unpin> Connection for TlsConnectionStream<IO> {
    /// Reports the underlying transport's connection info. HTTP/2 is additionally
    /// marked as negotiated when the TLS handshake selected the `h2` ALPN protocol.
    fn connected(&self) -> Connected {
        let (io, session) = self.0.get_ref();
        let info = io.connected();
        match session.alpn_protocol() {
            Some(b"h2") => info.negotiated_h2(),
            _ => info,
        }
    }
}
impl<IO> AsyncRead for TlsConnectionStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Delegates reading (decrypted application data) to the wrapped TLS stream.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), io::Error>> {
        let tls = self.project().0;
        tls.poll_read(cx, buf)
    }
}
/// Delegates every write operation to the wrapped TLS stream.
///
/// `poll_write_vectored` and `is_write_vectored` are forwarded too. Without them
/// the trait defaults would hide the inner stream's vectored-write support:
/// `is_write_vectored` would always report `false`, and vectored writes would be
/// split into single-buffer writes.
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsConnectionStream<IO> {
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().0.poll_write(cx, buf)
    }

    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().0.poll_write_vectored(cx, bufs)
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        self.0.is_write_vectored()
    }

    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().0.poll_flush(cx)
    }

    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().0.poll_shutdown(cx)
    }
}
/// Connector used by the TLS test clients.
///
/// Every connection is dialed to the fixed `addr`. The TLS handshake uses the
/// shared client `config`.
#[derive(Clone)]
pub struct TestConnect {
    // Address of the test server; dialed regardless of the request URI's host/port.
    pub(crate) addr: SocketAddr,
    // Client TLS configuration shared between clones.
    pub(crate) config: Arc<rustls::ClientConfig>,
}
impl Service<Uri> for TestConnect {
type Response = TokioIo<TlsConnectionStream<TcpStream>>;
type Error = tokio::io::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Ok(()).into()
}
fn call(&mut self, req: Uri) -> Self::Future {
let tls = TlsConnector::from(self.config.clone());
let address = self.addr;
async move {
match TcpStream::connect(address).await {
Ok(stream) => {
let domain = ServerName::try_from(req.host().unwrap())
.unwrap()
.to_owned();
match tls.connect(domain, stream).await {
Ok(tls_stream) => {
info!("Client TcpStream connected: {:?}", tls_stream);
Ok(TokioIo::new(TlsConnectionStream(tls_stream)))
}
Err(error) => {
info!("TLS TestClient error: {:?}", error);
Err(error)
}
}
}
Err(error) => Err(error),
}
}
.boxed()
}
}
impl From<SocketAddr> for TestConnect {
    /// Builds a connector that dials `addr` and trusts only the embedded test CA
    /// (`tls_ca_cert.der`). No client certificate is presented.
    ///
    /// # Panics
    ///
    /// Panics if the embedded CA certificate is rejected by the root store. That
    /// means the bundled test fixture is broken, not a runtime condition.
    fn from(addr: SocketAddr) -> Self {
        let mut root_store = RootCertStore::empty();
        let ca_cert = CertificateDer::from_slice(include_bytes!("tls_ca_cert.der"));
        root_store
            .add(ca_cert)
            .expect("embedded test CA certificate (tls_ca_cert.der) must be a valid trust anchor");
        let cfg = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        Self {
            addr,
            config: Arc::new(cfg),
        }
    }
}
// Runs the shared blocking and async test-server suites against the TLS servers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::helper::TestHandler;
    use crate::test::{self, async_test, Server};
    use tokio::sync::oneshot;

    // ---- blocking `TestServer` ----

    #[test]
    fn test_server_serves_requests() {
        test::common_tests::serves_requests(TestServer::new, TestServer::client)
    }

    #[test]
    fn test_server_times_out() {
        test::common_tests::times_out(TestServer::with_timeout, TestServer::client)
    }

    #[test]
    fn test_server_async_echo() {
        test::common_tests::async_echo(TestServer::new, TestServer::client)
    }

    #[test]
    fn test_server_supports_multiple_servers() {
        test::common_tests::supports_multiple_servers(TestServer::new, TestServer::client)
    }

    #[test]
    fn test_server_spawns_and_runs_futures() {
        let server = TestServer::new(TestHandler::default()).unwrap();
        // Carries a value into the spawned task.
        let (to_task, task_input) = oneshot::channel();
        // Carries the spawned task's answer back to the test body.
        let (task_output, from_task) = oneshot::channel();
        to_task.send(1).unwrap();
        server.spawn(async move {
            let received = task_input.await.unwrap();
            assert_eq!(1, received);
            task_output.send(42).unwrap();
        });
        let answer = server.run_future(from_task).unwrap();
        assert_eq!(42, answer);
    }

    #[test]
    fn test_server_adds_client_address_to_state() {
        test::common_tests::adds_client_address_to_state(TestServer::new, TestServer::client);
    }

    // ---- `AsyncTestServer` ----

    #[tokio::test]
    async fn async_test_server_serves_requests() {
        let (start, connect) = (AsyncTestServer::new, AsyncTestServer::client);
        async_test::common_tests::serves_requests(start, connect).await;
    }

    #[tokio::test]
    async fn async_test_server_times_out() {
        let (start, connect) = (AsyncTestServer::new_with_timeout, AsyncTestServer::client);
        async_test::common_tests::times_out(start, connect).await;
    }

    #[tokio::test]
    async fn async_test_server_echo() {
        let (start, connect) = (AsyncTestServer::new, AsyncTestServer::client);
        async_test::common_tests::echo(start, connect).await;
    }

    #[tokio::test]
    async fn async_test_server_supports_multiple_servers() {
        let (start, connect) = (AsyncTestServer::new, AsyncTestServer::client);
        async_test::common_tests::supports_multiple_servers(start, connect).await;
    }

    #[tokio::test]
    async fn async_test_server_adds_client_address_to_state() {
        let (start, connect) = (AsyncTestServer::new, AsyncTestServer::client);
        async_test::common_tests::adds_client_address_to_state(start, connect).await;
    }
}