use crate::error::{LinalgError, LinalgResult};
use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
pub struct DistributedCoordinator {
node_rank: usize,
num_nodes: usize,
sync_state: Arc<Mutex<CoordinationState>>,
communicator: Option<Arc<super::communication::DistributedCommunicator>>,
}
impl DistributedCoordinator {
pub fn new(config: &super::DistributedConfig) -> LinalgResult<Self> {
let sync_state = Arc::new(Mutex::new(CoordinationState::new(_config.num_nodes)));
Ok(Self {
node_rank: config.node_rank,
num_nodes: config.num_nodes,
sync_state,
communicator: None,
})
}
pub fn set_communicator(&mut self, communicator: Arc<super::communication::DistributedCommunicator>) {
self.communicator = Some(communicator);
}
pub fn barrier(&self) -> LinalgResult<()> {
self.barrier_with_timeout(Duration::from_secs(30))
}
pub fn barrier_with_timeout(&self, timeout: Duration) -> LinalgResult<()> {
let start_time = Instant::now();
if let Some(ref comm) = self.communicator {
comm.barrier()
} else {
self.simulate_barrier(timeout)
}
}
pub fn create_distributed_lock(&self, lockname: &str) -> LinalgResult<DistributedLock> {
DistributedLock::new(lock_name.to_string(), self.node_rank, self.num_nodes)
}
pub fn consensus<T>(&self, proposal: T) -> LinalgResult<T>
where
T: Clone + PartialEq + Send + Sync + 'static,
{
if self.node_rank == 0 {
Ok(proposal)
} else {
Ok(proposal) }
}
pub fn checkpoint(&self, checkpointid: u64) -> LinalgResult<()> {
let mut state = self.sync_state.lock().expect("Operation failed");
state.checkpoints.insert(self.node_rank, checkpoint_id);
if state.checkpoints.len() == self.num_nodes {
let min_checkpoint = *state.checkpoints.values().min().unwrap_or(&0);
if min_checkpoint >= checkpoint_id {
state.checkpoints.clear();
return Ok(());
}
}
drop(state);
std::thread::sleep(Duration::from_millis(10));
self.checkpoint(checkpoint_id)
}
pub fn coordinate_reduction(&self, operation: ReductionOperation) -> LinalgResult<ReductionCoordination> {
ReductionCoordination::new(operation, self.node_rank, self.num_nodes)
}
pub fn handle_node_failure(&self, failednode: usize) -> LinalgResult<RecoveryPlan> {
let mut state = self.sync_state.lock().expect("Operation failed");
state.failed_nodes.insert(failed_node);
let remaining_nodes: Vec<usize> = (0..self.num_nodes)
.filter(|&n| !state.failed_nodes.contains(&n))
.collect();
let recovery_plan = RecoveryPlan {
failed_node,
remaining_nodes: remaining_nodes.clone(),
redistribution_required: true,
estimated_recovery_time: Duration::from_secs(30),
};
state.active_nodes = remaining_nodes.len();
Ok(recovery_plan)
}
pub fn get_stats(&self) -> CoordinationStats {
let state = self.sync_state.lock().expect("Operation failed");
CoordinationStats {
active_nodes: state.active_nodes,
failed_nodes: state.failed_nodes.len(),
checkpoint_count: state.checkpoint_count,
barrier_count: state.barrier_count,
total_sync_time: state.total_sync_time,
}
}
fn simulate_barrier(&self, timeout: Duration) -> LinalgResult<()> {
let start_time = Instant::now();
let mut state = self.sync_state.lock().expect("Operation failed");
state.barrier_participants.insert(self.node_rank);
state.barrier_count += 1;
while state.barrier_participants.len() < self.num_nodes - state.failed_nodes.len() {
if start_time.elapsed() > timeout {
return Err(LinalgError::TimeoutError(
"Barrier timeout waiting for nodes".to_string()
));
}
drop(state);
std::thread::sleep(Duration::from_millis(1));
state = self.sync_state.lock().expect("Operation failed");
}
state.barrier_participants.clear();
state.total_sync_time += start_time.elapsed();
Ok(())
}
}
/// Mutable bookkeeping behind a `DistributedCoordinator`'s mutex.
#[derive(Debug)]
struct CoordinationState {
    /// Node count the state was created with.
    total_nodes: usize,
    /// Number of nodes currently considered active; starts at `total_nodes`.
    active_nodes: usize,
    /// Ranks that have been marked as failed.
    failed_nodes: std::collections::HashSet<usize>,
    /// Latest checkpoint id recorded per rank.
    checkpoints: HashMap<usize, u64>,
    /// Ranks currently waiting at the simulated barrier.
    barrier_participants: std::collections::HashSet<usize>,
    /// Number of completed checkpoints.
    checkpoint_count: usize,
    /// Number of barrier calls.
    barrier_count: usize,
    /// Accumulated time spent in synchronization.
    total_sync_time: Duration,
}

