#![allow(clippy::unwrap_used, clippy::expect_used)]
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
use tokio::sync::{oneshot, watch};
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::service::Routes;
use tonic::transport::{Endpoint, Server as TonicServer};
use crate::{Server, ServerError, ServingState};
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
use crate::leader_hint::not_leader_status;
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
use tsoracle_proto::v1::{
GetTsRequest, GetTsResponse, LeaderHint,
tso_service_server::{TsoService, TsoServiceServer},
};
/// Handle to a [`Server`] that [`boot_server`] started on an ephemeral loopback port.
///
/// Dropping this handle without calling [`BootedServer::shutdown`] also drops
/// `shutdown_tx`, which completes the shutdown future handed to the server.
pub struct BootedServer {
/// Address the server's listener is bound to (`127.0.0.1` with an OS-assigned port).
pub addr: SocketAddr,
/// Serving-state subscription taken before the server was moved into its task.
pub state_rx: watch::Receiver<ServingState>,
/// Spawned task running `Server::serve_with_listener`; resolves to the serve result.
pub serve_handle: JoinHandle<Result<(), ServerError>>,
// Sending on this channel, or dropping the sender, resolves the shutdown future
// passed to `serve_with_listener`.
shutdown_tx: oneshot::Sender<()>,
}
impl BootedServer {
    /// Signals the server to stop, then waits for its serve task to finish.
    ///
    /// # Errors
    /// Propagates whatever error the server's `serve_with_listener` call returned.
    ///
    /// # Panics
    /// Panics if the serve task panicked or was cancelled.
    pub async fn shutdown(self) -> Result<(), ServerError> {
        let Self {
            serve_handle,
            shutdown_tx,
            ..
        } = self;
        // If the task already exited, the receiver is gone and the send fails harmlessly.
        let _ = shutdown_tx.send(());
        let joined = serve_handle.await;
        joined.expect("server task panicked or was cancelled before shutdown")
    }
}
/// Binds `127.0.0.1:0`, spawns `server` on that listener, and returns a handle to it.
///
/// The serving-state subscription is taken before `server` is moved into the spawned
/// task, so the returned receiver is attached to this server instance.
///
/// # Panics
/// Panics if binding the loopback port fails or its local address cannot be read.
pub async fn boot_server(server: Server) -> BootedServer {
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let state_rx = server.subscribe();

    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind 127.0.0.1:0 for test server");
    let addr = listener
        .local_addr()
        .expect("local_addr on freshly bound listener");

    // Completes on an explicit send or when the sender half is dropped.
    let shutdown_signal = async move {
        let _ = shutdown_rx.await;
    };
    let serve_handle =
        tokio::spawn(async move { server.serve_with_listener(listener, shutdown_signal).await });

    BootedServer {
        addr,
        state_rx,
        serve_handle,
        shutdown_tx,
    }
}
/// Handle to a tonic router that [`boot_router`] started on an ephemeral loopback port.
///
/// Dropping this handle drops `shutdown_tx`, which completes the shutdown future the
/// router was served with.
pub struct BootedRouter {
/// Address the router's listener is bound to (`127.0.0.1` with an OS-assigned port).
pub addr: SocketAddr,
/// Spawned task running `serve_with_incoming_shutdown`; resolves to its result.
pub serve_handle: JoinHandle<Result<(), tonic::transport::Error>>,
// Sending on this channel, or dropping the sender, resolves the router's shutdown future.
shutdown_tx: oneshot::Sender<()>,
}
impl BootedRouter {
    /// Signals the router to stop, then waits for its serve task to finish.
    ///
    /// # Errors
    /// Propagates the transport error, if any, returned by the serve call.
    ///
    /// # Panics
    /// Panics if the router task panicked or was cancelled.
    pub async fn shutdown(self) -> Result<(), tonic::transport::Error> {
        let Self {
            serve_handle,
            shutdown_tx,
            ..
        } = self;
        // A closed receiver just means the task is already gone.
        let _ = shutdown_tx.send(());
        let joined = serve_handle.await;
        joined.expect("router task panicked or was cancelled before shutdown")
    }

