use crate::{Runtime, RuntimeTrait, ServerConnector, TestTransport, runtime};
use async_channel::Sender;
use std::{
any::{Any, type_name},
fmt::{self, Debug, Formatter},
future::{Future, IntoFuture},
net::IpAddr,
pin::Pin,
str,
sync::Arc,
};
use trillium::{Handler, Info, KnownHeaderName};
use trillium_client::{Client, IntoUrl};
use trillium_http::{HeaderName, HeaderValues, Headers, HttpContext, Method, Status};
/// A test harness that owns a [`Handler`] of type `H` together with a [`Client`] wired to it.
///
/// Requests are created with [`TestServer::build`] or the per-method shortcuts (`get`, `post`,
/// `put`, `delete`, `patch`); each of those returns a [`ConnTest`].
#[allow(clippy::test_attr_in_doctest, reason = "demonstrating test usage")]
#[derive(Clone, Debug)]
pub struct TestServer<H> {
// Client built over a clone of `connector` and configured with a base url; every request
// produced by this server is built from it.
client: Client,
// Sending half of the channel whose receiving half is stored on the connector; a clone is
// handed to every `ConnTest` so it can pass along a configured peer ip.
peer_ip_sender: Sender<IpAddr>,
// Connector that holds the handler and exposes its context and runtime.
connector: ServerConnector<H>,
}
impl<H: Handler> TestServer<H> {
    /// Builds a test server around `handler` using the runtime returned by `runtime()`.
    ///
    /// The handler's `init` has been awaited by the time this future resolves.
    pub async fn new(handler: H) -> Self {
        Self::new_with_runtime(handler, runtime()).await
    }

    /// Initializes `handler` and connects it to a client driven by the runtime `rt`.
    ///
    /// Before `init` is awaited, three values are inserted as shared state: `rt` itself, a
    /// [`Runtime`] wrapping `rt`, and the base url `http://trillium.test`.
    async fn new_with_runtime(mut handler: H, rt: impl RuntimeTrait) -> Self {
        let url = "http://trillium.test".into_url(None).unwrap();

        let mut info = Info::from(HttpContext::default());
        info.insert_shared_state(rt.clone());
        info.insert_shared_state(Runtime::new(rt.clone()));
        info.insert_shared_state(url.clone());
        handler.init(&mut info).await;

        let context: Arc<HttpContext> = Arc::new(info.into());
        let mut connector = ServerConnector::new(handler)
            .with_context(context.clone())
            .with_runtime(rt);

        // The connector keeps the receiving half; every `ConnTest` gets a clone of the sender.
        let (peer_ip_sender, receive) = async_channel::unbounded();
        connector.server_peer_ips_receiver = Some(receive);

        let client = Client::new(connector.clone()).with_base(url);
        Self {
            client,
            peer_ip_sender,
            connector,
        }
    }

    /// Blocking counterpart of [`TestServer::new`]: drives construction to completion with
    /// `block_on` on the runtime returned by `runtime()`.
    pub fn new_blocking(handler: H) -> Self {
        let rt = runtime();
        rt.clone().block_on(Self::new_with_runtime(handler, rt))
    }

    /// Starts a request with the given `method` and `path`, returning a [`ConnTest`].
    ///
    /// The request is not sent until the returned [`ConnTest`] is awaited or
    /// [`ConnTest::block`] is called. Conversion of `method` is delegated to
    /// `Client::build_conn`.
    pub fn build<M: TryInto<Method>>(&self, method: M, path: &str) -> ConnTest
    where
        M::Error: Debug,
    {
        ConnTest {
            inner: self.client.build_conn(method, path),
            body: None,
            peer_ip_sender: self.peer_ip_sender.clone(),
            peer_ip: None,
            runtime: self.connector.runtime().clone(),
        }
    }

