#![allow(clippy::collapsible_if)]
use super::types::{ShardId, ShardState};
use crate::core::hlc::HybridTimestamp;
use crate::core::id::{EdgeId, NodeId, TxId};
use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Errors that can occur when communicating with a remote shard.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum NetworkError {
    /// The transport-level connection to the shard could not be established.
    ConnectionFailed { shard_id: ShardId, reason: String },
    /// The named operation did not complete within `duration`.
    Timeout {
        shard_id: ShardId,
        operation: String,
        duration: Duration,
    },
    /// The circuit breaker for this shard is open; retry after `remaining`.
    CircuitOpen {
        shard_id: ShardId,
        remaining: Duration,
    },
    /// The shard is known but cannot currently serve requests.
    ShardUnavailable(ShardId),
    /// A payload could not be serialized or deserialized.
    SerializationError(String),
    /// The peer violated the expected wire protocol.
    ProtocolError(String),
    /// No connection could be checked out of the pool for this shard.
    PoolExhausted {
        shard_id: ShardId,
        max_connections: usize,
    },
}
impl fmt::Display for NetworkError {
    /// Human-readable rendering of each error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NetworkError::ConnectionFailed { shard_id, reason } => {
                write!(f, "Connection to {} failed: {}", shard_id, reason)
            }
            NetworkError::Timeout {
                shard_id,
                operation,
                duration,
            } => {
                write!(
                    f,
                    "Operation '{}' on {} timed out after {:?}",
                    operation, shard_id, duration
                )
            }
            NetworkError::CircuitOpen {
                shard_id,
                remaining,
            } => {
                // `as_secs()` yields a bare number; include the "s" unit so the
                // message reads "10s remaining" rather than an ambiguous
                // "10 remaining".
                write!(
                    f,
                    "Circuit breaker open for {}, {}s remaining",
                    shard_id,
                    remaining.as_secs()
                )
            }
            NetworkError::ShardUnavailable(shard_id) => {
                write!(f, "Shard {} is unavailable", shard_id)
            }
            NetworkError::SerializationError(msg) => {
                write!(f, "Serialization error: {}", msg)
            }
            NetworkError::ProtocolError(msg) => {
                write!(f, "Protocol error: {}", msg)
            }
            NetworkError::PoolExhausted {
                shard_id,
                max_connections,
            } => {
                write!(
                    f,
                    "Connection pool exhausted for {} (max: {})",
                    shard_id, max_connections
                )
            }
        }
    }
}
// Marker impl: `NetworkError` wraps no underlying source error, so the
// default `std::error::Error` methods suffice.
impl std::error::Error for NetworkError {}
/// Convenience alias for fallible shard-network operations.
pub type NetworkResult<T> = Result<T, NetworkError>;
/// Shard's answer to a two-phase-commit prepare request.
#[derive(Debug, Clone)]
pub struct PrepareResponse {
    /// True when the shard has staged the operations and can commit.
    pub ready: bool,
    /// Optional explanation when `ready` is false.
    pub reason: Option<String>,
}
/// Shard's answer to a commit request.
#[derive(Debug, Clone)]
pub struct CommitResponse {
    /// True when the transaction was made durable.
    pub success: bool,
}
/// Shard's answer to an abort request.
#[derive(Debug, Clone)]
pub struct AbortResponse {
    /// True when the shard acknowledged the rollback.
    pub acknowledged: bool,
}
/// Wire representation of a node transferred during migration.
#[derive(Debug, Clone)]
pub struct NodeData {
    pub id: NodeId,
    pub label: String,
    /// Opaque serialized property bag (encoding decided by the caller).
    pub properties: Vec<u8>,
    /// Start of the validity interval (temporal versioning).
    pub valid_from: u64,
    /// End of the validity interval; `None` means still current.
    pub valid_to: Option<u64>,
}
/// Wire representation of an edge transferred during migration.
#[derive(Debug, Clone)]
pub struct EdgeData {
    pub id: EdgeId,
    pub source: NodeId,
    pub target: NodeId,
    pub label: String,
    /// Opaque serialized property bag (encoding decided by the caller).
    pub properties: Vec<u8>,
    /// Start of the validity interval (temporal versioning).
    pub valid_from: u64,
    /// End of the validity interval; `None` means still current.
    pub valid_to: Option<u64>,
}
/// One batch of nodes/edges shipped between shards during a migration.
#[derive(Debug, Clone)]
pub struct MigrationBatch {
    /// Identifier of the overall migration this batch belongs to.
    pub migration_id: u64,
    /// Zero-based sequence number of this batch within the migration.
    pub batch_number: u64,
    /// True when this is the final batch of the migration.
    pub is_last: bool,
    pub nodes: Vec<NodeData>,
    pub edges: Vec<EdgeData>,
    /// Integrity checksum over the batch payload.
    pub checksum: u64,
}
/// Destination shard's acknowledgement of a migration batch.
#[derive(Debug, Clone)]
pub struct MigrationResponse {
    /// True when the batch was accepted and applied.
    pub accepted: bool,
    pub nodes_written: u64,
    pub edges_written: u64,
    /// Failure description when `accepted` is false.
    pub error: Option<String>,
}
/// Client-side interface to a single shard.
///
/// Implementations must be thread-safe (`Send + Sync`) because connection
/// pools hand out shared `Arc` references across threads.
pub trait ShardClient: Send + Sync + fmt::Debug {
    /// Identifier of the shard this client talks to.
    fn shard_id(&self) -> ShardId;
    /// Quick local health indicator; should not block on network I/O
    /// (use `health_check` for an active probe).
    fn is_healthy(&self) -> bool;
    /// Phase one of two-phase commit: ask the shard to stage `operations`.
    fn prepare(
        &self,
        tx_id: TxId,
        operations: &[u8],
        timestamp: Option<HybridTimestamp>,
    ) -> NetworkResult<PrepareResponse>;
    /// Phase two of two-phase commit: make a prepared transaction durable.
    fn commit(
        &self,
        tx_id: TxId,
        timestamp: Option<HybridTimestamp>,
    ) -> NetworkResult<CommitResponse>;
    /// Roll back a prepared or in-flight transaction.
    fn abort(&self, tx_id: TxId) -> NetworkResult<AbortResponse>;
    /// Execute an opaque query payload and return the raw response bytes.
    fn query(&self, query_id: u64, query_data: &[u8]) -> NetworkResult<Vec<u8>>;
    /// Fetch the shard's current state snapshot.
    fn get_state(&self) -> NetworkResult<ShardState>;
    /// Ingest one migration batch on the destination shard.
    fn receive_migration_batch(&self, batch: MigrationBatch) -> NetworkResult<MigrationResponse>;
    /// Read one migration batch from the source shard, filtered by `labels`,
    /// starting at `offset` with at most `batch_size` items.
    fn extract_migration_batch(
        &self,
        migration_id: u64,
        labels: &[String],
        batch_size: usize,
        offset: u64,
    ) -> NetworkResult<MigrationBatch>;
    /// Active round-trip health probe.
    fn health_check(&self) -> NetworkResult<()>;
}
/// States of the circuit breaker protecting a shard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation; requests flow through.
    Closed,
    /// Tripped; requests are rejected until `open_duration` elapses.
    Open,
    /// Probation after the open period; a few requests are allowed through
    /// to test whether the shard has recovered.
    HalfOpen,
}
/// Tuning knobs for `CircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (within `failure_window`) that trip the circuit.
    pub failure_threshold: usize,
    /// How long the circuit stays `Open` before moving to `HalfOpen`.
    pub open_duration: Duration,
    /// Successes required in `HalfOpen` to close the circuit again.
    pub success_threshold: usize,
    /// Failures older than this no longer count toward the streak.
    pub failure_window: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
