use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use parking_lot::Mutex;
use rustc_hash::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};
use crate::cluster::discovery::NodeId;
#[cfg(feature = "cluster")]
use crate::cluster::discovery::{NodeInfo, NodeState};
#[cfg(feature = "cluster")]
use tokio::sync::watch;
/// KV key under which the leader publishes its latest [`BarrierAnnouncement`] as JSON.
pub const ANNOUNCEMENT_KEY: &str = "control:barrier";
/// KV key under which a follower publishes its [`BarrierAck`] as JSON when no gRPC
/// server is running; the leader scans it while waiting for quorum.
pub const ACK_KEY: &str = "control:barrier-ack";
/// KV key under which each node publishes the advertised address of its barrier gRPC
/// server; it is also used to enumerate peers and to identify the local node.
#[cfg(feature = "cluster")]
pub const BARRIER_ADDR_KEY: &str = "barrier:addr";
/// Stage of the checkpoint barrier protocol carried by a [`BarrierAnnouncement`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Phase {
/// First stage; in gRPC mode the leader delivers it through the Prepare RPC issued by
/// `BarrierCoordinator::wait_for_quorum` and collects one ack per follower.
Prepare,
/// Sent after quorum, carrying the minimum follower watermark; delivery failures
/// are only logged by the leader.
Aligned,
/// Finalizes the epoch; followers drop cached Prepare acks for it and older epochs.
Commit,
/// Abandons the epoch; followers drop cached Prepare acks for it and older epochs.
Abort,
}
/// Barrier message broadcast by the leader, via gossip JSON and/or gRPC.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BarrierAnnouncement {
/// Barrier epoch; when two sources disagree, the higher epoch is treated as newer.
pub epoch: u64,
/// Checkpoint identifier associated with this epoch.
pub checkpoint_id: u64,
/// Protocol stage this announcement represents.
pub phase: Phase,
/// Opaque flags forwarded unchanged; not interpreted in this module.
pub flags: u64,
/// Minimum follower watermark in milliseconds; forwarded over gRPC only for the
/// Aligned and Commit phases. `serde(default)` lets older payloads that lack the
/// field decode as `None`.
#[serde(default)]
pub min_watermark_ms: Option<i64>,
}
/// A follower's response to a Prepare for one epoch.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BarrierAck {
/// Epoch being acknowledged; acks for other epochs are ignored by the leader.
pub epoch: u64,
/// `true` if the follower completed the phase successfully.
pub ok: bool,
/// Failure description when `ok` is `false`.
pub error: Option<String>,
/// Follower's local watermark in milliseconds; the leader keeps the minimum.
/// `serde(default)` lets older payloads without the field decode as `None`.
#[serde(default)]
pub local_watermark_ms: Option<i64>,
}
/// Result of `BarrierCoordinator::wait_for_quorum`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuorumOutcome {
/// Every expected node acked successfully.
Reached {
/// Nodes that acked.
acks: Vec<NodeId>,
/// Minimum watermark over the acks that reported one; `None` if none did.
min_follower_watermark_ms: Option<i64>,
},
/// Some nodes did not ack in time (or were unreachable).
TimedOut {
/// Nodes that acked successfully.
got: Vec<NodeId>,
/// Nodes that did not.
missing: Vec<NodeId>,
},
/// At least one node reported a failure; takes precedence over `TimedOut`.
Failed {
/// `(node, reason)` for every reported failure.
failures: Vec<(NodeId, String)>,
},
}
/// Per-node key/value replication used as the gossip channel for barriers.
///
/// `write` takes no node argument; values are read back per publishing node through
/// `read_from`, `scan` and `scan_prefix`.
#[async_trait]
pub trait ClusterKv: Send + Sync + 'static {
/// Publishes `value` under `key` on behalf of the local node.
async fn write(&self, key: &str, value: String);
/// Returns the value that node `who` published under `key`, if visible.
async fn read_from(&self, who: NodeId, key: &str) -> Option<String>;
/// Returns `(node, value)` for every node that published `key`.
async fn scan(&self, key: &str) -> Vec<(NodeId, String)>;
/// Returns `(node, key, value)` for every published key starting with `prefix`.
async fn scan_prefix(&self, prefix: &str) -> Vec<(NodeId, String, String)>;
/// Whether the backend supports subscription routing; defaults to `true`.
/// NOTE(review): not consulted anywhere in this file — semantics defined by callers.
fn supports_subscription_routing(&self) -> bool {
true
}
}
/// In-process [`ClusterKv`] backed by one map keyed by `(publishing node, key)`.
///
/// Every `write` is attributed to `local_id`; `seed` adds entries for other nodes.
#[derive(Debug)]
pub struct InMemoryKv {
/// Node that local writes are attributed to.
local_id: NodeId,
/// All entries, keyed by `(publishing node, key)`.
state: Mutex<FxHashMap<(NodeId, String), String>>,
}
impl InMemoryKv {
    /// Creates an empty store whose writes are attributed to `local_id`.
    #[must_use]
    pub fn new(local_id: NodeId) -> Self {
        let state = Mutex::new(FxHashMap::default());
        Self { local_id, state }
    }

    /// Stores `value` under `key` as though node `peer` had published it.
    ///
    /// Lets a single in-process store simulate entries replicated from other nodes.
    pub fn seed(&self, peer: NodeId, key: &str, value: String) {
        let entry_key = (peer, key.to_owned());
        self.state.lock().insert(entry_key, value);
    }
}
#[async_trait]
impl ClusterKv for InMemoryKv {
    /// Stores `value` under `(local_id, key)`, replacing any previous value.
    async fn write(&self, key: &str, value: String) {
        let slot = (self.local_id, key.to_owned());
        self.state.lock().insert(slot, value);
    }

    /// Looks up the value `who` published under `key`.
    async fn read_from(&self, who: NodeId, key: &str) -> Option<String> {
        let slot = (who, key.to_owned());
        self.state.lock().get(&slot).cloned()
    }

    /// Collects every node's value for exactly `key`.
    async fn scan(&self, key: &str) -> Vec<(NodeId, String)> {
        self.state
            .lock()
            .iter()
            .filter_map(|((node, k), v)| (k == key).then(|| (*node, v.clone())))
            .collect()
    }