    /// Returns the shared state of type `T` from the connector's context, if any.
    pub fn shared_state<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.connector.context().shared_state().get()
    }

    /// Asserts that shared state of type `T` exists and equals `expected`.
    ///
    /// # Panics
    ///
    /// Panics if no shared state of type `T` is present or if it differs from `expected`.
    #[track_caller]
    pub fn assert_shared_state<T>(&self, expected: T) -> &Self
    where
        T: Send + Sync + Debug + PartialEq + 'static,
    {
        match self.shared_state::<T>() {
            Some(actual) => assert_eq!(*actual, expected),
            None => panic!(
                "expected handler state of type {}, but none was found",
                type_name::<T>()
            ),
        };
        self
    }

    /// Runs `f` against the shared state of type `T`, for assertions that need more than
    /// equality.
    ///
    /// # Panics
    ///
    /// Panics if no shared state of type `T` is present. Marked `#[track_caller]`, like
    /// [`TestServer::assert_shared_state`], so that panic is reported at the calling test.
    #[track_caller]
    pub fn assert_shared_state_with<T, F>(&self, f: F) -> &Self
    where
        T: Send + Sync + 'static,
        F: FnOnce(&T),
    {
        match self.shared_state::<T>() {
            Some(state) => f(state),
            None => panic!(
                "expected handler state of type {}, but none was found",
                type_name::<T>()
            ),
        };
        self
    }

    /// Borrows the handler held by the connector.
    pub fn handler(&self) -> &H {
        self.connector.handler()
    }

    /// Chainable form of [`TestServer::set_host`].
    pub fn with_host(mut self, host: &str) -> Self {
        self.set_host(host);
        self
    }

    /// Replaces the host portion of the client's base url with `host`.
    ///
    /// The result of the underlying `set_host` call is discarded, so a failure there is
    /// silently ignored.
    ///
    /// # Panics
    ///
    /// Panics if the client has no base url. Both constructors and
    /// [`TestServer::set_base`] configure one.
    pub fn set_host(&mut self, host: &str) -> &mut Self {
        let base = self
            .client
            .base_mut()
            .expect("test server client should always have a base url");
        let _ = base.set_host(Some(host));
        self
    }

    /// Chainable form of [`TestServer::set_base`].
    pub fn with_base(mut self, base: impl IntoUrl) -> Self {
        self.set_base(base);
        self
    }

    /// Replaces the client's base url with `base`.
    ///
    /// # Panics
    ///
    /// Panics with "unable to build a base url" if the client rejects `base`.
    pub fn set_base(&mut self, base: impl IntoUrl) -> &mut Self {
        self.client
            .set_base(base)
            .expect("unable to build a base url");
        self
    }

    /// Starts a `GET` request for `path`.
    pub fn get(&self, path: &str) -> ConnTest {
        self.build(Method::Get, path)
    }

    /// Starts a `POST` request for `path`.
    pub fn post(&self, path: &str) -> ConnTest {
        self.build(Method::Post, path)
    }

    /// Starts a `PUT` request for `path`.
    pub fn put(&self, path: &str) -> ConnTest {
        self.build(Method::Put, path)
    }

    /// Starts a `DELETE` request for `path`.
    pub fn delete(&self, path: &str) -> ConnTest {
        self.build(Method::Delete, path)
    }

    /// Starts a `PATCH` request for `path`.
    pub fn patch(&self, path: &str) -> ConnTest {
        self.build(Method::Patch, path)
    }
}
/// A single request created by [`TestServer::build`] and, once awaited or run with
/// [`ConnTest::block`], the response to it.
///
/// Builder methods (`with_*`) configure the request; accessor and `assert_*` methods inspect
/// the response.
pub struct ConnTest {
// The underlying client conn that carries the request and response.
inner: trillium_client::Conn,
// Response body bytes; `None` until the request has been run.
body: Option<Vec<u8>>,
// Clone of the server's peer-ip sender, used to pass `peer_ip` along when the request runs.
peer_ip_sender: Sender<IpAddr>,
// Peer ip to report for this request, if one was set with `with_peer_ip`.
peer_ip: Option<IpAddr>,
// Runtime used by `block` to drive the request to completion.
runtime: Runtime,
}
/// Debug output shows the conn, the response body rendered as lossy UTF-8 text rather than
/// raw bytes, and the configured peer ip.
///
/// The channel sender and runtime are left out, which `finish_non_exhaustive` signals with a
/// trailing `..`.
impl Debug for ConnTest {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("ConnTest")
            .field("inner", &self.inner)
            .field("body", &self.body.as_deref().map(String::from_utf8_lossy))
            .field("peer_ip", &self.peer_ip)
            .finish_non_exhaustive()
    }
}
/// Request-building methods. Each consumes and returns the [`ConnTest`] so calls can be
/// chained before the request is run.
impl ConnTest {
    /// Inserts the request header `name` with `value`.
    pub fn with_request_header(
        mut self,
        name: impl Into<HeaderName<'static>>,
        value: impl Into<HeaderValues>,
    ) -> Self {
        self.inner.request_headers_mut().insert(name, value);
        self
    }

