use {
crate::{
PeerId,
discovery::{Catalog, PeerEntryVersion, SignedPeerEntry},
groups::{
Bond,
Bonds,
Error,
Groups,
StateMachine,
Storage,
When,
bond::BondEvent,
config::GroupConfig,
error::AlreadyBonded,
raft::Raft,
state::{
AcceptRequest,
GroupHandle,
WorkerCommand,
WorkerRaftCommand,
WorkerState,
},
},
network::Cancelled,
primitives::{AsyncWorkQueue, ShortFmtExt},
},
core::{any::TypeId, future::poll_fn, pin::Pin},
futures::{Stream, StreamExt, stream::SelectAll},
im::ordmap::Entry,
iroh::protocol::AcceptError,
std::sync::Arc,
tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel},
tokio_stream::wrappers::UnboundedReceiverStream,
};
/// Per-group background worker.
///
/// Holds the receiving ends of the accept and command channels whose senders
/// are stored in the shared [`WorkerState`], the merged event streams of all
/// subscribed bonds, and the Raft instance for this group. It is built and
/// started by [`Worker::spawn`] and driven by its `run` event loop.
///
/// Note: fields are dropped in declaration order; keep that in mind before
/// reordering them.
pub struct Worker<S, M>
where
    S: Storage<M::Command>,
    M: StateMachine,
{
    /// State shared with the returned `GroupHandle`, the Raft instance and
    /// the bond tasks spawned on the work queue.
    state: Arc<WorkerState<M>>,
    /// Inbound bond requests, handled by `accept_bond`.
    accepts: UnboundedReceiver<AcceptRequest>,
    /// Commands sent to this worker through `WorkerState::cmd_tx`.
    cmd_rx: UnboundedReceiver<WorkerCommand<M>>,
    /// Events of every subscribed bond, each tagged with the remote peer id.
    bond_events: BondEventsStream<M>,
    /// Raft instance for this group; polled from the event loop.
    raft: Raft<S, M>,
    /// In-flight async work (bond creation/acceptance, Raft feed and query
    /// completions) driven by the event loop.
    work_queue: AsyncWorkQueue,
    /// Update version of the local peer entry last passed to
    /// `notify_local_info_update`.
    last_local: PeerEntryVersion,
    /// `network` and `group` labels attached to the active-bonds gauge.
    metrics_labels: [(&'static str, String); 2],
}
impl<S, M> Worker<S, M>
where
    S: Storage<M::Command>,
    M: StateMachine,
{
    /// Creates the worker for one group and starts its event loop.
    ///
    /// Builds the shared [`WorkerState`] (wired to the accept and command
    /// channels), constructs the worker around it, spawns `run` as a Tokio
    /// task and returns a [`GroupHandle`] wrapping the same shared state.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
    pub fn spawn(
        groups: &Groups,
        config: GroupConfig,
        storage: S,
        state_machine: M,
    ) -> Arc<GroupHandle> {
        let (accept_tx, accept_rx) = unbounded_channel();
        let (command_tx, command_rx) = unbounded_channel();

        let shared = Arc::new(WorkerState {
            config,
            cmd_tx: command_tx,
            accepts: accept_tx,
            global_config: Arc::clone(&groups.config),
            local: groups.local.clone(),
            discovery: groups.discovery.clone(),
            bonds: Bonds::default(),
            cancel: groups.local.termination().child_token(),
            when: When::new(groups.local.id()),
            types: (TypeId::of::<M>(), TypeId::of::<S>()),
        });

        let worker = Self {
            cmd_rx: command_rx,
            accepts: accept_rx,
            state: Arc::clone(&shared),
            bond_events: SelectAll::default(),
            work_queue: AsyncWorkQueue::default(),
            raft: Raft::new(Arc::clone(&shared), storage, state_machine),
            last_local: groups.discovery.me().update_version(),
            metrics_labels: [
                ("network", shared.network_id().short().to_string()),
                ("group", shared.group_id().short().to_string()),
            ],
        };

        tokio::spawn(worker.run());
        Arc::new(GroupHandle::new(shared))
    }
}
impl<S, M> Worker<S, M>
where
    S: Storage<M::Command>,
    M: StateMachine,
{
    /// Main event loop of the worker.
    ///
    /// Registers the group locally (`on_init`), then multiplexes over the
    /// group's cancellation token, the async work queue, the Raft instance,
    /// discovery catalog changes, bond events, inbound accept requests and
    /// worker commands. The loop ends once the cancellation token fires.
    async fn run(mut self) {
        self.on_init();

        let mut catalog = self.state.discovery.catalog_watch();
        // Make the first `changed()` resolve immediately so peers already in
        // the catalog are considered without waiting for the next update.
        catalog.mark_changed();

        loop {
            tokio::select! {
                () = self.state.cancel.cancelled() => {
                    self.on_terminated();
                    break;
                }
                _ = self.work_queue.next() => { }
                () = poll_fn(|cx| self.raft.poll(cx)) => { }
                // Once the catalog sender is gone, `changed()` returns `Err`
                // immediately on every poll. Matching only `Ok` disables this
                // branch for the iteration instead of re-running the handler
                // in a busy loop.
                Ok(()) = catalog.changed() => {
                    let snapshot = catalog.borrow_and_update().clone();
                    self.on_catalog_update(snapshot);
                }
                Some((event, peer_id)) = self.bond_events.next() => {
                    self.on_bond_event(event, peer_id);
                }
                Some(request) = self.accepts.recv() => {
                    self.accept_bond(request);
                }
                Some(command) = self.cmd_rx.recv() => {
                    self.on_worker_command(command);
                }
            }
        }
    }
}
impl<S, M> Worker<S, M>
where
    S: Storage<M::Command>,
    M: StateMachine,
{
    /// Adds this group's id to the local discovery entry and logs the join.
    fn on_init(&self) {
        let group_id = *self.state.group_id();
        self.state
            .discovery
            .update_local_entry(move |entry| entry.add_groups(group_id));
        tracing::info!(
            group = %group_id.short(),
            network = %self.state.network_id().short(),
            "joining",
        );
    }

    /// Logs that the worker is leaving the group.
    ///
    /// Called from the event loop once the group's cancellation token fires.
    fn on_terminated(&self) {
        tracing::debug!(
            group = %self.state.group_id().short(),
            network = %self.state.network_id().short(),
            "leaving",
        );
    }

    /// Handles a new discovery catalog snapshot.
    ///
    /// Calls `create_bond` for every signed peer that lists this group and
    /// passes the group's `authorize_peer` check. `create_bond` itself skips
    /// the local peer and peers that already have an active bond.
    ///
    /// Then, if the local peer entry's update version has advanced past
    /// `last_local`, it records the new version and passes the current local
    /// entry to `bonds.notify_local_info_update`.
    #[expect(clippy::needless_pass_by_value)]
    fn on_catalog_update(&mut self, snapshot: Catalog) {
        let new_peers_in_group = snapshot.signed_peers().filter(|peer| {
            peer.groups().contains(self.state.group_id())
                && self.state.config.authorize_peer(peer).is_ok()
        });
        for peer in new_peers_in_group {
            self.create_bond(peer.clone());
        }
        let me = self.state.discovery.me();
        if me.update_version() > self.last_local {
            self.last_local = me.update_version();
            self.state.bonds.notify_local_info_update(&me);
        }
    }

    /// Dispatches an event received from the bond with `peer_id`.
    ///
    /// - `Terminated`: removes the bond from the active set and refreshes the
    ///   active-bonds gauge. The debug log is skipped for `AlreadyBonded`.
    /// - `Connected`: handled by `on_bond_formed`.
    /// - `Raft`: forwarded to the Raft instance as a protocol message.
    fn on_bond_event(&mut self, event: BondEvent<M>, peer_id: PeerId) {
        match event {
            BondEvent::Terminated(reason) => {
                self.state.bonds.update_with(|active| {
                    let Some(bond) = active.remove(&peer_id) else {
                        return;
                    };
                    // The entry is removed whatever the reason, so the gauge
                    // must be refreshed in every case. Previously it was only
                    // updated for non-`AlreadyBonded` terminations and went
                    // stale otherwise. Only the log line is reason-dependent.
                    metrics::gauge!("mosaik.groups.bonds.active", &self.metrics_labels)
                        .set(active.len() as f64);
                    if reason != AlreadyBonded {
                        tracing::debug!(
                            id = %bond.id().short(),
                            group = %self.state.group_id().short(),
                            peer = %peer_id.short(),
                            network = %self.state.network_id().short(),
                            reason = %reason,
                            "bond terminated",
                        );
                    }
                });
            }
            BondEvent::Connected => {
                self.on_bond_formed(peer_id);
            }
            BondEvent::Raft(message) => {
                self.raft.receive_protocol_message(message, peer_id);
            }
        }
    }

    /// Handles a bond that has just connected.
    ///
    /// Returns early if `peer_id` is no longer in the active bond set. Logs
    /// the bond, then passes the peer's signed catalog entry to
    /// `bonds.notify_bond_formed`. Warns and returns if that entry is missing
    /// from the catalog.
    fn on_bond_formed(&self, peer_id: PeerId) {
        let Some(bond) = self.state.bonds.get(&peer_id) else {
            return;
        };
        tracing::debug!(
            id = %bond.id().short(),
            peer = %peer_id.short(),
            group = %self.state.group_id().short(),
            network = %self.state.network_id().short(),
            "bond established",
        );
        let catalog = self.state.discovery.catalog();
        let Some(peer_entry) = catalog.get_signed(&peer_id).cloned() else {
            tracing::warn!(
                network = %self.state.network_id().short(),
                peer = %peer_id.short(),
                group = %self.state.group_id().short(),
                "peer entry not found in catalog after bond formed",
            );
            return;
        };
        self.state.bonds.notify_bond_formed(&peer_entry);
    }

    /// Dispatches a command received on `cmd_rx`.
    ///
    /// - `Connect`: calls `create_bond` for the given peer entry.
    /// - `Subscribe`: adds the bond's event receiver to `bond_events`, tagging
    ///   each event with the peer id.
    /// - `Raft`: handled by `on_raft_command`.
    fn on_worker_command(&mut self, command: WorkerCommand<M>) {
        match command {
            WorkerCommand::Connect(peer_entry) => {
                self.create_bond(*peer_entry);
            }
            WorkerCommand::Subscribe(events_rx, peer_id) => {
                self.bond_events.push(Box::pin(
                    UnboundedReceiverStream::new(events_rx)
                        .map(move |event| (event, peer_id)),
                ));
            }
            WorkerCommand::Raft(command) => {
                self.on_raft_command(command);
            }
        }
    }

    /// Handles a Raft feed or query request.
    ///
    /// The operation's future is obtained from `self.raft` here and then
    /// awaited on the work queue. Its result is sent back on `result_tx`, and
    /// any send failure is ignored (presumably the requester went away).
    fn on_raft_command(&mut self, command: WorkerRaftCommand<M>) {
        match command {
            WorkerRaftCommand::Feed(cmd, result_tx) => {
                let cmd_fut = self.raft.feed(cmd);
                self.work_queue.enqueue(async move {
                    let result = cmd_fut.await;
                    let _ = result_tx.send(result);
                });
            }
            WorkerRaftCommand::Query(query, consistency, result_tx) => {
                let query_fut = self.raft.query(query, consistency);
                self.work_queue.enqueue(async move {
                    let result = query_fut.await;
                    let _ = result_tx.send(result);
                });
            }
        }
    }
}
impl<S, M> Worker<S, M>
where
    S: Storage<M::Command>,
    M: StateMachine,
{
    /// Starts an outbound bond attempt to `peer` on the work queue.
    ///
    /// Skipped for the local peer and for peers that already have an active
    /// bond. When `Bond::create` succeeds, the bond is inserted into the
    /// active set only if the slot is still free and its event stream could be
    /// handed to the worker (`WorkerCommand::Subscribe`); otherwise the new
    /// bond is closed. Failures other than `AlreadyBonded` are logged at trace
    /// level.
    fn create_bond(&self, peer: SignedPeerEntry) {
        if *peer.id() == self.state.local_id() {
            return;
        }
        if self.state.bonds.contains_peer(peer.id()) {
            return;
        }
        let peer_id = *peer.id();
        let state = Arc::clone(&self.state);
        let labels = self.metrics_labels.clone();
        let fut = async move {
            match Bond::create(Arc::clone(&state), peer).await {
                Ok((handle, events)) => {
                    state.bonds.update_with(|active| {
                        match active.entry(peer_id) {
                            Entry::Vacant(place) => {
                                if state
                                    .cmd_tx
                                    .send(WorkerCommand::Subscribe(events, peer_id))
                                    .is_ok()
                                {
                                    place.insert(handle);
                                    metrics::gauge!("mosaik.groups.bonds.active", &labels)
                                        .set(active.len() as f64);
                                } else {
                                    // The worker's command receiver is gone, so
                                    // nobody will consume this bond's events.
                                    // Close it explicitly, as `accept_bond` does,
                                    // instead of silently dropping the handle.
                                    tokio::spawn(handle.close(Cancelled));
                                }
                            }
                            Entry::Occupied(_) => {
                                // Another bond to this peer was registered first.
                                tokio::spawn(handle.close(AlreadyBonded));
                            }
                        }
                    });
                }
                Err(reason) => {
                    if !matches!(reason, Error::AlreadyBonded(_)) {
                        tracing::trace!(
                            error = %reason,
                            network = %state.local.network_id().short(),
                            peer = %peer_id.short(),
                            group = %state.group_id().short(),
                            "bonding failed",
                        );
                    }
                }
            }
        };
        self.work_queue.enqueue(fut);
    }

    /// Completes an inbound bond request on the work queue.
    ///
    /// If the peer already has an active bond, the link is closed and
    /// `AlreadyBonded` is reported right away. Otherwise `Bond::accept` is run
    /// and the outcome is sent on `result_tx`:
    /// - `Ok(())` once the bond is in the active set and subscribed;
    /// - `Cancelled` if the worker's command channel is gone;
    /// - `AlreadyBonded` if another bond to the peer was registered first;
    /// - the `Bond::accept` error otherwise.
    ///
    /// # Panics
    /// Panics if the request's peer entry id differs from the link's remote id.
    /// NOTE(review): this panics the whole worker task; confirm that the code
    /// producing `AcceptRequest` guarantees the two ids match.
    fn accept_bond(&self, request: AcceptRequest) {
        let AcceptRequest {
            link,
            peer,
            handshake,
            result_tx,
        } = request;
        let peer_id = link.remote_id();
        assert_eq!(peer.id(), &peer_id);
        if self.state.bonds.contains_peer(&peer_id) {
            tokio::spawn(link.close(AlreadyBonded));
            let _ = result_tx.send(Err(AcceptError::from_err(AlreadyBonded)));
            return;
        }
        let state = Arc::clone(&self.state);
        let labels = self.metrics_labels.clone();
        let fut = async move {
            match Bond::accept(Arc::clone(&state), link, peer, handshake).await {
                Ok((handle, events)) => {
                    state.bonds.update_with(|active| {
                        match active.entry(peer_id) {
                            Entry::Vacant(place) => {
                                if state
                                    .cmd_tx
                                    .send(WorkerCommand::Subscribe(events, peer_id))
                                    .is_ok()
                                {
                                    place.insert(handle);
                                    metrics::gauge!("mosaik.groups.bonds.active", &labels)
                                        .set(active.len() as f64);
                                    let _ = result_tx.send(Ok(()));
                                } else {
                                    let _ = result_tx.send(Err(AcceptError::from_err(Cancelled)));
                                    tokio::spawn(handle.close(Cancelled));
                                }
                            }
                            Entry::Occupied(_) => {
                                tokio::spawn(handle.close(AlreadyBonded));
                                let _ =
                                    result_tx.send(Err(AcceptError::from_err(AlreadyBonded)));
                            }
                        }
                    });
                }
                Err(reason) => {
                    let _ = result_tx.send(Err(AcceptError::from_err(reason)));
                }
            }
        };
        self.work_queue.enqueue(fut);
    }
}
/// All subscribed bond event streams merged into one, each item tagged with
/// the id of the remote peer the event came from.
///
/// The parameter is intentionally unbounded: bounds on type-alias parameters
/// are not enforced by rustc (`type_alias_bounds` lint), and `M: StateMachine`
/// is already required wherever this alias is used.
type BondEventsStream<M> = SelectAll<
    Pin<Box<dyn Stream<Item = (BondEvent<M>, PeerId)> + Send + Sync + 'static>>,
>;