use std::net::{SocketAddr, UdpSocket};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use chitchat::transport::{Socket, Transport, UdpTransport};
use object_store::{
path::Path as OsPath, CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload,
ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult,
};
use parking_lot::Mutex;
use rustc_hash::FxHashSet;
use tokio::sync::watch;
use super::control::{AssignmentSnapshotStore, ChitchatKv, ClusterController, ClusterKv};
use super::discovery::{
Discovery, GossipDiscovery, GossipDiscoveryConfig, NodeId, NodeInfo, NodeMetadata, NodeState,
};
/// Shared table of directed address pairs whose traffic should be dropped.
///
/// `PartitionableSocket::send` consults this table (via [`NetworkRules::is_dropped`])
/// and silently discards a message when the `(sender, receiver)` pair is listed.
pub struct NetworkRules {
/// Directed `(src, dst)` pairs currently marked as dropped.
dropped: Mutex<FxHashSet<(SocketAddr, SocketAddr)>>,
}
/// Debug output shows only how many directed pairs are currently dropped,
/// not the pairs themselves.
impl std::fmt::Debug for NetworkRules {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read the size first so the lock is not held while formatting.
        let drop_count = self.dropped.lock().len();
        let mut out = f.debug_struct("NetworkRules");
        out.field("drop_count", &drop_count);
        out.finish()
    }
}
impl NetworkRules {
    /// Creates a rule table with no dropped pairs.
    #[must_use]
    pub fn new() -> Self {
        let dropped = Mutex::new(FxHashSet::default());
        Self { dropped }
    }

    /// Marks every pair between `side_a` and `side_b` as dropped, in both
    /// directions.
    ///
    /// Pairs that are already present are left as they are; pairs within the
    /// same side are not touched.
    pub fn partition(&self, side_a: &[SocketAddr], side_b: &[SocketAddr]) {
        let both_directions = side_a.iter().flat_map(|&left| {
            side_b
                .iter()
                .flat_map(move |&right| [(left, right), (right, left)])
        });
        self.dropped.lock().extend(both_directions);
    }

    /// Marks the single directed pair `src -> dst` as dropped.
    ///
    /// The reverse direction (`dst -> src`) is not affected.
    pub fn drop_pair(&self, src: SocketAddr, dst: SocketAddr) {
        let mut table = self.dropped.lock();
        table.insert((src, dst));
    }

    /// Removes every drop rule.
    pub fn heal(&self) {
        let mut table = self.dropped.lock();
        table.clear();
    }

    /// Returns `true` when the directed pair `src -> dst` is currently dropped.
    #[must_use]
    pub fn is_dropped(&self, src: SocketAddr, dst: SocketAddr) -> bool {
        let key = (src, dst);
        self.dropped.lock().contains(&key)
    }
}
impl Default for NetworkRules {
fn default() -> Self {
Self::new()
}
}
/// Failure mode that a `FaultyObjectStore` can be switched into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObjectStoreFault {
    /// No fault: neither reads nor writes are rejected.
    None,
    /// Write-side operations are rejected.
    FailWrites,
    /// Read-side operations are rejected.
    FailReads,
    /// Both read-side and write-side operations are rejected.
    FailAll,
}

impl ObjectStoreFault {
    /// Returns `true` for the modes that reject writes (`FailWrites`, `FailAll`).
    fn fails_writes(self) -> bool {
        match self {
            Self::FailWrites | Self::FailAll => true,
            Self::None | Self::FailReads => false,
        }
    }