impl CoordinationState {
    /// Creates empty state for `total_nodes` nodes, all initially active.
    fn new(total_nodes: usize) -> Self {
        Self {
            total_nodes,
            active_nodes: total_nodes,
            failed_nodes: std::collections::HashSet::new(),
            checkpoints: HashMap::new(),
            barrier_participants: std::collections::HashSet::new(),
            checkpoint_count: 0,
            barrier_count: 0,
            total_sync_time: Duration::default(),
        }
    }
}
/// A named lock that records which node rank holds it.
///
/// NOTE(review): `new` gives every instance its own `LockState`, and nothing in
/// this file shares that state between instances. Two `DistributedLock`s with
/// the same name therefore do not currently exclude each other. Confirm
/// whether a shared registry is intended.
pub struct DistributedLock {
    /// Lock name, used in error messages.
    name: String,
    /// Holder as seen by this handle; set on acquire, cleared on release.
    owner: Option<usize>,
    /// Rank of the node using this handle.
    node_rank: usize,
    /// Node count the lock was created for (not read in this file).
    num_nodes: usize,
    /// Mutex-guarded lock state.
    state: Arc<Mutex<LockState>>,
}

impl DistributedLock {
    /// Creates an unlocked lock called `name` for node `node_rank`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn new(name: String, node_rank: usize, num_nodes: usize) -> LinalgResult<Self> {
        Ok(Self {
            name,
            owner: None,
            node_rank,
            num_nodes,
            state: Arc::new(Mutex::new(LockState::Unlocked)),
        })
    }

    /// Acquires the lock with a 30-second timeout.
    ///
    /// # Errors
    ///
    /// See [`DistributedLock::acquire_with_timeout`].
    pub fn acquire(&mut self) -> LinalgResult<()> {
        self.acquire_with_timeout(Duration::from_secs(30))
    }

    /// Acquires the lock, polling every 10 ms while another rank holds it.
    ///
    /// Re-acquiring a lock this rank already holds succeeds immediately.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::TimeoutError` if another rank still holds the
    /// lock after `timeout`.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn acquire_with_timeout(&mut self, timeout: Duration) -> LinalgResult<()> {
        let start_time = Instant::now();
        loop {
            {
                let mut state = self.state.lock().expect("Operation failed");
                match *state {
                    LockState::Unlocked => {
                        *state = LockState::Locked(self.node_rank);
                        self.owner = Some(self.node_rank);
                        return Ok(());
                    }
                    LockState::Locked(holder) if holder == self.node_rank => {
                        return Ok(());
                    }
                    LockState::Locked(_) => {
                        if start_time.elapsed() > timeout {
                            return Err(LinalgError::TimeoutError(format!(
                                "Failed to acquire lock '{}' within timeout",
                                self.name
                            )));
                        }
                    }
                }
            }
            // Guard dropped above so the holder can release while we sleep.
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    /// Releases the lock if this rank holds it. Releasing an unlocked lock is
    /// a no-op.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::InvalidOperation` if another rank holds the lock.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn release(&mut self) -> LinalgResult<()> {
        let mut state = self.state.lock().expect("Operation failed");
        match *state {
            LockState::Locked(holder) if holder == self.node_rank => {
                *state = LockState::Unlocked;
                self.owner = None;
                Ok(())
            }
            LockState::Locked(_) => Err(LinalgError::InvalidOperation(
                "Cannot release lock owned by another node".to_string(),
            )),
            LockState::Unlocked => Ok(()),
        }
    }

    /// Returns `true` if this handle recorded itself as the holder.
    pub fn is_owned(&self) -> bool {
        self.owner == Some(self.node_rank)
    }
}
/// State held inside a `DistributedLock`'s mutex.
#[derive(Clone, Copy, Debug)]
enum LockState {
    /// No rank holds the lock.
    Unlocked,
    /// Held by the rank stored in the payload.
    Locked(usize),
}
/// Condition-variable barrier that releases once `expected_nodes` distinct
/// ranks have arrived.
///
/// Arrivals are not cleared automatically. While the set is full, further
/// waits return immediately. Call [`SynchronizationBarrier::reset`] before
/// reusing the barrier.
pub struct SynchronizationBarrier {
    /// Number of distinct ranks required to release the barrier.
    expected_nodes: usize,
    /// Arrived ranks plus a release generation, guarded by the single mutex
    /// the condvar waits on.
    arrived_nodes: Arc<Mutex<BarrierArrivals>>,
    /// Signalled when the barrier fills.
    condition: Arc<Condvar>,
    /// Identifier used in timeout error messages.
    barrier_id: u64,
}