open_duration: Duration::from_secs(30),
success_threshold: 3,
failure_window: Duration::from_secs(60),
}
}
}
/// Thread-safe circuit breaker guarding calls to a single shard.
///
/// State transitions happen lazily on the read path (`state`/`should_allow`)
/// and on `record_success`/`record_failure`; there is no background timer.
#[derive(Debug)]
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    // Current state; readers fall back to `Closed` on lock poisoning.
    state: RwLock<CircuitState>,
    // Consecutive failures inside the current `failure_window`.
    failure_count: AtomicUsize,
    // Consecutive successes while `HalfOpen`.
    success_count: AtomicUsize,
    // Timestamp of the most recent failure (for window expiry).
    last_failure: RwLock<Option<Instant>>,
    // When the circuit last entered `Open`.
    opened_at: RwLock<Option<Instant>>,
}
impl CircuitBreaker {
    /// Create a breaker in the `Closed` state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicUsize::new(0),
            success_count: AtomicUsize::new(0),
            last_failure: RwLock::new(None),
            opened_at: RwLock::new(None),
        }
    }
    /// Current state, after applying any pending Open -> HalfOpen timeout
    /// transition. Falls back to `Closed` if the state lock is poisoned.
    pub fn state(&self) -> CircuitState {
        self.maybe_transition();
        self.state
            .read()
            .map(|s| *s)
            .unwrap_or(CircuitState::Closed)
    }
    /// Whether a request should be attempted now (`Closed` or `HalfOpen`).
    pub fn should_allow(&self) -> bool {
        self.maybe_transition();
        let state = self
            .state
            .read()
            .map(|s| *s)
            .unwrap_or(CircuitState::Closed);
        matches!(state, CircuitState::Closed | CircuitState::HalfOpen)
    }
    /// Record a successful request.
    ///
    /// Closed: clears the failure streak. HalfOpen: counts toward
    /// `success_threshold`; on reaching it the breaker closes and both
    /// counters reset. Open: ignored.
    pub fn record_success(&self) {
        // NOTE(review): the state is read, the lock released, then acted on;
        // a concurrent transition between the two steps is tolerated by design.
        let state = match self.state.read() {
            Ok(s) => *s,
            Err(_) => return,
        };
        match state {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
                if successes >= self.config.success_threshold {
                    if let Ok(mut s) = self.state.write() {
                        *s = CircuitState::Closed;
                    }
                    self.failure_count.store(0, Ordering::SeqCst);
                    self.success_count.store(0, Ordering::SeqCst);
                }
            }
            CircuitState::Open => {
                // Stray successes (e.g. in-flight stragglers) do not close
                // the circuit.
            }
        }
    }
    /// Record a failed request.
    ///
    /// Closed: restarts the streak if the previous failure is older than
    /// `failure_window`, then trips to `Open` at `failure_threshold`.
    /// HalfOpen: any failure re-opens immediately. Open: refreshes
    /// `opened_at`.
    pub fn record_failure(&self) {
        let state = match self.state.read() {
            Ok(s) => *s,
            Err(_) => return,
        };
        match state {
            CircuitState::Closed => {
                if let Ok(last) = self.last_failure.read() {
                    if let Some(last_time) = *last {
                        if last_time.elapsed() > self.config.failure_window {
                            // Stale streak: failures outside the window
                            // do not accumulate.
                            self.failure_count.store(0, Ordering::SeqCst);
                        }
                    }
                }
                let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if let Ok(mut last) = self.last_failure.write() {
                    *last = Some(Instant::now());
                }
                if failures >= self.config.failure_threshold {
                    if let Ok(mut s) = self.state.write() {
                        *s = CircuitState::Open;
                    }
                    if let Ok(mut opened) = self.opened_at.write() {
                        *opened = Some(Instant::now());
                    }
                }
            }
            CircuitState::HalfOpen => {
                // One failure during probation re-opens the circuit.
                if let Ok(mut s) = self.state.write() {
                    *s = CircuitState::Open;
                }
                if let Ok(mut opened) = self.opened_at.write() {
                    *opened = Some(Instant::now());
                }
                self.success_count.store(0, Ordering::SeqCst);
            }
            CircuitState::Open => {
                // NOTE(review): refreshing opened_at here means failures
                // arriving while open keep pushing the recovery deadline
                // out — confirm this extension is intended.
                if let Ok(mut opened) = self.opened_at.write() {
                    *opened = Some(Instant::now());
                }
            }
        }
    }
    /// Lazily applies the Open -> HalfOpen transition once `open_duration`
    /// has elapsed; invoked from the read-side accessors.
    fn maybe_transition(&self) {
        let state = match self.state.read() {
            Ok(s) => *s,
            Err(_) => return,
        };
        if state == CircuitState::Open {
            if let Ok(opened) = self.opened_at.read() {
                if let Some(opened_time) = *opened {
                    if opened_time.elapsed() >= self.config.open_duration {
                        if let Ok(mut s) = self.state.write() {
                            *s = CircuitState::HalfOpen;
                        }
                        // Probation starts with a clean success count.
                        self.success_count.store(0, Ordering::SeqCst);
                    }
                }
            }
        }
    }
    /// Time left before an `Open` circuit becomes `HalfOpen`; `None` when
    /// the circuit is not open or the open period has already elapsed.
    pub fn remaining_open_time(&self) -> Option<Duration> {
        let state = self.state.read().ok()?;
        if *state != CircuitState::Open {
            return None;
        }
        if let Ok(opened) = self.opened_at.read() {
            if let Some(opened_time) = *opened {
                let elapsed = opened_time.elapsed();
                if elapsed < self.config.open_duration {
                    return Some(self.config.open_duration - elapsed);
                }
            }
        }
        None
    }
    /// Force the breaker back to `Closed` and clear both counters.
    pub fn reset(&self) {
        if let Ok(mut s) = self.state.write() {
            *s = CircuitState::Closed;
        }
        self.failure_count.store(0, Ordering::SeqCst);
        self.success_count.store(0, Ordering::SeqCst);
    }
}
/// Tuning knobs for `ConnectionPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum connections that may be checked out concurrently.
    pub max_connections: usize,
    /// Minimum idle connections to keep warm.
    pub min_idle: usize,
    /// Time budget for establishing a new connection.
    pub connect_timeout: Duration,
    /// Idle connections older than this may be reaped.
    pub idle_timeout: Duration,
    /// Connections are recycled after this lifetime regardless of use.
    pub max_lifetime: Duration,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_connections: 10,
