use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::future::{join_all, try_join_all};
use std::net::SocketAddr;
use crate::compute::{DaemonRuntime, MeshDaemon as ComputeMeshDaemon};
use crate::groups::{ReplicaGroup, ReplicaGroupConfig};
use crate::mesh::{Mesh, MeshBuilder};
use crate::meshos::{
EntityKeypair, LoggingDispatcher, MeshDaemon, MeshOsConfig, MeshOsDaemonHandle,
MeshOsDaemonSdk, MigrationSnapshotSource, NodeId, OrchestratorMigrationSnapshotSource,
RuntimeShutdownError,
};
use super::probes::install_mesh_probes;
/// Pre-shared key used by [`ClusterConfig::default`]; exactly 32 bytes long.
const HARNESS_PSK: [u8; 32] = *b"ai2070-cluster-harness-testing.x";

/// Tunables consumed by `ClusterHarness::with_config`.
#[derive(Clone, Debug)]
pub struct ClusterConfig {
    /// Pre-shared key handed to every node's `MeshBuilder`.
    pub psk: [u8; 32],
    /// Budget for one pairwise accept + connect handshake.
    pub handshake_timeout: Duration,
    /// Budget for every mesh to report `n - 1` peers.
    pub mesh_session_stable_timeout: Duration,
    /// Budget for every SDK runtime snapshot to list `n - 1` peers.
    pub meshos_snapshot_stable_timeout: Duration,
    /// Pause between two polls while waiting on either condition above.
    pub poll_interval: Duration,
    /// Copied into `MeshOsConfig::tick_interval` for every node.
    pub meshos_tick_interval: Duration,
    /// Optional verifier forwarded to each SDK at start-up (`deck` feature only).
    #[cfg(feature = "deck")]
    pub verifier: Option<Arc<crate::deck::AdminVerifier>>,
}

impl Default for ClusterConfig {
    /// Harness defaults: the built-in PSK, a 5 s handshake budget, 2 s / 3 s
    /// convergence budgets, 25 ms polling and a 100 ms tick interval.
    fn default() -> Self {
        ClusterConfig {
            psk: HARNESS_PSK,
            handshake_timeout: Duration::from_secs(5),
            mesh_session_stable_timeout: Duration::from_secs(2),
            meshos_snapshot_stable_timeout: Duration::from_secs(3),
            poll_interval: Duration::from_millis(25),
            meshos_tick_interval: Duration::from_millis(100),
            #[cfg(feature = "deck")]
            verifier: None,
        }
    }
}
/// One member of a `ClusterHarness`: a mesh plus the services started on it.
pub struct ClusterNode {
    /// Mesh instance owned by this node.
    pub(crate) mesh: Arc<Mesh>,
    /// MeshOS SDK; becomes `None` once `ClusterHarness::shutdown` takes it.
    pub(crate) sdk: Option<MeshOsDaemonSdk>,
    /// Daemon runtime; becomes `None` once `ClusterHarness::shutdown` takes it.
    pub(crate) daemon_runtime: Option<DaemonRuntime>,
    /// Address reported by the mesh when the harness was built.
    pub(crate) local_addr: SocketAddr,
    /// Node id reported by the mesh when the harness was built.
    pub(crate) node_id: NodeId,
    /// Public key reported by the mesh when the harness was built.
    pub(crate) public_key: [u8; 32],
}

impl ClusterNode {
    /// Shared handle to this node's mesh.
    pub fn mesh(&self) -> &Arc<Mesh> {
        let Self { mesh, .. } = self;
        mesh
    }

    /// The node's SDK, or `None` if it has already been taken.
    pub fn sdk(&self) -> Option<&MeshOsDaemonSdk> {
        let Self { sdk, .. } = self;
        sdk.as_ref()
    }

    /// The node's daemon runtime, or `None` if it has already been taken.
    pub fn daemon_runtime(&self) -> Option<&DaemonRuntime> {
        let Self { daemon_runtime, .. } = self;
        daemon_runtime.as_ref()
    }