    /// Returns `true` for the modes that reject reads (`FailReads`, `FailAll`).
    fn fails_reads(self) -> bool {
        match self {
            Self::FailReads | Self::FailAll => true,
            Self::None | Self::FailWrites => false,
        }
    }
}
/// [`ObjectStore`] wrapper that forwards operations to `inner` unless the
/// currently selected [`ObjectStoreFault`] says the operation should fail.
pub struct FaultyObjectStore {
/// Backing store that receives every operation that is not rejected.
inner: Arc<dyn ObjectStore>,
/// Active fault mode; readable and writable through a shared reference.
fault: Mutex<ObjectStoreFault>,
}
/// Debug output shows the active fault mode and elides the wrapped store.
impl std::fmt::Debug for FaultyObjectStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mode = self.fault();
        let mut out = f.debug_struct("FaultyObjectStore");
        out.field("fault", &mode);
        out.finish_non_exhaustive()
    }
}
/// Renders as `FaultyObjectStore(<mode>)`, using the `Debug` form of the
/// active fault mode.
impl std::fmt::Display for FaultyObjectStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mode = self.fault();
        write!(f, "FaultyObjectStore({mode:?})")
    }
}
impl FaultyObjectStore {
#[must_use]
pub fn new(inner: Arc<dyn ObjectStore>) -> Self {
Self {
inner,
fault: Mutex::new(ObjectStoreFault::None),
}
}
#[must_use]
pub fn fault(&self) -> ObjectStoreFault {
*self.fault.lock()
}
pub fn set_fault(&self, mode: ObjectStoreFault) {
*self.fault.lock() = mode;
}
fn check_write(&self) -> object_store::Result<()> {
if self.fault().fails_writes() {
return Err(object_store::Error::Generic {
store: "FaultyObjectStore",
source: "injected write failure".into(),
});
}
Ok(())
}
fn check_read(&self, path: &OsPath) -> object_store::Result<()> {
if self.fault().fails_reads() {
return Err(object_store::Error::NotFound {
path: path.to_string(),
source: "injected read failure".into(),
});
}
Ok(())
}
}
/// Fault-injecting [`ObjectStore`] implementation.
///
/// Every operation first consults the active [`ObjectStoreFault`]:
/// write-side operations (`put_opts`, `put_multipart_opts`, `delete_stream`,
/// `copy_opts`) are rejected while writes are failing, and read-side
/// operations (`get_opts`, `list`, `list_with_delimiter`) are rejected while
/// reads are failing. Anything not rejected is forwarded to the wrapped store.
#[async_trait]
impl ObjectStore for FaultyObjectStore {
    /// Forwards the put unless writes are failing.
    async fn put_opts(
        &self,
        location: &OsPath,
        payload: PutPayload,
        opts: PutOptions,
    ) -> object_store::Result<PutResult> {
        self.check_write()?;
        self.inner.put_opts(location, payload, opts).await
    }

    /// Forwards the multipart-upload start unless writes are failing.
    async fn put_multipart_opts(
        &self,
        location: &OsPath,
        opts: PutMultipartOptions,
    ) -> object_store::Result<Box<dyn MultipartUpload>> {
        self.check_write()?;
        self.inner.put_multipart_opts(location, opts).await
    }

    /// Forwards the get unless reads are failing.
    async fn get_opts(
        &self,
        location: &OsPath,
        options: GetOptions,
    ) -> object_store::Result<GetResult> {
        self.check_read(location)?;
        self.inner.get_opts(location, options).await
    }

    /// Forwards the delete stream unless writes are failing.
    ///
    /// While writes are failing, every item of `locations` is replaced by an
    /// injected error. The fault mode is sampled once, when this method is
    /// called, not as the stream is polled.
    fn delete_stream(
        &self,
        locations: futures::stream::BoxStream<'static, object_store::Result<OsPath>>,
    ) -> futures::stream::BoxStream<'static, object_store::Result<OsPath>> {
        if self.fault().fails_writes() {
            use futures::StreamExt;
            locations
                .map(|_| {
                    Err(object_store::Error::Generic {
                        store: "FaultyObjectStore",
                        source: "injected write failure (delete_stream)".into(),
                    })
                })
                .boxed()
        } else {
            self.inner.delete_stream(locations)
        }
    }

    /// Forwards the listing unless reads are failing.
    ///
    /// While reads are failing the returned stream yields a single injected
    /// error. Previously listings ignored the fault mode, so `FailReads` and
    /// `FailAll` still allowed callers to enumerate the store.
    fn list(
        &self,
        prefix: Option<&OsPath>,
    ) -> futures::stream::BoxStream<'static, object_store::Result<ObjectMeta>> {
        if self.fault().fails_reads() {
            use futures::StreamExt;
            let failure: object_store::Result<ObjectMeta> = Err(injected_list_failure());
            return futures::stream::once(futures::future::ready(failure)).boxed();
        }
        self.inner.list(prefix)
    }

    /// Forwards the delimited listing unless reads are failing.
    async fn list_with_delimiter(
        &self,
        prefix: Option<&OsPath>,
    ) -> object_store::Result<ListResult> {
        if self.fault().fails_reads() {
            return Err(injected_list_failure());
        }
        self.inner.list_with_delimiter(prefix).await
    }

    /// Forwards the copy unless writes are failing.
    async fn copy_opts(
        &self,
        from: &OsPath,
        to: &OsPath,
        options: CopyOptions,
    ) -> object_store::Result<()> {
        self.check_write()?;
        self.inner.copy_opts(from, to, options).await
    }
}

