use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use beamr::atom::{Atom, AtomTable};
use beamr::distribution::connection::ConnectionManager;
use beamr::distribution::control::{encode_pg_update_frame, encode_send_frame};
use beamr::distribution::pg::{PgRegistry, PgUpdate, RemoteMember};
use beamr::native::ProcessContext;
use beamr::term::Term;
use crate::cluster::discovery::ClusterResolver;
use liminal::channel::{ClusterObserver, encode_envelope};
use liminal::envelope::Envelope;
/// Handle that mirrors local channel subscriptions into the process-group
/// registry and forwards published envelopes to remote group members.
///
/// Cloning is cheap: every clone points at the same shared state through an
/// `Arc`.
#[derive(Clone)]
pub struct ClusterSync {
    /// Shared state behind every clone of this handle.
    inner: Arc<SyncInner>,
}
/// State shared by every clone of `ClusterSync`.
struct SyncInner {
    /// Process-group registry: local joins and leaves are applied to it and
    /// remote members are looked up from it.
    pg: Arc<PgRegistry>,
    /// Atom table used to intern channel names into group atoms, to resolve
    /// node names for logging, and when encoding outgoing frames.
    atoms: Arc<AtomTable>,
    /// Source of the per-node connections that outgoing frames are written to.
    connections: ConnectionManager,
    /// Atom naming this node; supplied when encoding pg update frames.
    local_node: Atom,
    /// Stored at construction but never read in this file (hence the
    /// underscore prefix).
    // NOTE(review): presumably held only to keep the resolver alive — confirm.
    _resolver: Arc<ClusterResolver>,
    /// Local subscriptions keyed by group atom; replayed to peers that join
    /// after the subscription was made.
    local: Mutex<HashMap<Atom, Vec<u64>>>,
}
impl std::fmt::Debug for ClusterSync {
    /// Renders only the local node atom; the remaining state is elided and
    /// the output is marked as non-exhaustive.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendering = formatter.debug_struct("ClusterSync");
        rendering.field("local_node", &self.inner.local_node);
        rendering.finish_non_exhaustive()
    }
}
impl ClusterSync {
#[must_use]
pub fn new(
pg: Arc<PgRegistry>,
atoms: Arc<AtomTable>,
connections: ConnectionManager,
local_node: Atom,
resolver: Arc<ClusterResolver>,
) -> Self {
Self {
inner: Arc::new(SyncInner {
pg,
atoms,
connections,
local_node,
_resolver: resolver,
local: Mutex::new(HashMap::new()),
}),
}
}
fn scope(&self) -> Atom {
self.inner.pg.default_scope()
}
fn group(&self, channel: &str) -> Atom {
self.inner.atoms.intern(channel)
}
#[must_use]
pub fn remote_targets(&self, channel: &str) -> Vec<RemoteMember> {
let group = self.group(channel);
self.inner.pg.remote_members(self.scope(), group)
}
fn record_local(&self, group: Atom, pid: u64) {
let mut local = self.lock_local();
let pids = local.entry(group).or_default();
if !pids.contains(&pid) {
pids.push(pid);
}
drop(local);
}
fn forget_local(&self, group: Atom, pid: u64) {
let mut local = self.lock_local();
if let Some(pids) = local.get_mut(&group) {
pids.retain(|candidate| *candidate != pid);
if pids.is_empty() {
local.remove(&group);
}
}
}
fn local_memberships(&self) -> Vec<(Atom, u64)> {
let local = self.lock_local();
local
.iter()
.flat_map(|(group, pids)| pids.iter().map(move |pid| (*group, *pid)))
.collect()
}
fn send_to_member(&self, member: RemoteMember, frame_bytes: &[u8]) {
let Some(to_pid) = Term::try_pid(member.pid_number) else {
tracing::warn!(
pid_number = member.pid_number,
"remote member pid out of immediate range; skipping cross-node delivery"
);
return;
};
let mut context = ProcessContext::new();
let Ok(payload) = context.alloc_binary(frame_bytes) else {
tracing::warn!("failed to allocate cross-node envelope payload");
return;
};
let Ok(frame) = encode_send_frame(
Term::atom(beamr::atom::Atom::OK),
to_pid,
payload,
&self.inner.atoms,
) else {
tracing::warn!("failed to encode cross-node send frame");
return;
};
self.write_frame(member.node, &frame);
}
fn write_frame(&self, node: Atom, frame: &[u8]) {
let Some(connection) = self.inner.connections.get_connection(node) else {
return;
};
write_raw_blocking(&connection, frame);
}
fn backfill_member(&self, node: Atom, group: Atom, pid: u64) {
let update = PgUpdate::Join {
scope: self.scope(),
group,
pid,
};
if let Ok(frame) = encode_pg_update_frame(update, self.inner.local_node, &self.inner.atoms)
{
self.write_frame(node, &frame);
} else {
tracing::warn!("failed to encode cluster backfill frame");
}
}
fn lock_local(&self) -> std::sync::MutexGuard<'_, HashMap<Atom, Vec<u64>>> {
self.inner
.local
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
impl ClusterObserver for ClusterSync {
    /// Joins `subscriber_pid` to the process group for `channel` and records
    /// the membership locally so it can be replayed to peers that join later.
    fn on_subscribe(&self, channel: &str, subscriber_pid: u64) {
        let group = self.group(channel);
        let scope = self.scope();
        self.inner.pg.join(scope, group, subscriber_pid);
        self.record_local(group, subscriber_pid);
        tracing::debug!(
            channel = %channel,
            pid = subscriber_pid,
            "advertised local subscription to cluster"
        );
    }

