use super::{
engine::{ChannelPair, EngineDefinition, InitContext},
tracker::FinalizationUpdate,
};
use commonware_p2p::simulated::{Link, Oracle};
use commonware_runtime::{deterministic, Handle, Supervisor as _};
use commonware_utils::channel::mpsc;
use std::collections::{BTreeMap, HashSet};
use tracing::info;
/// A set of simulated validators whose engines are built by a single
/// [`EngineDefinition`].
///
/// For each participant public key it tracks the running engine handle, the
/// state returned when the engine was initialized, and how many times that
/// validator has been started.
pub struct Team<D: EngineDefinition> {
/// Definition used to declare channels and to build/start each validator's engine.
definition: D,
/// All participant keys; a key's position in this list is its validator index.
participants: Vec<D::PublicKey>,
/// Handles of validators currently tracked as running (removed on crash).
handles: BTreeMap<D::PublicKey, Handle<()>>,
/// Most recent state produced at initialization for each started validator;
/// not removed when the validator crashes.
states: BTreeMap<D::PublicKey, D::State>,
/// Number of times each validator has been started (reported as the
/// `restart` context attribute).
restart_counts: BTreeMap<D::PublicKey, u32>,
}
impl<D: EngineDefinition> Team<D> {
    /// Creates a team over `participants` with no validators started yet.
    pub const fn new(definition: D, participants: Vec<D::PublicKey>) -> Self {
        Self {
            definition,
            participants,
            handles: BTreeMap::new(),
            states: BTreeMap::new(),
            restart_counts: BTreeMap::new(),
        }
    }

    /// Starts (or restarts) the validator identified by `pk`.
    ///
    /// Any handle already tracked for `pk` is aborted first. The validator is
    /// given a `"validator"` child context tagged with its participant `index`
    /// and a per-key `restart` counter (0 on the first start). Every channel
    /// returned by the definition's `channels()` is then registered with the
    /// oracle's control for `pk`. The engine is initialized and started through
    /// the definition. The resulting handle and state replace any previous
    /// entries for `pk`.
    ///
    /// # Panics
    ///
    /// Panics if `pk` is not one of the team's participants, or if a channel
    /// registration fails.
    pub async fn start_one(
        &mut self,
        ctx: &deterministic::Context,
        oracle: &Oracle<D::PublicKey, deterministic::Context>,
        pk: D::PublicKey,
        monitor: mpsc::Sender<FinalizationUpdate<D::PublicKey>>,
    ) {
        // Resolve the index before touching any bookkeeping, so that an unknown
        // key panics without leaving a stray restart counter behind.
        let index = self
            .participants
            .iter()
            .position(|p| p == &pk)
            .expect("participant not found");

        if let Some(handle) = self.handles.remove(&pk) {
            handle.abort();
        }

        // Record this start attempt; the context attribute carries the
        // pre-increment value (0 for the first start).
        let restart = {
            let count = self.restart_counts.entry(pk.clone()).or_insert(0);
            let current = *count;
            *count += 1;
            current
        };

        let validator_ctx = ctx
            .child("validator")
            .with_attribute("index", index)
            .with_attribute("restart", restart);

        let control = oracle.control(pk.clone());
        let channel_specs = self.definition.channels();
        let mut channels: Vec<ChannelPair<D::PublicKey>> = Vec::with_capacity(channel_specs.len());
        for (channel_id, quota) in &channel_specs {
            let pair = control
                .register(*channel_id, *quota)
                .await
                .expect("channel registration failed");
            channels.push(pair);
        }

        let (engine, state) = self
            .definition
            .init(InitContext {
                context: validator_ctx,
                index,
                public_key: &pk,
                oracle,
                channels,
                participants: &self.participants,
                monitor,
            })
            .await;
        let handle = D::start(engine);
        self.handles.insert(pk.clone(), handle);
        self.states.insert(pk, state);
    }

    /// Links every participant to every other participant, then starts all
    /// validators except those in `delayed`.
    ///
    /// A link using `link` is added for each ordered pair of distinct
    /// participants, so both directions are covered. Participants in `delayed`
    /// are still linked but not started; they can be started later with
    /// [`Self::start_one`] or [`Self::restart`].
    ///
    /// # Panics
    ///
    /// Panics if adding a link fails, or under the same conditions as
    /// [`Self::start_one`].
    pub async fn start(
        &mut self,
        ctx: &deterministic::Context,
        oracle: &Oracle<D::PublicKey, deterministic::Context>,
        link: Link,
        monitor: mpsc::Sender<FinalizationUpdate<D::PublicKey>>,
        delayed: &HashSet<D::PublicKey>,
    ) {
        // A snapshot is needed because `start_one` borrows `self` mutably.
        let participants = self.participants.clone();
        for v1 in &participants {
            for v2 in &participants {
                if v1 == v2 {
                    continue;
                }
                oracle
                    .add_link(v1.clone(), v2.clone(), link.clone())
                    .await
                    .expect("failed to add link");
            }
        }
        for pk in participants {
            if delayed.contains(&pk) {
                info!(target: "simulator", ?pk, "delayed participant");
                continue;
            }
            self.start_one(ctx, oracle, pk, monitor.clone()).await;
        }
    }

    /// Aborts the running validator for `pk`, if one is tracked.
    ///
    /// The validator's last state is kept; only its handle is dropped from
    /// tracking. Returns `true` if a handle was found and aborted, and `false`
    /// otherwise.
    pub fn crash(&mut self, pk: &D::PublicKey) -> bool {
        let Some(handle) = self.handles.remove(pk) else {
            return false;
        };
        handle.abort();
        info!(target: "simulator", ?pk, "crashed validator");
        true
    }

    /// Logs and (re)starts the validator for `pk` via [`Self::start_one`].
    ///
    /// # Panics
    ///
    /// Same conditions as [`Self::start_one`].
    pub async fn restart(
        &mut self,
        ctx: &deterministic::Context,
        oracle: &Oracle<D::PublicKey, deterministic::Context>,
        pk: D::PublicKey,
        monitor: mpsc::Sender<FinalizationUpdate<D::PublicKey>>,
    ) {
        info!(target: "simulator", ?pk, "restarting validator");
        self.start_one(ctx, oracle, pk, monitor).await;
    }

    /// Returns the states of validators that currently have a tracked handle,
    /// in key order.
    pub fn active_states(&self) -> Vec<&D::State> {
        self.handles
            .keys()
            .filter_map(|pk| self.states.get(pk))
            .collect()
    }

    /// Returns the keys of validators that currently have a tracked handle,
    /// in key order.
    pub fn active_keys(&self) -> Vec<D::PublicKey> {
        self.handles.keys().cloned().collect()
    }

    /// Returns all participant keys, in validator-index order.
    pub fn participants(&self) -> &[D::PublicKey] {
        &self.participants
    }
}