use std::{
io,
net::SocketAddr,
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
};
use trillium_client::{Client, ClientHandler, Conn, ConnExt, Status, Url};
use trillium_http::KnownHeaderName::ContentLength;
use trillium_server_common::Connector;
use trillium_testing::{ServerConnector, TestResult, harness, test};
/// Test handler that tallies how many times each `ClientHandler` hook fires.
#[derive(Debug, Default)]
struct Counter {
    /// Number of `run` invocations observed.
    runs: AtomicUsize,
    /// Number of `after_response` invocations observed.
    after_responses: AtomicUsize,
}
impl ClientHandler for Counter {
    /// Bumps the `runs` tally; never alters the conn.
    async fn run(&self, _conn: &mut Conn) -> trillium_client::Result<()> {
        let _previous = self.runs.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }

    /// Bumps the `after_responses` tally; never alters the conn.
    async fn after_response(&self, _conn: &mut Conn) -> trillium_client::Result<()> {
        let _previous = self.after_responses.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }
}
/// Test handler that answers the request itself and halts, producing a synthetic response.
#[derive(Debug)]
struct Halter;
impl ClientHandler for Halter {
    /// Fills in a `200 OK` response with a fixed body and Content-Length, then halts the conn.
    async fn run(&self, conn: &mut Conn) -> trillium_client::Result<()> {
        // Single source of truth for the synthesized body: the Content-Length header is
        // derived from it, so the two can no longer drift apart when the text changes.
        const BODY: &str = "synthesized";
        conn.set_status(Status::Ok).set_response_body(BODY);
        conn.response_headers_mut().insert(ContentLength, BODY.len().to_string());
        conn.halt();
        Ok(())
    }
}
/// Shared log of handler tags, appended to in the order each hook is invoked.
#[derive(Debug, Default)]
struct OrderRecorder {
    /// Tags pushed by `run`, in call order.
    runs: std::sync::Mutex<Vec<&'static str>>,
    /// Tags pushed by `after_response`, in call order.
    after_responses: std::sync::Mutex<Vec<&'static str>>,
}
#[derive(Debug)]
struct Tagged {
tag: &'static str,
recorder: std::sync::Arc<OrderRecorder>,
}
impl ClientHandler for Tagged {
async fn run(&self, _conn: &mut Conn) -> trillium_client::Result<()> {
self.recorder.runs.lock().unwrap().push(self.tag);
Ok(())
}
async fn after_response(&self, _conn: &mut Conn) -> trillium_client::Result<()> {
self.recorder.after_responses.lock().unwrap().push(self.tag);
Ok(())
}
}
// A lone handler should see exactly one `run` and one `after_response` per request.
#[test(harness)]
async fn single_handler_runs_both_passes() -> TestResult {
    let connector = ServerConnector::new(Status::Ok);
    let client = Client::new(connector).with_handler(Counter::default());
    let _response = client.get("http://example.com/").await?;

    let counter = client.downcast_handler::<Counter>().expect("handler installed");
    assert_eq!(counter.runs.load(Ordering::SeqCst), 1);
    assert_eq!(counter.after_responses.load(Ordering::SeqCst), 1);
    Ok(())
}
// The upstream connector is configured with a 500; observing 200 plus the synthetic body
// shows the response came from `Halter`.
#[test(harness)]
async fn handler_can_halt_and_synthesize_response() -> TestResult {
    let upstream = ServerConnector::new(Status::InternalServerError);
    let client = Client::new(upstream).with_handler(Halter);
    let mut response = client.get("http://synthetic.invalid/").await?;

    assert_eq!(response.status(), Some(Status::Ok));
    let body = response.response_body().read_string().await?;
    assert_eq!(body, "synthesized");
    Ok(())
}
// With `(Halter, Counter)`, the test expects the halt to skip `Counter::run` (runs == 0)
// while `Counter::after_response` still fires once on the unwind.
#[test(harness)]
async fn tuple_after_response_runs_in_reverse_after_halt() -> TestResult {
    let upstream = ServerConnector::new(Status::InternalServerError);
    let client = Client::new(upstream).with_handler((Halter, Counter::default()));
    let mut response = client.get("http://synthetic.invalid/").await?;

    assert_eq!(response.status(), Some(Status::Ok));
    let body = response.response_body().read_string().await?;
    assert_eq!(body, "synthesized");

    let (_, counter) = client
        .downcast_handler::<(Halter, Counter)>()
        .expect("handler installed");
    assert_eq!(counter.runs.load(Ordering::SeqCst), 0);
    assert_eq!(counter.after_responses.load(Ordering::SeqCst), 1);
    Ok(())
}
// Three tagged handlers share one recorder so the interleaving of hook calls is observable:
// `run` should go A, B, C and `after_response` should unwind C, B, A.
#[test(harness)]
async fn tuple_runs_forward_and_after_responses_in_reverse() -> TestResult {
    let recorder = Arc::new(OrderRecorder::default());
    // One constructor instead of three copy-pasted struct literals; each handler gets its own
    // handle to the same recorder.
    let tagged = |tag: &'static str| Tagged {
        tag,
        recorder: Arc::clone(&recorder),
    };
    let client = Client::new(ServerConnector::new(Status::Ok)).with_handler((
        tagged("A"),
        tagged("B"),
        tagged("C"),
    ));
    let _conn = client.get("http://example.com/").await?;
    assert_eq!(*recorder.runs.lock().unwrap(), vec!["A", "B", "C"]);
    assert_eq!(
        *recorder.after_responses.lock().unwrap(),
        vec!["C", "B", "A"]
    );
    Ok(())
}
// No `with_handler` call: the request should pass straight through to the upstream status.
#[test(harness)]
async fn unit_handler_is_default_and_no_op() -> TestResult {
    let client = Client::new(ServerConnector::new(Status::Ok));
    let response = client.get("http://example.com/").await?;
    let status = response.status();
    assert_eq!(status, Some(Status::Ok));
    Ok(())
}
// Downcasting must only succeed for the exact handler type that was installed.
#[test(harness)]
async fn downcast_handler_returns_none_for_wrong_type() -> TestResult {
    let client = Client::new(ServerConnector::new(Status::Ok)).with_handler(Counter::default());
    let wrong_type = client.downcast_handler::<Halter>();
    let right_type = client.downcast_handler::<Counter>();
    assert!(wrong_type.is_none());
    assert!(right_type.is_some());
    Ok(())
}
/// Connector whose `connect` always errors, used to drive the client's transport-error path.
///
/// It wraps a `ServerConnector` only so that `runtime` and `resolve` have something real to
/// delegate to.
#[derive(Debug)]
struct FailingConnector {
    /// Delegate for `runtime` and `resolve`; never used for `connect`.
    inner: ServerConnector<Status>,
}
impl FailingConnector {
    /// Builds a failing connector around a `ServerConnector` configured with `Status::Ok`.
    fn new() -> Self {
        let inner = ServerConnector::new(Status::Ok);
        Self { inner }
    }
}
impl Connector for FailingConnector {
type Runtime = <ServerConnector<Status> as Connector>::Runtime;
type Transport = <ServerConnector<Status> as Connector>::Transport;
type Udp = <ServerConnector<Status> as Connector>::Udp;
async fn connect(&self, _url: &Url) -> io::Result<Self::Transport> {
Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"test failure",
))
}
fn runtime(&self) -> Self::Runtime {
self.inner.runtime().clone()
}
async fn resolve(&self, host: &str, port: u16) -> io::Result<Vec<SocketAddr>> {
self.inner.resolve(host, port).await
}
}
/// Cheaply clonable handler; every clone shares the same [`ErrorObserverInner`] counters,
/// so a test can keep one clone while the client owns another.
#[derive(Debug, Default, Clone)]
struct ErrorObserver {
    /// Shared state, reference-counted across clones.
    inner: Arc<ErrorObserverInner>,
}