    /// Collects every `(node, key, value)` whose key begins with `prefix`.
    async fn scan_prefix(&self, prefix: &str) -> Vec<(NodeId, String, String)> {
        self.state
            .lock()
            .iter()
            .filter_map(|((node, k), v)| {
                k.starts_with(prefix)
                    .then(|| (*node, k.clone(), v.clone()))
            })
            .collect()
    }
}
/// Protobuf/gRPC bindings for the `laminar.barrier.v1` package, generated at build
/// time and pulled in with `tonic::include_proto!`.
#[cfg(feature = "cluster")]
// Lints are relaxed because the module body is machine-generated code.
#[allow(
clippy::doc_markdown,
clippy::default_trait_access,
clippy::missing_const_for_fn,
clippy::must_use_candidate,
clippy::too_many_lines,
missing_docs
)]
pub(crate) mod barrier_v1 {
tonic::include_proto!("laminar.barrier.v1");
}
/// Channel flavor for the bounded queue that carries announcements received over gRPC
/// to the relay task.
#[cfg(feature = "cluster")]
type BarrierFlavor = crossfire::mpsc::Array<BarrierAnnouncement>;
/// Shared cache of barrier gRPC clients, one per peer. Entries are evicted when an
/// RPC to that peer fails so the next call rebuilds the channel.
#[cfg(feature = "cluster")]
type BarrierClientPool = Arc<
parking_lot::Mutex<
FxHashMap<
NodeId,
barrier_v1::barrier_sync_client::BarrierSyncClient<tonic::transport::Channel>,
>,
>,
>;
/// Runtime state of the control-plane gRPC server, owned by the coordinator.
#[cfg(feature = "cluster")]
struct GrpcState {
/// Latest announcement received over gRPC, as published by the relay task.
latest_rx: watch::Receiver<Option<BarrierAnnouncement>>,
/// Sender side of the server-to-relay queue.
/// NOTE(review): never read; presumably held so the queue stays open for the
/// lifetime of this state — confirm.
#[allow(dead_code)]
incoming_tx: crossfire::MAsyncTx<BarrierFlavor>,
/// Per-epoch oneshot senders for Prepare RPCs still waiting on a local ack.
pending_acks: Arc<parking_lot::Mutex<FxHashMap<u64, tokio::sync::oneshot::Sender<BarrierAck>>>>,
/// Locally produced acks cached by epoch so a retried Prepare can be answered.
completed_acks: Arc<parking_lot::Mutex<FxHashMap<u64, BarrierAck>>>,
/// Cached barrier clients to peers.
clients: BarrierClientPool,
/// Task serving the gRPC listener; aborted when the coordinator is dropped.
server_handle: Arc<parking_lot::Mutex<Option<tokio::task::JoinHandle<()>>>>,
/// Task relaying queued announcements into `latest_rx`; aborted on drop as well.
relay_handle: Arc<parking_lot::Mutex<Option<tokio::task::JoinHandle<()>>>>,
/// Address published under `BARRIER_ADDR_KEY`; also used to recognise this node.
advertise_addr: String,
}
/// Optional membership view for leader validation: this node's id plus a watch of the
/// cluster member list. `None` until `set_leader_election` is called.
#[cfg(feature = "cluster")]
type ActiveLeaderState = Option<(NodeId, watch::Receiver<Vec<NodeInfo>>)>;
/// Follower-side implementation of the barrier gRPC service.
///
/// Shares its ack maps and election view with the owning `BarrierCoordinator`, and
/// forwards accepted announcements into `incoming_tx`.
#[cfg(feature = "cluster")]
struct GrpcBarrierServer {
/// KV store used to derive the leader when no election view is installed.
kv: Arc<dyn ClusterKv>,
/// Queue feeding received announcements to the relay task.
incoming_tx: crossfire::MAsyncTx<BarrierFlavor>,
/// Waiters for in-flight Prepare RPCs, keyed by epoch.
pending_acks: Arc<parking_lot::Mutex<FxHashMap<u64, tokio::sync::oneshot::Sender<BarrierAck>>>>,
/// Acks already produced locally, keyed by epoch.
completed_acks: Arc<parking_lot::Mutex<FxHashMap<u64, BarrierAck>>>,
/// Membership view shared with the coordinator.
leader_election: Arc<parking_lot::Mutex<ActiveLeaderState>>,
}
#[cfg(feature = "cluster")]
impl GrpcBarrierServer {
/// Checks that an RPC was sent by the node this follower considers the leader.
///
/// The sender identity is the `x-leader-id` metadata entry, parsed as a decimal
/// `u64`. The expected leader is `super::leader_of` applied either to the `Active`
/// members of the installed election view plus the local instance id, or — when no
/// view is installed — to every node that published `BARRIER_ADDR_KEY`.
///
/// # Errors
/// `PermissionDenied` when the header is missing or malformed, or when the sender
/// differs from the observed leader.
async fn validate_leader(
&self,
metadata: &tonic::metadata::MetadataMap,
) -> Result<(), tonic::Status> {
let leader_id_str = metadata
.get("x-leader-id")
.ok_or_else(|| tonic::Status::permission_denied("Missing leader identity"))?
.to_str()
.map_err(|_| tonic::Status::permission_denied("Invalid leader identity"))?;
let leader_id_u64 = leader_id_str
.parse::<u64>()
.map_err(|_| tonic::Status::permission_denied("Invalid leader identity"))?;
let sender_leader_id = NodeId(leader_id_u64);
// Clone the state so the mutex is released before any `.await` below.
let election_state = self.leader_election.lock().clone();
let observed_leader = if let Some((instance_id, members_rx)) = election_state {
// The watch borrow stays inside this branch, which contains no `.await`.
let members = members_rx.borrow();
let mut ids: Vec<NodeId> = members
.iter()
.filter(|m| matches!(m.state, NodeState::Active))
.map(|m| m.id)
.collect();
// The local node is always a candidate, listed as a member or not.
ids.push(instance_id);
super::leader_of(&ids)
} else {
// No election view: every node that advertised a barrier address is a candidate.
let live_nodes: Vec<NodeId> = self
.kv
.scan(BARRIER_ADDR_KEY)
.await
.into_iter()
.map(|(id, _)| id)
.collect();
super::leader_of(&live_nodes)
};
if Some(sender_leader_id) != observed_leader {
return Err(tonic::Status::permission_denied(
"Sender is not the observed leader",
));
}
Ok(())
}
}
/// Upper bound on how long a `Prepare` RPC waits for the local coordinator's ack.
#[cfg(feature = "cluster")]
const PREPARE_ACK_TIMEOUT: Duration = Duration::from_secs(30);

/// Converts a locally produced [`BarrierAck`] into the wire response.
#[cfg(feature = "cluster")]
fn grpc_ack_response(ack: BarrierAck) -> tonic::Response<barrier_v1::Ack> {
    tonic::Response::new(barrier_v1::Ack {
        epoch: ack.epoch,
        ok: ack.ok,
        error: ack.error,
        local_watermark_ms: ack.local_watermark_ms,
    })
}

/// Positive ack returned by the fire-and-forget phases (`Aligned`, `Commit`,
/// `Abort`) as soon as the announcement has been queued locally.
#[cfg(feature = "cluster")]
fn grpc_accepted_response(epoch: u64) -> tonic::Response<barrier_v1::Ack> {
    grpc_ack_response(BarrierAck {
        epoch,
        ok: true,
        error: None,
        local_watermark_ms: None,
    })
}

/// Removes the pending-ack waiter for `epoch`, but only if its receiver is gone.
///
/// A retried `Prepare` for the same epoch overwrites the sender in the map. An older
/// attempt that times out or fails must therefore not delete the retry's live waiter.
#[cfg(feature = "cluster")]
fn discard_abandoned_waiter(
    pending: &parking_lot::Mutex<FxHashMap<u64, tokio::sync::oneshot::Sender<BarrierAck>>>,
    epoch: u64,
) {
    let mut guard = pending.lock();
    if guard.get(&epoch).is_some_and(|tx| tx.is_closed()) {
        guard.remove(&epoch);
    }
}

