use super::comm::{CommunicationChannel, ConnectionManager, Message};
use super::process::{Communicator, ProcessError};
use crate::error::NumRs2Error;
use oxicode::{Decode, Encode};
use scirs2_core::ndarray::Array1;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{mpsc, Mutex, RwLock};
/// Errors reported by the communication primitives in this module.
///
/// The `#[error(...)]` attribute on each variant supplies its `Display` text.
#[derive(Error, Debug)]
pub enum CommunicationError {
/// Wraps a [`ProcessError`]; `#[from]` generates the `From<ProcessError>` conversion.
#[error("Process error: {0}")]
Process(#[from] ProcessError),
/// Encoding a message to bytes failed (produced by `AsyncCommunicator::isend`).
#[error("Serialization error: {0}")]
Serialization(String),
/// Decoding received bytes failed (produced by `AsyncCommunicator::irecv`).
#[error("Deserialization error: {0}")]
Deserialization(String),
/// Compression failure with a free-form description.
// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("Compression error: {0}")]
Compression(String),
/// Decompression failure with a free-form description.
// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("Decompression error: {0}")]
Decompression(String),
/// Channel-level failure; `AsyncCommunicator::irecv` returns it with
/// "No data available" when the receive buffer for a source is empty.
#[error("Channel error: {0}")]
Channel(String),
/// An operation timed out; the payload is rendered with an `ms` suffix.
#[error("Timeout: operation exceeded {0}ms")]
Timeout(u64),
/// A rank was outside the communicator: `isend`/`irecv` return this when
/// `rank >= size`.
#[error("Invalid rank {rank}, communicator size is {size}")]
InvalidRank { rank: usize, size: usize },
/// A message had a different size than expected.
// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("Message size mismatch: expected {expected}, got {actual}")]
SizeMismatch { expected: usize, actual: usize },
/// Network failure with a free-form description.
// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("Network error: {0}")]
Network(String),
}
impl From<CommunicationError> for NumRs2Error {
fn from(err: CommunicationError) -> Self {
NumRs2Error::DistributedComputing(err.to_string())
}
}
/// Scheduling priority attached to a message.
///
/// The derived `Ord` follows declaration order, so
/// `Low < Normal < High < Urgent`. `PrioritizedMessage` compares on this value
/// first, which makes the send queues (max-heaps) pop higher priorities first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Encode, Decode)]
pub enum MessagePriority {
/// Lowest priority.
Low = 0,
/// Priority used for ordinary traffic.
Normal = 1,
/// Ordered ahead of `Normal` and `Low`.
High = 2,
/// Highest priority; ordered ahead of everything else.
Urgent = 3,
}
/// Selects how `compress_tensor` reduces a tensor before transmission.
#[derive(Debug, Clone, Encode, Decode)]
pub enum CompressionStrategy {
/// No compression: the data is copied unchanged.
None,
/// Keep the `k` largest elements (by `partial_cmp`) together with their indices.
TopK { k: usize },
/// Keep `k` elements together with their indices.
// NOTE(review): despite the name, `compress_tensor` samples at a fixed stride
// of `len / k`; no random source is involved.
RandomK { k: usize },
/// Quantization to `bits` bits.
// NOTE(review): `compress_tensor` currently returns the data unchanged for this variant.
Quantization { bits: u8 },
/// Threshold-based sparsification.
// NOTE(review): `compress_tensor` currently returns the data unchanged for this variant.
Threshold { threshold: f64 },
}
/// A tensor payload plus the metadata that travels with it.
///
/// The whole struct is encoded with `oxicode` by `AsyncCommunicator::isend`
/// and decoded by `AsyncCommunicator::irecv`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct TensorMessage<T>
where
T: Clone + Encode + Decode,
{
/// Flat element buffer.
pub data: Vec<T>,
/// Tensor shape; `TensorMessage::new` sets it to `[data.len()]`.
pub shape: Vec<usize>,
/// Compression strategy recorded for this message.
pub compression: CompressionStrategy,
/// Priority used to order the message in the send queue.
pub priority: MessagePriority,
/// Sequence number (0 from the constructors); `isend` uses it as the
/// tie-breaker between messages of equal priority.
pub sequence: u64,
/// Sender rank (0 from the constructors; set with `with_sender`).
pub sender: usize,
/// Message tag (0 from the constructors; set with `with_tag`).
pub tag: u32,
/// Optional element indices; the constructors leave this as `None`.
// presumably the positions returned by `compress_tensor` for sparse payloads — TODO confirm
pub indices: Option<Vec<usize>>,
}
impl<T> TensorMessage<T>
where
    T: Clone + Encode + Decode,
{
    /// Builds a one-dimensional message whose shape is `[data.len()]`.
    ///
    /// `sequence`, `sender` and `tag` start at 0 and `indices` at `None`.
    pub fn new(data: Vec<T>, compression: CompressionStrategy, priority: MessagePriority) -> Self {
        let flat_shape = vec![data.len()];
        Self::with_shape(data, flat_shape, compression, priority)
    }

    /// Builds a message with an explicit `shape`.
    ///
    /// The shape is stored as given; it is not checked against `data.len()`.
    /// `sequence`, `sender` and `tag` start at 0 and `indices` at `None`.
    pub fn with_shape(
        data: Vec<T>,
        shape: Vec<usize>,
        compression: CompressionStrategy,
        priority: MessagePriority,
    ) -> Self {
        Self {
            data,
            shape,
            compression,
            priority,
            sequence: 0,
            sender: 0,
            tag: 0,
            indices: None,
        }
    }

    /// Returns the message with its sequence number replaced.
    pub fn with_sequence(self, sequence: u64) -> Self {
        Self { sequence, ..self }
    }

    /// Returns the message with its sender rank replaced.
    pub fn with_sender(self, sender: usize) -> Self {
        Self { sender, ..self }
    }

    /// Returns the message with its tag replaced.
    pub fn with_tag(self, tag: u32) -> Self {
        Self { tag, ..self }
    }

    /// In-memory size of the element buffer in bytes
    /// (`data.len() * size_of::<T>()`); metadata is not counted.
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of_val(self.data.as_slice())
    }
}
/// Send-queue entry: a payload plus the keys that decide its position in a
/// `BinaryHeap`. The payload itself never takes part in comparisons.
#[derive(Debug)]
struct PrioritizedMessage<T> {
    message: T,
    priority: MessagePriority,
    sequence: u64,
}