/// Builds the error reported by listing operations while reads are failing.
///
/// A listing has no single object path, so a `Generic` error is used rather
/// than `NotFound`.
fn injected_list_failure() -> object_store::Error {
    object_store::Error::Generic {
        store: "FaultyObjectStore",
        source: "injected read failure (list)".into(),
    }
}
/// [`Transport`] that opens sockets through a [`UdpTransport`] and wraps each
/// one so that outgoing messages are filtered by a shared [`NetworkRules`].
pub struct PartitionableTransport {
/// Drop rules handed to every socket opened by this transport.
rules: Arc<NetworkRules>,
/// Underlying transport that opens the real sockets.
inner: UdpTransport,
}
impl PartitionableTransport {
    /// Creates a transport whose sockets consult `rules` before sending.
    #[must_use]
    pub fn new(rules: Arc<NetworkRules>) -> Self {
        let inner = UdpTransport;
        Self { rules, inner }
    }
}
#[async_trait]
impl Transport for PartitionableTransport {
    /// Opens a socket on `listen_addr` via the inner transport and wraps it in
    /// a `PartitionableSocket` that shares this transport's rules.
    ///
    /// `listen_addr` is recorded as the socket's own address and is later used
    /// as the source side of drop-rule lookups.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner transport's `open`.
    async fn open(&self, listen_addr: SocketAddr) -> anyhow::Result<Box<dyn Socket>> {
        let inner = self.inner.open(listen_addr).await?;
        let wrapped: Box<dyn Socket> = Box::new(PartitionableSocket {
            my_addr: listen_addr,
            rules: Arc::clone(&self.rules),
            inner,
        });
        Ok(wrapped)
    }
}
/// Socket wrapper that discards outgoing messages whose `(my_addr, to)` pair is
/// listed in the shared [`NetworkRules`]; receiving is passed straight through.
struct PartitionableSocket {
/// Address this socket was opened on; used as the source in rule lookups.
my_addr: SocketAddr,
/// Shared drop rules consulted on every send.
rules: Arc<NetworkRules>,
/// Real socket that performs the I/O.
inner: Box<dyn Socket>,
}
#[async_trait]
impl Socket for PartitionableSocket {
    /// Sends `msg` to `to`, unless the pair `(my_addr, to)` is currently
    /// dropped.
    ///
    /// A dropped message is reported as success (`Ok(())`) without touching
    /// the inner socket, so the sender cannot tell that it was discarded.
    async fn send(&mut self, to: SocketAddr, msg: chitchat::ChitchatMessage) -> anyhow::Result<()> {
        let blocked = self.rules.is_dropped(self.my_addr, to);
        if blocked {
            Ok(())
        } else {
            self.inner.send(to, msg).await
        }
    }