min_idle: 2,
connect_timeout: Duration::from_secs(5),
idle_timeout: Duration::from_secs(300),
max_lifetime: Duration::from_secs(3600),
}
}
}
/// Pool of shared connections to a single shard, guarded by a circuit
/// breaker and carrying request success/failure statistics.
#[derive(Debug)]
pub struct ConnectionPool<C: ShardClient> {
    shard_id: ShardId,
    config: PoolConfig,
    // Registered connections; handed out as shared `Arc` clones.
    connections: RwLock<Vec<Arc<C>>>,
    // Number of currently checked-out connections (get/release balance).
    active_count: AtomicUsize,
    circuit_breaker: CircuitBreaker,
    total_requests: AtomicU64,
    failed_requests: AtomicU64,
}
impl<C: ShardClient> ConnectionPool<C> {
    /// Create an empty pool for `shard_id` with the given pool and circuit
    /// breaker configuration.
    pub fn new(
        shard_id: ShardId,
        config: PoolConfig,
        circuit_config: CircuitBreakerConfig,
    ) -> Self {
        Self {
            shard_id,
            config,
            connections: RwLock::new(Vec::new()),
            active_count: AtomicUsize::new(0),
            circuit_breaker: CircuitBreaker::new(circuit_config),
            total_requests: AtomicU64::new(0),
            failed_requests: AtomicU64::new(0),
        }
    }
    /// Shard this pool serves.
    pub fn shard_id(&self) -> ShardId {
        self.shard_id
    }
    /// Whether the circuit breaker currently allows requests.
    pub fn is_available(&self) -> bool {
        self.circuit_breaker.should_allow()
    }
    /// Current circuit breaker state.
    pub fn circuit_state(&self) -> CircuitState {
        self.circuit_breaker.state()
    }
    /// Check out a healthy connection.
    ///
    /// # Errors
    /// - `CircuitOpen` when the breaker is tripped.
    /// - `PoolExhausted` when `max_connections` are already checked out.
    /// - `ShardUnavailable` when no healthy connection exists.
    pub fn get(&self) -> NetworkResult<Arc<C>> {
        if !self.circuit_breaker.should_allow() {
            return Err(NetworkError::CircuitOpen {
                shard_id: self.shard_id,
                remaining: self
                    .circuit_breaker
                    .remaining_open_time()
                    .unwrap_or_default(),
            });
        }
        // Enforce the checkout cap *before* handing out a connection. The
        // previous version only reported exhaustion when no healthy
        // connection was found, so `active_count` could grow past
        // `max_connections` unchecked.
        if self.active_count.load(Ordering::SeqCst) >= self.config.max_connections {
            return Err(NetworkError::PoolExhausted {
                shard_id: self.shard_id,
                max_connections: self.config.max_connections,
            });
        }
        if let Ok(connections) = self.connections.read() {
            for conn in connections.iter() {
                if conn.is_healthy() {
                    self.active_count.fetch_add(1, Ordering::SeqCst);
                    return Ok(Arc::clone(conn));
                }
            }
        }
        Err(NetworkError::ShardUnavailable(self.shard_id))
    }
    /// Return a connection previously checked out via `get`.
    ///
    /// Saturates at zero so a spurious release cannot wrap the counter
    /// to `usize::MAX` (which `fetch_sub` would do).
    pub fn release(&self, _conn: Arc<C>) {
        let _ = self
            .active_count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1));
    }
    /// Register a new connection with the pool.
    pub fn add(&self, conn: Arc<C>) {
        if let Ok(mut connections) = self.connections.write() {
            connections.push(conn);
        }
    }
    /// Record a successful request against this shard.
    pub fn record_success(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.circuit_breaker.record_success();
    }
    /// Record a failed request against this shard (feeds the breaker).
    pub fn record_failure(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.failed_requests.fetch_add(1, Ordering::Relaxed);
        self.circuit_breaker.record_failure();
    }
    /// Snapshot of pool and circuit statistics.
    pub fn stats(&self) -> PoolStats {
        PoolStats {
            shard_id: self.shard_id,
            total_connections: self.connections.read().map(|c| c.len()).unwrap_or(0),
            active_connections: self.active_count.load(Ordering::SeqCst),
            circuit_state: self.circuit_breaker.state(),
            total_requests: self.total_requests.load(Ordering::Relaxed),
            failed_requests: self.failed_requests.load(Ordering::Relaxed),
        }
    }
    /// Force the circuit breaker back to `Closed`.
    pub fn reset_circuit(&self) {
        self.circuit_breaker.reset();
    }
}
/// Point-in-time statistics snapshot for a `ConnectionPool`.
#[derive(Debug, Clone)]
pub struct PoolStats {
    pub shard_id: ShardId,
    /// Connections registered in the pool.
    pub total_connections: usize,
    /// Connections currently checked out.
    pub active_connections: usize,
    pub circuit_state: CircuitState,
    pub total_requests: u64,
    pub failed_requests: u64,
}
impl PoolStats {
    /// Fraction of requests that succeeded, in `[0.0, 1.0]`.
    ///
    /// A pool that has seen no traffic reports a perfect 1.0 rather than
    /// dividing by zero.
    pub fn success_rate(&self) -> f64 {
        match self.total_requests {
            0 => 1.0,
            total => 1.0 - self.failed_requests as f64 / total as f64,
        }
    }
}
/// In-process `ShardClient` test double with scriptable responses,
/// injectable one-shot failures, simulated latency, and per-method
/// call counting.
#[derive(Debug)]
pub struct MockShardClient {
    shard_id: ShardId,
    // Health flag consulted by every RPC before answering.
    healthy: RwLock<bool>,
    prepare_response: RwLock<Option<PrepareResponse>>,
    commit_response: RwLock<Option<CommitResponse>>,
    abort_response: RwLock<Option<AbortResponse>>,
    query_response: RwLock<Vec<u8>>,
    state: RwLock<ShardState>,
    // Artificial per-call delay (thread sleep).
    latency: RwLock<Duration>,
    // One-shot error consumed by the next RPC, then cleared.
    fail_next: RwLock<Option<NetworkError>>,
    // Method name -> number of invocations.
    call_counts: RwLock<HashMap<String, usize>>,
}
impl MockShardClient {
    /// Build a mock that reports healthy and answers every RPC successfully,
    /// with a tiny default latency of 100µs.
    pub fn new(shard_id: ShardId) -> Self {
        let default_prepare = PrepareResponse {
            ready: true,
            reason: None,
        };
        Self {
            shard_id,
            healthy: RwLock::new(true),
            prepare_response: RwLock::new(Some(default_prepare)),
            commit_response: RwLock::new(Some(CommitResponse { success: true })),
            abort_response: RwLock::new(Some(AbortResponse { acknowledged: true })),
            query_response: RwLock::new(Vec::new()),
            state: RwLock::new(ShardState::new(shard_id)),
            latency: RwLock::new(Duration::from_micros(100)),
            fail_next: RwLock::new(None),
            call_counts: RwLock::new(HashMap::new()),
        }
    }
    /// Toggle the health flag consulted by every RPC.
    pub fn set_healthy(&self, healthy: bool) {
        let mut flag = self.healthy.write().unwrap();
        *flag = healthy;
    }
    /// Script the response returned by `prepare`.
    pub fn set_prepare_response(&self, response: PrepareResponse) {
        let mut slot = self.prepare_response.write().unwrap();
        *slot = Some(response);
    }
    /// Script the response returned by `commit`.
    pub fn set_commit_response(&self, response: CommitResponse) {
        let mut slot = self.commit_response.write().unwrap();
        *slot = Some(response);
    }
    /// Script the response returned by `abort`.
    pub fn set_abort_response(&self, response: AbortResponse) {
        let mut slot = self.abort_response.write().unwrap();
        *slot = Some(response);
    }
    /// Script the bytes returned by `query`.
    pub fn set_query_response(&self, response: Vec<u8>) {
        let mut slot = self.query_response.write().unwrap();
        *slot = response;
    }
    /// Script the snapshot returned by `get_state`.
    pub fn set_state(&self, state: ShardState) {
        let mut slot = self.state.write().unwrap();
        *slot = state;
    }
    /// Set the artificial delay applied to every RPC.
    pub fn set_latency(&self, latency: Duration) {
        let mut slot = self.latency.write().unwrap();
        *slot = latency;
    }
    /// Make the *next* RPC fail with `error`; subsequent calls succeed.
    pub fn fail_next(&self, error: NetworkError) {
        let mut slot = self.fail_next.write().unwrap();
        *slot = Some(error);
    }
    /// Number of times `method` has been invoked on this mock.
    pub fn call_count(&self, method: &str) -> usize {
        let counts = self.call_counts.read().unwrap();
        counts.get(method).map_or(0, |n| *n)
    }
    // Bump the invocation counter for `method`.
    fn increment_call(&self, method: &str) {
        let mut counts = self.call_counts.write().unwrap();
        *counts.entry(method.to_string()).or_default() += 1;
    }
    // Consume a pending injected failure, if any.
    fn check_fail(&self) -> NetworkResult<()> {
        match self.fail_next.write().unwrap().take() {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }
    // Sleep for the configured latency (no-op when zero).
    fn simulate_latency(&self) {
        let delay = *self.latency.read().unwrap();
        if !delay.is_zero() {
            std::thread::sleep(delay);
        }
    }
}
impl ShardClient for MockShardClient {
    fn shard_id(&self) -> ShardId {
        self.shard_id
    }
    fn is_healthy(&self) -> bool {
        *self.healthy.read().unwrap()
    }
    /// Every RPC below follows the same mock pipeline: count the call,
    /// consume any injected failure, gate on health, simulate latency,
    /// then answer from the scripted state.
    fn prepare(
        &self,
        _tx_id: TxId,
        _operations: &[u8],
        _timestamp: Option<HybridTimestamp>,
    ) -> NetworkResult<PrepareResponse> {
        self.increment_call("prepare");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        // ok_or_else: avoid allocating the error string on the success path
        // (plain `ok_or` builds its argument eagerly).
        self.prepare_response
            .read()
            .unwrap()
            .clone()
            .ok_or_else(|| NetworkError::ProtocolError("No response configured".into()))
    }
    fn commit(
        &self,
        _tx_id: TxId,
        _timestamp: Option<HybridTimestamp>,
    ) -> NetworkResult<CommitResponse> {
        self.increment_call("commit");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        self.commit_response
            .read()
            .unwrap()
            .clone()
            .ok_or_else(|| NetworkError::ProtocolError("No response configured".into()))
    }
    fn abort(&self, _tx_id: TxId) -> NetworkResult<AbortResponse> {
        self.increment_call("abort");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        self.abort_response
            .read()
            .unwrap()
            .clone()
            .ok_or_else(|| NetworkError::ProtocolError("No response configured".into()))
    }
    fn query(&self, _query_id: u64, _query_data: &[u8]) -> NetworkResult<Vec<u8>> {
        self.increment_call("query");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        Ok(self.query_response.read().unwrap().clone())
    }
    fn get_state(&self) -> NetworkResult<ShardState> {
        self.increment_call("get_state");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        Ok(self.state.read().unwrap().clone())
    }
    fn receive_migration_batch(&self, batch: MigrationBatch) -> NetworkResult<MigrationResponse> {
        self.increment_call("receive_migration_batch");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        Ok(MigrationResponse {
            accepted: true,
            nodes_written: batch.nodes.len() as u64,
            edges_written: batch.edges.len() as u64,
            error: None,
        })
    }
    fn extract_migration_batch(
        &self,
        migration_id: u64,
        _labels: &[String],
        batch_size: usize,
        offset: u64,
    ) -> NetworkResult<MigrationBatch> {
        self.increment_call("extract_migration_batch");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        // Guard the batch-number division: a zero batch size is a caller
        // error and previously caused a divide-by-zero panic.
        if batch_size == 0 {
            return Err(NetworkError::ProtocolError(
                "batch_size must be non-zero".into(),
            ));
        }
        Ok(MigrationBatch {
            migration_id,
            batch_number: offset / batch_size as u64,
            is_last: true,
            nodes: Vec::new(),
            edges: Vec::new(),
            checksum: 0,
        })
    }
    fn health_check(&self) -> NetworkResult<()> {
        self.increment_call("health_check");
        self.check_fail()?;
        if !self.is_healthy() {
            return Err(NetworkError::ShardUnavailable(self.shard_id));
        }
        self.simulate_latency();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Helper: build a ShardId, panicking on invalid ids (fine in tests).
    fn make_shard_id(id: u16) -> ShardId {
        ShardId::new(id).unwrap()
    }
    // Display impl includes the key details of each error variant.
    #[test]
    fn test_network_error_display() {
        let err = NetworkError::ConnectionFailed {
            shard_id: make_shard_id(0),
            reason: "refused".to_string(),
        };
        assert!(format!("{}", err).contains("refused"));
        let err = NetworkError::Timeout {
            shard_id: make_shard_id(1),
            operation: "query".to_string(),
            duration: Duration::from_secs(30),
        };
        assert!(format!("{}", err).contains("query"));
        assert!(format!("{}", err).contains("timed out"));
        let err = NetworkError::CircuitOpen {
            shard_id: make_shard_id(2),
            remaining: Duration::from_secs(10),
        };
        assert!(format!("{}", err).contains("Circuit breaker"));
        let err = NetworkError::PoolExhausted {
            shard_id: make_shard_id(3),
            max_connections: 10,
        };
        assert!(format!("{}", err).contains("exhausted"));
    }
    // A fresh breaker is Closed and admits requests.
    #[test]
    fn test_circuit_breaker_initial_state() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
    }
    // Exactly `failure_threshold` failures trip the circuit.
    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.should_allow());
    }
    // A success while Closed restarts the failure streak.
    #[test]
    fn test_circuit_breaker_success_resets_count() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }
    // After open_duration, reading the state lazily moves Open -> HalfOpen.
    #[test]
    fn test_circuit_breaker_half_open_transition() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(10),
            success_threshold: 2,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        assert!(cb.should_allow());
    }
    // success_threshold successes during probation close the circuit.
    #[test]
    fn test_circuit_breaker_closes_from_half_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(10),
            success_threshold: 2,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }
    // A single failure during probation re-opens immediately.
    #[test]
    fn test_circuit_breaker_failure_in_half_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(10),
            success_threshold: 2,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }
    // reset() forces an Open breaker back to Closed.
    #[test]
    fn test_circuit_breaker_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
    }
    // remaining_open_time is None when Closed, bounded by open_duration when Open.
    #[test]
    fn test_circuit_breaker_remaining_time() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_secs(30),
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        assert!(cb.remaining_open_time().is_none());
        cb.record_failure();
        let remaining = cb.remaining_open_time();
        assert!(remaining.is_some());
        assert!(remaining.unwrap() <= Duration::from_secs(30));
    }
    // A new pool is empty, available, and Closed.
    #[test]
    fn test_pool_creation() {
        let pool: ConnectionPool<MockShardClient> = ConnectionPool::new(
            make_shard_id(0),
            PoolConfig::default(),
            CircuitBreakerConfig::default(),
        );
        assert_eq!(pool.shard_id(), make_shard_id(0));
        assert!(pool.is_available());
        assert_eq!(pool.circuit_state(), CircuitState::Closed);
    }
    // add() registers a connection that get() then hands back out.
    #[test]
    fn test_pool_add_and_get() {
        let shard_id = make_shard_id(0);
        let pool: ConnectionPool<MockShardClient> = ConnectionPool::new(
            shard_id,
            PoolConfig::default(),
            CircuitBreakerConfig::default(),
        );
        let client = Arc::new(MockShardClient::new(shard_id));
        pool.add(Arc::clone(&client));
        let conn = pool.get().unwrap();
        assert_eq!(conn.shard_id(), shard_id);
    }
    // Recorded failures trip the pool's breaker and get() reports CircuitOpen.
    #[test]
    fn test_pool_circuit_breaker_integration() {
        let shard_id = make_shard_id(0);
        let pool: ConnectionPool<MockShardClient> = ConnectionPool::new(
            shard_id,
            PoolConfig::default(),
            CircuitBreakerConfig {
                failure_threshold: 2,
                ..Default::default()
            },
        );
        let client = Arc::new(MockShardClient::new(shard_id));
        pool.add(client);
        pool.record_failure();
        assert!(pool.is_available());
        pool.record_failure();
        assert!(!pool.is_available());
        let result = pool.get();
        assert!(matches!(result, Err(NetworkError::CircuitOpen { .. })));
    }
    // stats() aggregates request counters and derives the success rate.
    #[test]
    fn test_pool_stats() {
        let shard_id = make_shard_id(0);
        let pool: ConnectionPool<MockShardClient> = ConnectionPool::new(
            shard_id,
            PoolConfig::default(),
            CircuitBreakerConfig::default(),
        );
        let client = Arc::new(MockShardClient::new(shard_id));
        pool.add(client);
        pool.record_success();
        pool.record_success();
        pool.record_failure();
        let stats = pool.stats();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.failed_requests, 1);
        assert!((stats.success_rate() - 0.666).abs() < 0.01);
    }
    // Fresh mock is healthy and carries its shard id.
    #[test]
    fn test_mock_client_creation() {
        let client = MockShardClient::new(make_shard_id(0));
        assert!(client.is_healthy());
        assert_eq!(client.shard_id(), make_shard_id(0));
    }
    // Default prepare response is ready=true; the call is counted.
    #[test]
    fn test_mock_client_prepare() {
        let client = MockShardClient::new(make_shard_id(0));
        let response = client.prepare(TxId::new(1), &[], None).unwrap();
        assert!(response.ready);
        assert_eq!(client.call_count("prepare"), 1);
    }
    // Default commit response succeeds; the call is counted.
    #[test]
    fn test_mock_client_commit() {
        let client = MockShardClient::new(make_shard_id(0));
        let response = client.commit(TxId::new(1), None).unwrap();
        assert!(response.success);
        assert_eq!(client.call_count("commit"), 1);
    }
    // Default abort response is acknowledged; the call is counted.
    #[test]
    fn test_mock_client_abort() {
        let client = MockShardClient::new(make_shard_id(0));
        let response = client.abort(TxId::new(1)).unwrap();
        assert!(response.acknowledged);
        assert_eq!(client.call_count("abort"), 1);
    }
    // An unhealthy mock rejects RPCs with ShardUnavailable.
    #[test]
    fn test_mock_client_unhealthy() {
        let client = MockShardClient::new(make_shard_id(0));
        client.set_healthy(false);
        let result = client.prepare(TxId::new(1), &[], None);
        assert!(matches!(result, Err(NetworkError::ShardUnavailable(_))));
    }
    // fail_next injects exactly one failure; the following call succeeds.
    #[test]
    fn test_mock_client_fail_next() {
        let client = MockShardClient::new(make_shard_id(0));
        client.fail_next(NetworkError::Timeout {
            shard_id: make_shard_id(0),
            operation: "test".to_string(),
            duration: Duration::from_secs(30),
        });
        let result = client.prepare(TxId::new(1), &[], None);
        assert!(matches!(result, Err(NetworkError::Timeout { .. })));
        let result = client.prepare(TxId::new(2), &[], None);
        assert!(result.is_ok());
    }
    // Scripted prepare responses are returned verbatim.
    #[test]
    fn test_mock_client_custom_responses() {
        let client = MockShardClient::new(make_shard_id(0));
        client.set_prepare_response(PrepareResponse {
            ready: false,
            reason: Some("test".to_string()),
        });
        let response = client.prepare(TxId::new(1), &[], None).unwrap();
        assert!(!response.ready);
        assert_eq!(response.reason, Some("test".to_string()));
    }
    // set_state round-trips through get_state.
    #[test]
    fn test_mock_client_get_state() {
        let client = MockShardClient::new(make_shard_id(0));
        let mut state = ShardState::new(make_shard_id(0));
        state.node_count = 1000;
        client.set_state(state);
        let retrieved = client.get_state().unwrap();
        assert_eq!(retrieved.node_count, 1000);
    }
    // receive_migration_batch reports per-batch write counts.
    #[test]
    fn test_mock_client_migration() {
        let client = MockShardClient::new(make_shard_id(0));
        let batch = MigrationBatch {
            migration_id: 1,
            batch_number: 0,
            is_last: false,
            nodes: vec![NodeData {
                id: NodeId::new(1).unwrap(),
                label: "Person".to_string(),
                properties: vec![],
                valid_from: 0,
                valid_to: None,
            }],
            edges: vec![],
            checksum: 12345,
        };
        let response = client.receive_migration_batch(batch).unwrap();
        assert!(response.accepted);
        assert_eq!(response.nodes_written, 1);
    }
    // The mock's extract always yields an empty, final batch.
    #[test]
    fn test_mock_client_extract_migration() {
        let client = MockShardClient::new(make_shard_id(0));
        let batch = client
            .extract_migration_batch(1, &["Person".to_string()], 100, 0)
            .unwrap();
        assert!(batch.is_last);
        assert!(batch.nodes.is_empty());
    }
    // health_check mirrors the health flag.
    #[test]
    fn test_mock_client_health_check() {
        let client = MockShardClient::new(make_shard_id(0));
        assert!(client.health_check().is_ok());
        client.set_healthy(false);
        assert!(client.health_check().is_err());
    }
}