use std::collections::VecDeque;
use std::sync::{Arc, RwLock};
use std::time::Duration;
/// Settings bundle for the communication layer: a queue size, an optional
/// timeout, a heartbeat interval and free-form string properties.
///
/// Build one with [`CommunicationConfig::new`] (or [`Default`]) and refine it
/// with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct CommunicationConfig {
    /// Configured queue size; 1000 by default.
    /// NOTE(review): no code visible here enforces this limit — confirm where it is applied.
    pub queue_size: usize,
    /// Configured timeout; `Some(30 s)` by default, `None` when no timeout is set.
    pub timeout: Option<Duration>,
    /// Configured heartbeat interval; 5 seconds by default.
    pub heartbeat_interval: Duration,
    /// Arbitrary string key/value settings; empty by default.
    pub properties: HashMap<String, String>,
}
use std::collections::HashMap;
impl Default for CommunicationConfig {
    /// Produces the stock configuration: queue size 1000, a 30 second
    /// timeout, a 5 second heartbeat interval and no extra properties.
    fn default() -> Self {
        let timeout = Some(Duration::from_secs(30));
        let heartbeat_interval = Duration::from_secs(5);
        CommunicationConfig {
            queue_size: 1000,
            timeout,
            heartbeat_interval,
            properties: HashMap::new(),
        }
    }
}
impl CommunicationConfig {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `queue_size` replaced by `size`.
    pub fn with_queue_size(self, size: usize) -> Self {
        Self {
            queue_size: size,
            ..self
        }
    }

    /// Returns the configuration with the timeout set to `Some(timeout)`.
    ///
    /// There is no builder for clearing the timeout; assign the public
    /// `timeout` field directly to store `None`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Returns the configuration with `heartbeat_interval` replaced by `interval`.
    pub fn with_heartbeat_interval(self, interval: Duration) -> Self {
        Self {
            heartbeat_interval: interval,
            ..self
        }
    }
}
/// Kind tag carried by every [`Message`].
///
/// The enum is fieldless, so it is `Copy` (cheap to pass by value) and `Hash`
/// (usable as a map or set key, e.g. for per-kind bookkeeping).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Tag set by `Message::request`.
    Request,
    /// Tag set by `Message::response`.
    Response,
    /// Tag set by `Message::broadcast`.
    Broadcast,
    /// Tag set by `Message::heartbeat`.
    Heartbeat,
    /// Barrier message. NOTE(review): no constructor in this file produces it — confirm usage.
    Barrier,
    /// Data-exchange message. NOTE(review): no constructor in this file produces it — confirm usage.
    DataExchange,
}
/// Delivery state recorded in a message's `status` field.
///
/// NOTE(review): the code in this file only ever assigns `Pending`; the
/// meaning of the remaining variants is inferred from their names — confirm
/// against the code that updates the status.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageStatus {
/// Initial state assigned by the `Message` constructors.
Pending,
/// Presumably: handed to a channel.
Sent,
/// Presumably: taken off a channel by the recipient.
Received,
/// Presumably: fully handled by the recipient.
Processed,
/// Failure state carrying a `String` (presumably the reason).
Failed(String),
}
/// Identifier carried in a message's `id` field; an alias for `u64`.
pub type MessageId = u64;
/// Identifier of a communicating node; an alias for `usize`.
pub type NodeId = usize;
/// A single message exchanged between nodes.
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier; the constructors obtain it from `generate_message_id()`.
    pub id: MessageId,
    /// Node that created the message.
    pub from: NodeId,
    /// Addressed node; `None` for messages built by `broadcast` and `heartbeat`.
    pub to: Option<NodeId>,
    /// Kind tag set by the constructor that built the message.
    pub message_type: MessageType,
    /// Data carried by the message.
    pub payload: MessagePayload,
    /// Value of `current_timestamp_ms()` at construction time.
    pub timestamp: u64,
    /// Delivery state; every constructor starts at `MessageStatus::Pending`.
    pub status: MessageStatus,
}

impl Message {
    /// Common constructor behind the public ones: draws a fresh id, stamps
    /// the current time and starts the message in the `Pending` state.
    fn pending(
        from: NodeId,
        to: Option<NodeId>,
        message_type: MessageType,
        payload: MessagePayload,
    ) -> Self {
        // The id is drawn before the timestamp is read.
        let id = generate_message_id();
        let timestamp = current_timestamp_ms();
        Message {
            id,
            from,
            to,
            message_type,
            payload,
            timestamp,
            status: MessageStatus::Pending,
        }
    }

