#![allow(clippy::unwrap_used, clippy::expect_used)]
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
use tokio::sync::{oneshot, watch};
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::service::Routes;
use tonic::transport::{Endpoint, Server as TonicServer};
use crate::{Server, ServerError, ServingState};
/// Handle to a [`Server`] that `boot_server` started in a spawned task on a
/// loopback listener.
pub struct BootedServer {
    /// Address the server's listener was bound to.
    pub addr: SocketAddr,
    /// Clone of the server's serving-state receiver, taken before the server
    /// was moved into its task.
    pub state_rx: watch::Receiver<ServingState>,
    /// Task driving the server; yields the server's own result when it ends.
    pub serve_handle: JoinHandle<Result<(), ServerError>>,
    /// Sender half of the one-shot channel whose receiver gates the server's
    /// shutdown future.
    shutdown_tx: oneshot::Sender<()>,
}

impl BootedServer {
    /// Requests shutdown and waits for the server task to complete.
    ///
    /// # Errors
    /// Propagates the `ServerError` returned by the server task, if any.
    ///
    /// # Panics
    /// Panics if the server task panicked or was cancelled.
    pub async fn shutdown(self) -> Result<(), ServerError> {
        let Self {
            serve_handle,
            shutdown_tx,
            ..
        } = self;
        // A failed send only means the receiver is already gone (the server
        // stopped on its own); the join below still reports the outcome.
        let _ = shutdown_tx.send(());
        serve_handle
            .await
            .expect("server task panicked or was cancelled before shutdown")
    }
}
/// Binds an ephemeral port on `127.0.0.1` and runs `server` on it in a
/// spawned task.
///
/// The listener is bound before the task is spawned, so the returned address
/// is already bound when this function returns.
///
/// # Panics
/// Panics if the bind fails or the listener's local address cannot be read.
pub async fn boot_server(server: Server) -> BootedServer {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind 127.0.0.1:0 for test server");
    let addr = listener
        .local_addr()
        .expect("local_addr on freshly bound listener");

    // Clone the state receiver now; `server` is moved into the task below.
    let state_rx = server.state_rx.clone();

    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    // Resolves either when `shutdown_tx` fires or when it is dropped.
    let shutdown_signal = async move {
        let _ = shutdown_rx.await;
    };

    let serve_handle = tokio::spawn(async move {
        server.serve_with_listener(listener, shutdown_signal).await
    });

    BootedServer {
        addr,
        state_rx,
        serve_handle,
        shutdown_tx,
    }
}
/// Handle to a bare tonic router that `boot_router` started in a spawned task
/// on a loopback listener.
pub struct BootedRouter {
    /// Address the router's listener was bound to.
    pub addr: SocketAddr,
    /// Task driving the tonic server; yields tonic's transport result.
    pub serve_handle: JoinHandle<Result<(), tonic::transport::Error>>,
    /// Sender half of the one-shot channel whose receiver gates the router's
    /// shutdown future.
    shutdown_tx: oneshot::Sender<()>,
}

impl BootedRouter {
    /// Requests shutdown and waits for the router task to complete.
    ///
    /// # Errors
    /// Propagates the transport error returned by the router task, if any.
    ///
    /// # Panics
    /// Panics if the router task panicked or was cancelled.
    pub async fn shutdown(self) -> Result<(), tonic::transport::Error> {
        let Self {
            serve_handle,
            shutdown_tx,
            ..
        } = self;
        // Ignored on purpose: a closed receiver means the router already exited.
        let _ = shutdown_tx.send(());
        serve_handle
            .await
            .expect("router task panicked or was cancelled before shutdown")
    }

