use std::collections::HashMap;
use std::future::poll_fn;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};
use magnetar_proto::{HealthProbe, ServiceUrlProvider};
use moonpool_core::{NetworkProvider, Providers, TaskProvider, TimeProvider};
use parking_lot::Mutex;
use tokio::sync::Notify;
/// Service-URL provider that tracks which of several candidate URLs is currently active.
///
/// Every field is behind an `Arc`, so clones share the same URL list, probe and
/// active-index cell: a switch performed by the background task is visible
/// through every clone.
// NOTE(review): the `reason` text below says the probe is excluded from `Debug`, but the
// manual `Debug` impl in this file prints a `probe` field — reconcile the text or the impl.
#[allow(
missing_debug_implementations,
reason = "manual Debug impl below, intentionally excludes the probe trait object"
)]
pub struct AutoClusterFailover<P: Providers> {
/// Candidate service URLs; index 0 is the entry that is active right after construction.
urls: Arc<Vec<String>>,
/// Health check used by the background task to test the URLs.
probe: Arc<dyn HealthProbe>,
/// Index into `urls` of the currently active URL.
active: Arc<AtomicUsize>,
/// Ties the type to a provider bundle `P` without storing one.
_providers: std::marker::PhantomData<fn() -> P>,
}
impl<P: Providers> Clone for AutoClusterFailover<P> {
fn clone(&self) -> Self {
Self {
urls: self.urls.clone(),
probe: self.probe.clone(),
active: self.active.clone(),
_providers: std::marker::PhantomData,
}
}
}
/// Diagnostic formatting: prints the URL list, the probe and a snapshot of the
/// active index. The marker field is not printed.
impl<P: Providers> std::fmt::Debug for AutoClusterFailover<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AutoClusterFailover");
        out.field("urls", &self.urls);
        out.field("probe", &self.probe);
        // The atomic is read at formatting time, so the value is only a snapshot.
        out.field("active_index", &self.active.load(Ordering::Relaxed));
        out.finish()
    }
}
impl<P: Providers> AutoClusterFailover<P> {
    /// Creates a failover provider over `urls`, health-checked through `probe`.
    ///
    /// Thin wrapper around [`Self::new_with_probe`]; the first URL starts out as
    /// the active one.
    ///
    /// # Panics
    ///
    /// Panics if `urls` is empty.
    #[must_use]
    pub fn new(urls: Vec<String>, probe: Arc<dyn HealthProbe>) -> Self {
        Self::new_with_probe(urls, probe)
    }

    /// Creates a failover provider over `urls`, health-checked through `probe`.
    ///
    /// The active index starts at 0, i.e. the first URL is reported until the
    /// background task started by [`Self::start`] switches it.
    ///
    /// # Panics
    ///
    /// Panics if `urls` is empty.
    #[must_use]
    pub fn new_with_probe(urls: Vec<String>, probe: Arc<dyn HealthProbe>) -> Self {
        assert!(
            !urls.is_empty(),
            "AutoClusterFailover requires at least one URL"
        );
        Self {
            urls: Arc::new(urls),
            probe,
            active: Arc::new(AtomicUsize::new(0)),
            _providers: std::marker::PhantomData,
        }
    }