/// Two entries are equal when both priority and sequence match.
impl<T> PartialEq for PrioritizedMessage<T> {
    fn eq(&self, other: &Self) -> bool {
        (self.priority, self.sequence) == (other.priority, other.sequence)
    }
}

impl<T> Eq for PrioritizedMessage<T> {}

impl<T> PartialOrd for PrioritizedMessage<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Higher priority compares greater. Within one priority the sequence
/// comparison is reversed, so a max-heap pops the lowest sequence first.
impl<T> Ord for PrioritizedMessage<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.sequence.cmp(&self.sequence))
    }
}
/// Per-destination send queues: destination rank -> max-heap of encoded
/// messages ordered by priority, then by ascending sequence number.
type SendQueues = Arc<RwLock<HashMap<usize, BinaryHeap<PrioritizedMessage<Vec<u8>>>>>>;
/// Queue-based wrapper around a [`Communicator`].
///
/// Outgoing messages are encoded and stored in per-destination priority
/// queues; incoming bytes are read from per-source FIFO buffers.
pub struct AsyncCommunicator {
/// Underlying communicator; supplies `rank()` and `size()`.
communicator: Arc<Communicator>,
/// Encoded outgoing messages, keyed by destination rank.
send_queues: SendQueues,
/// Encoded incoming messages, keyed by source rank, consumed front-first by `irecv`.
// NOTE(review): nothing in the visible part of this file pushes into these buffers.
recv_buffers: Arc<RwLock<HashMap<usize, VecDeque<Vec<u8>>>>>,
/// Counter behind `next_sequence`; starts at 0 and increases by one per call.
sequence_counter: Arc<Mutex<u64>>,
/// Per-rank byte channels.
// NOTE(review): created empty in `new` and not used elsewhere in the visible code.
channels: Arc<Mutex<HashMap<usize, mpsc::Sender<Vec<u8>>>>>,
}
impl AsyncCommunicator {
pub fn new(communicator: Arc<Communicator>) -> Result<Self, CommunicationError> {
Ok(Self {
communicator,
send_queues: Arc::new(RwLock::new(HashMap::new())),
recv_buffers: Arc::new(RwLock::new(HashMap::new())),
sequence_counter: Arc::new(Mutex::new(0)),
channels: Arc::new(Mutex::new(HashMap::new())),
})
}
async fn next_sequence(&self) -> u64 {
let mut counter = self.sequence_counter.lock().await;
let seq = *counter;
*counter += 1;
seq
}
pub async fn isend<T>(
&self,
message: TensorMessage<T>,
dest: usize,
) -> Result<(), CommunicationError>
where
T: Clone + Encode + Decode,
{
if dest >= self.communicator.size() {
return Err(CommunicationError::InvalidRank {
rank: dest,
size: self.communicator.size(),
});
}
let data = oxicode::encode_to_vec(&message).map_err(|e| {
CommunicationError::Serialization(format!("Failed to serialize message: {}", e))
})?;
let prioritized = PrioritizedMessage {
message: data,
priority: message.priority,
sequence: message.sequence,
};
let mut queues = self.send_queues.write().await;
queues
.entry(dest)
.or_insert_with(BinaryHeap::new)
.push(prioritized);
Ok(())
}
pub async fn irecv<T>(&self, source: usize) -> Result<TensorMessage<T>, CommunicationError>
where
T: Clone + Encode + Decode,
{
if source >= self.communicator.size() {
return Err(CommunicationError::InvalidRank {
rank: source,
size: self.communicator.size(),
});
}
let mut buffers = self.recv_buffers.write().await;
let buffer = buffers.entry(source).or_insert_with(VecDeque::new);
let data = buffer
.pop_front()
.ok_or_else(|| CommunicationError::Channel("No data available".to_string()))?;
let (message, _) = oxicode::decode_from_slice(&data).map_err(|e| {
CommunicationError::Deserialization(format!("Failed to deserialize message: {}", e))
})?;
Ok(message)
}
pub async fn send<T>(
&self,
message: TensorMessage<T>,
dest: usize,
) -> Result<(), CommunicationError>
where
T: Clone + Encode + Decode,
{
self.isend(message, dest).await?;
self.flush_send_queue(dest).await?;
Ok(())
}
pub async fn recv<T>(&self, source: usize) -> Result<TensorMessage<T>, CommunicationError>
where
T: Clone + Encode + Decode,
{
self.irecv(source).await
}
async fn flush_send_queue(&self, dest: usize) -> Result<(), CommunicationError> {
let mut queues = self.send_queues.write().await;
if let Some(queue) = queues.get_mut(&dest) {
while let Some(prioritized) = queue.pop() {
let _ = prioritized.message;
}
}
Ok(())
}
pub fn rank(&self) -> usize {
self.communicator.rank()
}
pub fn size(&self) -> usize {
self.communicator.size()
}
}
/// An [`AsyncCommunicator`] that limits how many operations may be
/// outstanding at once.
pub struct PipelinedCommunicator {
/// Communicator that performs the actual enqueueing.
base: AsyncCommunicator,
/// Maximum number of stages allowed in `active_stages` before
/// `pipeline_send` waits for a free slot.
depth: usize,
/// Outstanding operations, oldest first.
active_stages: Arc<Mutex<VecDeque<PipelineStage>>>,
}
/// Bookkeeping record for one outstanding pipeline operation.
#[derive(Debug)]
struct PipelineStage {
/// Identifier returned to the caller by `pipeline_send`.
operation_id: u64,
/// Current lifecycle state of the operation.
status: PipelineStatus,
}
/// Lifecycle state of a [`PipelineStage`].
// NOTE(review): in the visible code stages are only ever created as `Pending`;
// nothing transitions them to `InProgress` or `Completed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PipelineStatus {
/// Registered but not yet started.
Pending,
/// Currently being processed.
InProgress,
/// Finished; `wait_for_pipeline_slot` removes such stages from the front of the queue.
Completed,
}
impl PipelinedCommunicator {
pub fn new(communicator: Arc<Communicator>, depth: usize) -> Result<Self, CommunicationError> {
Ok(Self {
base: AsyncCommunicator::new(communicator)?,
depth,
active_stages: Arc::new(Mutex::new(VecDeque::new())),
})
}
pub async fn pipeline_send<T>(
&self,
message: TensorMessage<T>,
dest: usize,
) -> Result<u64, CommunicationError>
where
T: Clone + Encode + Decode,
{
self.wait_for_pipeline_slot().await?;
let op_id = self.base.next_sequence().await;
let mut stages = self.active_stages.lock().await;
stages.push_back(PipelineStage {
operation_id: op_id,
status: PipelineStatus::Pending,
});
self.base.isend(message, dest).await?;
Ok(op_id)
}
pub async fn wait_operation(&self, op_id: u64) -> Result<(), CommunicationError> {
let mut stages = self.active_stages.lock().await;
let pos = stages.iter().position(|s| s.operation_id == op_id);
if let Some(pos) = pos {
stages.remove(pos);
}
Ok(())
}
pub async fn wait_all(&self) -> Result<(), CommunicationError> {
let mut stages = self.active_stages.lock().await;
stages.clear();
Ok(())
}
async fn wait_for_pipeline_slot(&self) -> Result<(), CommunicationError> {
loop {
let mut stages = self.active_stages.lock().await;
if stages.len() < self.depth {
return Ok(());
}
while let Some(stage) = stages.front() {
if stage.status == PipelineStatus::Completed {
stages.pop_front();
} else {
break;
}
}
drop(stages);
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
}
pub fn depth(&self) -> usize {
self.depth
}
pub async fn active_count(&self) -> usize {
self.active_stages.lock().await.len()
}
}
pub fn compress_tensor<T>(
data: &[T],
strategy: &CompressionStrategy,
) -> Result<(Vec<T>, Option<Vec<usize>>), CommunicationError>
where
T: Clone + PartialOrd + Default,
{
match strategy {
CompressionStrategy::None => Ok((data.to_vec(), None)),
CompressionStrategy::TopK { k } => {
if *k >= data.len() {
return Ok((data.to_vec(), None));
}
let mut indexed: Vec<(usize, T)> = data
.iter()
.enumerate()
.map(|(i, v)| (i, v.clone()))
.collect();
if *k < indexed.len() {
indexed.select_nth_unstable_by(*k, |a, b| {
b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)
});
}
let mut values = Vec::with_capacity(*k);
let mut indices = Vec::with_capacity(*k);
for (idx, val) in indexed.iter().take(*k) {
indices.push(*idx);
values.push(val.clone());
}
Ok((values, Some(indices)))
}
CompressionStrategy::RandomK { k } => {
if *k >= data.len() {
return Ok((data.to_vec(), None));
}
let step = data.len() / k;
let values: Vec<T> = (0..*k).map(|i| data[i * step].clone()).collect();
let indices: Vec<usize> = (0..*k).map(|i| i * step).collect();
Ok((values, Some(indices)))
}
CompressionStrategy::Quantization { .. } => {
Ok((data.to_vec(), None))
}
CompressionStrategy::Threshold { .. } => {
Ok((data.to_vec(), None))
}
}
}
/// Expands a (possibly sparse) payload back into a dense vector.
///
/// * With `indices == None`, `compressed` is returned as-is and
///   `original_size` is ignored.
/// * Otherwise a vector of `original_size` default values is created and
///   `compressed[i]` is written to position `indices[i]`. Positions outside
///   `0..original_size` and surplus indices or values are skipped silently.
///   When an index repeats, the later value wins.
///
/// # Errors
/// The current implementation never returns an error.
pub fn decompress_tensor<T>(
    compressed: &[T],
    indices: Option<&[usize]>,
    original_size: usize,
) -> Result<Vec<T>, CommunicationError>
where
    T: Clone + Default,
{
    let positions = match indices {
        Some(positions) => positions,
        None => return Ok(compressed.to_vec()),
    };

    let mut dense = vec![T::default(); original_size];
    // `zip` stops at the shorter of the two slices, which covers the case of
    // more indices than values (and vice versa).
    for (&slot, value) in positions.iter().zip(compressed) {
        if let Some(cell) = dense.get_mut(slot) {
            *cell = value.clone();
        }
    }
    Ok(dense)
}
#[cfg(test)]
mod tests {
    //! Unit tests for message construction, compression helpers and the
    //! oxicode round trip of the wire types.