    /// Aborts the router task without waiting for it. The shutdown sender is
    /// dropped together with the rest of `self`.
    pub fn abort(self) {
        let Self { serve_handle, .. } = self;
        serve_handle.abort();
    }
}
/// Serves `routes` with a default tonic server on an ephemeral `127.0.0.1`
/// port inside a spawned task.
///
/// The server is stopped via `serve_with_incoming_shutdown` once the returned
/// handle's shutdown sender fires or is dropped.
///
/// # Panics
/// Panics if the bind fails or the listener's local address cannot be read.
pub async fn boot_router(routes: Routes) -> BootedRouter {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind 127.0.0.1:0 for test router");
    let addr = listener
        .local_addr()
        .expect("local_addr on freshly bound listener");

    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let incoming = TcpListenerStream::new(listener);
    // Resolves either when `shutdown_tx` fires or when it is dropped.
    let shutdown_signal = async move {
        let _ = shutdown_rx.await;
    };

    let serve_handle = tokio::spawn(async move {
        TonicServer::builder()
            .add_routes(routes)
            .serve_with_incoming_shutdown(incoming, shutdown_signal)
            .await
    });

    BootedRouter {
        addr,
        serve_handle,
        shutdown_tx,
    }
}
/// Waits until the current serving state satisfies `predicate`.
///
/// The current value is checked first, then re-checked after each change
/// notification. Each check goes through `borrow_and_update`, so the value is
/// marked as seen.
///
/// # Panics
/// Panics if `changed()` reports that the state channel closed before the
/// predicate held.
pub async fn wait_until<F>(rx: &mut watch::Receiver<ServingState>, predicate: F)
where
    F: Fn(&ServingState) -> bool,
{
    // The borrow guard is a temporary of the `while` condition, so it is
    // released before the `.await` in the loop body.
    while !predicate(&rx.borrow_and_update()) {
        rx.changed()
            .await
            .expect("state stream closed before reaching expected state");
    }
}
/// Waits until the serving state is `ServingState::Serving`.
///
/// # Panics
/// Panics (via `wait_until`) if the state channel closes first.
pub async fn wait_until_serving(rx: &mut watch::Receiver<ServingState>) {
    let is_serving = |state: &ServingState| matches!(state, ServingState::Serving);
    wait_until(rx, is_serving).await
}
/// Waits until the serving state is any `ServingState::NotServing { .. }`
/// variant.
///
/// # Panics
/// Panics (via `wait_until`) if the state channel closes first.
pub async fn wait_until_not_serving(rx: &mut watch::Receiver<ServingState>) {
    let is_not_serving =
        |state: &ServingState| matches!(state, ServingState::NotServing { .. });
    wait_until(rx, is_not_serving).await
}
/// Repeatedly calls `endpoint.connect()` until it succeeds or `budget` elapses.
///
/// At least one attempt is always made. Each failed attempt is followed by a
/// 25 ms pause before the next one. Once the deadline has passed, the error
/// from the *most recent* attempt is returned.
///
/// NOTE(review): the deadline is only checked between attempts. A single
/// `connect()` that stalls is not interrupted here.
async fn connect_within_budget(
    endpoint: &Endpoint,
    budget: Duration,
) -> Result<(), tonic::transport::Error> {
    let deadline = Instant::now() + budget;
    loop {
        match endpoint.connect().await {
            // Only the successful handshake matters; the channel is discarded.
            Ok(_channel) => return Ok(()),
            Err(err) if Instant::now() >= deadline => return Err(err),
            Err(_) => sleep(Duration::from_millis(25)).await,
        }
    }
}

/// Waits until a plaintext (`http://`) gRPC connection to `addr` can be
/// established, retrying for up to `budget`.
///
/// # Errors
/// Returns the error from the last connection attempt if none succeeded
/// before the budget ran out.
///
/// # Panics
/// Panics if the endpoint URI built from `addr` fails to parse.
pub async fn wait_for_grpc_handshake(
    addr: SocketAddr,
    budget: Duration,
) -> Result<(), tonic::transport::Error> {
    let endpoint: Endpoint = format!("http://{addr}")
        .parse()
        .expect("constructed endpoint URI must parse");
    connect_within_budget(&endpoint, budget).await
}

/// Waits until a TLS (`https://`) gRPC connection to `addr` can be
/// established with `tls_config`, retrying for up to `budget`.
///
/// # Errors
/// Returns an error immediately if `tls_config` cannot be applied to the
/// endpoint. Otherwise returns the error from the last connection attempt if
/// none succeeded before the budget ran out.
///
/// # Panics
/// Panics if the endpoint URI built from `addr` fails to parse.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
pub async fn wait_for_grpc_handshake_tls(
    addr: SocketAddr,
    tls_config: tonic::transport::ClientTlsConfig,
    budget: Duration,
) -> Result<(), tonic::transport::Error> {
    let endpoint: Endpoint = format!("https://{addr}")
        .parse()
        .expect("constructed endpoint URI must parse");
    let endpoint = endpoint.tls_config(tls_config)?;
    connect_within_budget(&endpoint, budget).await
}