/// Mutex-protected state of a `SynchronizationBarrier`.
struct BarrierArrivals {
    /// Distinct ranks that have arrived since the last reset.
    ranks: std::collections::HashSet<usize>,
    /// Incremented whenever the barrier fills. Waiters compare it with the
    /// value they saw on arrival, so they recognise a release even if `reset`
    /// clears `ranks` before they wake.
    generation: u64,
}

impl SynchronizationBarrier {
    /// Creates a barrier for `expected_nodes` ranks, identified by `barrier_id`.
    pub fn new(expected_nodes: usize, barrier_id: u64) -> Self {
        Self {
            expected_nodes,
            arrived_nodes: Arc::new(Mutex::new(BarrierArrivals {
                ranks: std::collections::HashSet::new(),
                generation: 0,
            })),
            condition: Arc::new(Condvar::new()),
            barrier_id,
        }
    }

    /// Waits at the barrier for up to 60 seconds.
    ///
    /// # Errors
    ///
    /// See [`SynchronizationBarrier::wait_timeout`].
    pub fn wait(&self, node_rank: usize) -> LinalgResult<()> {
        self.wait_timeout(node_rank, Duration::from_secs(60))
    }

    /// Registers `node_rank` and blocks until `expected_nodes` distinct ranks
    /// have arrived or `timeout` elapses.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::TimeoutError` if the barrier has not been
    /// released within `timeout`.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn wait_timeout(&self, node_rank: usize, timeout: Duration) -> LinalgResult<()> {
        let mut arrivals = self.arrived_nodes.lock().expect("Operation failed");
        arrivals.ranks.insert(node_rank);
        if arrivals.ranks.len() >= self.expected_nodes {
            arrivals.generation = arrivals.generation.wrapping_add(1);
            self.condition.notify_all();
            return Ok(());
        }
        let arrival_generation = arrivals.generation;
        // `wait_timeout_while` re-checks the predicate after every wake-up, so
        // a spurious wake-up no longer ends the wait with a false `Ok`.
        let (_guard, wait_result) = self
            .condition
            .wait_timeout_while(arrivals, timeout, |state| {
                state.generation == arrival_generation
            })
            .expect("Operation failed");
        if wait_result.timed_out() {
            Err(LinalgError::TimeoutError(format!(
                "Barrier {} timeout waiting for nodes",
                self.barrier_id
            )))
        } else {
            Ok(())
        }
    }

    /// Clears the set of arrived ranks so the barrier can be reused.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn reset(&self) {
        let mut arrivals = self.arrived_nodes.lock().expect("Operation failed");
        arrivals.ranks.clear();
    }
}
/// Operator selected for a distributed reduction.
///
/// This file only stores the value (in `ReductionCoordination`); the code that
/// applies it lives elsewhere.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionOperation {
    Sum,
    Max,
    Min,
    Average,
    And,
    Or,
}
pub struct ReductionCoordination {
operation: ReductionOperation,
node_rank: usize,
num_nodes: usize,
tree: ReductionTree,
}
impl ReductionCoordination {
pub fn new(_operation: ReductionOperation, node_rank: usize, numnodes: usize) -> LinalgResult<Self> {
let tree = ReductionTree::new(num_nodes);
Ok(Self {
operation,
node_rank,
num_nodes,
tree,
})
}
pub fn get_receive_nodes(&self) -> Vec<usize> {
self.tree.get_children(self.node_rank)
}
pub fn get_send_node(&self) -> Option<usize> {
self.tree.get_parent(self.node_rank)
}
pub fn is_root(&self) -> bool {
self.node_rank == 0
}
}
/// Implicit binary tree over ranks `0..num_nodes`, rooted at rank 0, in
/// binary-heap layout: node `i` has children `2i + 1` and `2i + 2`.
struct ReductionTree {
    /// Number of ranks in the tree; child indices at or beyond this are omitted.
    num_nodes: usize,
}