    /// Builds a `Request` message from `from` addressed to `to`.
    pub fn request(from: NodeId, to: NodeId, payload: MessagePayload) -> Self {
        Self::pending(from, Some(to), MessageType::Request, payload)
    }

    /// Builds a `Broadcast` message from `from` with no specific recipient.
    pub fn broadcast(from: NodeId, payload: MessagePayload) -> Self {
        Self::pending(from, None, MessageType::Broadcast, payload)
    }

    /// Builds a `Response` message from `from` addressed to `to`.
    pub fn response(from: NodeId, to: NodeId, payload: MessagePayload) -> Self {
        Self::pending(from, Some(to), MessageType::Response, payload)
    }

    /// Builds a `Heartbeat` message from `from`; it has no recipient and
    /// carries the `MessagePayload::Heartbeat` payload.
    pub fn heartbeat(from: NodeId) -> Self {
        Self::pending(
            from,
            None,
            MessageType::Heartbeat,
            MessagePayload::Heartbeat,
        )
    }
}
/// Data carried by a message.
#[derive(Debug, Clone)]
pub enum MessagePayload {
    /// Plain text.
    Text(String),
    /// Raw bytes.
    Binary(Vec<u8>),
    /// JSON held as a string; it is not parsed or validated here.
    Json(String),
    /// List of `(usize, f64)` pairs — presumably (node index, value); confirm with callers.
    NodeValues(Vec<(usize, f64)>),
    /// Map from `usize` to `f64` — presumably boundary node index to value; confirm with callers.
    BoundaryValues(HashMap<usize, f64>),
    /// Barrier payload.
    Barrier {
        /// Identifier of the barrier; the only field used for equality.
        barrier_id: usize,
        /// Participant count attached to the barrier; ignored by equality.
        participant_count: usize,
    },
    /// Payload of heartbeat messages; carries no data.
    Heartbeat,
    /// Custom content held as a string.
    Custom(String),
}

impl PartialEq for MessagePayload {
    /// Compares two payloads.
    ///
    /// Payloads of different variants are never equal. Payloads of the same
    /// variant compare their contents, except `Barrier`, which compares only
    /// `barrier_id` and ignores `participant_count`.
    ///
    /// `Binary`, `NodeValues` and `BoundaryValues` previously fell through
    /// to the catch-all arm, so identical payloads of those kinds compared
    /// unequal; they now compare by content. The `f64` values follow IEEE
    /// comparison, so a payload containing `NaN` is not equal to itself.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Text(a), Self::Text(b)) => a == b,
            (Self::Binary(a), Self::Binary(b)) => a == b,
            (Self::Json(a), Self::Json(b)) => a == b,
            (Self::NodeValues(a), Self::NodeValues(b)) => a == b,
            (Self::BoundaryValues(a), Self::BoundaryValues(b)) => a == b,
            (Self::Custom(a), Self::Custom(b)) => a == b,
            (Self::Heartbeat, Self::Heartbeat) => true,
            (Self::Barrier { barrier_id: a, .. }, Self::Barrier { barrier_id: b, .. }) => a == b,
            // Any remaining pair mixes two different variants.
            _ => false,
        }
    }
}

impl MessagePayload {
    /// Builds a `Text` payload from anything convertible into a `String`.
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text(text.into())
    }

    /// Builds a `Json` payload from anything convertible into a `String`.
    /// The content is stored as given, without validation.
    pub fn json(json: impl Into<String>) -> Self {
        Self::Json(json.into())
    }

    /// Builds a `NodeValues` payload from the given pairs.
    pub fn node_values(values: Vec<(usize, f64)>) -> Self {
        Self::NodeValues(values)
    }
}
/// Transport abstraction for passing `Message`s; implementors must be
/// `Send + Sync`.
pub trait Channel: Send + Sync {
/// Hands `message` to the channel, returning a description of the failure
/// as a `String` on error.
fn send(&self, message: Message) -> Result<(), String>;
/// Takes the next message from the channel, or returns `None` when none is
/// obtained. How `timeout` is honoured is up to the implementation.
fn recv(&self, timeout: Option<Duration>) -> Option<Message>;
/// Sends `payload` as a broadcast originating at `from`.
/// NOTE(review): the `usize` is presumably the number of deliveries made —
/// confirm against implementors.
fn broadcast(&self, from: NodeId, payload: MessagePayload) -> Result<usize, String>;
/// Number of messages the channel currently holds.
fn len(&self) -> usize;
/// Whether the channel holds no messages; defaults to `len() == 0`.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// `Channel` backed by an in-process FIFO queue behind an `RwLock`.
pub struct InMemoryChannel {
    /// Pending messages; `send` appends at the back, `recv` removes from the front.
    queue: Arc<RwLock<VecDeque<Message>>>,
    /// Node id given at construction; stored but not read by the code here.
    _node_id: NodeId,
}

impl InMemoryChannel {
    /// Creates a channel for `node_id` with its own, initially empty queue.
    pub fn new(node_id: NodeId) -> Self {
        Self::shared(node_id, Arc::new(RwLock::new(VecDeque::new())))
    }