    /// Spawns the background probing task and returns a handle that stops it.
    ///
    /// Each round the task sleeps for `interval`, then probes the URLs in list
    /// order and makes the first one that reports healthy the active URL. When
    /// no URL reports healthy, the active index is left unchanged.
    ///
    /// The task returns when the handle is aborted or dropped, or when the
    /// provider's `sleep` returns an error.
    pub fn start(&self, providers: &P, interval: Duration) -> FailoverProbeHandle<P> {
        let urls = self.urls.clone();
        let probe = self.probe.clone();
        let active = self.active.clone();
        let time = providers.time().clone();

        // Probes take an `Instant` deadline, while the provider clock yields a
        // `Duration`; map one onto the other by offsetting an anchor captured here.
        // NOTE(review): a probe has to measure the deadline against a compatible
        // clock — confirm against the probe implementations in use.
        let clock_anchor = Instant::now();
        let time_for_deadline = time.clone();
        let now_instant = move || {
            clock_anchor
                .checked_add(time_for_deadline.now())
                .unwrap_or(clock_anchor)
        };

        let stop = Arc::new(Notify::new());
        let stop_for_task = stop.clone();
        let join = providers.task().spawn_task(
            "magnetar-moonpool-auto-cluster-failover",
            async move {
                loop {
                    // `biased` polls the stop signal before the sleep, so a pending
                    // stop request wins over a sleep that has also completed.
                    tokio::select! {
                        biased;
                        () = stop_for_task.notified() => return,
                        slept = time.sleep(interval) => {
                            if slept.is_err() {
                                return;
                            }
                        }
                    }

                    // All probes of one round share a single deadline, `interval`
                    // past the round start. `Instant + Duration` panics on overflow,
                    // so add with `checked_add` and fall back to the round start
                    // instead of bringing down the detached task.
                    let round_start = now_instant();
                    let deadline = round_start.checked_add(interval).unwrap_or(round_start);

                    // Scan in list order; the first healthy URL wins.
                    let mut new_active: Option<usize> = None;
                    for (idx, url) in urls.iter().enumerate() {
                        let healthy = tokio::select! {
                            biased;
                            () = stop_for_task.notified() => return,
                            healthy = poll_fn(|cx| probe.poll_probe(url, deadline, cx)) => healthy,
                        };
                        if healthy {
                            new_active = Some(idx);
                            break;
                        }
                    }

                    if let Some(idx) = new_active {
                        let prev = active.swap(idx, Ordering::Relaxed);
                        // Log only actual switches, not every round.
                        if prev != idx {
                            tracing::info!(
                                from_index = prev,
                                to_index = idx,
                                to_url = %urls[idx],
                                "AutoClusterFailover (moonpool): switching active URL",
                            );
                        }
                    }
                }
            },
        );
        FailoverProbeHandle { stop, _join: join }
    }

    /// Returns the index of the currently active URL.
    #[must_use]
    pub fn active_index(&self) -> usize {
        self.active.load(Ordering::Relaxed)
    }
}
/// Handle to the background probing task spawned by `AutoClusterFailover::start`.
///
/// Calling `abort` or dropping the handle posts a stop notification for the task.
pub struct FailoverProbeHandle<P: Providers> {
/// Stop signal shared with the task; the task returns once it observes a notification.
stop: Arc<Notify>,
/// Join handle of the spawned task; stored but not otherwise used in this file.
_join: <P::Task as TaskProvider>::JoinHandle,
}
impl<P: Providers> FailoverProbeHandle<P> {
    /// Signals the background probing task to stop.
    ///
    /// This only posts a notification on the shared stop signal; it does not
    /// wait for the task to finish.
    pub fn abort(&self) {
        let Self { stop, .. } = self;
        stop.notify_one();
    }
}
/// Dropping the handle posts the same stop notification as an explicit abort,
/// so the background task does not outlive its handle indefinitely.
impl<P: Providers> Drop for FailoverProbeHandle<P> {
    fn drop(&mut self) {
        let signal = &self.stop;
        signal.notify_one();
    }
}
/// Opaque debug output: only the type name is printed, all fields are elided.
impl<P: Providers> std::fmt::Debug for FailoverProbeHandle<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FailoverProbeHandle");
        out.finish_non_exhaustive()
    }
}
impl<P: Providers> ServiceUrlProvider for AutoClusterFailover<P> {
    /// Returns a copy of the currently active URL.
    ///
    /// Falls back to the first URL when the stored index is out of range, and
    /// to an empty string when the list holds no URL at all.
    fn get_service_url(&self) -> String {
        let active = self.active_index();
        let urls: &[String] = &self.urls;
        match urls.get(active) {
            Some(url) => url.clone(),
            None => urls.first().cloned().unwrap_or_default(),
        }
    }
}
/// `HealthProbe` implementation that checks an endpoint by opening a connection
/// through the provider bundle's network provider.
pub struct MoonpoolHealthProbe<P: Providers> {
/// Per-endpoint probe state, keyed by the endpoint string passed to `poll_probe`.
inflight: Arc<Mutex<HashMap<String, Arc<Mutex<ProbeSlot>>>>>,
/// Provider bundle used to spawn probe tasks and to perform the connect and timeout.
providers: P,
}
/// Builds a probe around a default-constructed provider bundle.
impl<P> Default for MoonpoolHealthProbe<P>
where
    P: Providers + Default,
{
    fn default() -> Self {
        let providers = P::default();
        Self::new(providers)
    }
}
/// Debug output reports only how many endpoints currently have a slot in the
/// in-flight table; the provider bundle is elided.
impl<P: Providers> std::fmt::Debug for MoonpoolHealthProbe<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read the size in its own scope so the table lock is released before formatting.
        let tracked = {
            let table = self.inflight.lock();
            table.len()
        };
        let mut out = f.debug_struct("MoonpoolHealthProbe");
        out.field("inflight", &tracked);
        out.finish_non_exhaustive()
    }
}
impl<P: Providers> MoonpoolHealthProbe<P> {
    /// Creates a probe with an empty in-flight table that runs on `providers`.
    #[must_use]
    pub fn new(providers: P) -> Self {
        let inflight = Arc::new(Mutex::new(HashMap::new()));
        Self {
            inflight,
            providers,
        }
    }