    /// Socket address recorded for this node.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Node id recorded for this node.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// 32-byte public key recorded for this node.
    pub fn public_key(&self) -> [u8; 32] {
        self.public_key
    }
}
/// Handle to one daemon registered by `ClusterHarness::spawn_where`.
pub struct NodeDaemonHandle {
    /// Position of the hosting node inside the harness' node list.
    pub node_index: usize,
    /// Node id of the hosting node.
    pub node_id: NodeId,
    /// Origin hash of the keypair the daemon was registered with.
    pub daemon_id: u64,
    /// Underlying SDK handle returned by `register_daemon`.
    pub handle: MeshOsDaemonHandle,
}

impl NodeDaemonHandle {
    /// Consumes the handle and forwards to the wrapped handle's
    /// `graceful_shutdown` with the given `grace` period.
    ///
    /// # Errors
    /// Returns whatever the wrapped handle's `graceful_shutdown` returns.
    pub async fn graceful_shutdown(self, grace: Duration) -> Result<(), crate::meshos::SdkError> {
        let Self { handle, .. } = self;
        handle.graceful_shutdown(grace).await
    }
}
/// Point-in-time convergence counters produced by `ClusterHarness::health`.
#[derive(Clone, Copy, Debug)]
pub struct ClusterHealth {
    /// Number of nodes in the harness.
    pub total_nodes: usize,
    /// Nodes whose mesh reports the full set of remote peers.
    pub meshes_with_full_peers: usize,
    /// Nodes whose SDK runtime snapshot lists the full set of remote peers.
    pub runtimes_with_full_peers: usize,
}