impl ReductionTree {
    /// Creates a tree spanning ranks `0..num_nodes`.
    fn new(num_nodes: usize) -> Self {
        Self { num_nodes }
    }

    /// Returns the parent of `node`, or `None` for the root (rank 0).
    fn get_parent(&self, node: usize) -> Option<usize> {
        if node == 0 {
            None
        } else {
            Some((node - 1) / 2)
        }
    }

    /// Returns the children of `node` that exist in the tree, in ascending
    /// order (zero, one or two ranks).
    fn get_children(&self, node: usize) -> Vec<usize> {
        // Checked arithmetic: a huge `node` yields no children instead of
        // overflowing (which panics in debug builds).
        let left = node.checked_mul(2).and_then(|d| d.checked_add(1));
        let right = left.and_then(|l| l.checked_add(1));
        left.into_iter()
            .chain(right)
            .filter(|&child| child < self.num_nodes)
            .collect()
    }
}
/// Plan for continuing after a node failure.
#[derive(Debug, Clone)]
pub struct RecoveryPlan {
    /// Rank of the node that failed.
    pub failed_node: usize,
    /// Surviving ranks; a rank's position in this list is its new rank.
    pub remaining_nodes: Vec<usize>,
    /// Whether work has to be redistributed (always `true` as built by
    /// `DistributedCoordinator::handle_node_failure`).
    pub redistribution_required: bool,
    /// Estimated recovery duration (a fixed 30 s as built by
    /// `DistributedCoordinator::handle_node_failure`).
    pub estimated_recovery_time: Duration,
}

impl RecoveryPlan {
    /// Executes the plan. Currently a no-op that always returns `Ok(())`.
    pub fn execute(&self) -> LinalgResult<()> {
        Ok(())
    }

    /// Maps each surviving node's old rank to its compacted new rank (its
    /// index in `remaining_nodes`). If a rank is listed twice, its later
    /// position wins.
    pub fn get_node_mapping(&self) -> HashMap<usize, usize> {
        let mut old_to_new = HashMap::new();
        for (position, old) in (0..).zip(self.remaining_nodes.iter().copied()) {
            old_to_new.insert(old, position);
        }
        old_to_new
    }
}
/// Snapshot of coordination counters, as returned by
/// `DistributedCoordinator::get_stats`.
#[derive(Debug, Clone, Default)]
pub struct CoordinationStats {
    /// Nodes currently considered active.
    pub active_nodes: usize,
    /// Nodes marked as failed.
    pub failed_nodes: usize,
    /// Completed checkpoints.
    pub checkpoint_count: usize,
    /// Barrier calls.
    pub barrier_count: usize,
    /// Accumulated synchronization time.
    pub total_sync_time: Duration,
}

impl CoordinationStats {
    /// Mean synchronization time per event (barriers plus checkpoints).
    ///
    /// Returns a zero duration when no events were recorded.
    pub fn avg_sync_time(&self) -> Duration {
        let events = self.barrier_count.saturating_add(self.checkpoint_count);
        if events == 0 {
            Duration::default()
        } else if events <= u32::MAX as usize {
            // Exact integer division when the count fits `Duration / u32`.
            self.total_sync_time / events as u32
        } else {
            // `as u32` would truncate here, possibly to 0, which makes
            // `Duration / u32` panic. Fall back to floating-point division.
            self.total_sync_time.div_f64(events as f64)
        }
    }