    /// Extracts the authority part of an endpoint string.
    ///
    /// A leading `pulsar+ssl://` or `pulsar://` scheme is removed when present,
    /// and everything from the first `/` onwards is discarded. Returns `None`
    /// when nothing is left.
    fn authority(endpoint: &str) -> Option<String> {
        let without_scheme = ["pulsar+ssl://", "pulsar://"]
            .iter()
            .find_map(|scheme| endpoint.strip_prefix(*scheme))
            .unwrap_or(endpoint);
        let host_part = match without_scheme.find('/') {
            Some(cut) => &without_scheme[..cut],
            None => without_scheme,
        };
        (!host_part.is_empty()).then(|| host_part.to_owned())
    }

    /// Spawns a detached task that runs one probe of `endpoint`, stores the
    /// verdict in `slot` and wakes the waker parked there, if any.
    fn spawn_probe(&self, endpoint: &str, deadline: Instant, slot: Arc<Mutex<ProbeSlot>>) {
        let target = endpoint.to_owned();
        let providers = self.providers.clone();
        let inflight = self.inflight.clone();
        let spawner = providers.task().clone();
        let _detached = spawner.spawn_task("magnetar-moonpool-health-probe", async move {
            let outcome = run_probe::<P>(&providers, &target, deadline).await;
            // Publish under the slot lock, but wake only after the lock is released.
            let parked = {
                let mut state = slot.lock();
                state.verdict = Some(outcome);
                state.waker.take()
            };
            if let Some(waker) = parked {
                waker.wake();
            }
            // The table handle is only referenced here, never read.
            // NOTE(review): confirm whether this reference is still needed.
            let _ = inflight;
        });
    }
}
/// State for one endpoint's probe, shared between `poll_probe` and the spawned probe task.
#[derive(Default)]
struct ProbeSlot {
/// Result written by the probe task; taken by the poll that reports it.
verdict: Option<bool>,
/// Waker stored by the most recent pending poll; taken and woken by the probe task.
waker: Option<Waker>,
/// Set once a probe task has been spawned for this slot.
spawned: bool,
}
/// Performs one connectivity probe against `endpoint`.
///
/// Returns `true` only when a connection to the endpoint's authority is
/// established within the time budget derived from `deadline`. An endpoint
/// that cannot be parsed, a connect error and a timeout all yield `false`.
// NOTE(review): "now" is rebuilt here as `Instant::now()` plus the provider clock
// reading, whereas `AutoClusterFailover::start` derives its deadlines from an anchor
// captured once at start — confirm both sides measure against a compatible clock.
async fn run_probe<P: Providers>(providers: &P, endpoint: &str, deadline: Instant) -> bool {
    let authority = match MoonpoolHealthProbe::<P>::authority(endpoint) {
        Some(authority) => authority,
        None => {
            tracing::debug!(endpoint = %endpoint, "MoonpoolHealthProbe: cannot parse endpoint");
            return false;
        }
    };

    // Remaining budget; saturates to zero when the deadline has already passed.
    let anchor = Instant::now();
    let current = anchor
        .checked_add(providers.time().now())
        .unwrap_or(anchor);
    let budget = deadline.saturating_duration_since(current);

    let attempt = async {
        match providers.network().connect(&authority).await {
            Err(e) => {
                tracing::debug!(
                    authority = %authority,
                    error = %e,
                    "MoonpoolHealthProbe: connect failed",
                );
                false
            }
            Ok(stream) => {
                // Reachability is all that matters; the stream is dropped right away.
                let _stream: <P::Network as NetworkProvider>::TcpStream = stream;
                true
            }
        }
    };

    match providers.time().timeout(budget, attempt).await {
        Ok(verdict) => verdict,
        Err(_) => {
            tracing::debug!(
                authority = %authority,
                "MoonpoolHealthProbe: connect timed out",
            );
            false
        }
    }
}
impl<P: Providers + Send + Sync> HealthProbe for MoonpoolHealthProbe<P> {
    /// Polls the health of `endpoint`, spawning a background probe task when the
    /// endpoint's slot has none yet.
    ///
    /// Returns `Poll::Ready(verdict)` once a verdict is stored in the slot;
    /// otherwise records the caller's waker in the slot and returns
    /// `Poll::Pending`. `deadline` is only forwarded by the call that actually
    /// spawns the probe task.
    fn poll_probe(&self, endpoint: &str, deadline: Instant, cx: &mut Context<'_>) -> Poll<bool> {
        // Look up (or create) the slot, holding the table lock only for the lookup.
        let slot = {
            let mut table = self.inflight.lock();
            table
                .entry(endpoint.to_owned())
                .or_insert_with(|| Arc::new(Mutex::new(ProbeSlot::default())))
                .clone()
        };

        // Removes this poll's slot from the table, unless the entry now points
        // at a different slot.
        let retire = || {
            let mut table = self.inflight.lock();
            let is_ours = table
                .get(endpoint)
                .is_some_and(|current| Arc::ptr_eq(current, &slot));
            if is_ours {
                table.remove(endpoint);
            }
        };

        // Either claim a finished verdict or park the caller's waker.
        let finished = {
            let mut state = slot.lock();
            let taken = state.verdict.take();
            if taken.is_none() {
                state.waker = Some(cx.waker().clone());
            }
            taken
        };
        if let Some(healthy) = finished {
            retire();
            return Poll::Ready(healthy);
        }

        // Spawn at most one probe task per slot.
        let should_spawn = {
            let mut state = slot.lock();
            if state.verdict.is_some() || state.spawned {
                false
            } else {
                state.spawned = true;
                true
            }
        };
        if !should_spawn {
            return Poll::Pending;
        }

        self.spawn_probe(endpoint, deadline, slot.clone());

        // Re-check in case the verdict was already published by the time the spawn returned.
        let immediate = slot.lock().verdict.take();
        match immediate {
            Some(healthy) => {
                retire();
                Poll::Ready(healthy)
            }
            None => Poll::Pending,
        }
    }
}
// Unit and integration tests for the failover provider and the connect-based health probe.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use magnetar_proto::{HealthProbe, ServiceUrlProvider};
use moonpool_core::TokioProviders;
use super::{AutoClusterFailover, MoonpoolHealthProbe};
// Probe stub that immediately reports the same fixed verdict for every endpoint.
#[derive(Debug)]
struct ConstProbe(bool);
impl HealthProbe for ConstProbe {
fn poll_probe(
&self,
_endpoint: &str,
_deadline: Instant,
_cx: &mut Context<'_>,
) -> Poll<bool> {
Poll::Ready(self.0)
}
}
// Constructing the provider with no URLs must panic.
#[test]
fn empty_url_list_panics() {
let r = std::panic::catch_unwind(|| {
AutoClusterFailover::<TokioProviders>::new(vec![], Arc::new(ConstProbe(true)))
});
assert!(r.is_err());
}
// Both constructors must yield the same initial index and service URL.
#[test]
fn new_with_probe_is_an_alias_of_new() {
let a = AutoClusterFailover::<TokioProviders>::new(
vec!["pulsar://a:6650".into(), "pulsar://b:6650".into()],
Arc::new(ConstProbe(true)),
);
let b = AutoClusterFailover::<TokioProviders>::new_with_probe(
vec!["pulsar://a:6650".into(), "pulsar://b:6650".into()],
Arc::new(ConstProbe(true)),
);
assert_eq!(a.active_index(), b.active_index());
assert_eq!(a.get_service_url(), b.get_service_url());
}
// Before any probing round runs, the first URL is the active one.
#[test]
fn initial_active_is_primary() {
let f = AutoClusterFailover::<TokioProviders>::new(
vec!["pulsar://a:6650".into(), "pulsar://b:6650".into()],
Arc::new(ConstProbe(true)),
);
assert_eq!(f.active_index(), 0);
assert_eq!(f.get_service_url(), "pulsar://a:6650");
}
// Drives the background task through healthy -> unhealthy -> healthy for the primary
// and checks that the active index follows.
#[tokio::test(flavor = "current_thread")]
async fn failover_switches_on_unhealthy_primary() {
// Probe whose verdict for endpoints containing "primary" is controlled by an atomic
// (non-zero means healthy); every other endpoint is always healthy.
#[derive(Debug)]
struct Flipping {
primary_healthy: AtomicUsize,
}
impl HealthProbe for Flipping {
fn poll_probe(
&self,
endpoint: &str,
_deadline: Instant,
_cx: &mut Context<'_>,
) -> Poll<bool> {
let healthy = if endpoint.contains("primary") {
self.primary_healthy.load(Ordering::SeqCst) != 0
} else {
true
};
Poll::Ready(healthy)
}
}
let local = tokio::task::LocalSet::new();
local
.run_until(async {
let providers = TokioProviders::new();
let probe = Arc::new(Flipping {
primary_healthy: AtomicUsize::new(1),
});
let f = AutoClusterFailover::<TokioProviders>::new(
vec![
"pulsar://primary:6650".into(),
"pulsar://standby:6650".into(),
],
probe.clone(),
);
let tick = Duration::from_millis(40);
let handle = f.start(&providers, tick);
// Each wait is slightly longer than one tick; presumably this lets at least one
// probing round complete before the assertion — timing-sensitive.
tokio::time::sleep(tick + Duration::from_millis(10)).await;
assert_eq!(f.active_index(), 0);
// Primary goes unhealthy: the standby (index 1) must take over.
probe.primary_healthy.store(0, Ordering::SeqCst);
tokio::time::sleep(tick + Duration::from_millis(10)).await;
assert_eq!(f.active_index(), 1);
assert_eq!(f.get_service_url(), "pulsar://standby:6650");
// Primary recovers: the provider must switch back to index 0.
probe.primary_healthy.store(1, Ordering::SeqCst);
tokio::time::sleep(tick + Duration::from_millis(10)).await;
assert_eq!(f.active_index(), 0);
handle.abort();
})
.await;
}
// Both supported schemes are removed from the front of the endpoint.
#[test]
fn moonpool_probe_authority_strips_pulsar_scheme() {
assert_eq!(
MoonpoolHealthProbe::<TokioProviders>::authority("pulsar://broker.local:6650"),
Some("broker.local:6650".to_owned()),
);
assert_eq!(
MoonpoolHealthProbe::<TokioProviders>::authority("pulsar+ssl://broker.local:6651"),
Some("broker.local:6651".to_owned()),
);
}
// An endpoint without a scheme is returned unchanged.
#[test]
fn moonpool_probe_authority_passes_through_bare_host_port() {
assert_eq!(
MoonpoolHealthProbe::<TokioProviders>::authority("127.0.0.1:6650"),
Some("127.0.0.1:6650".to_owned()),
);
}
// Everything from the first `/` after the scheme is dropped.
#[test]
fn moonpool_probe_authority_trims_trailing_path() {
assert_eq!(
MoonpoolHealthProbe::<TokioProviders>::authority("pulsar://broker.local:6650/admin/v2"),
Some("broker.local:6650".to_owned()),
);
}
// An empty endpoint has no authority.
#[test]
fn moonpool_probe_authority_rejects_empty_input() {
assert_eq!(MoonpoolHealthProbe::<TokioProviders>::authority(""), None,);
}
// A probe against a locally bound listener must report healthy.
#[tokio::test(flavor = "current_thread")]
async fn moonpool_probe_reports_healthy_for_live_listener() {
let local = tokio::task::LocalSet::new();
local
.run_until(async {
// Port 0 asks the OS for a free port; the real address is read back below.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local_addr");
// Accept a single connection in the background so the probe's connect can complete.
let accept = tokio::spawn(async move {
let _ = listener.accept().await;
});
let probe = MoonpoolHealthProbe::new(TokioProviders::new());
let endpoint = format!("pulsar://{addr}");
let deadline = Instant::now() + Duration::from_secs(2);
let verdict =
std::future::poll_fn(|cx| probe.poll_probe(&endpoint, deadline, cx)).await;
assert!(verdict, "live listener must read healthy");
accept.abort();
})
.await;
}
// A probe against a port with no listener must report unhealthy.
#[tokio::test(flavor = "current_thread")]
async fn moonpool_probe_reports_unhealthy_for_closed_port() {
let local = tokio::task::LocalSet::new();
local
.run_until(async {
// Bind to obtain a port number; the listener is dropped at the end of this
// block, so nothing should be accepting on that port afterwards.
let probe_port = {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
listener.local_addr().expect("local_addr").port()
};
let probe = MoonpoolHealthProbe::new(TokioProviders::new());
let endpoint = format!("127.0.0.1:{probe_port}");
let deadline = Instant::now() + Duration::from_secs(2);
let verdict =
std::future::poll_fn(|cx| probe.poll_probe(&endpoint, deadline, cx)).await;
assert!(!verdict, "closed port must read unhealthy");
})
.await;
}
}