impl ClusterHealth {
    /// `true` when both counters equal `total_nodes`.
    pub fn fully_converged(&self) -> bool {
        let Self {
            total_nodes,
            meshes_with_full_peers,
            runtimes_with_full_peers,
        } = *self;
        meshes_with_full_peers == total_nodes && runtimes_with_full_peers == total_nodes
    }
}
/// Errors reported by the cluster harness.
#[derive(Debug, thiserror::Error)]
pub enum ClusterError {
/// A build precondition was violated (e.g. a node count of zero).
#[error("cluster build: {0}")]
Invariant(String),
/// Building a mesh, announcing capabilities, or starting a daemon runtime failed.
#[error("mesh build failed: {0}")]
MeshBuild(String),
/// The pairwise handshake between nodes `from` and `to` (harness indices)
/// failed or timed out; `reason` says which side failed.
#[error("handshake failed between node[{from}] and node[{to}]: {reason}")]
Handshake {
from: usize,
to: usize,
reason: String,
},
/// A polled condition (`what`) did not hold within `budget_ms` milliseconds.
#[error("timed out waiting for {what} after {budget_ms}ms")]
Timeout { what: String, budget_ms: u64 },
/// An SDK reported an error while the harness was shutting down.
#[error("shutdown failed: {0}")]
Shutdown(String),
/// Spawning a daemon or replica group on node `node_index` failed.
#[error("daemon spawn on node[{node_index}] failed: {reason}")]
Spawn { node_index: usize, reason: String },
}
/// An in-process cluster of fully handshaken mesh nodes.
pub struct ClusterHarness {
// Nodes in creation order; harness indices refer to positions in this vector.
nodes: Vec<ClusterNode>,
// Set by `shutdown`; the `Drop` impl reads it to decide whether to warn.
shutdown_called: bool,
}
impl ClusterHarness {
/// Builds an `n`-node cluster using `ClusterConfig::default()`.
///
/// # Errors
/// Propagates every error of [`ClusterHarness::with_config`].
pub async fn new(n: usize) -> Result<Self, ClusterError> {
    let defaults = ClusterConfig::default();
    Self::with_config(n, defaults).await
}
/// Builds an `n`-node cluster with the supplied configuration.
///
/// Steps, in order: build `n` meshes, handshake every pair of nodes, start
/// the meshes, announce capabilities, start a daemon runtime and a MeshOS SDK
/// per node, then wait until every mesh and every SDK snapshot sees `n - 1`
/// peers.
///
/// # Errors
/// * [`ClusterError::Invariant`] when `n == 0`.
/// * [`ClusterError::MeshBuild`] when a mesh cannot be built, capabilities
///   cannot be announced, or a daemon runtime fails to start.
/// * [`ClusterError::Handshake`] when a pairwise handshake fails or exceeds
///   `cfg.handshake_timeout`.
/// * [`ClusterError::Timeout`] when either convergence wait runs out of budget.
pub async fn with_config(n: usize, cfg: ClusterConfig) -> Result<Self, ClusterError> {
if n == 0 {
return Err(ClusterError::Invariant("n must be > 0".into()));
}
// One build future per node, all driven together by `try_join_all`.
// NOTE(review): "127.0.0.1:0" presumably means loopback with an OS-assigned
// port — confirm against `MeshBuilder::new`.
let mesh_futures = (0..n).map(|_| async {
let builder = MeshBuilder::new("127.0.0.1:0", &cfg.psk)
.map_err(|e| ClusterError::MeshBuild(e.to_string()))?;
builder
.build()
.await
.map(Arc::new)
.map_err(|e| ClusterError::MeshBuild(e.to_string()))
});
let meshes: Vec<Arc<Mesh>> = try_join_all(mesh_futures).await?;
// Capture (node id, public key, local address) once per mesh; indices match `meshes`.
let identities: Vec<(NodeId, [u8; 32], SocketAddr)> = meshes
.iter()
.map(|m| (m.node_id(), *m.public_key(), m.local_addr()))
.collect();
let handshake_budget = cfg.handshake_timeout;
// Handshake every unordered pair (i, j) with i < j, one pair at a time:
// node j accepts while node i connects.
for i in 0..n {
for j in (i + 1)..n {
let mesh_i = Arc::clone(&meshes[i]);
let mesh_j = Arc::clone(&meshes[j]);
let (id_i, _, _addr_i) = identities[i];
let (id_j, pubkey_j, addr_j) = identities[j];
let accept = async move { mesh_j.accept(id_i).await };
let connect = async move {
// Delay the connecting side by 10 ms — presumably so the accept side is
// already waiting; TODO confirm.
tokio::time::sleep(Duration::from_millis(10)).await;
let peer_addr = format!("{addr_j}");
mesh_i.connect(&peer_addr, &pubkey_j, id_j).await
};
// Both halves share a single timeout budget.
let result =
tokio::time::timeout(
handshake_budget,
async move { tokio::join!(accept, connect) },
)
.await;
match result {
Err(_) => {
return Err(ClusterError::Handshake {
from: i,
to: j,
reason: format!("timed out after {}ms", handshake_budget.as_millis()),
});
}
Ok((accept_res, connect_res)) => {
// The accept error is reported in preference to the connect error.
if let Err(e) = accept_res {
return Err(ClusterError::Handshake {
from: i,
to: j,
reason: format!("accept: {e}"),
});
}
if let Err(e) = connect_res {
return Err(ClusterError::Handshake {
from: i,
to: j,
reason: format!("connect: {e}"),
});
}
}
}
}
}
// Meshes are started only after every pair has completed its handshake.
for m in &meshes {
m.start();
}
// Announce a freshly constructed `CapabilitySet` from each mesh, sequentially.
for m in meshes.iter() {
m.announce_capabilities(crate::capabilities::CapabilitySet::new())
.await
.map_err(|e| ClusterError::MeshBuild(format!("announce_capabilities: {e}")))?;
}
// A single dispatcher instance is shared by every node's SDK.
let dispatcher = Arc::new(LoggingDispatcher::new());
// Node ids of all nodes (each node's own id included), handed to the probes.
let expected_peers: Arc<Vec<NodeId>> =
Arc::new(identities.iter().map(|(id, _, _)| *id).collect());
let mut nodes = Vec::with_capacity(n);
for (i, mesh) in meshes.iter().enumerate() {
let (node_id, public_key, local_addr) = identities[i];
let daemon_runtime = DaemonRuntime::new(Arc::clone(mesh));
daemon_runtime.start().await.map_err(|e| {
ClusterError::MeshBuild(format!("daemon_runtime.start() on node[{i}]: {e}"))
})?;
// Migration snapshots are sourced from this node's runtime orchestrator.
let migration_source: Arc<dyn MigrationSnapshotSource> =
Arc::new(OrchestratorMigrationSnapshotSource::new(
daemon_runtime.migration_orchestrator_arc(),
));
// Default MeshOS config with only the node id and tick interval overridden.
let mut mesh_cfg = MeshOsConfig::default();
mesh_cfg.this_node = node_id;
mesh_cfg.tick_interval = cfg.meshos_tick_interval;
// The verifier argument is `cfg.verifier` with the `deck` feature, `None` without it.
let sdk = MeshOsDaemonSdk::start_with_verifier_and_migration_source(
mesh_cfg,
Arc::clone(&dispatcher) as Arc<LoggingDispatcher>,
#[cfg(feature = "deck")]
cfg.verifier.clone(),
#[cfg(not(feature = "deck"))]
None,
Some(migration_source),
);
install_mesh_probes(sdk.runtime(), Arc::clone(mesh), Arc::clone(&expected_peers));
nodes.push(ClusterNode {
mesh: Arc::clone(mesh),
sdk: Some(sdk),
daemon_runtime: Some(daemon_runtime),
local_addr,
node_id,
public_key,
});
}
// First wait: every mesh must report exactly `n - 1` peers.
wait_for(
"mesh session table",
cfg.mesh_session_stable_timeout,
cfg.poll_interval,
|| nodes.iter().all(|n| n.mesh.peer_count() == nodes.len() - 1),
)
.await?;
let expected_remote = nodes.len() - 1;
// Second wait: every SDK runtime snapshot must list exactly `n - 1` peers.
wait_for(
"meshos snapshot.peers fold",
cfg.meshos_snapshot_stable_timeout,
cfg.poll_interval,
|| {
nodes.iter().all(|n| {
n.sdk
.as_ref()
.map(|sdk| sdk.runtime().snapshot().peers.len() == expected_remote)
.unwrap_or(false)
})
},
)
.await?;
Ok(Self {
nodes,
shutdown_called: false,
})
}
/// All nodes of the cluster, in creation order.
pub fn nodes(&self) -> &[ClusterNode] {
    self.nodes.as_slice()
}
/// The node at harness index `i`.
///
/// # Panics
/// Panics when `i` is not a valid index into the node list.
pub fn nth(&self, i: usize) -> &ClusterNode {
    let roster = self.nodes.as_slice();
    &roster[i]
}
/// Number of nodes in the cluster.
pub fn len(&self) -> usize {
    let roster = &self.nodes;
    roster.len()
}
/// `true` when the harness holds no nodes.
pub fn is_empty(&self) -> bool {
    let roster = &self.nodes;
    roster.is_empty()
}
/// Samples the cluster's convergence state.
///
/// A node counts as "full" when it sees `total - 1` peers (saturating at
/// zero). Meshes are sampled first, then the SDK runtime snapshots; a node
/// whose SDK has been taken never counts as a full runtime.
pub fn health(&self) -> ClusterHealth {
    let total_nodes = self.nodes.len();
    let want = total_nodes.saturating_sub(1);

    // Pass 1: mesh-level peer counts.
    let mut meshes_with_full_peers = 0usize;
    for node in &self.nodes {
        if node.mesh.peer_count() == want {
            meshes_with_full_peers += 1;
        }
    }

    // Pass 2: peers visible in each SDK runtime snapshot.
    let mut runtimes_with_full_peers = 0usize;
    for node in &self.nodes {
        let full = match node.sdk.as_ref() {
            Some(sdk) => sdk.runtime().snapshot().peers.len() == want,
            None => false,
        };
        if full {
            runtimes_with_full_peers += 1;
        }
    }

    ClusterHealth {
        total_nodes,
        meshes_with_full_peers,
        runtimes_with_full_peers,
    }
}
pub async fn spawn_per_node<D, F>(
&self,
factory: F,
) -> Result<Vec<NodeDaemonHandle>, ClusterError>
where
D: MeshDaemon + 'static,
F: Fn() -> D,
{
self.spawn_where(factory, |_| true).await
}
/// Registers one daemon, produced by `factory`, on every node accepted by
/// `predicate`, visiting nodes in harness order.
///
/// Each daemon is registered with a freshly generated keypair; the keypair's
/// origin hash becomes the handle's `daemon_id`.
///
/// # Errors
/// Returns [`ClusterError::Spawn`] for the first node whose SDK is gone or
/// whose registration fails. Before returning, every daemon spawned so far is
/// shut down in reverse order with a 200 ms grace period each.
pub async fn spawn_where<D, F, P>(
    &self,
    factory: F,
    predicate: P,
) -> Result<Vec<NodeDaemonHandle>, ClusterError>
where
    D: MeshDaemon + 'static,
    F: Fn() -> D,
    P: Fn(&ClusterNode) -> bool,
{
    let rollback_grace = Duration::from_millis(200);
    let mut spawned: Vec<NodeDaemonHandle> = Vec::new();

    for (i, node) in self.nodes.iter().enumerate() {
        if !predicate(node) {
            continue;
        }

        // Attempt the registration; any failure is reduced to its message so
        // that a single rollback path below can handle both failure modes.
        let outcome = match node.sdk.as_ref() {
            None => Err(String::from("node sdk already shut down")),
            Some(sdk) => {
                let daemon = Box::new(factory());
                let keypair = EntityKeypair::generate();
                let daemon_id = keypair.origin_hash();
                sdk.register_daemon(daemon, keypair)
                    .map(|handle| (daemon_id, handle))
                    .map_err(|e| e.to_string())
            }
        };

        match outcome {
            Ok((daemon_id, handle)) => spawned.push(NodeDaemonHandle {
                node_index: i,
                node_id: node.node_id,
                daemon_id,
                handle,
            }),
            Err(reason) => {
                rollback(spawned, rollback_grace).await;
                return Err(ClusterError::Spawn {
                    node_index: i,
                    reason,
                });
            }
        }
    }

    Ok(spawned)
}
/// Shuts the whole cluster down and consumes the harness.
///
/// Daemon runtimes are shut down first, one after another, and their results
/// are discarded. The SDKs are then shut down concurrently.
///
/// # Errors
/// Returns [`ClusterError::Shutdown`] carrying the `Debug` rendering of the
/// first SDK shutdown error, in node order.
pub async fn shutdown(mut self) -> Result<(), ClusterError> {
    if self.shutdown_called {
        return Ok(());
    }
    self.shutdown_called = true;

    // Take every runtime out of its node; the vector stays alive until the
    // end of this function.
    let mut runtimes: Vec<DaemonRuntime> = Vec::with_capacity(self.nodes.len());
    for node in self.nodes.iter_mut() {
        if let Some(rt) = node.daemon_runtime.take() {
            runtimes.push(rt);
        }
    }
    for rt in runtimes.iter() {
        let _ = rt.shutdown().await;
    }

    // Take every SDK and prepare one shutdown future per SDK.
    let mut pending = Vec::with_capacity(self.nodes.len());
    for node in self.nodes.iter_mut() {
        if let Some(sdk) = node.sdk.take() {
            pending.push(async move { sdk.shutdown().await });
        }
    }
    let outcomes = join_all(pending).await;

    let first_failure: Option<RuntimeShutdownError> =
        outcomes.into_iter().find_map(Result::err);
    match first_failure {
        Some(e) => Err(ClusterError::Shutdown(format!("{e:?}"))),
        None => Ok(()),
    }
}
/// Registers `factory` under `kind` on the anchor node's daemon runtime and
/// spawns a [`ReplicaGroup`] of that kind from it.
///
/// # Errors
/// Returns [`ClusterError::Spawn`] (with `node_index == anchor_index`) when
/// the index is out of range, the anchor has no daemon runtime, the factory
/// cannot be registered, or `ReplicaGroup::spawn` fails.
pub fn spawn_replica_group<D, F>(
    &self,
    anchor_index: usize,
    kind: &str,
    factory: F,
    config: ReplicaGroupConfig,
) -> Result<ReplicaGroup, ClusterError>
where
    D: ComputeMeshDaemon + 'static,
    F: Fn() -> D + Send + Sync + 'static,
{
    // Every failure in this method is a `Spawn` error on the anchor node.
    let fail = |reason: String| ClusterError::Spawn {
        node_index: anchor_index,
        reason,
    };

    let anchor = match self.nodes.get(anchor_index) {
        Some(node) => node,
        None => return Err(fail("anchor_index out of range".into())),
    };
    let rt = match anchor.daemon_runtime.as_ref() {
        Some(rt) => rt,
        None => return Err(fail("anchor node has no daemon runtime".into())),
    };

    if let Err(e) = rt.register_factory(kind, move || Box::new(factory())) {
        return Err(fail(format!("register_factory({kind}): {e:?}")));
    }

    ReplicaGroup::spawn(rt, kind, config)
        .map_err(|e| fail(format!("ReplicaGroup::spawn({kind}): {e:?}")))
}
}
impl Drop for ClusterHarness {
    /// Prints a warning to stderr when the harness is dropped without
    /// `shutdown` having been called; performs no cleanup of its own.
    fn drop(&mut self) {
        if !self.shutdown_called {
            eprintln!(
                "[net-sdk testing] ClusterHarness dropped without \
                 explicit shutdown — relying on Drop impls. \
                 Awaiting `harness.shutdown().await` is the clean path."
            );
        }
    }
}
/// Gracefully shuts down `handles` from last to first, one at a time, giving
/// each `grace`. Individual shutdown results are discarded.
async fn rollback(handles: Vec<NodeDaemonHandle>, grace: Duration) {
    let mut remaining = handles;
    // `pop` removes from the back, so the most recently spawned daemon goes first.
    while let Some(latest) = remaining.pop() {
        let _ = latest.graceful_shutdown(grace).await;
    }
}
/// Polls `cond` until it returns `true` or `budget` has elapsed.
///
/// `cond` is evaluated immediately and then after every sleep, so it is always
/// checked one final time before a timeout is reported. Each sleep lasts
/// `poll_interval`, clamped to the time left in `budget`, so the total wait
/// does not overshoot the budget by a whole poll interval.
///
/// # Errors
/// Returns [`ClusterError::Timeout`] naming `what` and the budget in
/// milliseconds (saturating at `u64::MAX`) when the budget runs out.
async fn wait_for<F: FnMut() -> bool>(
    what: &'static str,
    budget: Duration,
    poll_interval: Duration,
    mut cond: F,
) -> Result<(), ClusterError> {
    let start = Instant::now();
    loop {
        if cond() {
            return Ok(());
        }
        // Zero remaining time is exactly the `elapsed >= budget` condition.
        let remaining = budget.saturating_sub(start.elapsed());
        if remaining.is_zero() {
            return Err(ClusterError::Timeout {
                what: what.to_string(),
                // `as_millis` yields u128; saturate instead of silently truncating.
                budget_ms: u64::try_from(budget.as_millis()).unwrap_or(u64::MAX),
            });
        }
        // Never sleep past the deadline.
        tokio::time::sleep(poll_interval.min(remaining)).await;
    }
}