    /// Fraction of known nodes that are active, in `[0, 1]`.
    ///
    /// Returns `1.0` when no nodes are known.
    pub fn availability_ratio(&self) -> f64 {
        let total_nodes = self.active_nodes.saturating_add(self.failed_nodes);
        if total_nodes == 0 {
            1.0
        } else {
            self.active_nodes as f64 / total_nodes as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_coordinator() {
        use super::super::DistributedConfig;
        let config = DistributedConfig::default().with_num_nodes(3).with_node_rank(0);
        let coordinator = DistributedCoordinator::new(&config).expect("Operation failed");
        let stats = coordinator.get_stats();
        assert_eq!(stats.active_nodes, 3);
        assert_eq!(stats.failed_nodes, 0);
    }

    #[test]
    fn test_single_node_barrier_and_checkpoint() {
        use super::super::DistributedConfig;
        let config = DistributedConfig::default().with_num_nodes(1).with_node_rank(0);
        let coordinator = DistributedCoordinator::new(&config).expect("Operation failed");
        // A lone node completes both primitives without waiting for anyone.
        assert!(coordinator.barrier().is_ok());
        assert!(coordinator.checkpoint(1).is_ok());
        assert_eq!(coordinator.get_stats().barrier_count, 1);
    }

    #[test]
    fn test_distributed_lock() {
        let mut lock = DistributedLock::new("test_lock".to_string(), 0, 2).expect("Operation failed");
        assert!(lock.acquire().is_ok());
        assert!(lock.is_owned());
        assert!(lock.release().is_ok());
        assert!(!lock.is_owned());
    }

    #[test]
    fn test_synchronization_barrier() {
        let barrier = SynchronizationBarrier::new(2, 1);
        let start = Instant::now();
        let result = barrier.wait_timeout(0, Duration::from_millis(100));
        assert!(result.is_err());
        assert!(start.elapsed() >= Duration::from_millis(90));
    }

    #[test]
    fn test_reduction_coordination() {
        let reduction = ReductionCoordination::new(ReductionOperation::Sum, 0, 4).expect("Operation failed");
        assert!(reduction.is_root());
        assert!(reduction.get_send_node().is_none());
        assert_eq!(reduction.get_receive_nodes(), vec![1, 2]);
    }

    #[test]
    fn test_recovery_plan() {
        let plan = RecoveryPlan {
            failed_node: 2,
            remaining_nodes: vec![0, 1, 3],
            redistribution_required: true,
            estimated_recovery_time: Duration::from_secs(30),
        };
        let mapping = plan.get_node_mapping();
        assert_eq!(mapping.len(), 3);
        assert_eq!(mapping[&0], 0);
        assert_eq!(mapping[&1], 1);
        assert_eq!(mapping[&3], 2);
    }

    #[test]
    fn test_reduction_tree() {
        let tree = ReductionTree::new(7);
        assert_eq!(tree.get_parent(0), None);
        assert_eq!(tree.get_children(0), vec![1, 2]);
        assert_eq!(tree.get_parent(1), Some(0));
        assert_eq!(tree.get_children(1), vec![3, 4]);
        assert_eq!(tree.get_parent(6), Some(2));
        // A bare `vec![]` gives the compiler no element type to infer;
        // `is_empty` avoids needing an annotation.
        assert!(tree.get_children(6).is_empty());
    }

    #[test]
    fn test_coordination_stats() {
        let stats = CoordinationStats {
            active_nodes: 3,
            failed_nodes: 1,
            checkpoint_count: 1,
            barrier_count: 3,
            total_sync_time: Duration::from_secs(8),
        };
        assert_eq!(stats.avg_sync_time(), Duration::from_secs(2));
        assert_eq!(stats.availability_ratio(), 0.75);
        assert_eq!(CoordinationStats::default().avg_sync_time(), Duration::default());
        assert_eq!(CoordinationStats::default().availability_ratio(), 1.0);
    }
}