/// State recorded by [`ErrorObserver`].
#[derive(Debug, Default)]
struct ErrorObserverInner {
    /// How many times `after_response` was entered.
    after_response_runs: AtomicUsize,
    /// Set once `after_response` sees an error on the conn.
    saw_error: AtomicBool,
}
impl ClientHandler for ErrorObserver {
    /// Counts every `after_response` pass and sets `saw_error` when the conn carries an error.
    /// `saw_error` is only ever set to `true` here, never cleared.
    async fn after_response(&self, conn: &mut Conn) -> trillium_client::Result<()> {
        let state = &self.inner;
        state.after_response_runs.fetch_add(1, Ordering::SeqCst);
        if conn.error().is_none() {
            return Ok(());
        }
        state.saw_error.store(true, Ordering::SeqCst);
        Ok(())
    }
}
// A connect failure must still unwind through `after_response`, with the error visible there.
#[test(harness)]
async fn after_response_runs_on_transport_error() -> TestResult {
    let observer = ErrorObserver::default();
    let client = Client::new(FailingConnector::new()).with_handler(observer.clone());

    let result = client.get("http://example.com/").await;
    assert!(result.is_err(), "expected transport error, got {result:?}");

    let state = &observer.inner;
    let runs = state.after_response_runs.load(Ordering::SeqCst);
    assert_eq!(
        runs,
        1,
        "after_response should run exactly once on transport failure"
    );
    assert!(
        state.saw_error.load(Ordering::SeqCst),
        "after_response should observe the stashed error"
    );
    Ok(())
}
/// Handler that takes a stashed error off the conn and replaces it with a synthetic
/// `200 OK` response whose body is `"recovered"`.
#[derive(Debug)]
struct Recoverer;
impl ClientHandler for Recoverer {
    /// Removes any error from the conn; if one was present, writes a replacement response.
    async fn after_response(&self, conn: &mut Conn) -> trillium_client::Result<()> {
        let had_error = conn.take_error().is_some();
        if had_error {
            conn.set_status(Status::Ok).set_response_body("recovered");
        }
        Ok(())
    }
}
// Taking the error in `after_response` should turn a failed connect into a successful result.
#[test(harness)]
async fn after_response_can_recover_from_transport_error() -> TestResult {
    let connector = FailingConnector::new();
    let client = Client::new(connector).with_handler(Recoverer);
    let mut recovered = client.get("http://example.com/").await?;

    assert_eq!(recovered.status(), Some(Status::Ok));
    let body = recovered.response_body().read_string().await?;
    assert_eq!(body, "recovered");
    Ok(())
}
/// Handler that queues a follow-up request whenever it sees an error, while leaving the
/// error in place. Clones share the run counter through an `Arc`.
#[derive(Debug, Default, Clone)]
struct ErroringFollowupQueuer {
    /// How many times `after_response` was entered, shared across clones.
    after_response_runs: Arc<AtomicUsize>,
}
impl ClientHandler for ErroringFollowupQueuer {
    /// Counts the pass; on error, queues a follow-up GET without clearing the error.
    async fn after_response(&self, conn: &mut Conn) -> trillium_client::Result<()> {
        self.after_response_runs.fetch_add(1, Ordering::SeqCst);
        if conn.error().is_none() {
            return Ok(());
        }
        // Build the follow-up first so the borrow from `conn.client()` ends before `conn` is
        // mutated.
        let next_request = conn.client().get("http://example.com/followup");
        conn.set_followup(next_request);
        Ok(())
    }
}
// If an error is still stashed after `after_response`, it must be returned and any queued
// follow-up must be dropped rather than executed.
#[test(harness)]
async fn error_wins_over_queued_followup() -> TestResult {
    let handler = ErroringFollowupQueuer::default();
    let client = Client::new(FailingConnector::new()).with_handler(handler.clone());

    let mut request = client.get("http://example.com/");
    // Await through a mutable reference so `request` can still be inspected afterwards.
    let result = (&mut request).await;
    assert!(
        result.is_err(),
        "transport error should propagate when after_response leaves it stashed, got {result:?}"
    );

    let runs = handler.after_response_runs.load(Ordering::SeqCst);
    assert_eq!(
        runs,
        1,
        "after_response should run exactly once — the queued follow-up must not be picked up"
    );
    assert!(
        request.followup().is_none(),
        "trampoline should clear the queued follow-up before propagating the error"
    );
    Ok(())
}