/// Follower-side handlers: every RPC is authenticated against the observed leader
/// before it touches any state, then forwarded to the local coordinator.
#[cfg(feature = "cluster")]
#[tonic::async_trait]
impl barrier_v1::barrier_sync_server::BarrierSync for GrpcBarrierServer {
    /// Delivers a Prepare locally and blocks until the coordinator acks the epoch.
    ///
    /// A retried Prepare whose ack was already produced is answered from the cache.
    /// Fails with `aborted` if the local coordinator is gone and `deadline_exceeded`
    /// after [`PREPARE_ACK_TIMEOUT`].
    async fn prepare(
        &self,
        request: tonic::Request<barrier_v1::PrepareRequest>,
    ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
        // Authenticate first: an unauthorized caller must neither consume a cached
        // ack nor displace the leader's pending waiter for the same epoch.
        self.validate_leader(request.metadata()).await?;
        let req = request.into_inner();
        let epoch = req.epoch;
        let (tx, rx) = tokio::sync::oneshot::channel::<BarrierAck>();
        {
            // Hold `completed_acks` while registering the waiter. `ack()` caches the
            // ack first and only then looks for a waiter. A concurrent ack is therefore
            // either visible here or finds our sender, and can never slip in between.
            let mut completed = self.completed_acks.lock();
            if let Some(ack) = completed.remove(&epoch) {
                return Ok(grpc_ack_response(ack));
            }
            self.pending_acks.lock().insert(epoch, tx);
        }
        let ann = BarrierAnnouncement {
            epoch,
            checkpoint_id: req.checkpoint_id,
            phase: Phase::Prepare,
            flags: req.flags,
            min_watermark_ms: None,
        };
        if self.incoming_tx.send(ann).await.is_err() {
            // Close our receiver so only our own (now closed) sender gets discarded.
            drop(rx);
            discard_abandoned_waiter(&self.pending_acks, epoch);
            return Err(tonic::Status::aborted("Follower coordinator shutdown"));
        }
        // Bound to a local so the timeout future (and `rx`) is dropped before matching.
        let outcome = tokio::time::timeout(PREPARE_ACK_TIMEOUT, rx).await;
        match outcome {
            Ok(Ok(ack)) => Ok(grpc_ack_response(ack)),
            Ok(Err(_)) => Err(tonic::Status::internal("Ack sender dropped")),
            Err(_) => {
                discard_abandoned_waiter(&self.pending_acks, epoch);
                Err(tonic::Status::deadline_exceeded(
                    "Follower checkpoint prepare timed out",
                ))
            }
        }
    }

    /// Queues an Aligned announcement carrying the leader's minimum watermark.
    async fn aligned(
        &self,
        request: tonic::Request<barrier_v1::AlignedRequest>,
    ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
        self.validate_leader(request.metadata()).await?;
        let req = request.into_inner();
        let ann = BarrierAnnouncement {
            epoch: req.epoch,
            checkpoint_id: req.checkpoint_id,
            phase: Phase::Aligned,
            flags: req.flags,
            min_watermark_ms: req.min_watermark_ms,
        };
        if self.incoming_tx.send(ann).await.is_err() {
            return Err(tonic::Status::aborted("Follower coordinator shutdown"));
        }
        Ok(grpc_accepted_response(req.epoch))
    }

    /// Queues a Commit and drops cached Prepare acks for this and older epochs.
    async fn commit(
        &self,
        request: tonic::Request<barrier_v1::CommitRequest>,
    ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
        self.validate_leader(request.metadata()).await?;
        let req = request.into_inner();
        self.completed_acks
            .lock()
            .retain(|&cached, _| cached > req.epoch);
        let ann = BarrierAnnouncement {
            epoch: req.epoch,
            checkpoint_id: req.checkpoint_id,
            phase: Phase::Commit,
            flags: req.flags,
            min_watermark_ms: req.min_watermark_ms,
        };
        if self.incoming_tx.send(ann).await.is_err() {
            return Err(tonic::Status::aborted("Follower coordinator shutdown"));
        }
        Ok(grpc_accepted_response(req.epoch))
    }