    /// Creates a channel for `node_id` on top of an existing `queue`, so
    /// several channels can feed and drain the same queue.
    pub fn shared(node_id: NodeId, queue: Arc<RwLock<VecDeque<Message>>>) -> Self {
        InMemoryChannel {
            queue,
            _node_id: node_id,
        }
    }
}

impl Channel for InMemoryChannel {
    /// Appends `message` to the queue.
    ///
    /// # Errors
    /// Returns the lock error rendered as a string when the lock cannot be
    /// acquired (it is poisoned).
    fn send(&self, message: Message) -> Result<(), String> {
        match self.queue.write() {
            Ok(mut pending) => {
                pending.push_back(message);
                Ok(())
            }
            Err(poisoned) => Err(poisoned.to_string()),
        }
    }

    /// Removes and returns the oldest queued message.
    ///
    /// Never waits: `_timeout` is ignored, and `None` is returned at once
    /// when the queue is empty or the lock cannot be acquired.
    fn recv(&self, _timeout: Option<Duration>) -> Option<Message> {
        match self.queue.write() {
            Ok(mut pending) => pending.pop_front(),
            Err(_) => None,
        }
    }

    /// Queues a single broadcast message from `from` and reports `Ok(1)`;
    /// a failed `send` is passed through as the error.
    fn broadcast(&self, from: NodeId, payload: MessagePayload) -> Result<usize, String> {
        self.send(Message::broadcast(from, payload)).map(|()| 1)
    }

    /// Number of queued messages, or 0 when the lock cannot be acquired.
    fn len(&self) -> usize {
        match self.queue.read() {
            Ok(pending) => pending.len(),
            Err(_) => 0,
        }
    }
}
/// Routes messages to per-node channels and keeps a log of broadcasts.
pub struct MessageRouter {
    /// Registered channel for each node id.
    channels: HashMap<NodeId, Arc<dyn Channel>>,
    /// FIFO of broadcast messages, drained through `get_broadcast`.
    broadcast_channel: Arc<RwLock<VecDeque<Message>>>,
}

impl MessageRouter {
    /// Creates a router with no registered channels and an empty broadcast log.
    pub fn new() -> Self {
        let broadcast_channel = Arc::new(RwLock::new(VecDeque::new()));
        MessageRouter {
            channels: HashMap::new(),
            broadcast_channel,
        }
    }

    /// Registers `channel` for `node_id`, replacing any channel already
    /// registered under that id.
    pub fn register_channel(&mut self, node_id: NodeId, channel: Arc<dyn Channel>) {
        self.channels.insert(node_id, channel);
    }

    /// Sends `message` through the channel registered for `to`.
    ///
    /// # Errors
    /// Returns `"Node {to} not found"` when no channel is registered for
    /// `to`; otherwise passes the channel's own result through.
    pub fn send_to(&self, to: NodeId, message: Message) -> Result<(), String> {
        match self.channels.get(&to) {
            Some(channel) => channel.send(message),
            None => Err(format!("Node {} not found", to)),
        }
    }

    /// Records a broadcast from `from` in the broadcast log, then sends a
    /// copy of `payload` to every registered node except `from`.
    ///
    /// Returns the number of channels whose `send` succeeded; individual
    /// send failures are not reported.
    ///
    /// # Errors
    /// Fails only when the broadcast log's lock cannot be acquired, in which
    /// case nothing is sent.
    pub fn broadcast(&self, from: NodeId, payload: MessagePayload) -> Result<usize, String> {
        let record = Message::broadcast(from, payload.clone());
        {
            // Keep the write guard scoped so it is released before fan-out.
            let mut log = self
                .broadcast_channel
                .write()
                .map_err(|e| e.to_string())?;
            log.push_back(record);
        }
        // NOTE(review): the per-node copies are built with `Message::request`,
        // so recipients see `MessageType::Request` rather than `Broadcast` —
        // confirm this is intended.
        let delivered = self
            .channels
            .iter()
            .filter(|(node_id, _)| **node_id != from)
            .map(|(node_id, channel)| {
                channel.send(Message::request(from, *node_id, payload.clone()))
            })
            .filter(Result::is_ok)
            .count();
        Ok(delivered)
    }

