use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context;
use miden_node_db::DatabaseError;
use miden_node_proto::domain::account::NetworkAccountId;
use miden_node_proto::domain::mempool::MempoolEvent;
use miden_protocol::account::delta::AccountUpdateDetails;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{Semaphore, mpsc};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::actor::{AccountActor, AccountActorContext, AccountOrigin, ActorShutdownReason};
use crate::db::Db;
/// Coordinator-side handle to a single spawned account actor.
///
/// The coordinator keeps one of these per registered actor and uses it to deliver mempool
/// events to the actor and to cancel it.
#[derive(Clone)]
struct ActorHandle {
/// Sending half of the actor's event channel; the receiving half is given to the actor when
/// it is constructed.
event_tx: mpsc::Sender<Arc<MempoolEvent>>,
/// Token the coordinator cancels when it drops the actor from its registry. A clone of this
/// token is handed to the actor at construction.
cancel_token: CancellationToken,
}
impl ActorHandle {
fn new(event_tx: mpsc::Sender<Arc<MempoolEvent>>, cancel_token: CancellationToken) -> Self {
Self { event_tx, cancel_token }
}
}
/// Spawns account actors, tracks them, and routes mempool events to them.
///
/// Each actor runs as a task in `actor_join_set` and is reachable through the [`ActorHandle`]
/// stored in `actor_registry` under its network account id.
pub struct Coordinator {
/// Handles of the currently registered actors, keyed by network account id.
actor_registry: HashMap<NetworkAccountId, ActorHandle>,
/// Tasks of all spawned actors; each task resolves to the reason the actor shut down.
actor_join_set: JoinSet<ActorShutdownReason>,
/// Semaphore created with `max_inflight_transactions` permits; a clone of the `Arc` is passed
/// to every actor's `run` call.
semaphore: Arc<Semaphore>,
/// Database handle used by `write_event` to persist mempool events.
db: Db,
/// Capacity of the bounded event channel created for each spawned actor.
actor_channel_size: usize,
}
impl Coordinator {
/// Builds a coordinator that has no actors yet.
///
/// * `max_inflight_transactions` - number of permits of the semaphore shared with all actors.
/// * `actor_channel_size` - capacity of the event channel created for each spawned actor.
/// * `db` - database handle used when persisting mempool events.
pub fn new(max_inflight_transactions: usize, actor_channel_size: usize, db: Db) -> Self {
    let semaphore = Arc::new(Semaphore::new(max_inflight_transactions));

    Self {
        actor_registry: HashMap::new(),
        actor_join_set: JoinSet::new(),
        semaphore,
        db,
        actor_channel_size,
    }
}
#[tracing::instrument(name = "ntx.builder.spawn_actor", skip(self, origin, actor_context))]
pub async fn spawn_actor(
&mut self,
origin: AccountOrigin,
actor_context: &AccountActorContext,
) -> Result<(), SendError<Arc<MempoolEvent>>> {
let account_id = origin.id();
if let Some(handle) = self.actor_registry.remove(&account_id) {
tracing::error!(
account_id = %account_id,
"Account actor already exists"
);
handle.cancel_token.cancel();
}
let (event_tx, event_rx) = mpsc::channel(self.actor_channel_size);
let cancel_token = tokio_util::sync::CancellationToken::new();
let actor = AccountActor::new(origin, actor_context, event_rx, cancel_token.clone());
let handle = ActorHandle::new(event_tx, cancel_token);
let semaphore = self.semaphore.clone();
self.actor_join_set.spawn(Box::pin(actor.run(semaphore)));
self.actor_registry.insert(account_id, handle);
tracing::info!(account_id = %account_id, "Created actor for account prefix");
Ok(())
}
/// Sends `event` to every registered actor, one after another.
///
/// Each send awaits the actor's bounded channel. An actor whose send fails is logged, removed
/// from the registry, and has its cancellation token cancelled once the whole send pass is
/// done.
#[tracing::instrument(name = "ntx.coordinator.broadcast", skip_all, fields(
    actor.count = self.actor_registry.len(),
    event.kind = %event.kind()
))]
pub async fn broadcast(&mut self, event: Arc<MempoolEvent>) {
    // The registry is borrowed while iterating, so failures are collected and handled after.
    let mut dead_actors = Vec::new();

    for (account_id, handle) in &self.actor_registry {
        let Err(err) = Self::send(handle, Arc::clone(&event)).await else {
            continue;
        };
        tracing::error!(
            account_id = %account_id,
            error = %err,
            "Failed to send event to actor"
        );
        dead_actors.push(*account_id);
    }

    for account_id in &dead_actors {
        self.actor_registry
            .remove(account_id)
            .expect("actor found in send loop above")
            .cancel_token
            .cancel();
    }
}
/// Waits for the next actor task to finish and maps its outcome to a result.
///
/// * Join set is empty: never resolves (awaits a pending future).
/// * Task failed to join: the error is logged and `Ok(())` is returned.
/// * Actor was cancelled: logged at info level, `Ok(())` is returned. The registry is not
///   touched here.
///
/// # Errors
///
/// Returns an error when the actor reports that its event channel closed, or when it reports
/// a semaphore failure (wrapped with the context "semaphore failed").
pub async fn next(&mut self) -> anyhow::Result<()> {
    match self.actor_join_set.join_next().await {
        // No tasks to wait for; park until the caller drops this future.
        None => std::future::pending().await,
        Some(Err(err)) => {
            tracing::error!(err = %err, "actor task failed");
            Ok(())
        },
        Some(Ok(ActorShutdownReason::Cancelled(account_id))) => {
            tracing::info!(account_id = %account_id, "Account actor cancelled");
            Ok(())
        },
        Some(Ok(ActorShutdownReason::EventChannelClosed)) => {
            Err(anyhow::anyhow!("event channel closed"))
        },
        Some(Ok(ActorShutdownReason::SemaphoreFailed(err))) => {
            Err(err).context("semaphore failed")
        },
    }
}
/// Sends `event` only to the registered actors it concerns.
///
/// Only `TransactionAdded` events have targets. They are the account of the event's account
/// delta (when the update is a delta and the account is a network account) and the target
/// account of each network note. Accounts without a registered actor are skipped, and each
/// actor receives the event at most once.
///
/// # Errors
///
/// Returns the first send error encountered; remaining targets are not attempted.
///
/// # Panics
///
/// Panics if an account reported as a network account, or a network note's target account,
/// cannot be converted into a [`NetworkAccountId`].
pub async fn send_targeted(
    &mut self,
    event: &Arc<MempoolEvent>,
) -> Result<(), SendError<Arc<MempoolEvent>>> {
    // Keyed by account id so an actor hit by both the delta and a note is notified once.
    let mut recipients = HashMap::new();

    if let MempoolEvent::TransactionAdded { network_notes, account_delta, .. } = event.as_ref()
    {
        if let Some(AccountUpdateDetails::Delta(delta)) = account_delta {
            let account_id = delta.id();
            if account_id.is_network() {
                let network_account_id = NetworkAccountId::try_from(account_id)
                    .expect("account is network account");
                if let Some(handle) = self.actor_registry.get(&network_account_id) {
                    recipients.insert(network_account_id, handle);
                }
            }
        }

        for note in network_notes {
            let target = NetworkAccountId::try_from(note.target_account_id())
                .expect("network note target account should be a network account");
            let Some(handle) = self.actor_registry.get(&target) else {
                continue;
            };
            recipients.insert(target, handle);
        }
    }

    for handle in recipients.values() {
        Self::send(handle, Arc::clone(event)).await?;
    }

    Ok(())
}
/// Persists a mempool event through the database handle.
///
/// * `TransactionAdded` and `BlockCommitted` are forwarded to the matching database handler
///   and yield an empty list.
/// * `TransactionsReverted` yields whatever account ids the database handler returns
///   (presumably the accounts affected by the revert; confirm against the `Db` API).
///
/// # Errors
///
/// Propagates any error returned by the database handler.
pub async fn write_event(
    &self,
    event: &MempoolEvent,
) -> Result<Vec<NetworkAccountId>, DatabaseError> {
    match event {
        // The only variant whose handler produces the returned list.
        MempoolEvent::TransactionsReverted(tx_ids) => {
            return self
                .db
                .handle_transactions_reverted(tx_ids.iter().copied().collect())
                .await;
        },
        MempoolEvent::TransactionAdded {
            id,
            nullifiers,
            network_notes,
            account_delta,
        } => {
            self.db
                .handle_transaction_added(
                    *id,
                    account_delta.clone(),
                    network_notes.clone(),
                    nullifiers.clone(),
                )
                .await?;
        },
        MempoolEvent::BlockCommitted { header, txs } => {
            self.db
                .handle_block_committed(
                    txs.clone(),
                    header.block_num(),
                    header.as_ref().clone(),
                )
                .await?;
        },
    }

    Ok(Vec::new())
}
/// Removes the actor registered for `account_id` and cancels its token.
///
/// Does nothing when no actor is registered under that id.
pub fn cancel_actor(&mut self, account_id: &NetworkAccountId) {
    let Some(handle) = self.actor_registry.remove(account_id) else {
        return;
    };
    handle.cancel_token.cancel();
}
/// Delivers `event` to a single actor through its event channel.
///
/// # Errors
///
/// Returns the channel's `SendError`, carrying the undelivered event, when the send fails.
async fn send(
    handle: &ActorHandle,
    event: Arc<MempoolEvent>,
) -> Result<(), SendError<Arc<MempoolEvent>>> {
    let sender = &handle.event_tx;
    sender.send(event).await
}
}