    /// Queues an Abort and drops cached Prepare acks for this and older epochs.
    async fn abort(
        &self,
        request: tonic::Request<barrier_v1::AbortRequest>,
    ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
        self.validate_leader(request.metadata()).await?;
        let req = request.into_inner();
        self.completed_acks
            .lock()
            .retain(|&cached, _| cached > req.epoch);
        let ann = BarrierAnnouncement {
            epoch: req.epoch,
            checkpoint_id: req.checkpoint_id,
            phase: Phase::Abort,
            flags: req.flags,
            min_watermark_ms: None,
        };
        if self.incoming_tx.send(ann).await.is_err() {
            return Err(tonic::Status::aborted("Follower coordinator shutdown"));
        }
        Ok(grpc_accepted_response(req.epoch))
    }
}
/// Returns the cached barrier gRPC client for `peer`, creating one on a cache miss.
///
/// On a miss, the peer's address is read from the KV store under [`BARRIER_ADDR_KEY`]
/// and a lazily connecting channel is built, so no network I/O happens here. Returns
/// `None` when the peer has not published an address or the address cannot be turned
/// into an endpoint; the latter is logged.
#[cfg(feature = "cluster")]
async fn get_barrier_client(
    peer: NodeId,
    pool: &BarrierClientPool,
    kv: &Arc<dyn ClusterKv>,
) -> Option<barrier_v1::barrier_sync_client::BarrierSyncClient<tonic::transport::Channel>> {
    // Copy the hit out first so the pool lock is released before awaiting.
    let cached = pool.lock().get(&peer).cloned();
    if cached.is_some() {
        return cached;
    }
    let addr_str = kv.read_from(peer, BARRIER_ADDR_KEY).await?;
    let endpoint = match super::tls::client_endpoint(&addr_str) {
        Ok(endpoint) => endpoint,
        Err(e) => {
            tracing::warn!(
                peer = peer.0,
                addr = %addr_str,
                error = ?e,
                "cannot build barrier endpoint for peer"
            );
            return None;
        }
    };
    let client =
        barrier_v1::barrier_sync_client::BarrierSyncClient::new(endpoint.connect_lazy());
    // A concurrent caller may have raced us; keep its client so everyone shares one channel.
    let shared = pool.lock().entry(peer).or_insert(client).clone();
    Some(shared)
}
/// Attaches this node's id as the `x-leader-id` metadata entry used by followers to
/// authenticate barrier RPCs. Does nothing when the local id is unknown.
#[cfg(feature = "cluster")]
fn stamp_leader_id<T>(req: &mut tonic::Request<T>, local_id: Option<NodeId>) {
    let Some(leader) = local_id else {
        return;
    };
    let encoded = leader.0.to_string();
    if let Ok(value) = encoded.parse() {
        req.metadata_mut().insert("x-leader-id", value);
    }
}
/// Delivers one announcement to `peer` over the barrier gRPC service.
///
/// Resolves (or lazily creates) the peer's client, builds the RPC matching `ann.phase`
/// and stamps `x-leader-id` from `local_id`. On RPC failure the cached client is
/// evicted so the next call rebuilds it.
///
/// # Errors
/// Returns a message if no client can be obtained for `peer` or if the RPC fails.
#[cfg(feature = "cluster")]
async fn send_phase_rpc(
peer: NodeId,
clients_pool: BarrierClientPool,
kv: Arc<dyn ClusterKv>,
ann: BarrierAnnouncement,
local_id: Option<NodeId>,
) -> Result<(), String> {
let mut client = get_barrier_client(peer, &clients_pool, &kv)
.await
.ok_or_else(|| format!("failed to get client for peer {}", peer.0))?;
// Each arm yields `Result<(), (rpc name, status)>` so the error path below is shared.
let result = match ann.phase {
Phase::Aligned => {
let mut req = tonic::Request::new(barrier_v1::AlignedRequest {
epoch: ann.epoch,
checkpoint_id: ann.checkpoint_id,
flags: ann.flags,
min_watermark_ms: ann.min_watermark_ms,
});
stamp_leader_id(&mut req, local_id);
client
.aligned(req)
.await
.map(|_| ())
.map_err(|e| ("aligned", e))
}
Phase::Commit => {
let mut req = tonic::Request::new(barrier_v1::CommitRequest {
epoch: ann.epoch,
checkpoint_id: ann.checkpoint_id,
flags: ann.flags,
min_watermark_ms: ann.min_watermark_ms,
});
stamp_leader_id(&mut req, local_id);
client
.commit(req)
.await
.map(|_| ())
.map_err(|e| ("commit", e))
}
Phase::Abort => {
let mut req = tonic::Request::new(barrier_v1::AbortRequest {
epoch: ann.epoch,
checkpoint_id: ann.checkpoint_id,
flags: ann.flags,
});
stamp_leader_id(&mut req, local_id);
client
.abort(req)
.await
.map(|_| ())
.map_err(|e| ("abort", e))
}
// Prepare is sent by `wait_for_quorum`, not here. The client lookup above still
// runs first, so this arm is only reached when the peer is resolvable.
Phase::Prepare => Ok(()),
};
result.map_err(|(rpc, e)| {
// Evict the cached client so the next attempt reconnects.
clients_pool.lock().remove(&peer);
format!("{rpc} RPC to peer {} failed: {e}", peer.0)
})
}
/// Why a single follower did not successfully ack a Prepare RPC.
#[cfg(feature = "cluster")]
enum PeerFailure {
/// No usable answer: no client, a local timeout, or a transient status
/// (`Unavailable`, `DeadlineExceeded`, `Cancelled`, `Aborted`). Reported as missing.
Unreachable,
/// The follower returned a failed ack, or the RPC failed with any other status.
/// Reported as a hard failure.
Nack(String),
}
/// Drives the checkpoint barrier protocol for one node, as leader or follower.
///
/// Always communicates through the gossip KV store. With the `cluster` feature and a
/// started gRPC server it additionally uses direct RPCs for announcements and acks.
pub struct BarrierCoordinator {
/// Gossip KV store shared with the rest of the node.
kv: Arc<dyn ClusterKv>,
/// gRPC server state; `None` until `start_server` succeeds.
#[cfg(feature = "cluster")]
grpc: Arc<parking_lot::Mutex<Option<Arc<GrpcState>>>>,
/// Membership view used by the gRPC server to validate the leader.
#[cfg(feature = "cluster")]
leader_election: Arc<parking_lot::Mutex<ActiveLeaderState>>,
}
impl std::fmt::Debug for BarrierCoordinator {
    /// Prints only the type name; the internals are channels, locks and task handles.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BarrierCoordinator");
        out.finish_non_exhaustive()
    }
}
impl Drop for BarrierCoordinator {
    /// Aborts the gRPC server and relay tasks, if a server was started.
    fn drop(&mut self) {
        #[cfg(feature = "cluster")]
        {
            let Some(state) = self.grpc.lock().take() else {
                return;
            };
            for slot in [&state.server_handle, &state.relay_handle] {
                // Release the slot lock before aborting the task.
                let taken = slot.lock().take();
                if let Some(task) = taken {
                    task.abort();
                }
            }
        }
    }
}
impl BarrierCoordinator {
    /// Creates a coordinator that talks to peers through `kv`.
    ///
    /// No gRPC server is running yet; until [`Self::start_server`] succeeds every
    /// operation goes through the KV store only.
    #[must_use]
    pub fn new(kv: Arc<dyn ClusterKv>) -> Self {
        Self {
            kv,
            #[cfg(feature = "cluster")]
            grpc: Arc::new(parking_lot::Mutex::new(None)),
            #[cfg(feature = "cluster")]
            leader_election: Arc::new(parking_lot::Mutex::new(None)),
        }
    }

    /// Installs the membership view the gRPC server uses to decide who the leader is.
    ///
    /// `instance_id` is this node's id and `members_rx` publishes the member list.
    /// Once set, incoming barrier RPCs are accepted only from the leader computed over
    /// the `Active` members plus `instance_id`. Without it, the server falls back to
    /// the nodes that published [`BARRIER_ADDR_KEY`].
    #[cfg(feature = "cluster")]
    pub fn set_leader_election(
        &mut self,
        instance_id: NodeId,
        members_rx: watch::Receiver<Vec<NodeInfo>>,
    ) {
        *self.leader_election.lock() = Some((instance_id, members_rx));
    }

    /// Resolves this node's id by matching its advertised address against the
    /// addresses published under [`BARRIER_ADDR_KEY`].
    ///
    /// Returns `None` when no gRPC server is running or the address is not yet visible.
    #[cfg(feature = "cluster")]
    async fn local_node_id(&self) -> Option<NodeId> {
        let advertise_addr = self.grpc.lock().as_ref()?.advertise_addr.clone();
        self.kv
            .scan(BARRIER_ADDR_KEY)
            .await
            .into_iter()
            .find_map(|(node_id, addr)| (addr == advertise_addr).then_some(node_id))
    }

    /// Aborts the server and relay tasks owned by `state`.
    #[cfg(feature = "cluster")]
    fn abort_grpc_tasks(state: &GrpcState) {
        for slot in [&state.server_handle, &state.relay_handle] {
            let taken = slot.lock().take();
            if let Some(task) = taken {
                task.abort();
            }
        }
    }