    /// Receives the next message from the inner socket; no rules are applied
    /// on the receive path.
    async fn recv(&mut self) -> anyhow::Result<(SocketAddr, chitchat::ChitchatMessage)> {
        let received = self.inner.recv().await;
        received
    }
}
/// Asks the OS for a free UDP port on `127.0.0.1` and returns its number.
///
/// The probe socket is closed before returning, so the port is only known to
/// have been free at the time of the call; nothing reserves it afterwards.
///
/// # Panics
///
/// Panics if the loopback bind fails or the local address cannot be read.
fn grab_port() -> u16 {
    let port = {
        let probe = UdpSocket::bind("127.0.0.1:0").expect("bind 127.0.0.1:0");
        let bound = probe.local_addr().expect("local_addr");
        bound.port()
        // `probe` is closed here, at the end of the inner scope.
    };
    port
}
/// One running node of a [`MiniCluster`].
pub struct NodeHandle {
/// Identifier the node was started with.
pub instance_id: NodeId,
/// Gossip address of the node, formatted as `127.0.0.1:<port>`.
pub gossip_addr: String,
/// Cluster controller built on top of this node's gossip state.
pub controller: Arc<ClusterController>,
/// Gossip discovery instance owned by this node; stopped by `kill`/`crash`.
discovery: GossipDiscovery,
}
impl NodeHandle {
    /// Shuts the node down after announcing that it has left.
    ///
    /// Announces a [`NodeInfo`] with state [`NodeState::Left`], waits 150 ms
    /// (presumably so the announcement can reach peers before the node goes
    /// away), and then stops discovery. The outcomes of the announcement and
    /// of the stop are deliberately ignored.
    pub async fn kill(mut self) {
        let mut farewell = current_info(&self);
        farewell.state = NodeState::Left;
        let _ = self.discovery.announce(farewell).await;

        let grace = Duration::from_millis(150);
        tokio::time::sleep(grace).await;

        let _ = self.discovery.stop().await;
    }

