use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context;
use indexmap::IndexMap;
use miden_node_proto::domain::account::NetworkAccountId;
use miden_node_proto::domain::mempool::MempoolEvent;
use miden_node_proto::domain::note::NetworkNote;
use miden_protocol::account::delta::AccountUpdateDetails;
use miden_protocol::transaction::TransactionId;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{Semaphore, mpsc};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::actor::{AccountActor, AccountActorContext, AccountOrigin, ActorShutdownReason};
/// Coordinator-side handle to one running account actor.
///
/// Bundles the sending half of the actor's event channel with the token that cancels the actor.
#[derive(Clone)]
struct ActorHandle {
    event_tx: mpsc::Sender<Arc<MempoolEvent>>,
    cancel_token: CancellationToken,
}

impl ActorHandle {
    /// Wraps an actor's event sender and its cancellation token into a handle.
    fn new(event_tx: mpsc::Sender<Arc<MempoolEvent>>, cancel_token: CancellationToken) -> Self {
        Self { cancel_token, event_tx }
    }
}
/// Owns the running account actors and routes mempool events to them.
// NOTE: field declaration order is also drop order; keep it stable.
pub struct Coordinator {
    /// Handles of registered actors, keyed by the network account each actor serves.
    actor_registry: HashMap<NetworkAccountId, ActorHandle>,
    /// Tasks running the actors; each task resolves with the reason its actor shut down.
    actor_join_set: JoinSet<ActorShutdownReason>,
    /// Semaphore cloned into every spawned actor (sized from `max_inflight_transactions`).
    semaphore: Arc<Semaphore>,
    /// Events that targeted an account before it had an actor, per account and keyed by
    /// transaction id in insertion order; replayed to the actor when it is spawned.
    predating_events: HashMap<NetworkAccountId, IndexMap<TransactionId, Arc<MempoolEvent>>>,
}
impl Coordinator {
const ACTOR_CHANNEL_SIZE: usize = 100;
pub fn new(max_inflight_transactions: usize) -> Self {
Self {
actor_registry: HashMap::new(),
actor_join_set: JoinSet::new(),
semaphore: Arc::new(Semaphore::new(max_inflight_transactions)),
predating_events: HashMap::new(),
}
}
#[tracing::instrument(name = "ntx.builder.spawn_actor", skip(self, origin, actor_context))]
pub async fn spawn_actor(
&mut self,
origin: AccountOrigin,
actor_context: &AccountActorContext,
) -> Result<(), SendError<Arc<MempoolEvent>>> {
let account_id = origin.id();
if let Some(handle) = self.actor_registry.remove(&account_id) {
tracing::error!("account actor already exists for account: {}", account_id);
handle.cancel_token.cancel();
}
let (event_tx, event_rx) = mpsc::channel(Self::ACTOR_CHANNEL_SIZE);
let cancel_token = tokio_util::sync::CancellationToken::new();
let actor = AccountActor::new(origin, actor_context, event_rx, cancel_token.clone());
let handle = ActorHandle::new(event_tx, cancel_token);
let semaphore = self.semaphore.clone();
self.actor_join_set.spawn(Box::pin(actor.run(semaphore)));
if let Some(predating_events) = self.predating_events.remove(&account_id) {
for event in predating_events.values() {
Self::send(&handle, event.clone()).await?;
}
}
self.actor_registry.insert(account_id, handle);
tracing::info!("created actor for account: {}", account_id);
Ok(())
}
pub async fn broadcast(&mut self, event: Arc<MempoolEvent>) {
tracing::debug!(
actor_count = self.actor_registry.len(),
"broadcasting event to all actors"
);
let mut failed_actors = Vec::new();
for (account_id, handle) in &self.actor_registry {
if let Err(err) = Self::send(handle, event.clone()).await {
tracing::error!("failed to send event to actor {}: {}", account_id, err);
failed_actors.push(*account_id);
}
}
for account_id in failed_actors {
let handle =
self.actor_registry.remove(&account_id).expect("actor found in send loop above");
handle.cancel_token.cancel();
}
}
pub async fn next(&mut self) -> anyhow::Result<()> {
let actor_result = self.actor_join_set.join_next().await;
match actor_result {
Some(Ok(shutdown_reason)) => match shutdown_reason {
ActorShutdownReason::Cancelled(account_id) => {
tracing::info!("account actor cancelled: {}", account_id);
Ok(())
},
ActorShutdownReason::AccountReverted(account_id) => {
tracing::info!("account reverted: {}", account_id);
self.actor_registry.remove(&account_id);
Ok(())
},
ActorShutdownReason::EventChannelClosed => {
anyhow::bail!("event channel closed");
},
ActorShutdownReason::SemaphoreFailed(err) => Err(err).context("semaphore failed"),
},
Some(Err(err)) => {
tracing::error!(err = %err, "actor task failed");
Ok(())
},
None => {
std::future::pending().await
},
}
}
pub async fn send_targeted(
&mut self,
event: &Arc<MempoolEvent>,
) -> Result<(), SendError<Arc<MempoolEvent>>> {
let mut target_actors = HashMap::new();
if let MempoolEvent::TransactionAdded { id, network_notes, account_delta, .. } =
event.as_ref()
{
if let Some(AccountUpdateDetails::Delta(delta)) = account_delta {
let account_id = delta.id();
if account_id.is_network() {
let network_account_id =
account_id.try_into().expect("account is network account");
if let Some(actor) = self.actor_registry.get(&network_account_id) {
target_actors.insert(network_account_id, actor);
}
}
}
for note in network_notes {
let NetworkNote::SingleTarget(note) = note;
let network_account_id = note.account_id();
if let Some(actor) = self.actor_registry.get(&network_account_id) {
target_actors.insert(network_account_id, actor);
} else {
self.predating_events
.entry(network_account_id)
.or_default()
.insert(*id, event.clone());
}
}
}
for actor in target_actors.values() {
Self::send(actor, event.clone()).await?;
}
Ok(())
}
pub fn drain_predating_events(&mut self, tx_id: &TransactionId) {
self.predating_events.retain(|_, account_events| {
account_events.shift_remove(tx_id);
!account_events.is_empty()
});
}
async fn send(
handle: &ActorHandle,
event: Arc<MempoolEvent>,
) -> Result<(), SendError<Arc<MempoolEvent>>> {
handle.event_tx.send(event).await
}
}