use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use futures_util::future::select_all;
use tokio::sync::{broadcast, oneshot};
/// Reasons why [`ReadinessGate::wait_all_ready`] can fail to open the gate.
#[derive(Debug)]
pub enum GateError {
    /// The timeout elapsed before every server actor signalled readiness.
    ///
    /// `not_ready` holds the names of the actors that were still pending.
    Timeout { not_ready: Vec<String> },
    /// A server actor's readiness sender was dropped without ever sending.
    ServerFailed { actor: String },
}

impl std::fmt::Display for GateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ServerFailed { actor } => {
                write!(f, "server actor '{actor}' failed before becoming ready")
            }
            // Pending actor names are rendered as a comma-separated list.
            Self::Timeout { not_ready } => {
                write!(f, "readiness timeout: not ready: {}", not_ready.join(", "))
            }
        }
    }
}

/// `GateError` carries no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for GateError {}
/// Holds client start-up until every registered server actor has signalled readiness.
///
/// Build one with [`ReadinessGate::new`], which also hands back one readiness sender
/// per server actor. Clients call [`ReadinessGate::subscribe`] and are released by a
/// single broadcast once [`ReadinessGate::wait_all_ready`] succeeds.
pub struct ReadinessGate {
    /// One readiness receiver per server actor, tagged with that actor's name.
    ready_rxs: Vec<(String, oneshot::Receiver<()>)>,
    /// Broadcasts the one-shot "gate open" signal to subscribed clients.
    gate_tx: broadcast::Sender<()>,
}

impl ReadinessGate {
    /// Creates a gate for `server_actors` and returns it together with the matching senders.
    ///
    /// The returned vector holds one `(name, sender)` pair per entry of `server_actors`,
    /// in the same order. A server actor signals readiness by calling `send(())` on its
    /// sender. [`ReadinessGate::wait_all_ready`] reports a sender that is dropped without
    /// sending as [`GateError::ServerFailed`].
    #[must_use]
    // The paired sender/receiver bindings below are intentionally near-identical names.
    #[allow(clippy::similar_names)]
    pub fn new(server_actors: &[String]) -> (Self, Vec<(String, oneshot::Sender<()>)>) {
        // Capacity 1 is enough: at most one "open" message is ever broadcast. The initial
        // receiver is discarded; clients obtain their own through `subscribe`.
        let (gate_tx, _) = broadcast::channel(1);
        let (ready_rxs, senders): (Vec<_>, Vec<_>) = server_actors
            .iter()
            .map(|name| {
                let (tx, rx) = oneshot::channel();
                ((name.clone(), rx), (name.clone(), tx))
            })
            .unzip();
        (Self { ready_rxs, gate_tx }, senders)
    }

    /// Returns a new receiver that yields `()` once the gate opens.
    ///
    /// Subscribe before calling [`ReadinessGate::wait_all_ready`]. That method consumes
    /// the gate, so no further subscriptions are possible afterwards.
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<()> {
        self.gate_tx.subscribe()
    }

    /// Waits up to `timeout` for every server actor to report ready, then opens the gate.
    ///
    /// # Errors
    ///
    /// - [`GateError::ServerFailed`] if an actor's sender is dropped before sending.
    /// - [`GateError::Timeout`] listing the actors still pending when `timeout` elapses.
    ///
    /// On error nothing is broadcast. The gate's sender is dropped, so subscribers
    /// receive `RecvError::Closed` instead of the open signal.
    pub async fn wait_all_ready(self, timeout: Duration) -> Result<(), GateError> {
        wait_all_receivers(self.ready_rxs, timeout).await?;
        // A send error only means no client subscribed; the gate is still considered open.
        let _ = self.gate_tx.send(());
        Ok(())
    }
}
/// Boxed future that resolves to an actor's name paired with its readiness-receive outcome.
type ReadyWait =
    Pin<Box<dyn Future<Output = (String, Result<(), oneshot::error::RecvError>)> + Send>>;

