use crate::common::IntegrateFloat;
use crate::distributed::types::{
    AckStatus, BoundaryConditions, BoundaryData, ChunkId, ChunkResult, DistributedError,
    DistributedMessage, DistributedResult, NodeId, NodeStatus, WorkChunk,
};
use scirs2_core::ndarray::Array1;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
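/// In-memory, thread-safe message channel with separate outbox/inbox queues
/// and acknowledgment tracking. Actual delivery is left to a transport layer
/// that drains this node's outbox and calls `deliver` on the peer's channel.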
pub struct MessageChannel<F: IntegrateFloat> {
outbox: Mutex<VecDeque<(NodeId, DistributedMessage<F>)>>,
inbox: Mutex<VecDeque<(NodeId, DistributedMessage<F>)>>,
next_message_id: AtomicU64,
pending_acks: Mutex<HashMap<u64, (Instant, NodeId)>>,
ack_timeout: Duration,
    /// Signalled by `deliver`; paired with the `inbox` mutex so waiters
    /// cannot miss a wakeup.
    inbox_cv: Condvar,
}
impl<F: IntegrateFloat> MessageChannel<F> {
pub fn new(ack_timeout: Duration) -> Self {
Self {
outbox: Mutex::new(VecDeque::new()),
inbox: Mutex::new(VecDeque::new()),
next_message_id: AtomicU64::new(1),
pending_acks: Mutex::new(HashMap::new()),
ack_timeout,
            inbox_cv: Condvar::new(),
}
}
pub fn generate_message_id(&self) -> u64 {
self.next_message_id.fetch_add(1, Ordering::SeqCst)
}
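    /// Queues `message` for `target` and registers it in `pending_acks`,
    /// returning the generated id. Note that the id is only recorded locally;
    /// correlating acks requires the surrounding transport to carry the id
    /// along with the message.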
pub fn send(&self, target: NodeId, message: DistributedMessage<F>) -> DistributedResult<u64> {
let message_id = self.generate_message_id();
match self.outbox.lock() {
Ok(mut outbox) => {
outbox.push_back((target, message));
}
Err(_) => {
return Err(DistributedError::CommunicationError(
"Failed to acquire outbox lock".to_string(),
))
}
}
match self.pending_acks.lock() {
Ok(mut pending) => {
pending.insert(message_id, (Instant::now(), target));
}
Err(_) => {
return Err(DistributedError::CommunicationError(
"Failed to track acknowledgment".to_string(),
))
}
}
Ok(message_id)
}
    /// Blocking receive: waits up to `timeout` for a message to arrive.
    pub fn receive(&self, timeout: Duration) -> Option<(NodeId, DistributedMessage<F>)> {
        let deadline = Instant::now() + timeout;
        // Hold the inbox lock across the condvar wait so a message delivered
        // between the emptiness check and the wait cannot be missed. The
        // original separate `inbox_mutex` left a lost-wakeup window in which
        // a delivered message sat unnoticed until the timeout expired.
        let mut inbox = self.inbox.lock().ok()?;
        loop {
            if let Some(msg) = inbox.pop_front() {
                return Some(msg);
            }
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return None;
            }
            match self.inbox_cv.wait_timeout(inbox, remaining) {
                Ok((guard, _)) => inbox = guard,
                Err(_) => return None,
            }
        }
    }
pub fn try_receive(&self) -> Option<(NodeId, DistributedMessage<F>)> {
match self.inbox.lock() {
Ok(mut inbox) => inbox.pop_front(),
Err(_) => None,
}
}
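    /// Injects a message into this node's inbox (called by the transport on
    /// behalf of `source`) and wakes one blocked `receive` call.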
pub fn deliver(&self, source: NodeId, message: DistributedMessage<F>) -> DistributedResult<()> {
match self.inbox.lock() {
Ok(mut inbox) => {
inbox.push_back((source, message));
self.inbox_cv.notify_one();
Ok(())
}
Err(_) => Err(DistributedError::CommunicationError(
"Failed to acquire inbox lock".to_string(),
)),
}
}
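    /// Resolves a pending acknowledgment. Ids that are no longer pending
    /// (duplicate acks, or acks for messages already reaped by
    /// `check_timeouts`) are silently ignored.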
pub fn process_ack(&self, message_id: u64, status: AckStatus) -> DistributedResult<()> {
match self.pending_acks.lock() {
Ok(mut pending) => {
                if pending.remove(&message_id).is_some() && status == AckStatus::Error {
                    return Err(DistributedError::CommunicationError(
                        "Message processing failed at remote node".to_string(),
                    ));
                }
                Ok(())
}
Err(_) => Err(DistributedError::CommunicationError(
"Failed to process acknowledgment".to_string(),
)),
}
}
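    /// Removes and returns every message whose acknowledgment has been
    /// outstanding longer than `ack_timeout`, so the caller can retry or
    /// mark the peer unresponsive.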
pub fn check_timeouts(&self) -> Vec<(u64, NodeId)> {
match self.pending_acks.lock() {
Ok(mut pending) => {
let now = Instant::now();
let timed_out: Vec<_> = pending
.iter()
.filter(|(_, (sent_at, _))| now.duration_since(*sent_at) > self.ack_timeout)
.map(|(id, (_, node))| (*id, *node))
.collect();
for (id, _) in &timed_out {
pending.remove(id);
}
timed_out
}
Err(_) => Vec::new(),
}
}
pub fn outbox_size(&self) -> usize {
self.outbox.lock().map(|o| o.len()).unwrap_or(0)
}
pub fn inbox_size(&self) -> usize {
self.inbox.lock().map(|i| i.len()).unwrap_or(0)
}
pub fn drain_outbox(&self) -> Vec<(NodeId, DistributedMessage<F>)> {
match self.outbox.lock() {
Ok(mut outbox) => outbox.drain(..).collect(),
Err(_) => Vec::new(),
}
}
}
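/// Collects boundary (halo) data received from neighboring chunks, keyed by
/// `(receiving chunk, sending chunk)`, and tracks outstanding requests so
/// stalled exchanges surface through `check_timeouts`.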
pub struct BoundaryExchanger<F: IntegrateFloat> {
received_boundaries: RwLock<HashMap<(ChunkId, ChunkId), BoundaryData<F>>>,
pending_requests: Mutex<HashMap<(ChunkId, ChunkId), Instant>>,
timeout: Duration,
}
impl<F: IntegrateFloat> BoundaryExchanger<F> {
pub fn new(timeout: Duration) -> Self {
Self {
received_boundaries: RwLock::new(HashMap::new()),
pending_requests: Mutex::new(HashMap::new()),
timeout,
}
}
pub fn request_boundary(
&self,
target_chunk: ChunkId,
source_chunk: ChunkId,
) -> DistributedResult<()> {
match self.pending_requests.lock() {
Ok(mut pending) => {
pending.insert((target_chunk, source_chunk), Instant::now());
Ok(())
}
Err(_) => Err(DistributedError::CommunicationError(
"Failed to register boundary request".to_string(),
)),
}
}
pub fn receive_boundary(
&self,
target_chunk: ChunkId,
source_chunk: ChunkId,
data: BoundaryData<F>,
) -> DistributedResult<()> {
match self.received_boundaries.write() {
Ok(mut boundaries) => {
boundaries.insert((target_chunk, source_chunk), data);
if let Ok(mut pending) = self.pending_requests.lock() {
pending.remove(&(target_chunk, source_chunk));
}
Ok(())
}
Err(_) => Err(DistributedError::CommunicationError(
"Failed to store boundary data".to_string(),
)),
}
}
pub fn get_boundary(
&self,
target_chunk: ChunkId,
source_chunk: ChunkId,
) -> Option<BoundaryData<F>> {
match self.received_boundaries.read() {
Ok(boundaries) => boundaries.get(&(target_chunk, source_chunk)).cloned(),
Err(_) => None,
}
}
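    /// Assembles boundary conditions for `chunk_id` from whatever neighbor
    /// data has arrived so far; sides with no received data stay `None`.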
pub fn build_boundary_conditions(
&self,
chunk_id: ChunkId,
left_neighbor: Option<ChunkId>,
right_neighbor: Option<ChunkId>,
) -> BoundaryConditions<F> {
let mut bc = BoundaryConditions::default();
if let Some(left_id) = left_neighbor {
bc.left_boundary = self.get_boundary(chunk_id, left_id);
}
if let Some(right_id) = right_neighbor {
bc.right_boundary = self.get_boundary(chunk_id, right_id);
}
bc
}
pub fn check_timeouts(&self) -> Vec<(ChunkId, ChunkId)> {
match self.pending_requests.lock() {
Ok(mut pending) => {
let now = Instant::now();
let timed_out: Vec<_> = pending
.iter()
.filter(|(_, sent_at)| now.duration_since(**sent_at) > self.timeout)
.map(|(key, _)| *key)
.collect();
for key in &timed_out {
pending.remove(key);
}
timed_out
}
Err(_) => Vec::new(),
}
}
pub fn clear(&self) {
if let Ok(mut boundaries) = self.received_boundaries.write() {
boundaries.clear();
}
if let Ok(mut pending) = self.pending_requests.lock() {
pending.clear();
}
}
}
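/// Reusable counting barrier: nodes register with `arrive`, and `wait`
/// blocks until `expected_count` distinct nodes have arrived or the timeout
/// elapses. `new_barrier` starts a fresh generation with a new id.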
pub struct SyncBarrier {
barrier_id: AtomicU64,
expected_count: usize,
state: Mutex<BarrierState>,
cv: Condvar,
}
struct BarrierState {
current_id: u64,
arrived: Vec<NodeId>,
released: bool,
}
impl SyncBarrier {
    pub fn new(expected_count: usize) -> Self {
        Self::with_start_id(expected_count, 1)
    }
    /// Creates a barrier whose first generation uses `start_id`, so an
    /// external allocator (e.g. `Communicator::create_barrier`) can hand out
    /// globally unique barrier ids.
    pub fn with_start_id(expected_count: usize, start_id: u64) -> Self {
        Self {
            barrier_id: AtomicU64::new(start_id),
            expected_count,
            state: Mutex::new(BarrierState {
                current_id: start_id,
                arrived: Vec::new(),
                released: false,
            }),
            cv: Condvar::new(),
        }
    }
pub fn new_barrier(&self) -> u64 {
let new_id = self.barrier_id.fetch_add(1, Ordering::SeqCst);
if let Ok(mut state) = self.state.lock() {
state.current_id = new_id;
state.arrived.clear();
state.released = false;
}
new_id
}
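    /// Records `node_id` at the current barrier generation; duplicate
    /// arrivals are idempotent. Once `expected_count` distinct nodes have
    /// arrived, the barrier releases and all waiters are woken.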
pub fn arrive(&self, barrier_id: u64, node_id: NodeId) -> DistributedResult<()> {
let mut state = self
.state
.lock()
.map_err(|_| DistributedError::SyncError("Failed to acquire barrier lock".into()))?;
if state.current_id != barrier_id {
return Err(DistributedError::SyncError(format!(
"Barrier ID mismatch: expected {}, got {}",
state.current_id, barrier_id
)));
}
if !state.arrived.contains(&node_id) {
state.arrived.push(node_id);
}
if state.arrived.len() >= self.expected_count {
state.released = true;
self.cv.notify_all();
}
Ok(())
}
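    /// Blocks until the barrier releases or `timeout` elapses. If a newer
    /// barrier generation has superseded `barrier_id`, the call returns
    /// `Ok(())` immediately.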
pub fn wait(&self, barrier_id: u64, timeout: Duration) -> DistributedResult<()> {
let deadline = Instant::now() + timeout;
let mut state = self
.state
.lock()
.map_err(|_| DistributedError::SyncError("Failed to acquire barrier lock".into()))?;
while !state.released && state.current_id == barrier_id {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(DistributedError::SyncError(
"Barrier wait timeout".to_string(),
));
}
let (new_state, result) = self.cv.wait_timeout(state, remaining).map_err(|_| {
DistributedError::SyncError("Failed to wait on barrier".to_string())
})?;
state = new_state;
if result.timed_out() && !state.released {
return Err(DistributedError::SyncError(
"Barrier wait timeout".to_string(),
));
}
}
Ok(())
}
pub fn is_complete(&self, barrier_id: u64) -> bool {
match self.state.lock() {
Ok(state) => state.current_id == barrier_id && state.released,
Err(_) => false,
}
}
pub fn arrived_count(&self) -> usize {
match self.state.lock() {
Ok(state) => state.arrived.len(),
Err(_) => 0,
}
}
}
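/// Per-node facade tying together point-to-point messaging, boundary
/// exchange, and barrier synchronization for the distributed solver.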
pub struct Communicator<F: IntegrateFloat> {
local_node_id: NodeId,
channel: Arc<MessageChannel<F>>,
boundary_exchanger: Arc<BoundaryExchanger<F>>,
barriers: RwLock<HashMap<u64, Arc<SyncBarrier>>>,
    peers: RwLock<Vec<NodeId>>,
    next_barrier_id: AtomicU64,
}
impl<F: IntegrateFloat> Communicator<F> {
pub fn new(
local_node_id: NodeId,
channel: Arc<MessageChannel<F>>,
boundary_exchanger: Arc<BoundaryExchanger<F>>,
) -> Self {
Self {
local_node_id,
channel,
boundary_exchanger,
barriers: RwLock::new(HashMap::new()),
            peers: RwLock::new(Vec::new()),
            next_barrier_id: AtomicU64::new(1),
}
}
pub fn local_id(&self) -> NodeId {
self.local_node_id
}
pub fn add_peer(&self, node_id: NodeId) -> DistributedResult<()> {
match self.peers.write() {
Ok(mut peers) => {
if !peers.contains(&node_id) {
peers.push(node_id);
}
Ok(())
}
Err(_) => Err(DistributedError::CommunicationError(
"Failed to add peer".to_string(),
)),
}
}
pub fn remove_peer(&self, node_id: NodeId) -> DistributedResult<()> {
match self.peers.write() {
Ok(mut peers) => {
peers.retain(|&id| id != node_id);
Ok(())
}
Err(_) => Err(DistributedError::CommunicationError(
"Failed to remove peer".to_string(),
)),
}
}
pub fn get_peers(&self) -> Vec<NodeId> {
match self.peers.read() {
Ok(peers) => peers.clone(),
Err(_) => Vec::new(),
}
}
pub fn send_work(
&self,
target: NodeId,
chunk: WorkChunk<F>,
deadline: Option<Duration>,
) -> DistributedResult<u64> {
let message = DistributedMessage::WorkAssignment { chunk, deadline };
self.channel.send(target, message)
}
pub fn send_result(&self, target: NodeId, result: ChunkResult<F>) -> DistributedResult<u64> {
let message = DistributedMessage::WorkResult { result };
self.channel.send(target, message)
}
pub fn send_boundary(
&self,
target: NodeId,
source_chunk: ChunkId,
target_chunk: ChunkId,
boundary_data: BoundaryData<F>,
) -> DistributedResult<u64> {
let message = DistributedMessage::BoundaryExchange {
source_chunk,
target_chunk,
boundary_data,
};
self.channel.send(target, message)
}
pub fn broadcast(&self, message: DistributedMessage<F>) -> DistributedResult<Vec<u64>> {
let peers = self.get_peers();
let mut message_ids = Vec::with_capacity(peers.len());
for peer in peers {
let id = self.channel.send(peer, message.clone())?;
message_ids.push(id);
}
Ok(message_ids)
}
    pub fn create_barrier(&self, expected_count: usize) -> DistributedResult<u64> {
        // Draw the id from a communicator-level counter: a fresh `SyncBarrier`
        // starts its own counter at 1, so creating each barrier with
        // `SyncBarrier::new` would make every map entry collide on key 1 and
        // silently overwrite earlier barriers.
        let barrier_id = self.next_barrier_id.fetch_add(1, Ordering::SeqCst);
        let barrier = Arc::new(SyncBarrier::with_start_id(expected_count, barrier_id));
        match self.barriers.write() {
            Ok(mut barriers) => {
                barriers.insert(barrier_id, barrier);
                Ok(barrier_id)
            }
            Err(_) => Err(DistributedError::SyncError(
                "Failed to create barrier".to_string(),
            )),
        }
    }
pub fn barrier(&self, barrier_id: u64, timeout: Duration) -> DistributedResult<()> {
let barrier = {
match self.barriers.read() {
Ok(barriers) => barriers.get(&barrier_id).cloned(),
Err(_) => None,
}
};
let barrier = barrier.ok_or_else(|| {
DistributedError::SyncError(format!("Barrier {} not found", barrier_id))
})?;
barrier.arrive(barrier_id, self.local_node_id)?;
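        // Best-effort announcement of our arrival; peers mirror it into their
        // local barrier via `process_barrier_message`. Broadcast errors are
        // ignored here because the subsequent wait will time out anyway if
        // peers never hear from us.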
let message = DistributedMessage::SyncBarrier {
barrier_id,
node_id: self.local_node_id,
};
let _ = self.broadcast(message);
barrier.wait(barrier_id, timeout)
}
pub fn process_barrier_message(
&self,
barrier_id: u64,
node_id: NodeId,
) -> DistributedResult<()> {
match self.barriers.read() {
Ok(barriers) => {
if let Some(barrier) = barriers.get(&barrier_id) {
barrier.arrive(barrier_id, node_id)?;
}
Ok(())
}
Err(_) => Err(DistributedError::SyncError(
"Failed to process barrier message".to_string(),
)),
}
}
pub fn receive_boundary(
&self,
target_chunk: ChunkId,
source_chunk: ChunkId,
data: BoundaryData<F>,
) -> DistributedResult<()> {
self.boundary_exchanger
.receive_boundary(target_chunk, source_chunk, data)
}
pub fn get_boundary_conditions(
&self,
chunk_id: ChunkId,
left_neighbor: Option<ChunkId>,
right_neighbor: Option<ChunkId>,
) -> BoundaryConditions<F> {
self.boundary_exchanger
.build_boundary_conditions(chunk_id, left_neighbor, right_neighbor)
}
}
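/// Serializes boundary data into a little-endian byte layout:
/// `time: f64 | state_len: u64 | state: state_len x f64 | source_chunk: u64`.
/// Note that the optional `derivative` field is not serialized; peers
/// reconstruct it as `None`.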
pub fn serialize_boundary_data<F: IntegrateFloat>(data: &BoundaryData<F>) -> Vec<u8> {
let mut bytes = Vec::new();
let time_f64 = data.time.to_f64().unwrap_or(0.0);
bytes.extend_from_slice(&time_f64.to_le_bytes());
let state_len = data.state.len() as u64;
bytes.extend_from_slice(&state_len.to_le_bytes());
for val in data.state.iter() {
let val_f64 = val.to_f64().unwrap_or(0.0);
bytes.extend_from_slice(&val_f64.to_le_bytes());
}
bytes.extend_from_slice(&data.source_chunk.0.to_le_bytes());
bytes
}
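/// Reconstructs boundary data from the layout written by
/// `serialize_boundary_data`, converting each `f64` back into `F`.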
pub fn deserialize_boundary_data<F: IntegrateFloat>(
bytes: &[u8],
) -> DistributedResult<BoundaryData<F>> {
if bytes.len() < 16 {
return Err(DistributedError::CommunicationError(
"Insufficient data for boundary deserialization".to_string(),
));
}
let mut offset = 0;
let time_bytes: [u8; 8] = bytes[offset..offset + 8]
.try_into()
.map_err(|_| DistributedError::CommunicationError("Invalid time bytes".to_string()))?;
let time_f64 = f64::from_le_bytes(time_bytes);
let time = F::from(time_f64).ok_or_else(|| {
DistributedError::CommunicationError("Failed to convert time".to_string())
})?;
offset += 8;
let len_bytes: [u8; 8] = bytes[offset..offset + 8]
.try_into()
.map_err(|_| DistributedError::CommunicationError("Invalid length bytes".to_string()))?;
let state_len = u64::from_le_bytes(len_bytes) as usize;
offset += 8;
    // Reject truncated input or an overflowing length field before allocating:
    // we still need `state_len` f64 values plus the trailing chunk id.
    let needed = state_len.checked_mul(8).and_then(|n| n.checked_add(offset + 8));
    if needed.map_or(true, |n| bytes.len() < n) {
        return Err(DistributedError::CommunicationError(
            "Insufficient data for state values".to_string(),
        ));
    }
let mut state = Array1::zeros(state_len);
for i in 0..state_len {
let val_bytes: [u8; 8] = bytes[offset..offset + 8]
.try_into()
.map_err(|_| DistributedError::CommunicationError("Invalid value bytes".to_string()))?;
let val_f64 = f64::from_le_bytes(val_bytes);
state[i] = F::from(val_f64).ok_or_else(|| {
DistributedError::CommunicationError("Failed to convert value".to_string())
})?;
offset += 8;
}
let chunk_bytes: [u8; 8] = bytes[offset..offset + 8]
.try_into()
.map_err(|_| DistributedError::CommunicationError("Invalid chunk ID bytes".to_string()))?;
let source_chunk = ChunkId(u64::from_le_bytes(chunk_bytes));
Ok(BoundaryData {
time,
state,
derivative: None,
source_chunk,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_channel() {
let channel: MessageChannel<f64> = MessageChannel::new(Duration::from_secs(5));
let node_id = NodeId::new(1);
let message = DistributedMessage::Heartbeat {
node_id,
status: NodeStatus::Available,
timestamp: 12345,
};
let msg_id = channel.send(node_id, message.clone());
assert!(msg_id.is_ok());
channel.deliver(node_id, message).expect("Delivery failed");
let received = channel.try_receive();
assert!(received.is_some());
}
#[test]
fn test_boundary_exchanger() {
let exchanger: BoundaryExchanger<f64> = BoundaryExchanger::new(Duration::from_secs(5));
let target = ChunkId::new(1);
let source = ChunkId::new(0);
let data = BoundaryData {
time: 1.0,
state: Array1::from_vec(vec![1.0, 2.0, 3.0]),
derivative: None,
source_chunk: source,
};
exchanger
.receive_boundary(target, source, data)
.expect("Failed to receive boundary");
let retrieved = exchanger.get_boundary(target, source);
assert!(retrieved.is_some());
assert_eq!(retrieved.map(|b| b.time), Some(1.0));
}
#[test]
fn test_sync_barrier() {
let barrier = SyncBarrier::new(2);
let barrier_id = barrier.new_barrier();
barrier
.arrive(barrier_id, NodeId::new(1))
.expect("Failed to arrive");
assert!(!barrier.is_complete(barrier_id));
barrier
.arrive(barrier_id, NodeId::new(2))
.expect("Failed to arrive");
assert!(barrier.is_complete(barrier_id));
}
#[test]
fn test_boundary_serialization() {
let data = BoundaryData {
time: 1.5,
state: Array1::from_vec(vec![1.0, 2.0, 3.0]),
derivative: None,
source_chunk: ChunkId::new(42),
};
let bytes = serialize_boundary_data(&data);
let deserialized: BoundaryData<f64> =
deserialize_boundary_data(&bytes).expect("Deserialization failed");
assert!((deserialized.time - data.time).abs() < 1e-10);
assert_eq!(deserialized.state.len(), data.state.len());
assert_eq!(deserialized.source_chunk.0, data.source_chunk.0);
}
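    // Extra sketches covering the blocking receive path and ack timeout
    // bookkeeping. These assume `DistributedMessage<f64>` is `Send` (needed
    // for cross-thread delivery anyway); timings are generous to keep the
    // tests deterministic under load.
    #[test]
    fn test_blocking_receive() {
        use std::thread;
        let channel: Arc<MessageChannel<f64>> =
            Arc::new(MessageChannel::new(Duration::from_secs(5)));
        let node_id = NodeId::new(1);
        let sender = Arc::clone(&channel);
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            sender
                .deliver(
                    node_id,
                    DistributedMessage::Heartbeat {
                        node_id,
                        status: NodeStatus::Available,
                        timestamp: 1,
                    },
                )
                .expect("Delivery failed");
        });
        // Blocks on the condvar until the spawned thread delivers.
        let received = channel.receive(Duration::from_secs(2));
        assert!(received.is_some());
        handle.join().expect("Sender thread panicked");
    }
    #[test]
    fn test_ack_timeout() {
        let channel: MessageChannel<f64> = MessageChannel::new(Duration::from_millis(10));
        let node_id = NodeId::new(1);
        let msg_id = channel
            .send(
                node_id,
                DistributedMessage::Heartbeat {
                    node_id,
                    status: NodeStatus::Available,
                    timestamp: 1,
                },
            )
            .expect("Send failed");
        std::thread::sleep(Duration::from_millis(20));
        let timed_out = channel.check_timeouts();
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0].0, msg_id);
        assert!(timed_out[0].1 == node_id);
        // A second sweep finds nothing: timed-out entries were removed.
        assert!(channel.check_timeouts().is_empty());
    }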
}