    /// Cancels the router task via `JoinHandle::abort` without waiting for it to exit.
    pub fn abort(self) {
        let Self { serve_handle, .. } = self;
        serve_handle.abort();
    }
}
/// Binds `127.0.0.1:0`, serves `routes` on it from a spawned task, and returns a handle.
///
/// # Panics
/// Panics if binding the loopback port fails or its local address cannot be read.
pub async fn boot_router(routes: Routes) -> BootedRouter {
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind 127.0.0.1:0 for test router");
    let addr = listener
        .local_addr()
        .expect("local_addr on freshly bound listener");

    let incoming = TcpListenerStream::new(listener);
    // Completes on an explicit send or when the sender half is dropped.
    let shutdown_signal = async move {
        let _ = shutdown_rx.await;
    };
    let serve_handle = tokio::spawn(async move {
        TonicServer::builder()
            .add_routes(routes)
            .serve_with_incoming_shutdown(incoming, shutdown_signal)
            .await
    });

    BootedRouter {
        addr,
        serve_handle,
        shutdown_tx,
    }
}
/// Waits until the value held by `rx` satisfies `predicate`.
///
/// The current value is checked first and marked as seen. After that, the predicate is
/// re-evaluated on the latest value each time `changed()` wakes.
///
/// # Panics
/// Panics if the watch channel's sender is dropped while this is still waiting.
pub async fn wait_until<F>(rx: &mut watch::Receiver<ServingState>, predicate: F)
where
    F: Fn(&ServingState) -> bool,
{
    loop {
        // Separate statement: the borrow guard is released before the await below.
        let reached = predicate(&rx.borrow_and_update());
        if reached {
            break;
        }
        rx.changed()
            .await
            .expect("state stream closed before reaching expected state");
    }
}
/// Waits until `rx` reports [`ServingState::Serving`].
///
/// # Panics
/// Panics if the watch channel's sender is dropped before that state is reached.
pub async fn wait_until_serving(rx: &mut watch::Receiver<ServingState>) {
    let is_serving = |state: &ServingState| matches!(state, ServingState::Serving);
    wait_until(rx, is_serving).await;
}
/// Waits until `rx` reports any `ServingState::NotServing { .. }` variant.
///
/// # Panics
/// Panics if the watch channel's sender is dropped before that state is reached.
pub async fn wait_until_not_serving(rx: &mut watch::Receiver<ServingState>) {
    let is_not_serving =
        |state: &ServingState| matches!(state, ServingState::NotServing { .. });
    wait_until(rx, is_not_serving).await;
}
/// Repeatedly connects to `http://{addr}` until a tonic channel is established or
/// `budget` has elapsed.
///
/// Failed attempts are retried every 25 ms, and a retry never sleeps past the deadline.
/// At least one attempt is always made, even with a zero `budget`. A single `connect`
/// call is not itself bounded by `budget`, so a stalled attempt can overrun it.
///
/// # Errors
/// Returns the error from the most recent failed attempt once the deadline has passed.
///
/// # Panics
/// Panics if the formatted URI does not parse, which would indicate a bug.
pub async fn wait_for_grpc_handshake(
    addr: SocketAddr,
    budget: Duration,
) -> Result<(), tonic::transport::Error> {
    const RETRY_INTERVAL: Duration = Duration::from_millis(25);
    let deadline = Instant::now() + budget;
    let endpoint: Endpoint = format!("http://{addr}")
        .parse()
        .expect("constructed endpoint URI must parse");
    loop {
        match endpoint.connect().await {
            Ok(channel) => {
                drop(channel);
                return Ok(());
            }
            Err(err) => {
                let now = Instant::now();
                if now >= deadline {
                    // Report the freshest failure, not the one from the previous attempt.
                    return Err(err);
                }
                sleep(RETRY_INTERVAL.min(deadline.saturating_duration_since(now))).await;
            }
        }
    }
}
/// Repeatedly connects to `https://{addr}` with `tls_config` until a tonic channel is
/// established or `budget` has elapsed.
///
/// Failed attempts are retried every 25 ms, and a retry never sleeps past the deadline.
/// At least one attempt is always made, even with a zero `budget`. A single `connect`
/// call is not itself bounded by `budget`, so a stalled attempt can overrun it.
///
/// # Errors
/// Returns immediately if `tls_config` cannot be applied to the endpoint. Otherwise,
/// returns the error from the most recent failed attempt once the deadline has passed.
///
/// # Panics
/// Panics if the formatted URI does not parse, which would indicate a bug.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
pub async fn wait_for_grpc_handshake_tls(
    addr: SocketAddr,
    tls_config: tonic::transport::ClientTlsConfig,
    budget: Duration,
) -> Result<(), tonic::transport::Error> {
    const RETRY_INTERVAL: Duration = Duration::from_millis(25);
    let deadline = Instant::now() + budget;
    let endpoint: Endpoint = format!("https://{addr}")
        .parse()
        .expect("constructed endpoint URI must parse");
    let endpoint = endpoint.tls_config(tls_config)?;
    loop {
        match endpoint.connect().await {
            Ok(channel) => {
                drop(channel);
                return Ok(());
            }
            Err(err) => {
                let now = Instant::now();
                if now >= deadline {
                    // Report the freshest failure, not the one from the previous attempt.
                    return Err(err);
                }
                sleep(RETRY_INTERVAL.min(deadline.saturating_duration_since(now))).await;
            }
        }
    }
}
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// Test-only `TsoService` whose `get_ts` always fails with the status built by
/// `not_leader_status`, carrying a fixed leader-endpoint hint.
struct FixedHintService {
/// Value reported as `leader_endpoint` in every hint this service returns.
hint_endpoint: String,
}
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
#[tonic::async_trait]
impl TsoService for FixedHintService {
    /// Ignores the request and always returns the `not_leader_status` error.
    ///
    /// The hint names `self.hint_endpoint` and leaves the epoch unset.
    async fn get_ts(
        &self,
        _request: tonic::Request<GetTsRequest>,
    ) -> Result<tonic::Response<GetTsResponse>, tonic::Status> {
        let reporter = crate::reporter::Reporter::for_tests();
        let hint = LeaderHint {
            leader_endpoint: Some(self.hint_endpoint.clone()),
            leader_epoch: None,
        };
        Err(not_leader_status(&reporter, hint))
    }

