use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use miden_node_db::DatabaseError;
use miden_node_proto::domain::account::NetworkAccountId;
use miden_node_proto::domain::mempool::MempoolEvent;
use miden_protocol::account::delta::AccountUpdateDetails;
use tokio::sync::{Notify, Semaphore};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::actor::{AccountActor, AccountActorContext, AccountOrigin};
use crate::db::Db;
pub struct WriteEventResult {
pub accounts_to_notify: Vec<NetworkAccountId>,
}
/// Coordinator-side handle to a spawned account actor.
///
/// Bundles the two signals the coordinator holds for an actor: a `Notify` used to wake it up and a
/// `CancellationToken` used to stop it.
#[derive(Clone)]
struct ActorHandle {
    /// Signalled with `notify_one` by the coordinator to wake the actor.
    notify: Arc<Notify>,
    /// Cancelled by the coordinator when the actor is replaced by a new one.
    cancel_token: CancellationToken,
}

impl ActorHandle {
    /// Builds a handle from the actor's notifier and cancellation token.
    fn new(notify: Arc<Notify>, cancel_token: CancellationToken) -> Self {
        ActorHandle {
            cancel_token,
            notify,
        }
    }
}
/// Owns the per-account actors, routes mempool events to them and persists events to the database.
///
/// NOTE(review): fields are dropped in declaration order, so `actor_join_set` (which aborts its
/// remaining tasks when dropped) goes away before `db`. Keep this order unless that is verified to
/// be irrelevant.
pub struct Coordinator {
    /// Handles of the actors the coordinator considers running, keyed by account.
    actor_registry: HashMap<NetworkAccountId, ActorHandle>,
    /// Actor tasks; each task yields its account id together with the actor's result.
    actor_join_set: JoinSet<(NetworkAccountId, anyhow::Result<()>)>,
    /// Shared by every actor; sized by `max_inflight_transactions` in `Coordinator::new`.
    semaphore: Arc<Semaphore>,
    /// Database that mempool events are written to.
    db: Db,
    /// Number of times each account's actor has exited with an error.
    crash_counts: HashMap<NetworkAccountId, usize>,
    /// Crash count at or above which no further actor is spawned for an account.
    max_account_crashes: usize,
}
impl Coordinator {
pub fn new(max_inflight_transactions: usize, max_account_crashes: usize, db: Db) -> Self {
Self {
actor_registry: HashMap::new(),
actor_join_set: JoinSet::new(),
semaphore: Arc::new(Semaphore::new(max_inflight_transactions)),
db,
crash_counts: HashMap::new(),
max_account_crashes,
}
}
#[tracing::instrument(name = "ntx.builder.spawn_actor", skip(self, origin, actor_context))]
pub fn spawn_actor(&mut self, origin: AccountOrigin, actor_context: &AccountActorContext) {
let account_id = origin.id();
if let Some(&count) = self.crash_counts.get(&account_id) {
if count >= self.max_account_crashes {
tracing::warn!(
account.id = %account_id,
crash_count = count,
"Account deactivated due to repeated crashes, skipping actor spawn"
);
return;
}
}
if let Some(handle) = self.actor_registry.remove(&account_id) {
tracing::error!(
account_id = %account_id,
"Account actor already exists"
);
handle.cancel_token.cancel();
}
let notify = Arc::new(Notify::new());
let cancel_token = tokio_util::sync::CancellationToken::new();
let actor = AccountActor::new(origin, actor_context, notify.clone(), cancel_token.clone());
let handle = ActorHandle::new(notify, cancel_token);
let semaphore = self.semaphore.clone();
self.actor_join_set
.spawn(Box::pin(async move { (account_id, actor.run(semaphore).await) }));
self.actor_registry.insert(account_id, handle);
tracing::info!(account_id = %account_id, "Created actor for account prefix");
}
pub fn notify_accounts(&self, account_ids: &[NetworkAccountId]) {
for account_id in account_ids {
if let Some(handle) = self.actor_registry.get(account_id) {
handle.notify.notify_one();
}
}
}
pub async fn next(&mut self) -> anyhow::Result<Option<NetworkAccountId>> {
let actor_result = self.actor_join_set.join_next().await;
match actor_result {
Some(Ok((account_id, Ok(())))) => {
let should_respawn =
self.actor_registry.remove(&account_id).is_some_and(|handle| {
let notified = handle.notify.notified();
tokio::pin!(notified);
notified.enable()
});
Ok(should_respawn.then_some(account_id))
},
Some(Ok((account_id, Err(err)))) => {
let count = self.crash_counts.entry(account_id).or_insert(0);
*count += 1;
tracing::error!(
account.id = %account_id,
"Account actor crashed: {err:#}"
);
self.actor_registry.remove(&account_id);
Ok(None)
},
Some(Err(err)) => {
tracing::error!(err = %err, "actor task failed");
Ok(None)
},
None => {
std::future::pending().await
},
}
}
pub fn send_targeted(&self, event: &MempoolEvent) -> Vec<NetworkAccountId> {
let mut target_account_ids = HashSet::new();
let mut inactive_targets = Vec::new();
if let MempoolEvent::TransactionAdded { network_notes, account_delta, .. } = event {
if let Some(AccountUpdateDetails::Delta(delta)) = account_delta {
let account_id = delta.id();
if account_id.is_network() {
let network_account_id =
account_id.try_into().expect("account is network account");
if self.actor_registry.contains_key(&network_account_id) {
target_account_ids.insert(network_account_id);
}
}
}
for note in network_notes {
let account = note.target_account_id();
let account = NetworkAccountId::try_from(account)
.expect("network note target account should be a network account");
if self.actor_registry.contains_key(&account) {
target_account_ids.insert(account);
} else {
inactive_targets.push(account);
}
}
}
for account_id in &target_account_ids {
if let Some(handle) = self.actor_registry.get(account_id) {
handle.notify.notify_one();
}
}
inactive_targets
}
pub async fn write_event(
&self,
event: &MempoolEvent,
) -> Result<WriteEventResult, DatabaseError> {
match event {
MempoolEvent::TransactionAdded {
id,
nullifiers,
network_notes,
account_delta,
} => {
self.db
.handle_transaction_added(
*id,
account_delta.clone(),
network_notes.clone(),
nullifiers.clone(),
)
.await?;
Ok(WriteEventResult { accounts_to_notify: Vec::new() })
},
MempoolEvent::BlockCommitted { header, txs } => {
let affected_accounts = self
.db
.handle_block_committed(
txs.clone(),
header.block_num(),
header.as_ref().clone(),
)
.await?;
Ok(WriteEventResult { accounts_to_notify: affected_accounts })
},
MempoolEvent::TransactionsReverted(tx_ids) => {
let affected_accounts =
self.db.handle_transactions_reverted(tx_ids.iter().copied().collect()).await?;
Ok(WriteEventResult { accounts_to_notify: affected_accounts })
},
}
}
}
#[cfg(test)]
impl Coordinator {
    /// Builds a coordinator on top of a freshly set up test database.
    ///
    /// The coordinator allows 4 in-flight transactions and deactivates an account after 10
    /// crashes. The returned `TempDir` is deleted when dropped. The test database presumably
    /// lives in it, so callers keep it alive while the coordinator is in use.
    pub async fn test() -> (Self, tempfile::TempDir) {
        const MAX_INFLIGHT_TRANSACTIONS: usize = 4;
        const MAX_ACCOUNT_CRASHES: usize = 10;

        let (db, temp_dir) = Db::test_setup().await;
        let coordinator = Self::new(MAX_INFLIGHT_TRANSACTIONS, MAX_ACCOUNT_CRASHES, db);
        (coordinator, temp_dir)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use miden_node_proto::domain::mempool::MempoolEvent;

    use super::*;
    use crate::actor::{AccountActorContext, AccountOrigin};
    use crate::db::Db;
    use crate::test_utils::*;

    /// Registers a handle for `account_id` without starting any task behind it, so the
    /// coordinator treats the account as having an active actor.
    fn register_dummy_actor(coordinator: &mut Coordinator, account_id: NetworkAccountId) {
        let handle = ActorHandle::new(Arc::new(Notify::new()), CancellationToken::new());
        coordinator.actor_registry.insert(account_id, handle);
    }

    #[tokio::test]
    async fn send_targeted_returns_inactive_targets() {
        let (mut coordinator, _dir) = Coordinator::test().await;

        let active_id = mock_network_account_id();
        let inactive_id = mock_network_account_id_seeded(42);
        register_dummy_actor(&mut coordinator, active_id);

        let event = MempoolEvent::TransactionAdded {
            id: mock_tx_id(1),
            nullifiers: vec![],
            network_notes: vec![
                mock_single_target_note(active_id, 10),
                mock_single_target_note(inactive_id, 20),
            ],
            account_delta: None,
        };

        // Only the account without a registered actor is reported back.
        assert_eq!(coordinator.send_targeted(&event), vec![inactive_id]);
    }

    #[tokio::test]
    async fn spawn_actor_skips_deactivated_account() {
        const MAX_CRASHES: usize = 3;

        let (db, _dir) = Db::test_setup().await;
        let mut coordinator = Coordinator::new(4, MAX_CRASHES, db.clone());
        let actor_context = AccountActorContext::test(&db);
        let account_id = mock_network_account_id();

        // Exactly at the limit: the account counts as deactivated.
        coordinator.crash_counts.insert(account_id, MAX_CRASHES);
        coordinator.spawn_actor(AccountOrigin::Store(account_id), &actor_context);

        let registered = coordinator.actor_registry.contains_key(&account_id);
        assert!(!registered, "Deactivated account should not have an actor in the registry");
    }

    #[tokio::test]
    async fn spawn_actor_allows_below_threshold() {
        const MAX_CRASHES: usize = 3;

        let (db, _dir) = Db::test_setup().await;
        let mut coordinator = Coordinator::new(4, MAX_CRASHES, db.clone());
        let actor_context = AccountActorContext::test(&db);
        let account_id = mock_network_account_id();

        // One crash short of the limit: spawning must still go ahead.
        coordinator.crash_counts.insert(account_id, MAX_CRASHES - 1);
        coordinator.spawn_actor(AccountOrigin::Store(account_id), &actor_context);

        let registered = coordinator.actor_registry.contains_key(&account_id);
        assert!(registered, "Account below crash threshold should have an actor in the registry");
    }
}