    /// Stops discovery immediately, without announcing a departure first.
    ///
    /// The outcome of the stop is deliberately ignored.
    pub async fn crash(mut self) {
        let _ = self.discovery.stop().await;
    }
}
/// Builds an `Active` [`NodeInfo`] for `node` from its instance id.
///
/// Addresses are left empty, `cores` is set to 1 and the heartbeat timestamp
/// to 0.
///
/// NOTE(review): the name produced here is `minicluster-n<id>`, while
/// `MiniCluster::spawn_inner` names nodes `minicluster-n<index>` (index =
/// id - 1) and `MiniCluster::join_node` uses `minicluster-rejoin-<id>`, so
/// the name announced on `kill` differs from the one the node started with.
/// Confirm whether anything depends on the name.
fn current_info(node: &NodeHandle) -> NodeInfo {
    let id = node.instance_id;
    let metadata = NodeMetadata {
        cores: 1,
        ..NodeMetadata::default()
    };
    let name = format!("minicluster-n{}", id.0);
    NodeInfo {
        id,
        name,
        rpc_address: String::new(),
        raft_address: String::new(),
        state: NodeState::Active,
        metadata,
        last_heartbeat_ms: 0,
    }
}
/// In-process cluster of gossip nodes bound to loopback ports.
pub struct MiniCluster {
/// Handles of the nodes started so far, in start order.
pub nodes: Vec<NodeHandle>,
/// Shared drop rules; `Some` only when the cluster was spawned with
/// `spawn_partitionable`.
pub rules: Option<Arc<NetworkRules>>,
/// Snapshot store passed to every node's controller, if one was supplied.
pub snapshot: Option<Arc<AssignmentSnapshotStore>>,
}
impl MiniCluster {
pub async fn spawn(n: usize) -> Self {
Self::spawn_inner(n, None, None).await
}
pub async fn spawn_partitionable(n: usize) -> Self {
let rules = Arc::new(NetworkRules::new());
Self::spawn_inner(n, Some(rules), None).await
}
pub async fn spawn_with_snapshot(n: usize, snapshot: Arc<AssignmentSnapshotStore>) -> Self {
Self::spawn_inner(n, None, Some(snapshot)).await
}
pub async fn join_node(&mut self, instance_id: NodeId) {
assert!(!self.nodes.is_empty(), "cannot join empty cluster");
let seeds: Vec<String> = self.nodes.iter().map(|n| n.gossip_addr.clone()).collect();
let port = grab_port();
let gossip_addr = format!("127.0.0.1:{port}");
let local_node = NodeInfo {
id: instance_id,
name: format!("minicluster-rejoin-{}", instance_id.0),
rpc_address: String::new(),
raft_address: String::new(),
state: NodeState::Active,
metadata: NodeMetadata {
cores: 1,
..NodeMetadata::default()
},
last_heartbeat_ms: 0,
};
let cfg = GossipDiscoveryConfig {
gossip_address: gossip_addr.clone(),
seed_nodes: seeds,
gossip_interval: Duration::from_millis(50),
phi_threshold: 3.0,
dead_node_grace_period: Duration::from_secs(1),
cluster_id: "minicluster".to_string(),
node_id: instance_id,
local_node,
advertise_host: None,
};
let mut discovery = GossipDiscovery::new(cfg);
match &self.rules {
Some(rules) => {
let transport = PartitionableTransport::new(Arc::clone(rules));
discovery
.start_with_transport(&transport)
.await
.expect("partitionable chitchat start on rejoin");
}
None => discovery.start().await.expect("chitchat start on rejoin"),
}
let handle = discovery
.chitchat_handle()
.expect("chitchat handle available after start");
let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
let members_rx = discovery.membership_watch();
let controller = Arc::new(ClusterController::new(
instance_id,
kv,
self.snapshot.clone(),
members_rx,
));
self.nodes.push(NodeHandle {
instance_id,
gossip_addr,
controller,
discovery,
});
}
async fn spawn_inner(
n: usize,
rules: Option<Arc<NetworkRules>>,
snapshot: Option<Arc<AssignmentSnapshotStore>>,
) -> Self {
assert!(n >= 1, "MiniCluster needs at least one node");
let ports: Vec<u16> = (0..n).map(|_| grab_port()).collect();
let seed = format!("127.0.0.1:{}", ports[0]);
let transport = rules
.as_ref()
.map(|r| PartitionableTransport::new(Arc::clone(r)));
let mut nodes = Vec::with_capacity(n);
for (idx, port) in ports.iter().enumerate() {
let instance_id = NodeId((idx as u64) + 1); let gossip_addr = format!("127.0.0.1:{port}");
let local_node = NodeInfo {
id: instance_id,
name: format!("minicluster-n{idx}"),
rpc_address: String::new(),
raft_address: String::new(),
state: NodeState::Active,
metadata: NodeMetadata {
cores: 1,
..NodeMetadata::default()
},
last_heartbeat_ms: 0,
};
let seeds = if idx == 0 {
Vec::new()
} else {
vec![seed.clone()]
};
let cfg = GossipDiscoveryConfig {
gossip_address: gossip_addr.clone(),
seed_nodes: seeds,
gossip_interval: Duration::from_millis(50),
phi_threshold: 3.0,
dead_node_grace_period: Duration::from_secs(1),
cluster_id: "minicluster".to_string(),
node_id: instance_id,
local_node,
advertise_host: None,
};
let mut discovery = GossipDiscovery::new(cfg);
match &transport {
Some(t) => discovery
.start_with_transport(t)
.await
.expect("partitionable chitchat start"),
None => discovery.start().await.expect("chitchat start on loopback"),
}
let handle = discovery
.chitchat_handle()
.expect("chitchat handle available after start");
let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
let members_rx: watch::Receiver<Vec<NodeInfo>> = discovery.membership_watch();
let controller = Arc::new(ClusterController::new(
instance_id,
kv,
snapshot.clone(),
members_rx,
));
nodes.push(NodeHandle {
instance_id,
gossip_addr,
controller,
discovery,
});
}
Self {
nodes,
rules,
snapshot,
}
}
#[must_use]
pub fn addrs(&self) -> Vec<SocketAddr> {
self.nodes
.iter()
.map(|n| n.gossip_addr.parse().expect("valid gossip_addr"))
.collect()
}
pub async fn wait_for_convergence(&self, deadline: Duration) -> Result<(), String> {
let start = Instant::now();
loop {
let mut all_converged = true;
let mut missing_summary = Vec::new();
for node in &self.nodes {
let peers = node
.discovery
.peers()
.await
.map_err(|e| format!("peers() failed on {}: {e}", node.instance_id.0))?;
let expected = self.nodes.len() - 1;
if peers.len() < expected {
all_converged = false;
missing_summary.push(format!(
"node {} sees {} peers (expected {})",
node.instance_id.0,
peers.len(),
expected
));
}
}
if all_converged {
return Ok(());
}
if start.elapsed() >= deadline {
return Err(format!(
"convergence timeout after {:?}: {}",
deadline,
missing_summary.join("; "),
));
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
pub async fn shutdown(mut self) {
for node in self.nodes.drain(..) {
node.kill().await;
}
}
}