    use super::*;

    /// `new` stores the data and derives a one-dimensional shape.
    #[test]
    fn test_tensor_message_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let msg = TensorMessage::new(
            data.clone(),
            CompressionStrategy::None,
            MessagePriority::Normal,
        );
        assert_eq!(msg.data, data);
        assert_eq!(msg.shape, vec![4]);
        assert_eq!(msg.priority, MessagePriority::Normal);
        assert_eq!((msg.sequence, msg.sender, msg.tag), (0, 0, 0));
        assert!(msg.indices.is_none());
    }

    /// `with_shape` stores the caller-provided shape verbatim.
    #[test]
    fn test_tensor_message_with_shape() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![2, 3];
        let msg = TensorMessage::with_shape(
            data.clone(),
            shape.clone(),
            CompressionStrategy::None,
            MessagePriority::High,
        );
        assert_eq!(msg.data, data);
        assert_eq!(msg.shape, shape);
        assert_eq!(msg.priority, MessagePriority::High);
    }

    /// Priorities are ordered Low < Normal < High < Urgent.
    #[test]
    fn test_message_priority_ordering() {
        assert!(MessagePriority::Urgent > MessagePriority::High);
        assert!(MessagePriority::High > MessagePriority::Normal);
        assert!(MessagePriority::Normal > MessagePriority::Low);
    }

    /// `None` copies the data and reports no indices.
    #[test]
    fn test_compression_none() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let (compressed, indices) =
            compress_tensor(&data, &CompressionStrategy::None).expect("compression failed");
        assert_eq!(compressed, data);
        assert!(indices.is_none());
    }

    /// `TopK` keeps exactly the k largest values, each paired with its
    /// source index. The output order is unspecified, so compare sorted.
    #[test]
    fn test_compression_topk() {
        let data = vec![5.0, 1.0, 8.0, 3.0, 9.0, 2.0];
        let (compressed, indices) =
            compress_tensor(&data, &CompressionStrategy::TopK { k: 3 }).expect("compression failed");
        let indices = indices.expect("indices missing");
        assert_eq!(compressed.len(), 3);
        assert_eq!(indices.len(), 3);
        for (value, index) in compressed.iter().zip(indices.iter()) {
            assert_eq!(*value, data[*index]);
        }
        let mut sorted = compressed.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).expect("values are comparable"));
        assert_eq!(sorted, vec![5.0, 8.0, 9.0]);
    }

    /// `RandomK` samples at a fixed stride of `len / k` starting at 0.
    #[test]
    fn test_compression_randomk() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (compressed, indices) = compress_tensor(&data, &CompressionStrategy::RandomK { k: 3 })
            .expect("compression failed");
        assert_eq!(compressed, vec![1.0, 3.0, 5.0]);
        assert_eq!(indices, Some(vec![0, 2, 4]));
    }

    /// Without indices the payload is already dense.
    #[test]
    fn test_decompress_none() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let decompressed =
            decompress_tensor(&data, None, data.len()).expect("decompression failed");
        assert_eq!(decompressed, data);
    }

    /// With indices the values are scattered and the gaps hold the default.
    #[test]
    fn test_decompress_with_indices() {
        let compressed = vec![5.0, 8.0, 9.0];
        let indices = vec![0, 2, 4];
        let original_size = 6;
        let decompressed = decompress_tensor(&compressed, Some(&indices), original_size)
            .expect("decompression failed");
        assert_eq!(decompressed, vec![5.0, 0.0, 8.0, 0.0, 9.0, 0.0]);
    }

    /// A message survives an encode/decode round trip with its fields intact.
    #[test]
    fn test_tensor_message_serialization() {
        let data = vec![1.0_f32, 2.0, 3.0];
        let msg = TensorMessage::new(
            data.clone(),
            CompressionStrategy::None,
            MessagePriority::Normal,
        );
        let bytes = oxicode::encode_to_vec(&msg).expect("serialization failed");
        let deserialized: Result<(TensorMessage<f32>, usize), _> =
            oxicode::decode_from_slice(&bytes);
        let (decoded, _) =
            deserialized.unwrap_or_else(|e| panic!("deserialization failed: {}", e));
        assert_eq!(decoded.data, data);
        assert_eq!(decoded.shape, vec![3]);
        assert_eq!(decoded.priority, MessagePriority::Normal);
        assert_eq!((decoded.sequence, decoded.sender, decoded.tag), (0, 0, 0));
        assert!(decoded.indices.is_none());
    }

    /// Every strategy variant round-trips to an identical value.
    #[test]
    fn test_compression_strategy_serialization() {
        let strategies = vec![
            CompressionStrategy::None,
            CompressionStrategy::TopK { k: 10 },
            CompressionStrategy::RandomK { k: 5 },
            CompressionStrategy::Quantization { bits: 8 },
            CompressionStrategy::Threshold { threshold: 0.01 },
        ];
        for strategy in strategies {
            let bytes = oxicode::encode_to_vec(&strategy).expect("serialization failed");
            let deserialized: Result<(CompressionStrategy, usize), _> =
                oxicode::decode_from_slice(&bytes);
            let (decoded, _) =
                deserialized.unwrap_or_else(|e| panic!("deserialization failed: {}", e));
            // `CompressionStrategy` has no `PartialEq`; compare the Debug form.
            assert_eq!(format!("{:?}", decoded), format!("{:?}", strategy));
        }
    }

    /// Every priority round-trips to the same variant.
    #[test]
    fn test_message_priority_serialization() {
        let priorities = vec![
            MessagePriority::Low,
            MessagePriority::Normal,
            MessagePriority::High,
            MessagePriority::Urgent,
        ];
        for priority in priorities {
            let bytes = oxicode::encode_to_vec(&priority).expect("serialization failed");
            let deserialized: Result<(MessagePriority, usize), _> =
                oxicode::decode_from_slice(&bytes);
            let (decoded, _) =
                deserialized.unwrap_or_else(|e| panic!("deserialization failed: {}", e));
            assert_eq!(decoded, priority);
        }
    }

    /// `size_bytes` counts only the element buffer.
    #[test]
    fn test_tensor_message_size() {
        let data = vec![1.0_f64; 1000];
        let msg = TensorMessage::new(data, CompressionStrategy::None, MessagePriority::Normal);
        assert_eq!(msg.size_bytes(), 1000 * std::mem::size_of::<f64>());
    }

    /// `TopK` with `k >= len` falls back to the dense payload.
    #[test]
    fn test_compression_topk_full_data() {
        let data = vec![1.0, 2.0, 3.0];
        let (compressed, indices) = compress_tensor(&data, &CompressionStrategy::TopK { k: 10 })
            .expect("compression failed");
        assert_eq!(compressed, data);
        assert!(indices.is_none());
    }

    /// `RandomK` with `k >= len` falls back to the dense payload.
    #[test]
    fn test_compression_randomk_full_data() {
        let data = vec![1.0, 2.0, 3.0];
        let (compressed, indices) = compress_tensor(&data, &CompressionStrategy::RandomK { k: 10 })
            .expect("compression failed");
        assert_eq!(compressed, data);
        assert!(indices.is_none());
    }
}