    /// Extends the request headers with every `(name, value)` pair in `headers`.
    pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
    where
        I: IntoIterator<Item = (HN, HV)> + Send,
        HN: Into<HeaderName<'static>>,
        HV: Into<HeaderValues>,
    {
        self.inner.request_headers_mut().extend(headers);
        self
    }

    /// Removes the request header `name`.
    pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
        self.inner.request_headers_mut().remove(name);
        self
    }

    /// Sets the request body.
    pub fn with_body(mut self, body: impl Into<trillium_http::Body>) -> Self {
        self.inner.set_request_body(body);
        self
    }

    /// Serializes `body` as JSON and uses it as the request body.
    ///
    /// `Content-Type: application/json` is added with `try_insert`, so a content type that
    /// was already set is expected to be left alone.
    ///
    /// # Panics
    ///
    /// Panics if `body` cannot be serialized to JSON.
    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
    #[track_caller]
    pub fn with_json_body(mut self, body: &impl serde::Serialize) -> Self {
        self.inner
            .request_headers_mut()
            .try_insert(KnownHeaderName::ContentType, "application/json");
        let json = crate::to_json_string(body).expect("failed to serialize request body as JSON");
        self.with_body(json)
    }

    /// Sets the peer ip that is passed along when the request runs.
    pub fn with_peer_ip(mut self, peer_ip: impl Into<IpAddr>) -> Self {
        self.peer_ip = Some(peer_ip.into());
        self
    }

    /// Runs the request to completion on this conn's runtime and returns the completed
    /// [`ConnTest`]; the synchronous alternative to `.await`.
    pub fn block(self) -> Self {
        self.runtime.clone().block_on(self.into_future())
    }
}
/// Response accessors and assertions.
///
/// Every `assert_*` method returns `&Self` so assertions can be chained. Methods that can
/// panic are `#[track_caller]`, so failures are reported at the line of the calling test.
impl ConnTest {
    /// Returns the state of type `T` stored on the underlying conn, if any.
    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.inner.state::<T>()
    }

    /// Asserts that state of type `T` exists and equals `expected`.
    ///
    /// # Panics
    ///
    /// Panics if no state of type `T` is present or if it differs from `expected`.
    #[track_caller]
    pub fn assert_state<T>(&self, expected: T) -> &Self
    where
        T: PartialEq + Debug + Send + Sync + 'static,
    {
        match self.state::<T>() {
            Some(actual) => assert_eq!(*actual, expected),
            None => panic!(
                "expected handler state of type {}, but none was found",
                type_name::<T>()
            ),
        }
        self
    }

    /// Asserts that no state of type `T` is present.
    ///
    /// # Panics
    ///
    /// Panics, printing the value found, if state of type `T` exists.
    #[track_caller]
    pub fn assert_no_state<T>(&self) -> &Self
    where
        T: Debug + Send + Sync + 'static,
    {
        if let Some(value) = self.state::<T>() {
            panic!(
                "expected no handler state of type {}, but found {:?}",
                type_name::<T>(),
                value
            );
        }
        self
    }

    /// Returns the response status.
    ///
    /// # Panics
    ///
    /// Panics if the conn has no status, which is the case before the request has been run.
    #[track_caller]
    pub fn status(&self) -> Status {
        self.inner
            .status()
            .expect("response not yet received — did you .await this ConnTest?")
    }

    /// Returns the response body as a string slice.
    ///
    /// # Panics
    ///
    /// Panics if the body has not been read yet or is not valid UTF-8.
    #[track_caller]
    pub fn body(&self) -> &str {
        str::from_utf8(self.body_bytes()).expect("body was not utf-8")
    }

    /// Returns the raw response body bytes.
    ///
    /// # Panics
    ///
    /// Panics if the body has not been read yet, which is the case before the request has
    /// been run.
    #[track_caller]
    pub fn body_bytes(&self) -> &[u8] {
        self.body.as_deref().expect("body was not set")
    }

    /// Borrows the response headers.
    pub fn response_headers(&self) -> &Headers {
        self.inner.response_headers()
    }

    /// Borrows the response trailers, if the conn has any.
    pub fn response_trailers(&self) -> Option<&Headers> {
        self.inner.response_trailers()
    }

    /// Borrows the request trailers, if the conn has any.
    pub fn request_trailers(&self) -> Option<&Headers> {
        self.inner.request_trailers()
    }

    /// Looks up response header `name` as a string, via `Headers::get_str`.
    pub fn header<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
        self.inner.response_headers().get_str(name)
    }

    /// Looks up response trailer `name` as a string; `None` when there are no trailers or
    /// the lookup finds nothing.
    pub fn trailer<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
        self.inner
            .response_trailers()
            .and_then(|trailers| trailers.get_str(name))
    }

    /// Asserts that the response status equals `status`.
    ///
    /// # Panics
    ///
    /// Panics if `status` does not convert into a [`Status`], if no response has been
    /// received, or if the statuses differ.
    #[track_caller]
    pub fn assert_status(&self, status: impl TryInto<Status>) -> &Self {
        let expected: Status = status
            .try_into()
            .ok()
            .expect("expected a valid status code");
        let actual = self.status();
        assert_eq!(actual, expected, "expected status {expected}, got {actual}");
        self
    }

    /// Asserts that the response status is 200.
    #[track_caller]
    pub fn assert_ok(&self) -> &Self {
        self.assert_status(200)
    }

    /// Asserts that the response body equals `expected`, ignoring trailing whitespace on
    /// both sides.
    #[track_caller]
    pub fn assert_body(&self, expected: &str) -> &Self {
        assert_eq!(self.body().trim_end(), expected.trim_end());
        self
    }

    /// Asserts that the response body contains `pattern` as a substring.
    #[track_caller]
    pub fn assert_body_contains(&self, pattern: &str) -> &Self {
        let body = self.body();
        assert!(
            body.contains(pattern),
            "expected body to contain {pattern:?}, but body was:\n{body}"
        );
        self
    }

    /// Asserts that response header `name` is set and its values equal `expected`.
    ///
    /// # Panics
    ///
    /// Panics if the header is not set or its values differ from `expected`.
    #[track_caller]
    pub fn assert_header<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
    where
        HeaderValues: PartialEq<HV>,
        HV: Debug,
        HN: Into<HeaderName<'a>>,
    {
        let name = name.into();
        match self.inner.response_headers().get_values(name.clone()) {
            Some(actual) => assert_eq!(*actual, expected, "for header {name:?}"),
            None => panic!("header {name} not set"),
        };
        self
    }

    /// Applies [`ConnTest::assert_header`] to every `(name, expected)` pair in `headers`.
    #[track_caller]
    pub fn assert_headers<'a, I, HN, HV>(&self, headers: I) -> &Self
    where
        I: IntoIterator<Item = (HN, HV)> + Send,
        HN: Into<HeaderName<'a>>,
        HV: Debug,
        HeaderValues: PartialEq<HV>,
    {
        for (name, expected) in headers {
            self.assert_header(name, expected);
        }
        self
    }

    /// Asserts that [`ConnTest::header`] finds nothing for `name`.
    #[track_caller]
    pub fn assert_no_header(&self, name: &str) -> &Self {
        let actual = self.header(name);
        assert!(
            actual.is_none(),
            "expected no header {name:?}, but found {actual:?}"
        );
        self
    }

    /// Runs `f` against the values of response header `name`.
    ///
    /// # Panics
    ///
    /// Panics if the header is not set.
    #[track_caller]
    pub fn assert_header_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
    where
        F: FnOnce(&HeaderValues),
    {
        let name = name.into();
        match self.response_headers().get_values(name.clone()) {
            Some(values) => f(values),
            None => panic!("expected header {name:?}, but it was not found"),
        }
        self
    }

    /// Runs `f` against the state of type `T`.
    ///
    /// # Panics
    ///
    /// Panics if no state of type `T` is present.
    #[track_caller]
    pub fn assert_state_with<T, F>(&self, f: F) -> &Self
    where
        T: Send + Sync + 'static,
        F: FnOnce(&T),
    {
        match self.state::<T>() {
            Some(state) => f(state),
            None => panic!(
                "expected handler state of type {}, but none was found",
                type_name::<T>()
            ),
        };
        self
    }

    /// Runs `f` against the response body as a string.
    #[track_caller]
    pub fn assert_body_with<F>(&self, f: F) -> &Self
    where
        F: FnOnce(&str),
    {
        f(self.body());
        self
    }

    /// Parses the response body as JSON into a `T` and runs `f` against it.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be parsed as a `T`.
    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
    #[track_caller]
    pub fn assert_json_body_with<T, F>(&self, f: F) -> &Self
    where
        T: serde::de::DeserializeOwned,
        F: FnOnce(&T),
    {
        let parsed: T =
            crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
        f(&parsed);
        self
    }

    /// Parses the response body as JSON into a `T` and asserts it equals `body`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be parsed as a `T` or the parsed value differs from `body`.
    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
    #[track_caller]
    pub fn assert_json_body<T>(&self, body: &T) -> &Self
    where
        T: serde::de::DeserializeOwned + PartialEq + Debug,
    {
        let parsed: T =
            crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
        assert_eq!(&parsed, body);
        self
    }

    /// Asserts that response trailer `name` is set and its values equal `expected`.
    ///
    /// # Panics
    ///
    /// Panics if there are no trailers, the trailer is not set, or its values differ.
    #[track_caller]
    pub fn assert_trailer<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
    where
        HeaderValues: PartialEq<HV>,
        HV: Debug,
        HN: Into<HeaderName<'a>>,
    {
        let name = name.into();
        match self
            .inner
            .response_trailers()
            .and_then(|trailers| trailers.get_values(name.clone()))
        {
            Some(actual) => assert_eq!(*actual, expected, "for trailer {name:?}"),
            None => panic!("trailer {name} not set"),
        };
        self
    }

    /// Applies [`ConnTest::assert_trailer`] to every `(name, expected)` pair in `trailers`.
    #[track_caller]
    pub fn assert_trailers<'a, I, HN, HV>(&self, trailers: I) -> &Self
    where
        I: IntoIterator<Item = (HN, HV)> + Send,
        HN: Into<HeaderName<'a>>,
        HV: Debug,
        HeaderValues: PartialEq<HV>,
    {
        for (name, expected) in trailers {
            self.assert_trailer(name, expected);
        }
        self
    }

    /// Asserts that [`ConnTest::trailer`] finds nothing for `name`.
    #[track_caller]
    pub fn assert_no_trailer(&self, name: &str) -> &Self {
        let actual = self.trailer(name);
        assert!(
            actual.is_none(),
            "expected no trailer {name:?}, but found {actual:?}"
        );
        self
    }

    /// Runs `f` against the values of response trailer `name`.
    ///
    /// # Panics
    ///
    /// Panics if there are no trailers or the trailer is not set.
    #[track_caller]
    pub fn assert_trailer_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
    where
        F: FnOnce(&HeaderValues),
    {
        let name = name.into();
        match self
            .response_trailers()
            .and_then(|trailers| trailers.get_values(name.clone()))
        {
            Some(values) => f(values),
            None => panic!("expected trailer {name:?}, but it was not found"),
        }
        self
    }
}
/// Awaiting a [`ConnTest`] runs the request and resolves to the same `ConnTest`, now holding
/// the response and its fully-read body.
impl IntoFuture for ConnTest {
type IntoFuture = Pin<Box<dyn Future<Output = ConnTest> + Send + 'static>>;
type Output = ConnTest;
/// Builds the boxed future that runs the request.
///
/// # Panics
///
/// The returned future panics if the request fails, if the conn's transport is not a
/// `TestTransport`, if `write()` on the transport's state returns an error, or if the
/// response body cannot be read.
fn into_future(mut self) -> Self::IntoFuture {
Box::pin(async move {
// Pass the configured peer ip (if any) into the channel before the request runs. A send
// error is deliberately ignored.
if let Some(peer_ip) = self.peer_ip.take() {
let _ = self.peer_ip_sender.send(peer_ip).await;
}
// When a Host request header was set explicitly, copy it onto the request url's host.
// The header value is cloned into a `String` first so the immutable borrow of the headers
// ends before the url is mutated; a failure to set the host is ignored.
if let Some(host) = self
.inner
.request_headers()
.get_str(KnownHeaderName::Host)
.map(ToString::to_string)
{
let _ = self.inner.url_mut().set_host(Some(&host));
}
// Run the request by awaiting the conn through a mutable borrow, so `self.inner` remains
// usable afterwards.
let inner = &mut self.inner;
inner.await.expect("request to virtual server failed");
let inner = &mut self.inner;
// NOTE(review): this appears to move the state recorded on the `TestTransport` onto the
// client conn, which is what `ConnTest::state` reads from — confirm against
// `TestTransport::state`. `mem::take` leaves a default value behind in the transport, and
// the write guard is a temporary that is released at the end of the `let` statement.
if let Some(transport) = inner.transport_mut() {
let state = std::mem::take(
&mut *((transport as &dyn Any)
.downcast_ref::<TestTransport>()
.unwrap()
.state()
.write()
.unwrap()),
);
*inner.as_mut() = state;
}
// Read the whole response body now so `body` and `body_bytes` can hand out borrows later.
self.body = Some(
self.inner
.response_body()
.read_bytes()
.await
.expect("failed to read response body"),
);
self
})
}
}