    /// Binds the control-plane gRPC server and enables the gRPC barrier path.
    ///
    /// Serves the barrier service together with the query service on `bind_addr`. TLS
    /// is used when `super::tls::server_tls()` yields a config. The method also spawns
    /// a relay that feeds received announcements into [`Self::announcement_watch`] and
    /// publishes the advertised address under [`BARRIER_ADDR_KEY`].
    ///
    /// The advertised address is `advertise_host:port` when a host is given.
    /// Otherwise, for an unspecified bind IP the machine hostname is substituted, and
    /// the bound socket address is the final fallback.
    ///
    /// Calling this again replaces the previous server, whose tasks are aborted so the
    /// old listener does not leak. Returns the locally bound socket address.
    ///
    /// # Errors
    /// Returns a message when the socket cannot be bound or configured, or when the
    /// TLS configuration is rejected.
    #[cfg(feature = "cluster")]
    pub async fn start_server(
        &self,
        bind_addr: std::net::SocketAddr,
        advertise_host: Option<String>,
        query_handler: super::QueryHandlerSlot,
    ) -> Result<std::net::SocketAddr, String> {
        use super::query::query_service_server;
        use barrier_v1::barrier_sync_server::BarrierSyncServer;
        use std::net::TcpListener;
        use tonic::transport::Server;

        let listener = TcpListener::bind(bind_addr).map_err(|e| e.to_string())?;
        let local_addr = listener.local_addr().map_err(|e| e.to_string())?;
        // tokio requires the std listener to be non-blocking before conversion.
        listener.set_nonblocking(true).map_err(|e| e.to_string())?;
        let tokio_listener =
            tokio::net::TcpListener::from_std(listener).map_err(|e| e.to_string())?;

        let (incoming_tx, incoming_rx) =
            crossfire::mpsc::bounded_async::<BarrierAnnouncement>(128);
        let pending_acks = Arc::new(parking_lot::Mutex::new(FxHashMap::default()));
        let completed_acks = Arc::new(parking_lot::Mutex::new(FxHashMap::default()));
        let clients = Arc::new(parking_lot::Mutex::new(FxHashMap::default()));
        let server_impl = GrpcBarrierServer {
            kv: Arc::clone(&self.kv),
            incoming_tx: incoming_tx.clone(),
            pending_acks: Arc::clone(&pending_acks),
            completed_acks: Arc::clone(&completed_acks),
            leader_election: Arc::clone(&self.leader_election),
        };
        let query_svc = query_service_server(query_handler);
        let mut builder = Server::builder();
        if let Some(tls) = super::tls::server_tls() {
            builder = builder
                .tls_config(tls.clone())
                .map_err(|e| format!("cluster control-plane TLS config: {e}"))?;
        }
        let router = builder
            .add_service(BarrierSyncServer::new(server_impl))
            .add_service(query_svc);
        let server_task = tokio::spawn(async move {
            let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(tokio_listener);
            if let Err(e) = router.serve_with_incoming(incoming_stream).await {
                tracing::error!(error = %e, "cluster control-plane gRPC server stopped with an error");
            }
        });

        let advertise_addr = if let Some(ref host) = advertise_host {
            format!("{host}:{}", local_addr.port())
        } else if local_addr.ip().is_unspecified() {
            let hostname = gethostname::gethostname();
            let hostname = hostname.to_string_lossy();
            if hostname.is_empty() {
                local_addr.to_string()
            } else {
                format!("{hostname}:{}", local_addr.port())
            }
        } else {
            local_addr.to_string()
        };

        let (latest_tx, latest_rx) = watch::channel::<Option<BarrierAnnouncement>>(None);
        let relay_task = tokio::spawn(async move {
            while let Ok(ann) = incoming_rx.recv().await {
                let _ = latest_tx.send(Some(ann));
            }
        });

        let grpc_state = Arc::new(GrpcState {
            latest_rx,
            incoming_tx,
            pending_acks,
            completed_acks,
            clients,
            server_handle: Arc::new(parking_lot::Mutex::new(Some(server_task))),
            relay_handle: Arc::new(parking_lot::Mutex::new(Some(relay_task))),
            advertise_addr: advertise_addr.clone(),
        });
        // Replace any earlier server; otherwise its tasks (and bound port) would leak.
        let previous = self.grpc.lock().replace(grpc_state);
        if let Some(old) = previous {
            Self::abort_grpc_tasks(&old);
        }
        self.kv.write(BARRIER_ADDR_KEY, advertise_addr).await;
        Ok(local_addr)
    }

    /// Broadcasts a barrier announcement from the leader.
    ///
    /// The announcement is always written as JSON under [`ANNOUNCEMENT_KEY`]. With a
    /// running gRPC server, non-Prepare phases are first pushed to every other node
    /// that published [`BARRIER_ADDR_KEY`]. Prepare is instead delivered by
    /// [`Self::wait_for_quorum`].
    ///
    /// # Errors
    /// Returns the first failed Commit/Abort RPC; Aligned failures are only logged.
    /// After such a failure nothing is written to the KV store. Also returns an error
    /// if serialization fails.
    pub async fn announce(&self, ann: &BarrierAnnouncement) -> Result<(), String> {
        #[cfg(feature = "cluster")]
        {
            let grpc_opt = self.grpc.lock().clone();
            if let Some(state) = grpc_opt {
                if ann.phase != Phase::Prepare {
                    self.push_phase_to_peers(&state, ann).await?;
                }
            }
        }
        let json = serde_json::to_string(ann).map_err(|e| e.to_string())?;
        self.kv.write(ANNOUNCEMENT_KEY, json).await;
        Ok(())
    }

    /// Sends `ann` to every peer except this node and waits for all RPCs to finish.
    #[cfg(feature = "cluster")]
    async fn push_phase_to_peers(
        &self,
        state: &GrpcState,
        ann: &BarrierAnnouncement,
    ) -> Result<(), String> {
        let local_id = self.local_node_id().await;
        let peers: Vec<NodeId> = self
            .kv
            .scan(BARRIER_ADDR_KEY)
            .await
            .into_iter()
            .filter(|(_, addr)| *addr != state.advertise_addr)
            .map(|(node_id, _)| node_id)
            .collect();
        let rpcs: Vec<_> = peers
            .into_iter()
            .map(|peer| {
                send_phase_rpc(
                    peer,
                    Arc::clone(&state.clients),
                    Arc::clone(&self.kv),
                    ann.clone(),
                    local_id,
                )
            })
            .collect();
        for res in futures::future::join_all(rpcs).await {
            match res {
                Ok(()) => {}
                // Aligned is advisory: a follower that misses it still resumes on Commit.
                Err(e) if ann.phase == Phase::Aligned => {
                    tracing::warn!(
                        epoch = ann.epoch,
                        error = %e,
                        "aligned announcement RPC failed; peer resumes on Commit"
                    );
                }
                Err(e) => return Err(e),
            }
        }
        Ok(())
    }

    /// Returns a watch over the latest announcement received over gRPC, or `None` if
    /// no server has been started.
    #[cfg(feature = "cluster")]
    #[must_use]
    pub fn announcement_watch(&self) -> Option<watch::Receiver<Option<BarrierAnnouncement>>> {
        self.grpc.lock().as_ref().map(|s| s.latest_rx.clone())
    }