    /// Removes `subscriber_pid` from the process group for `channel` and
    /// forgets the locally recorded membership.
    fn on_unsubscribe(&self, channel: &str, subscriber_pid: u64) {
        let group = self.group(channel);
        let scope = self.scope();
        self.inner.pg.leave(scope, group, subscriber_pid);
        self.forget_local(group, subscriber_pid);
        tracing::debug!(
            channel = %channel,
            pid = subscriber_pid,
            "withdrew local subscription from cluster"
        );
    }

    /// Forwards `envelope` to every remote member of `channel`.
    ///
    /// The envelope is encoded once, and only when at least one remote
    /// member exists.
    fn on_publish(&self, channel: &str, envelope: &Envelope) {
        let recipients = self.remote_targets(channel);
        if !recipients.is_empty() {
            let encoded = encode_envelope(envelope);
            recipients
                .into_iter()
                .for_each(|member| self.send_to_member(member, &encoded));
        }
    }
}
impl ClusterSync {
    /// Replays every recorded local membership to `node` as an individual
    /// backfill frame.
    pub fn on_peer_join(&self, node: Atom) {
        self.local_memberships()
            .into_iter()
            .for_each(|(group, pid)| self.backfill_member(node, group, pid));
    }

    /// Logs that `node` has left. No local state is modified here; the log
    /// message notes that the peer's remote subscriptions were purged by
    /// beamr.
    pub fn on_peer_leave(&self, node: Atom) {
        // Fall back to the atom's debug form when it has no resolvable name.
        let name = match self.inner.atoms.resolve(node) {
            Some(text) => text.to_owned(),
            None => format!("<atom {node:?}>"),
        };
        tracing::info!(
            peer = %name,
            "peer departed; its remote subscriptions were purged by beamr"
        );
    }
}
/// Writes `frame` to `connection` from synchronous code, picking a strategy
/// from the ambient Tokio runtime:
///
/// * **multi-thread runtime** — blocks in place on the current handle until
///   the write has completed;
/// * **any other ambient runtime** (e.g. current-thread, the default for
///   `#[tokio::test]`) — queues the write as a task on that runtime and
///   returns without waiting. Blocking is not an option there: both
///   `block_in_place` and `block_on` on a second runtime panic when the
///   calling thread is already driving tasks;
/// * **no ambient runtime** — builds a temporary current-thread runtime and
///   blocks on it; a failure to build that runtime is logged as a warning.
///
/// The frame is copied so the write future owns its data. The result of the
/// write itself is deliberately discarded.
fn write_raw_blocking(
    connection: &Arc<beamr::distribution::connection::DistConnection>,
    frame: &[u8],
) {
    let connection = Arc::clone(connection);
    let frame = frame.to_vec();
    let write = async move {
        let _ = connection.write_raw(&frame).await;
    };

    if let Ok(handle) = tokio::runtime::Handle::try_current() {
        if matches!(
            handle.runtime_flavor(),
            tokio::runtime::RuntimeFlavor::MultiThread
        ) {
            tokio::task::block_in_place(|| handle.block_on(write));
        } else {
            // Starting a nested runtime here would panic, so hand the write
            // to the runtime that is already in charge of this thread.
            handle.spawn(write);
        }
        return;
    }

    match tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
    {
        Ok(runtime) => runtime.block_on(write),
        Err(error) => tracing::warn!(error = %error, "failed to build cluster send runtime"),
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::ClusterSync;
    use crate::cluster::discovery::ClusterResolver;
    use beamr::atom::AtomTable;
    use beamr::distribution::connection::ConnectionManager;
    use beamr::distribution::pg::PgRegistry;
    use beamr::distribution::resolver::StaticResolver;
    use liminal::channel::ClusterObserver;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a `ClusterSync` with an empty static resolver, and returns it
    /// together with the registry and atom table it shares.
    fn sync_fixture() -> (ClusterSync, Arc<PgRegistry>, Arc<AtomTable>) {
        let atoms = Arc::new(AtomTable::with_common_atoms());
        let pg = Arc::new(PgRegistry::new(&atoms));
        let connections = ConnectionManager::new(
            Arc::clone(&atoms),
            Arc::new(StaticResolver::new(HashMap::new())),
            "test-cookie",
            "local@127.0.0.1",
            1,
        );
        let local_node = atoms.intern("local@127.0.0.1");
        let resolver = Arc::new(ClusterResolver::new());
        let sync = ClusterSync::new(
            Arc::clone(&pg),
            Arc::clone(&atoms),
            connections,
            local_node,
            resolver,
        );
        (sync, pg, atoms)
    }

    /// Subscribing registers the pid as a local member of the channel group.
    #[test]
    fn subscribe_joins_the_channel_pg_group() {
        let (sync, pg, atoms) = sync_fixture();
        sync.on_subscribe("orders", 42);

        let group = atoms.intern("orders");
        let scope = pg.default_scope();
        let members = pg.local_members(scope, group);
        assert_eq!(members, vec![42]);
    }

    /// Unsubscribing removes the pid from the channel group again.
    #[test]
    fn unsubscribe_leaves_the_channel_pg_group() {
        let (sync, pg, atoms) = sync_fixture();
        sync.on_subscribe("orders", 42);
        sync.on_unsubscribe("orders", 42);

        let group = atoms.intern("orders");
        let scope = pg.default_scope();
        let members = pg.local_members(scope, group);
        assert!(members.is_empty());
    }

    /// The backfill bookkeeping follows subscribe and unsubscribe calls.
    #[test]
    fn local_memberships_track_subscriptions_for_backfill() {
        let (sync, _pg, _atoms) = sync_fixture();
        sync.on_subscribe("orders", 1);
        sync.on_subscribe("orders", 2);
        sync.on_subscribe("events", 3);

        let mut memberships = sync.local_memberships();
        memberships.sort_by_key(|&(group, pid)| (group, pid));
        assert_eq!(memberships.len(), 3);

        sync.on_unsubscribe("events", 3);
        let remaining = sync.local_memberships();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.iter().all(|&(_, pid)| matches!(pid, 1 | 2)));
    }

    /// Local subscriptions alone never show up as remote targets.
    #[test]
    fn remote_targets_empty_without_remote_members() {
        let (sync, _pg, _atoms) = sync_fixture();
        sync.on_subscribe("orders", 1);

        let targets = sync.remote_targets("orders");
        assert!(targets.is_empty());
    }

    /// A remote join applied to the registry is reported as a remote target.
    #[test]
    fn remote_targets_reflect_applied_remote_joins() {
        let (sync, pg, atoms) = sync_fixture();
        let group = atoms.intern("orders");
        let remote_node = atoms.intern("node-b@127.0.0.1");
        pg.apply_remote_join(pg.default_scope(), group, remote_node, 99, 0);

        let targets = sync.remote_targets("orders");
        assert_eq!(targets.len(), 1);
        let member = &targets[0];
        assert_eq!(member.node, remote_node);
        assert_eq!(member.pid_number, 99);
    }
}