    /// Ignores the request and replies with a default (all-zero/empty) response.
    async fn get_current_max_safe(
        &self,
        _request: tonic::Request<tsoracle_proto::v1::GetCurrentMaxSafeRequest>,
    ) -> Result<tonic::Response<tsoracle_proto::v1::GetCurrentMaxSafeResponse>, tonic::Status> {
        let reply = tsoracle_proto::v1::GetCurrentMaxSafeResponse::default();
        Ok(tonic::Response::new(reply))
    }
}
/// Starts a TLS gRPC server on `127.0.0.1:0` whose only service is a `FixedHintService`
/// reporting `hint_endpoint`, and returns the bound address.
///
/// The server runs in a detached task with no shutdown signal. It keeps running until
/// it errors or the Tokio runtime shuts down. A serve error is reported on stderr
/// instead of being discarded, so a broken fixture is visible in test output.
///
/// # Panics
/// Panics if binding the loopback port fails, its local address cannot be read, or
/// `tls_config` is rejected by the server builder.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
pub async fn boot_fixed_hint_server_tls(
    hint_endpoint: String,
    tls_config: tonic::transport::ServerTlsConfig,
) -> SocketAddr {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind 127.0.0.1:0 for fixed-hint TLS server");
    let addr = listener
        .local_addr()
        .expect("local_addr for fixed-hint TLS server");
    let server = TonicServer::builder()
        .tls_config(tls_config)
        .expect("fixed-hint server tls config")
        .add_service(TsoServiceServer::new(FixedHintService { hint_endpoint }));
    tokio::spawn(async move {
        // Best-effort background fixture: surface failures, but never panic the task.
        if let Err(err) = server
            .serve_with_incoming(TcpListenerStream::new(listener))
            .await
        {
            eprintln!("fixed-hint TLS server on {addr} exited with error: {err}");
        }
    });
    addr
}