    /// Returns the most recent announcement seen from `leader`.
    ///
    /// Merges the latest gRPC-delivered announcement (if a server is running) with the
    /// JSON value `leader` published under [`ANNOUNCEMENT_KEY`]. The higher epoch wins
    /// and the gRPC value wins ties, so lagging gossip cannot mask fresher gRPC state.
    ///
    /// # Errors
    /// Returns the parse error when the gossip value is corrupt and there is no gRPC
    /// value to fall back to.
    pub async fn observe(&self, leader: NodeId) -> Result<Option<BarrierAnnouncement>, String> {
        #[cfg(feature = "cluster")]
        let grpc_latest: Option<BarrierAnnouncement> = {
            let grpc_opt = self.grpc.lock().clone();
            grpc_opt.and_then(|state| state.latest_rx.borrow().clone())
        };
        #[cfg(not(feature = "cluster"))]
        let grpc_latest: Option<BarrierAnnouncement> = None;
        let kv_latest: Option<BarrierAnnouncement> =
            match self.kv.read_from(leader, ANNOUNCEMENT_KEY).await {
                Some(json) => match serde_json::from_str(&json) {
                    Ok(a) => Some(a),
                    Err(e) if grpc_latest.is_some() => {
                        tracing::warn!(error = %e, "corrupt gossip announcement; using gRPC value");
                        None
                    }
                    Err(e) => return Err(e.to_string()),
                },
                None => None,
            };
        Ok(match (grpc_latest, kv_latest) {
            (Some(g), Some(k)) if k.epoch > g.epoch => Some(k),
            (Some(g), _) => Some(g),
            (None, k) => k,
        })
    }

    /// Publishes this node's ack for a barrier epoch.
    ///
    /// With a running gRPC server the ack is cached by epoch, so a retried Prepare can
    /// be answered, and handed to the Prepare RPC waiting on that epoch, if any.
    /// Nothing is written to the KV store in that case. Otherwise the ack is written
    /// as JSON under [`ACK_KEY`].
    ///
    /// # Errors
    /// Returns a message only if the ack fails to serialize (KV path).
    pub async fn ack(&self, ack: &BarrierAck) -> Result<(), String> {
        #[cfg(feature = "cluster")]
        {
            let grpc_opt = self.grpc.lock().clone();
            if let Some(state) = grpc_opt {
                // Cache first, then resolve the waiter: `prepare` relies on this order.
                state.completed_acks.lock().insert(ack.epoch, ack.clone());
                let waiter = state.pending_acks.lock().remove(&ack.epoch);
                if let Some(tx) = waiter {
                    let _ = tx.send(ack.clone());
                }
                return Ok(());
            }
        }
        let json = serde_json::to_string(ack).map_err(|e| e.to_string())?;
        self.kv.write(ACK_KEY, json).await;
        Ok(())
    }

    /// Waits until every node in `expected` has acknowledged `epoch`, or gives up.
    ///
    /// With a running gRPC server this sends a `Prepare` RPC to each peer concurrently,
    /// each bounded by `deadline`. The RPC carries the checkpoint id from the matching
    /// announcement, or `0` if none is found. Otherwise it polls acks published under
    /// [`ACK_KEY`] until `deadline` elapses.
    ///
    /// Any reported failure yields [`QuorumOutcome::Failed`], and missing acks yield
    /// [`QuorumOutcome::TimedOut`]. Otherwise [`QuorumOutcome::Reached`] carries the
    /// minimum watermark reported by followers (`None` if none reported one).
    /// Duplicate ids in `expected` are treated as a single node.
    pub async fn wait_for_quorum(
        &self,
        epoch: u64,
        expected: &[NodeId],
        deadline: Duration,
    ) -> QuorumOutcome {
        // Duplicates would otherwise make quorum unreachable on the KV path. On the
        // gRPC path they would send two Prepare RPCs to one follower, the second
        // displacing the first's pending waiter.
        let peers = Self::dedup_peers(expected);
        #[cfg(feature = "cluster")]
        {
            let grpc_opt = self.grpc.lock().clone();
            if let Some(state) = grpc_opt {
                return self
                    .wait_for_quorum_grpc(&state, epoch, &peers, deadline)
                    .await;
            }
        }
        self.wait_for_quorum_kv(epoch, &peers, deadline).await
    }

    /// Returns `expected` without duplicates, keeping first-occurrence order.
    fn dedup_peers(expected: &[NodeId]) -> Vec<NodeId> {
        let mut seen = FxHashSet::default();
        expected
            .iter()
            .copied()
            .filter(|peer| seen.insert(*peer))
            .collect()
    }

    /// Folds an optional reported watermark into the running minimum.
    fn fold_min_watermark(current: Option<i64>, reported: Option<i64>) -> Option<i64> {
        match (current, reported) {
            (Some(cur), Some(rep)) => Some(cur.min(rep)),
            (cur, None) => cur,
            (None, rep) => rep,
        }
    }

    /// Checkpoint id of the first published announcement for `epoch`, or `0`.
    #[cfg(feature = "cluster")]
    async fn announced_checkpoint_id(&self, epoch: u64) -> u64 {
        self.kv
            .scan(ANNOUNCEMENT_KEY)
            .await
            .into_iter()
            .find_map(|(_, json)| {
                serde_json::from_str::<BarrierAnnouncement>(&json)
                    .ok()
                    .filter(|a| a.epoch == epoch)
            })
            .map_or(0, |a| a.checkpoint_id)
    }

    /// gRPC quorum: one concurrent Prepare RPC per peer, each bounded by `deadline`.
    #[cfg(feature = "cluster")]
    async fn wait_for_quorum_grpc(
        &self,
        state: &GrpcState,
        epoch: u64,
        peers: &[NodeId],
        deadline: Duration,
    ) -> QuorumOutcome {
        let checkpoint_id = self.announced_checkpoint_id(epoch).await;
        let local_id = self.local_node_id().await;
        let mut prepares = Vec::with_capacity(peers.len());
        for &peer in peers {
            let clients_pool = Arc::clone(&state.clients);
            let kv = Arc::clone(&self.kv);
            prepares.push(async move {
                let client_opt = get_barrier_client(peer, &clients_pool, &kv).await;
                let Some(mut client) = client_opt else {
                    return Err((peer, PeerFailure::Unreachable));
                };
                let mut req = tonic::Request::new(barrier_v1::PrepareRequest {
                    epoch,
                    checkpoint_id,
                    flags: 0,
                });
                stamp_leader_id(&mut req, local_id);
                match tokio::time::timeout(deadline, client.prepare(req)).await {
                    Ok(Ok(response)) => {
                        let ack = response.into_inner();
                        if ack.ok {
                            Ok((peer, ack.local_watermark_ms))
                        } else {
                            let reason = ack
                                .error
                                .unwrap_or_else(|| "Unknown prepare failure".to_string());
                            Err((peer, PeerFailure::Nack(reason)))
                        }
                    }
                    Ok(Err(status)) => {
                        // Evict the cached client so the next attempt reconnects.
                        clients_pool.lock().remove(&peer);
                        match status.code() {
                            tonic::Code::Unavailable
                            | tonic::Code::DeadlineExceeded
                            | tonic::Code::Cancelled
                            | tonic::Code::Aborted => Err((peer, PeerFailure::Unreachable)),
                            _ => Err((peer, PeerFailure::Nack(status.to_string()))),
                        }
                    }
                    Err(_elapsed) => {
                        clients_pool.lock().remove(&peer);
                        Err((peer, PeerFailure::Unreachable))
                    }
                }
            });
        }

        let mut acked = Vec::new();
        let mut failures = Vec::new();
        let mut unreachable = Vec::new();
        let mut min_follower_wm = None;
        for res in futures::future::join_all(prepares).await {
            match res {
                Ok((peer, wm)) => {
                    acked.push(peer);
                    min_follower_wm = Self::fold_min_watermark(min_follower_wm, wm);
                }
                Err((peer, PeerFailure::Unreachable)) => unreachable.push(peer),
                Err((peer, PeerFailure::Nack(msg))) => failures.push((peer, msg)),
            }
        }
        if !failures.is_empty() {
            return QuorumOutcome::Failed { failures };
        }
        // Each (deduplicated) peer yields exactly one result, so with no failures the
        // only way to miss quorum is an unreachable peer.
        if !unreachable.is_empty() {
            return QuorumOutcome::TimedOut {
                got: acked,
                missing: unreachable,
            };
        }
        QuorumOutcome::Reached {
            acks: acked,
            min_follower_watermark_ms: min_follower_wm,
        }
    }