    /// Removes and returns the oldest logged broadcast, or `None` when the
    /// log is empty or its lock cannot be acquired.
    pub fn get_broadcast(&self) -> Option<Message> {
        let mut log = self.broadcast_channel.write().ok()?;
        log.pop_front()
    }
}

impl Default for MessageRouter {
    /// Same as [`MessageRouter::new`].
    fn default() -> Self {
        MessageRouter::new()
    }
}
/// Returns the next message id from a process-wide atomic counter.
///
/// The counter starts at 1 and each call returns the current value before
/// incrementing it, so concurrent callers receive distinct ids.
fn generate_message_id() -> MessageId {
    use std::sync::atomic::{AtomicU64, Ordering::Relaxed};

    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    // Relaxed ordering: only the atomicity of the increment is relied on.
    NEXT_ID.fetch_add(1, Relaxed)
}
/// Returns the current wall-clock time as whole milliseconds since the Unix
/// epoch.
///
/// Yields 0 when the system clock reads earlier than the epoch. The
/// millisecond count is a `u128`; it is converted with a checked conversion
/// that saturates at `u64::MAX` rather than being silently truncated by an
/// `as` cast.
fn current_timestamp_ms() -> u64 {
    // Explicit import keeps `u64::try_from` available on editions whose
    // prelude does not include `TryFrom`.
    use std::convert::TryFrom;
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder method overrides its own field and leaves the rest alone.
    #[test]
    fn test_communication_config() {
        use std::time::Duration;
        let config = CommunicationConfig::new()
            .with_queue_size(2000)
            .with_timeout(Duration::from_secs(60))
            .with_heartbeat_interval(Duration::from_secs(10));
        assert_eq!(config.queue_size, 2000);
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.heartbeat_interval, Duration::from_secs(10));
        assert!(config.properties.is_empty());
    }

    /// `Message::request` fills in sender, recipient, kind and initial status.
    #[test]
    fn test_message_creation() {
        let msg = Message::request(0, 1, MessagePayload::text("hello"));
        assert_eq!(msg.from, 0);
        assert_eq!(msg.to, Some(1));
        assert_eq!(msg.message_type, MessageType::Request);
        assert_eq!(msg.status, MessageStatus::Pending);
    }

    /// `Message::broadcast` has no recipient and the `Broadcast` kind.
    #[test]
    fn test_broadcast_message() {
        let msg = Message::broadcast(0, MessagePayload::text("broadcast"));
        assert_eq!(msg.from, 0);
        assert_eq!(msg.to, None);
        assert_eq!(msg.message_type, MessageType::Broadcast);
    }

    /// `Message::heartbeat` carries the heartbeat kind and payload.
    #[test]
    fn test_heartbeat_message() {
        let msg = Message::heartbeat(0);
        assert_eq!(msg.from, 0);
        assert_eq!(msg.message_type, MessageType::Heartbeat);
        assert!(matches!(msg.payload, MessagePayload::Heartbeat));
    }

    /// A sent message is counted, then handed back by `recv`, leaving the
    /// channel empty again.
    #[test]
    fn test_in_memory_channel() {
        let channel = InMemoryChannel::new(0);
        assert!(channel.is_empty());
        let msg = Message::request(0, 1, MessagePayload::text("test"));
        assert!(channel.send(msg.clone()).is_ok());
        assert_eq!(channel.len(), 1);
        assert!(!channel.is_empty());
        let received = channel.recv(None);
        assert!(received.is_some());
        assert_eq!(received.unwrap().payload, msg.payload);
        assert!(channel.is_empty());
    }

    /// A broadcast from an unregistered sender reaches every registered
    /// channel and is recorded once in the router's broadcast log.
    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();
        let shared_queue = Arc::new(RwLock::new(VecDeque::new()));
        let channel1 = InMemoryChannel::shared(1, shared_queue.clone());
        let channel2 = InMemoryChannel::shared(2, shared_queue.clone());
        router.register_channel(1, Arc::new(channel1));
        router.register_channel(2, Arc::new(channel2));
        let payload = MessagePayload::text("broadcast test");
        let count = router.broadcast(0, payload.clone());
        assert!(count.is_ok());
        assert_eq!(count.unwrap(), 2);
        // Both channels share one queue, so both fan-out copies land in it.
        assert_eq!(shared_queue.read().unwrap().len(), 2);
        let logged = router
            .get_broadcast()
            .expect("broadcast should be recorded in the log");
        assert_eq!(logged.message_type, MessageType::Broadcast);
        assert_eq!(logged.from, 0);
        assert_eq!(logged.payload, payload);
        assert!(router.get_broadcast().is_none());
    }
}