use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Transport used to move messages between distributed nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationBackend {
    /// Single-process simulation backed by a shared in-memory buffer.
    InMemory,
    /// Point-to-point TCP sockets; a fresh connection is opened per transfer.
    TCP,
    /// MPI transport (send/recv currently return `NotImplemented`).
    MPI,
    /// RDMA transport (send/recv currently return `NotImplemented`).
    RDMA,
    /// User-supplied backend; operations currently return `NotImplemented`.
    Custom,
}
/// Fixed header framing every TCP payload (written by `send_tcp`, parsed by `recv_tcp`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TcpMessageHeader {
    /// Rank of the sending node.
    source: usize,
    /// Rank of the intended recipient.
    destination: usize,
    /// `MessageTag` discriminant cast to `u32`.
    tag: u32,
    /// Number of payload bytes that follow the header on the wire.
    datasize: usize,
    /// Per-communicator monotonically increasing message number.
    sequence: u64,
    /// Send time in milliseconds since the Unix epoch.
    timestamp: u64,
    /// Hash of the payload; verified by the receiver after the read.
    checksum: u64,
}
/// Pool of TCP connections keyed by node rank.
///
/// NOTE(review): neither field is read anywhere in this file — `send_tcp`
/// opens a fresh `TcpStream` and `recv_tcp` binds a fresh listener on every
/// call, so this pool is currently unused scaffolding.
struct TcpConnectionPool {
    /// Open streams keyed by peer rank (currently never populated).
    connections: HashMap<usize, std::net::TcpStream>,
    /// Listening socket for inbound connections (currently never set).
    listener: Option<std::net::TcpListener>,
}
impl TcpConnectionPool {
    /// Creates an empty pool with no connections and no listener.
    fn new() -> Self {
        Self {
            connections: HashMap::new(),
            listener: None,
        }
    }
}
/// Logical channel of a message; the discriminant travels on the wire as the
/// `u32` in `TcpMessageHeader::tag`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageTag {
    /// Matrix payloads (broadcast / gather / scatter traffic).
    Data = 0,
    /// Control messages.
    Control = 1,
    /// Barrier synchronization handshakes.
    Sync = 2,
    /// Status reporting.
    Status = 3,
    /// Matrix-multiplication traffic.
    MatMul = 4,
    /// Decomposition traffic.
    Decomp = 5,
    /// Solver traffic.
    Solve = 6,
}
/// Envelope describing one message: routing, size, ordering, and flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Rank of the sending node.
    pub source: usize,
    /// Rank of the intended recipient.
    pub destination: usize,
    /// Logical channel of the message.
    pub tag: MessageTag,
    /// Payload size in bytes (see `Message::new` for how it is computed).
    pub size_bytes: usize,
    /// Monotonic per-communicator message number (0 when not tracked).
    pub sequence: u64,
    /// Creation time in milliseconds since the Unix epoch (0 when not tracked).
    pub timestamp: u64,
    /// Whether the payload is compressed (always `false` in this file).
    pub compressed: bool,
}
/// A payload of type `T` together with its routing metadata.
#[derive(Debug, Clone)]
pub struct Message<T> {
    /// Envelope: routing, ordering, and size information.
    pub metadata: MessageMetadata,
    /// The payload itself.
    pub data: T,
}
impl<T> Message<T> {
    /// Wraps `data` in an envelope stamped with the current wall-clock time.
    ///
    /// NOTE(review): `size_bytes` is `size_of::<T>()` — the size of the
    /// value's stack representation, not a serialized length (for
    /// `T = Vec<u8>` this is the Vec header, not the payload). Confirm
    /// whether a serialized size is intended here.
    pub fn new(
        source: usize,
        destination: usize,
        tag: MessageTag,
        data: T,
        sequence: u64,
    ) -> Self {
        // Milliseconds since the Unix epoch; falls back to 0 if the system
        // clock reports a time before the epoch.
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;
        Self {
            metadata: MessageMetadata {
                source,
                destination,
                tag,
                size_bytes: std::mem::size_of::<T>(),
                sequence,
                timestamp: now_ms,
                compressed: false,
            },
            data,
        }
    }
}
/// Point-to-point and collective communication between distributed nodes.
pub struct DistributedCommunicator {
    /// Rank of this node; rank 0 acts as the root for collectives.
    pub rank: usize,
    /// Total number of participating nodes.
    pub size: usize,
    /// Selected transport backend.
    backend: CommunicationBackend,
    /// Monotonic counter used to stamp outgoing messages.
    sequence_counter: Arc<Mutex<u64>>,
    /// In-memory mailbox keyed by (source rank, tag); used by the InMemory backend.
    message_buffer: Arc<Mutex<HashMap<(usize, MessageTag), Vec<u8>>>>,
    /// Aggregate send/receive statistics.
    stats: Arc<Mutex<CommunicationStats>>,
    /// TCP connection pool (currently unused scaffolding; see `TcpConnectionPool`).
    tcp_pool: Arc<Mutex<TcpConnectionPool>>,
    /// rank -> "host:port" address map consulted by the TCP backend.
    node_addresses: HashMap<usize, String>,
}
impl DistributedCommunicator {
/// Creates a communicator from the distributed configuration.
///
/// Every rank is assigned a loopback TCP address (port `7000 + rank`) so
/// the TCP backend can resolve peers; other backends ignore the map.
///
/// # Errors
/// Currently infallible; returns `LinalgResult` for interface stability.
pub fn new(config: &super::DistributedConfig) -> LinalgResult<Self> {
    // Bug fix: the loop previously read `_config.num_nodes`, which does not
    // name the `config` parameter and therefore fails to compile.
    let mut node_addresses = HashMap::new();
    for rank in 0..config.num_nodes {
        let port = 7000 + rank;
        node_addresses.insert(rank, format!("127.0.0.1:{}", port));
    }
    Ok(Self {
        rank: config.node_rank,
        size: config.num_nodes,
        backend: config.backend,
        sequence_counter: Arc::new(Mutex::new(0)),
        message_buffer: Arc::new(Mutex::new(HashMap::new())),
        stats: Arc::new(Mutex::new(CommunicationStats::default())),
        tcp_pool: Arc::new(Mutex::new(TcpConnectionPool::new())),
        node_addresses,
    })
}
/// Serializes `matrix` and sends it to node `dest` on channel `tag`,
/// recording byte count and elapsed time in the communicator statistics.
///
/// # Errors
/// Propagates serialization and transport errors; returns `NotImplemented`
/// for the `Custom` backend.
pub fn sendmatrix<T>(&self, matrix: &ArrayView2<T>, dest: usize, tag: MessageTag) -> LinalgResult<()>
where
    T: Clone + Send + Sync + Serialize,
{
    let start_time = Instant::now();
    let serialized = self.serializematrix(matrix)?;
    // Bug fix: capture the payload length before `serialized` is moved into
    // `Message::new`; the original read `serialized.len()` after the move,
    // which is a use-after-move and does not compile.
    let payload_len = serialized.len();
    let sequence = self.next_sequence();
    let message = Message::new(self.rank, dest, tag, serialized, sequence);
    match self.backend {
        CommunicationBackend::InMemory => self.send_in_memory(message)?,
        CommunicationBackend::TCP => self.send_tcp(message)?,
        CommunicationBackend::MPI => self.send_mpi(message)?,
        CommunicationBackend::RDMA => self.send_rdma(message)?,
        CommunicationBackend::Custom => {
            return Err(LinalgError::NotImplemented(
                "Custom backend not implemented".to_string(),
            ));
        }
    }
    let elapsed = start_time.elapsed();
    self.update_send_stats(payload_len, elapsed);
    Ok(())
}
/// Receives one matrix from node `source` on channel `tag`, deserializes it,
/// and records byte count and elapsed time in the statistics.
///
/// # Errors
/// Propagates transport and deserialization errors; returns `NotImplemented`
/// for the `Custom` backend.
pub fn recvmatrix<T>(&self, source: usize, tag: MessageTag) -> LinalgResult<Array2<T>>
where
    T: Clone + Send + Sync + for<'de> Deserialize<'de>,
{
    let started = Instant::now();
    // Dispatch to the transport selected at construction time.
    let message = match self.backend {
        CommunicationBackend::InMemory => self.recv_in_memory(source, tag)?,
        CommunicationBackend::TCP => self.recv_tcp(source, tag)?,
        CommunicationBackend::MPI => self.recv_mpi(source, tag)?,
        CommunicationBackend::RDMA => self.recv_rdma(source, tag)?,
        CommunicationBackend::Custom => {
            return Err(LinalgError::NotImplemented(
                "Custom backend not implemented".to_string(),
            ));
        }
    };
    let matrix = self.deserializematrix(&message.data)?;
    let elapsed = started.elapsed();
    self.update_recv_stats(message.data.len(), elapsed);
    Ok(matrix)
}
/// Root-side broadcast: rank 0 sends `matrix` to every other rank.
///
/// Non-root ranks are a no-op here; they must obtain the data with a
/// matching `recvmatrix(0, MessageTag::Data)` call (as `allreduce_sum` does).
pub fn broadcastmatrix<T>(&self, matrix: &ArrayView2<T>) -> LinalgResult<()>
where
    T: Clone + Send + Sync + Serialize,
{
    if self.rank != 0 {
        return Ok(());
    }
    for dest in 1..self.size {
        self.sendmatrix(matrix, dest, MessageTag::Data)?;
    }
    Ok(())
}
/// Gathers one matrix per rank onto the root.
///
/// Returns `Some(matrices)` on rank 0 — the root's own block first, then one
/// block per worker in ascending rank order — and `None` on every other
/// rank after sending its local block to the root.
pub fn gather_matrices<T>(&self, localmatrix: &ArrayView2<T>) -> LinalgResult<Option<Vec<Array2<T>>>>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>,
{
    if self.rank != 0 {
        // Worker: ship the local block to the root and return nothing.
        self.sendmatrix(localmatrix, 0, MessageTag::Data)?;
        return Ok(None);
    }
    // Root: own copy first, then one receive per remaining rank.
    let mut gathered = Vec::with_capacity(self.size);
    gathered.push(localmatrix.to_owned());
    for source in 1..self.size {
        gathered.push(self.recvmatrix(source, MessageTag::Data)?);
    }
    Ok(Some(gathered))
}
/// All-reduce with element-wise summation.
///
/// The root gathers every rank's block, sums them, and broadcasts the total;
/// workers send their block and then receive the result. Every rank returns
/// the same summed matrix.
pub fn allreduce_sum<T>(&self, localmatrix: &ArrayView2<T>) -> LinalgResult<Array2<T>>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + scirs2_core::numeric::Zero + std::ops::Add<Output = T>,
{
    match self.gather_matrices(localmatrix)? {
        Some(matrices) => {
            // Root: fold the gathered blocks into an element-wise sum.
            let mut total = matrices[0].clone();
            for m in &matrices[1..] {
                total = &total + m;
            }
            // Push the reduced result back out to the workers.
            self.broadcastmatrix(&total.view())?;
            Ok(total)
        }
        // Worker: gather sent our block; now wait for the reduced result.
        None => self.recvmatrix(0, MessageTag::Data),
    }
}
/// Scatters contiguous row blocks of `matrix` from rank 0 to all ranks.
///
/// Rows are split as evenly as possible: each of the first `rows % size`
/// ranks receives one extra row. Rank 0 keeps the leading block; rank `r`
/// receives the `r`-th block. Non-root ranks pass `None` and receive theirs.
///
/// Bug fix: the original started the worker chunks at row 0 while the root
/// *also* kept rows `0..root_rows`, so the leading rows were duplicated and
/// the trailing rows were never distributed. It also gave an extra row to
/// `remainder + 1` ranks (root plus every `dest <= remainder`), over-counting
/// by one row whenever `remainder > 0`.
///
/// # Errors
/// `InvalidInput` if the root is not given a matrix; transport errors otherwise.
pub fn scattermatrix<T>(&self, matrix: Option<&ArrayView2<T>>) -> LinalgResult<Array2<T>>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>,
{
    if self.rank != 0 {
        return self.recvmatrix(0, MessageTag::Data);
    }
    let matrix = matrix.ok_or_else(|| {
        LinalgError::InvalidInput("Root node must provide matrix for scatter".to_string())
    })?;
    let (rows, _cols) = matrix.dim();
    let rows_per_node = rows / self.size;
    let remainder = rows % self.size;
    // Ranks 0..remainder get one extra row each; the totals add up to `rows`.
    let root_rows = rows_per_node + usize::from(remainder > 0);
    let mut start_row = root_rows;
    for dest in 1..self.size {
        let chunk_rows = rows_per_node + usize::from(dest < remainder);
        let end_row = start_row + chunk_rows;
        let chunk = matrix.slice(scirs2_core::ndarray::s![start_row..end_row, ..]);
        self.sendmatrix(&chunk, dest, MessageTag::Data)?;
        start_row = end_row;
    }
    // Root keeps the leading block.
    Ok(matrix.slice(scirs2_core::ndarray::s![..root_rows, ..]).to_owned())
}
/// Blocks until every rank has entered the barrier.
///
/// Two-phase handshake on the `Sync` channel: each worker sends a 1x1 dummy
/// matrix to the root; once the root has heard from everyone it replies to
/// each worker, releasing all ranks.
pub fn barrier(&self) -> LinalgResult<()> {
    let token = Array2::<f64>::zeros((1, 1));
    if self.rank == 0 {
        // Phase 1: collect an arrival token from every worker.
        for source in 1..self.size {
            let _: Array2<f64> = self.recvmatrix(source, MessageTag::Sync)?;
        }
        // Phase 2: release the workers.
        for dest in 1..self.size {
            self.sendmatrix(&token.view(), dest, MessageTag::Sync)?;
        }
    } else {
        // Announce arrival, then wait for the root's release token.
        self.sendmatrix(&token.view(), 0, MessageTag::Sync)?;
        let _: Array2<f64> = self.recvmatrix(0, MessageTag::Sync)?;
    }
    Ok(())
}
/// Releases backend resources held by this communicator.
pub fn finalize(&self) -> LinalgResult<()> {
    match self.backend {
        CommunicationBackend::InMemory => {
            // Drop any undelivered messages from the shared mailbox.
            self.message_buffer.lock().expect("Operation failed").clear();
        }
        CommunicationBackend::MPI => {
            // Placeholder: MPI finalization would go here once MPI is wired up.
        }
        // TCP/RDMA/Custom currently need no teardown.
        _ => {}
    }
    Ok(())
}
/// Returns a snapshot of the accumulated communication statistics.
pub fn get_stats(&self) -> CommunicationStats {
    let guard = self.stats.lock().expect("Operation failed");
    guard.clone()
}
/// Atomically increments and returns the next message sequence number
/// (first value handed out is 1).
fn next_sequence(&self) -> u64 {
    let mut seq = self.sequence_counter.lock().expect("Operation failed");
    *seq += 1;
    *seq
}
/// Encodes a matrix view into a byte buffer with the oxicode codec.
fn serializematrix<T>(&self, matrix: &ArrayView2<T>) -> LinalgResult<Vec<u8>>
where
    T: Serialize,
{
    match oxicode::serde::encode_to_vec(matrix, oxicode::config::standard()) {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(LinalgError::SerializationError(format!(
            "Failed to serialize matrix (oxicode): {}",
            e
        ))),
    }
}
/// Decodes a matrix from bytes produced by `serializematrix`.
fn deserializematrix<T>(&self, data: &[u8]) -> LinalgResult<Array2<T>>
where
    T: for<'de> Deserialize<'de>,
{
    match oxicode::serde::decode_owned_from_slice(data, oxicode::config::standard()) {
        // The codec also reports how many bytes it consumed; only the value matters here.
        Ok((matrix, _consumed)) => Ok(matrix),
        Err(e) => Err(LinalgError::SerializationError(format!(
            "Failed to deserialize matrix (oxicode): {}",
            e
        ))),
    }
}
/// Stores the payload in the shared mailbox under (source rank, tag).
///
/// NOTE(review): the destination rank is not part of the key, and a second
/// send with the same (source, tag) overwrites the first — adequate for the
/// single-process simulation this backend models, but worth confirming.
fn send_in_memory(&self, message: Message<Vec<u8>>) -> LinalgResult<()> {
    self.message_buffer
        .lock()
        .expect("Operation failed")
        .insert((message.metadata.source, message.metadata.tag), message.data);
    Ok(())
}
/// Removes and returns the mailbox entry for (source, tag).
///
/// # Errors
/// `CommunicationError` if no message is currently buffered under that key.
fn recv_in_memory(&self, source: usize, tag: MessageTag) -> LinalgResult<Message<Vec<u8>>> {
    let data = self
        .message_buffer
        .lock()
        .expect("Operation failed")
        .remove(&(source, tag))
        .ok_or_else(|| {
            LinalgError::CommunicationError(format!(
                "No message available from {} with tag {:?}",
                source, tag
            ))
        })?;
    // Synthesize metadata; sequence and timestamp are not tracked by this backend.
    Ok(Message {
        metadata: MessageMetadata {
            source,
            destination: self.rank,
            tag,
            size_bytes: data.len(),
            sequence: 0,
            timestamp: 0,
            compressed: false,
        },
        data,
    })
}
/// Sends one framed message over a freshly opened TCP connection.
///
/// Wire format: 4-byte big-endian header length, oxicode-encoded
/// `TcpMessageHeader`, then the raw payload bytes.
fn send_tcp(&self, message: Message<Vec<u8>>) -> LinalgResult<()> {
    use std::io::Write;
    use std::net::TcpStream;
    let dest_address = self.get_node_address(message.metadata.destination)?;
    let mut stream = TcpStream::connect(&dest_address).map_err(|e| {
        LinalgError::CommunicationError(format!("Failed to connect to {}: {}", dest_address, e))
    })?;
    // Build the wire header, including a payload checksum the receiver verifies.
    let header = TcpMessageHeader {
        source: message.metadata.source,
        destination: message.metadata.destination,
        tag: message.metadata.tag as u32,
        datasize: message.data.len(),
        sequence: message.metadata.sequence,
        timestamp: message.metadata.timestamp,
        checksum: self.calculate_checksum(&message.data),
    };
    let header_bytes = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
        .map_err(|e| {
            LinalgError::SerializationError(format!("Header serialization failed (oxicode): {}", e))
        })?;
    // Length-prefix the header so the receiver knows how much to decode.
    let headersize = header_bytes.len() as u32;
    stream
        .write_all(&headersize.to_be_bytes())
        .map_err(|e| LinalgError::CommunicationError(format!("Failed to send header size: {}", e)))?;
    stream
        .write_all(&header_bytes)
        .map_err(|e| LinalgError::CommunicationError(format!("Failed to send header: {}", e)))?;
    // Stream the payload in 64 KiB chunks.
    const CHUNK_SIZE: usize = 64 * 1024;
    for chunk in message.data.chunks(CHUNK_SIZE) {
        stream.write_all(chunk).map_err(|e| {
            LinalgError::CommunicationError(format!("Failed to send data chunk: {}", e))
        })?;
    }
    stream
        .flush()
        .map_err(|e| LinalgError::CommunicationError(format!("Failed to flush TCP stream: {}", e)))
}
fn recv_tcp(&self, source: usize, tag: MessageTag) -> LinalgResult<Message<Vec<u8>>> {
use std::net::{TcpListener, TcpStream};
use std::io::Read;
let listen_address = self.get_node_address(self.rank)?;
let listener = TcpListener::bind(&listen_address)
.map_err(|e| LinalgError::CommunicationError(
format!("Failed to bind to {}: {}", listen_address, e)
))?;
let (mut stream, remote_addr) = listener.accept()
.map_err(|e| LinalgError::CommunicationError(format!("Failed to accept connection: {}", e)))?;
let mut headersize_bytes = [0u8; 4];
stream.read_exact(&mut headersize_bytes)
.map_err(|e| LinalgError::CommunicationError(format!("Failed to read header size: {}", e)))?;
let headersize = u32::from_be_bytes(headersize_bytes) as usize;
let mut header_bytes = vec![0u8; headersize];
stream.read_exact(&mut header_bytes)
.map_err(|e| LinalgError::CommunicationError(format!("Failed to read header: {}", e)))?;
let cfg = oxicode::config::standard();
let (header, _len): (TcpMessageHeader, usize) = oxicode::serde::decode_owned_from_slice(&header_bytes, cfg)
.map_err(|e| LinalgError::SerializationError(format!("Header deserialization failed (oxicode): {}", e)))?;
if header.source != source {
return Err(LinalgError::CommunicationError(format!(
"Expected message from node {}, got from node {}",
source, header.source
)));
}
if header.tag != tag as u32 {
return Err(LinalgError::CommunicationError(format!(
"Expected message with tag {:?}, got tag {}",
tag, header.tag
)));
}
let mut data = vec![0u8; header.datasize];
let mut bytes_read = 0;
while bytes_read < header.datasize {
let chunksize = std::cmp::min(header.datasize - bytes_read, 64 * 1024);
let mut chunk = vec![0u8; chunksize];
stream.read_exact(&mut chunk)
.map_err(|e| LinalgError::CommunicationError(format!("Failed to read data chunk: {}", e)))?;
data[bytes_read..bytes_read + chunksize].copy_from_slice(&chunk);
bytes_read += chunksize;
}
let received_checksum = self.calculate_checksum(&data);
if received_checksum != header.checksum {
return Err(LinalgError::CommunicationError(format!(
"Checksum mismatch: expected {}, got {}",
header.checksum, received_checksum
)));
}
let metadata = MessageMetadata {
source: header.source,
destination: header.destination,
tag,
size_bytes: data.len(),
sequence: header.sequence,
timestamp: header.timestamp,
compressed: false,
};
Ok(Message { metadata, data })
}
fn send_mpi(&selfmessage: Message<Vec<u8>>) -> LinalgResult<()> {
Err(LinalgError::NotImplemented("MPI backend not implemented".to_string()))
}
fn recv_mpi(&self_source: usize, tag: MessageTag) -> LinalgResult<Message<Vec<u8>>> {
Err(LinalgError::NotImplemented("MPI backend not implemented".to_string()))
}
fn send_rdma(&selfmessage: Message<Vec<u8>>) -> LinalgResult<()> {
Err(LinalgError::NotImplemented("RDMA backend not implemented".to_string()))
}
fn recv_rdma(&self_source: usize, tag: MessageTag) -> LinalgResult<Message<Vec<u8>>> {
Err(LinalgError::NotImplemented("RDMA backend not implemented".to_string()))
}
/// Folds one completed send into the aggregate statistics.
fn update_send_stats(&self, bytes: usize, duration: Duration) {
    let mut s = self.stats.lock().expect("Operation failed");
    s.messages_sent += 1;
    s.bytes_sent += bytes;
    s.total_send_time += duration;
}
/// Folds one completed receive into the aggregate statistics.
fn update_recv_stats(&self, bytes: usize, duration: Duration) {
    let mut s = self.stats.lock().expect("Operation failed");
    s.messages_received += 1;
    s.bytes_received += bytes;
    s.total_recv_time += duration;
}
/// Looks up the "host:port" address registered for `rank`.
///
/// # Errors
/// `CommunicationError` if the rank has no entry in the address map.
fn get_node_address(&self, rank: usize) -> LinalgResult<String> {
    match self.node_addresses.get(&rank) {
        Some(addr) => Ok(addr.clone()),
        None => Err(LinalgError::CommunicationError(format!(
            "No address found for node rank {}",
            rank
        ))),
    }
}
/// Hashes the payload with the standard library's `DefaultHasher`.
///
/// NOTE(review): `DefaultHasher::new()` is deterministic across instances
/// within one build, so sender and receiver agree in-process; the algorithm
/// is not guaranteed stable across Rust versions, so both ends of a TCP
/// transfer must be built with the same toolchain — confirm this is acceptable.
fn calculate_checksum(&self, data: &[u8]) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    data.hash(&mut hasher);
    hasher.finish()
}
/// Compression placeholder: currently an identity copy of `data`.
fn compress_data(&self, data: &[u8]) -> LinalgResult<Vec<u8>> {
    Ok(Vec::from(data))
}
fn decompress_data(&self, compresseddata: &[u8]) -> LinalgResult<Vec<u8>> {
Ok(compressed_data.to_vec())
}
}
/// Running totals for traffic sent and received by a communicator.
#[derive(Debug, Clone, Default)]
pub struct CommunicationStats {
    /// Number of messages sent.
    pub messages_sent: usize,
    /// Number of messages received.
    pub messages_received: usize,
    /// Total payload bytes sent.
    pub bytes_sent: usize,
    /// Total payload bytes received.
    pub bytes_received: usize,
    /// Cumulative wall-clock time spent in send operations.
    pub total_send_time: Duration,
    /// Cumulative wall-clock time spent in receive operations.
    pub total_recv_time: Duration,
}
impl CommunicationStats {
    /// Average outbound bandwidth in bytes/second; 0.0 before any send time
    /// has accumulated.
    pub fn avg_send_bandwidth(&self) -> f64 {
        let secs = self.total_send_time.as_secs_f64();
        if secs > 0.0 { self.bytes_sent as f64 / secs } else { 0.0 }
    }
    /// Average inbound bandwidth in bytes/second; 0.0 before any receive
    /// time has accumulated.
    pub fn avg_recv_bandwidth(&self) -> f64 {
        let secs = self.total_recv_time.as_secs_f64();
        if secs > 0.0 { self.bytes_received as f64 / secs } else { 0.0 }
    }
    /// Heuristic score in (0, 1]: `1 / (1 + mean seconds per message)`;
    /// returns 1.0 when no messages have been exchanged yet.
    pub fn efficiency_ratio(&self) -> f64 {
        let total_messages = self.messages_sent + self.messages_received;
        if total_messages == 0 {
            return 1.0;
        }
        let avg_time =
            (self.total_send_time + self.total_recv_time).as_secs_f64() / total_messages as f64;
        1.0 / (1.0 + avg_time)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// `Message::new` must preserve the routing fields and the payload.
    #[test]
    fn test_message_creation() {
        let data = vec![1, 2, 3, 4];
        let message = Message::new(0, 1, MessageTag::Data, data.clone(), 1);
        assert_eq!(message.metadata.source, 0);
        assert_eq!(message.metadata.destination, 1);
        assert_eq!(message.metadata.tag, MessageTag::Data);
        assert_eq!(message.data, data);
    }

    /// Bandwidth and efficiency computed from hand-filled statistics must be
    /// positive and the efficiency bounded by 1.0.
    #[test]
    fn test_communication_stats() {
        let mut stats = CommunicationStats::default();
        stats.messages_sent = 10;
        stats.bytes_sent = 1024;
        stats.total_send_time = Duration::from_millis(100);
        let bandwidth = stats.avg_send_bandwidth();
        assert!(bandwidth > 0.0);
        let efficiency = stats.efficiency_ratio();
        assert!(efficiency > 0.0 && efficiency <= 1.0);
    }

    /// Round-trips a small matrix through the communicator's serialize /
    /// deserialize codec (no actual message transfer is exercised here).
    #[test]
    fn test_in_memory_communication() {
        use super::super::DistributedConfig;
        let config = DistributedConfig::default()
            .with_backend(CommunicationBackend::InMemory);
        let comm = DistributedCommunicator::new(&config).expect("Operation failed");
        let matrix = Array2::from_shape_fn((3, 3), |(i, j)| (i + j) as f64);
        let serialized = comm.serializematrix(&matrix.view()).expect("Operation failed");
        let deserialized: Array2<f64> = comm.deserializematrix(&serialized).expect("Operation failed");
        assert_eq!(matrix, deserialized);
    }
}