    /// KV quorum: polls published acks until all peers acked, one failed, or the
    /// deadline passes.
    async fn wait_for_quorum_kv(
        &self,
        epoch: u64,
        peers: &[NodeId],
        deadline: Duration,
    ) -> QuorumOutcome {
        const POLL_INTERVAL: Duration = Duration::from_millis(50);
        let start = Instant::now();
        let expected_set: FxHashSet<NodeId> = peers.iter().copied().collect();
        loop {
            let mut successful: Vec<NodeId> = Vec::new();
            let mut failures: Vec<(NodeId, String)> = Vec::new();
            let mut min_follower_wm: Option<i64> = None;
            for (from, json) in self.kv.scan(ACK_KEY).await {
                if !expected_set.contains(&from) {
                    continue;
                }
                // Unparseable or other-epoch acks are ignored rather than counted as failures.
                let Ok(ack) = serde_json::from_str::<BarrierAck>(&json) else {
                    continue;
                };
                if ack.epoch != epoch {
                    continue;
                }
                if ack.ok {
                    successful.push(from);
                    min_follower_wm =
                        Self::fold_min_watermark(min_follower_wm, ack.local_watermark_ms);
                } else {
                    failures.push((from, ack.error.unwrap_or_default()));
                }
            }
            if !failures.is_empty() {
                return QuorumOutcome::Failed { failures };
            }
            if successful.len() == expected_set.len() {
                return QuorumOutcome::Reached {
                    acks: successful,
                    min_follower_watermark_ms: min_follower_wm,
                };
            }
            let elapsed = start.elapsed();
            if elapsed >= deadline {
                let got: FxHashSet<NodeId> = successful.iter().copied().collect();
                let missing = peers.iter().copied().filter(|n| !got.contains(n)).collect();
                return QuorumOutcome::TimedOut {
                    got: successful,
                    missing,
                };
            }
            // Never sleep past the deadline, so the final poll happens on time.
            tokio::time::sleep(POLL_INTERVAL.min(deadline - elapsed)).await;
        }
    }
}
// Unit tests for the barrier coordinator over the in-memory KV store, plus
// loopback gRPC tests when the `cluster` feature is enabled.
#[cfg(test)]
mod tests {
use super::*;
// Builds an in-memory KV store owned by node `id`.
fn kv(id: NodeId) -> Arc<InMemoryKv> {
Arc::new(InMemoryKv::new(id))
}
// End-to-end tests that run real gRPC servers on loopback.
#[cfg(all(test, feature = "cluster"))]
mod grpc_tests {
use super::*;
use std::net::SocketAddr;
// Polls `observe` every 50ms, up to 100 times, until the announcement from
// `leader` is in `phase`; panics if it never is.
async fn wait_observe(
coord: &BarrierCoordinator,
leader: NodeId,
phase: Phase,
) -> BarrierAnnouncement {
for _ in 0..100 {
if let Some(ann) = coord.observe(leader).await.unwrap() {
if ann.phase == phase {
return ann;
}
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
panic!("timed out waiting for {phase:?} announcement from leader {leader:?}");
}
// Full Prepare -> Aligned -> Commit round between leader node 1 and follower node 2.
#[tokio::test]
async fn test_grpc_barrier_flow() {
let leader_kv = kv(NodeId(1));
let follower_kv = kv(NodeId(2));
let leader_coord = BarrierCoordinator::new(leader_kv.clone());
let follower_coord = BarrierCoordinator::new(follower_kv.clone());
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let slot = || Arc::new(parking_lot::RwLock::new(None));
let leader_addr = leader_coord.start_server(addr, None, slot()).await.unwrap();
let bound_addr = follower_coord
.start_server(addr, None, slot())
.await
.unwrap();
// Each store only holds its own address, so cross-seed the peer's address.
leader_kv.seed(NodeId(2), BARRIER_ADDR_KEY, bound_addr.to_string());
follower_kv.seed(NodeId(1), BARRIER_ADDR_KEY, leader_addr.to_string());
// Signals the leader that the follower saw Aligned before Commit is sent.
let (aligned_seen_tx, aligned_seen_rx) = tokio::sync::oneshot::channel::<()>();
// Follower: ack Prepare with watermark 100, then expect it echoed on Aligned and Commit.
let follower_task = tokio::spawn(async move {
let ann = wait_observe(&follower_coord, NodeId(1), Phase::Prepare).await;
assert_eq!(ann.epoch, 1);
assert_eq!(ann.checkpoint_id, 42);
follower_coord
.ack(&BarrierAck {
epoch: 1,
ok: true,
error: None,
local_watermark_ms: Some(100),
})
.await
.unwrap();
let aligned_ann = wait_observe(&follower_coord, NodeId(1), Phase::Aligned).await;
assert_eq!(aligned_ann.epoch, 1);
assert_eq!(aligned_ann.min_watermark_ms, Some(100));
aligned_seen_tx.send(()).unwrap();
let commit_ann = wait_observe(&follower_coord, NodeId(1), Phase::Commit).await;
assert_eq!(commit_ann.min_watermark_ms, Some(100));
});
// Leader: Prepare is only written to KV here; wait_for_quorum sends the RPC.
leader_coord
.announce(&BarrierAnnouncement {
epoch: 1,
checkpoint_id: 42,
phase: Phase::Prepare,
flags: 0,
min_watermark_ms: None,
})
.await
.unwrap();
let outcome = leader_coord
.wait_for_quorum(1, &[NodeId(2)], Duration::from_secs(5))
.await;
match outcome {
QuorumOutcome::Reached {
acks,
min_follower_watermark_ms,
} => {
assert_eq!(acks, vec![NodeId(2)]);
assert_eq!(min_follower_watermark_ms, Some(100));
leader_coord
.announce(&BarrierAnnouncement {
epoch: 1,
checkpoint_id: 42,
phase: Phase::Aligned,
flags: 0,
min_watermark_ms: min_follower_watermark_ms,
})
.await
.unwrap();
aligned_seen_rx.await.unwrap();
leader_coord
.announce(&BarrierAnnouncement {
epoch: 1,
checkpoint_id: 42,
phase: Phase::Commit,
flags: 0,
min_watermark_ms: min_follower_watermark_ms,
})
.await
.unwrap();
}
other => panic!("expected Reached, got {other:?}"),
}
follower_task.await.unwrap();
}
}
// observe() must prefer whichever source (gRPC watch or gossip KV) has the higher
// epoch; on equal epochs the gRPC value wins.
#[cfg(feature = "cluster")]
#[tokio::test]
async fn observe_merges_grpc_and_gossip_by_epoch() {
let leader_kv = kv(NodeId(1));
let follower_kv = kv(NodeId(2));
let leader_coord = BarrierCoordinator::new(leader_kv.clone());
let follower_coord = BarrierCoordinator::new(follower_kv.clone());
let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
let slot = || Arc::new(parking_lot::RwLock::new(None));
let leader_addr = leader_coord.start_server(addr, None, slot()).await.unwrap();
let bound_addr = follower_coord
.start_server(addr, None, slot())
.await
.unwrap();
leader_kv.seed(NodeId(2), BARRIER_ADDR_KEY, bound_addr.to_string());
follower_kv.seed(NodeId(1), BARRIER_ADDR_KEY, leader_addr.to_string());
leader_coord
.announce(&BarrierAnnouncement {
epoch: 5,
checkpoint_id: 9,
phase: Phase::Abort,
flags: 0,
min_watermark_ms: None,
})
.await
.unwrap();
// Wait until the follower has received the epoch-5 Abort over gRPC.
for _ in 0..100 {
if let Some(ann) = follower_coord.observe(NodeId(1)).await.unwrap() {
if ann.phase == Phase::Abort {
break;
}
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
// Gossip advances to epoch 6: it must win over the gRPC epoch-5 value.
let next = serde_json::to_string(&BarrierAnnouncement {
epoch: 6,
checkpoint_id: 10,
phase: Phase::Prepare,
flags: 0,
min_watermark_ms: None,
})
.unwrap();
follower_kv.seed(NodeId(1), ANNOUNCEMENT_KEY, next);
let got = follower_coord.observe(NodeId(1)).await.unwrap().unwrap();
assert_eq!(got.epoch, 6);
assert_eq!(got.phase, Phase::Prepare);
// Gossip falls back to epoch 5 (equal to gRPC): the gRPC Abort must be kept.
let stale = serde_json::to_string(&BarrierAnnouncement {
epoch: 5,
checkpoint_id: 9,
phase: Phase::Prepare,
flags: 0,
min_watermark_ms: None,
})
.unwrap();
follower_kv.seed(NodeId(1), ANNOUNCEMENT_KEY, stale);
let got = follower_coord.observe(NodeId(1)).await.unwrap().unwrap();
assert_eq!(
got.phase,
Phase::Abort,
"lagging gossip must not mask the fresher gRPC announcement",
);
}
// KV-only path: an announcement written by the leader is read back via observe().
#[tokio::test]
async fn leader_announces_follower_observes() {
let leader_kv = kv(NodeId(1));
let coord = BarrierCoordinator::new(leader_kv.clone());
coord
.announce(&BarrierAnnouncement {
epoch: 5,
checkpoint_id: 42,
phase: Phase::Prepare,
flags: 0,
min_watermark_ms: None,
})
.await
.unwrap();
let got = coord.observe(NodeId(1)).await.unwrap().unwrap();
assert_eq!(got.epoch, 5);
assert_eq!(got.checkpoint_id, 42);
}
// With nothing published, observe() yields Ok(None) rather than an error.
#[tokio::test]
async fn observe_returns_none_when_leader_silent() {
let k = kv(NodeId(1));
let coord = BarrierCoordinator::new(k);
assert!(coord.observe(NodeId(1)).await.unwrap().is_none());
}
// KV path: all expected nodes acked OK -> Reached, with no watermark reported.
#[tokio::test]
async fn quorum_reached_when_all_ack_success() {
let k = kv(NodeId(1));
let ack_json = serde_json::to_string(&BarrierAck {
epoch: 7,
ok: true,
error: None,
local_watermark_ms: None,
})
.unwrap();
k.seed(NodeId(2), ACK_KEY, ack_json.clone());
k.seed(NodeId(3), ACK_KEY, ack_json);
let coord = BarrierCoordinator::new(k);
let outcome = coord
.wait_for_quorum(7, &[NodeId(2), NodeId(3)], Duration::from_millis(200))
.await;
match outcome {
QuorumOutcome::Reached {
mut acks,
min_follower_watermark_ms,
} => {
// Scan order is not specified, so sort before comparing.
acks.sort_by_key(|n| n.0);
assert_eq!(acks, vec![NodeId(2), NodeId(3)]);
assert_eq!(
min_follower_watermark_ms, None,
"no follower reported a watermark — min is None"
);
}
other => panic!("expected Reached, got {other:?}"),
}
}
// KV path: one expected node never acks -> TimedOut naming it as missing.
#[tokio::test]
async fn quorum_timeout_when_follower_silent() {
let k = kv(NodeId(1));
let ack_json = serde_json::to_string(&BarrierAck {
epoch: 8,
ok: true,
error: None,
local_watermark_ms: None,
})
.unwrap();
k.seed(NodeId(2), ACK_KEY, ack_json);
let coord = BarrierCoordinator::new(k);
let outcome = coord
.wait_for_quorum(8, &[NodeId(2), NodeId(3)], Duration::from_millis(150))
.await;
match outcome {
QuorumOutcome::TimedOut { got, missing } => {
assert_eq!(got, vec![NodeId(2)]);
assert_eq!(missing, vec![NodeId(3)]);
}
other => panic!("expected TimedOut, got {other:?}"),
}
}
// KV path: a failed ack returns Failed immediately, well before the 2s deadline.
#[tokio::test]
async fn quorum_fails_fast_on_reported_error() {
let k = kv(NodeId(1));
let good = serde_json::to_string(&BarrierAck {
epoch: 9,
ok: true,
error: None,
local_watermark_ms: None,
})
.unwrap();
let bad = serde_json::to_string(&BarrierAck {
epoch: 9,
ok: false,
error: Some("state snapshot failed: disk full".into()),
local_watermark_ms: None,
})
.unwrap();
k.seed(NodeId(2), ACK_KEY, good);
k.seed(NodeId(3), ACK_KEY, bad);
let coord = BarrierCoordinator::new(k);
let outcome = coord
.wait_for_quorum(9, &[NodeId(2), NodeId(3)], Duration::from_secs(2))
.await;
match outcome {
QuorumOutcome::Failed { failures } => {
assert_eq!(failures.len(), 1);
assert_eq!(failures[0].0, NodeId(3));
assert!(failures[0].1.contains("disk full"));
}
other => panic!("expected Failed, got {other:?}"),
}
}
// KV path: an ack for a different epoch must not count toward quorum.
#[tokio::test]
async fn wrong_epoch_ack_is_ignored() {
let k = kv(NodeId(1));
let stale = serde_json::to_string(&BarrierAck {
epoch: 9,
ok: true,
error: None,
local_watermark_ms: None,
})
.unwrap();
k.seed(NodeId(2), ACK_KEY, stale);
let coord = BarrierCoordinator::new(k);
let outcome = coord
.wait_for_quorum(10, &[NodeId(2)], Duration::from_millis(100))
.await;
assert!(
matches!(outcome, QuorumOutcome::TimedOut { .. }),
"stale-epoch ack must not satisfy quorum"
);
}
}