/// Waits concurrently until every receiver in `ready_rxs` has received its signal.
///
/// A single overall `timeout`, measured from the call, bounds the whole wait.
///
/// # Errors
///
/// - [`GateError::ServerFailed`] as soon as any receiver reports that its sender was
///   dropped without sending.
/// - [`GateError::Timeout`] if `timeout` elapses first. `not_ready` lists the names of
///   the still-pending receivers in sorted order, one entry per pending receiver.
async fn wait_all_receivers(
    ready_rxs: Vec<(String, oneshot::Receiver<()>)>,
    timeout: Duration,
) -> Result<(), GateError> {
    // Names of receivers that have not signalled yet. This is a multiset (one entry per
    // receiver), so duplicate actor names are tracked individually.
    let mut not_ready: Vec<String> = ready_rxs.iter().map(|(name, _)| name.clone()).collect();
    let waits: Vec<ReadyWait> = ready_rxs
        .into_iter()
        .map(|(name, rx)| Box::pin(async move { (name, rx.await) }) as ReadyWait)
        .collect();

    // One timeout around the whole loop. `Instant::now() + timeout` would panic on overflow.
    // `tokio::time::timeout` treats an unrepresentable deadline as "far future", so huge
    // durations such as `Duration::MAX` simply mean "wait indefinitely".
    let outcome = tokio::time::timeout(timeout, async {
        let mut waits = waits;
        // `select_all` panics on an empty list, hence the emptiness check.
        while !waits.is_empty() {
            let ((name, received), _idx, remaining) = select_all(waits).await;
            if received.is_err() {
                return Err(GateError::ServerFailed { actor: name });
            }
            // Remove exactly one entry. `retain` would also discard a duplicate-named
            // receiver that is still pending and hide it from the timeout report.
            if let Some(pos) = not_ready.iter().position(|n| *n == name) {
                not_ready.swap_remove(pos);
            }
            waits = remaining;
        }
        Ok(())
    })
    .await;

    match outcome {
        Ok(result) => result,
        Err(_elapsed) => {
            // `swap_remove` scrambles order; sort for a deterministic report.
            not_ready.sort();
            Err(GateError::Timeout { not_ready })
        }
    }
}
impl std::fmt::Debug for ReadinessGate {
    /// Shows how many server receivers the gate holds; the channel handles are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pending_servers = self.ready_rxs.len();
        let mut builder = f.debug_struct("ReadinessGate");
        builder.field("pending_servers", &pending_servers);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned actor names `ReadinessGate::new` expects.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| (*s).to_string()).collect()
    }

    #[tokio::test]
    async fn all_servers_ready_opens_gate() {
        let (gate, ready_txs) = ReadinessGate::new(&names(&["server1", "server2"]));
        let mut client_rx = gate.subscribe();
        ready_txs
            .into_iter()
            .for_each(|(_, tx)| tx.send(()).unwrap());
        gate.wait_all_ready(Duration::from_secs(5)).await.unwrap();
        assert!(client_rx.recv().await.is_ok());
    }

    #[tokio::test]
    async fn timeout_returns_error_with_server_names() {
        let (gate, mut ready_txs) = ReadinessGate::new(&names(&["server1", "server2"]));
        // Only server1 reports ready. server2's sender stays alive in `ready_txs`, so it
        // times out instead of being reported as failed.
        let (_, first_tx) = ready_txs.remove(0);
        first_tx.send(()).unwrap();
        let result = gate.wait_all_ready(Duration::from_millis(50)).await;
        assert!(result.is_err());
        match result {
            Err(GateError::Timeout { not_ready }) => {
                assert_eq!(not_ready, vec!["server2".to_string()]);
            }
            result => panic!("Expected GateError::Timeout, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn zero_servers_no_gate_needed() {
        let (gate, ready_txs) = ReadinessGate::new(&[]);
        assert!(ready_txs.is_empty());
        gate.wait_all_ready(Duration::from_secs(1)).await.unwrap();
    }

    #[tokio::test]
    async fn dropped_sender_detected() {
        let (gate, ready_txs) = ReadinessGate::new(&names(&["server1"]));
        // Dropping every sender without sending simulates a server dying during start-up.
        drop(ready_txs);
        let result = gate.wait_all_ready(Duration::from_secs(1)).await;
        assert!(result.is_err());
        match result {
            Err(GateError::ServerFailed { actor }) => assert_eq!(actor, "server1"),
            result => panic!("Expected GateError::ServerFailed, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn subscribe_before_ready() {
        let (gate, ready_txs) = ReadinessGate::new(&names(&["server1"]));
        let mut client_rx = gate.subscribe();
        let gate_handle = tokio::spawn(gate.wait_all_ready(Duration::from_secs(5)));
        let (_, tx) = ready_txs.into_iter().next().unwrap();
        // Let the spawned waiter start before readiness is signalled.
        tokio::time::sleep(Duration::from_millis(10)).await;
        tx.send(()).unwrap();
        gate_handle.await.unwrap().unwrap();
        assert!(client_rx.recv().await.is_ok());
    }
}