use crate::cluster::{ClusterError, ClusterResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Mutex};
use tokio::time::interval;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Tunable settings for the cluster networking layer.
///
/// Of these fields, the code visible in this file reads `listen_addr`,
/// `cluster_port`, `keep_alive_interval_secs`, `max_message_size` and nothing
/// else; the remaining fields are presumably consumed elsewhere (TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkConfig {
/// Host portion of the listener address; formatted with `cluster_port` as `host:port`.
pub listen_addr: String,
/// Port used for the listener address and for the peer address built when a connection is established.
pub cluster_port: u16,
/// Connection timeout; the name suggests milliseconds (not read in this file).
pub connection_timeout_ms: u64,
/// Idle threshold in seconds: connection management removes connections whose
/// last activity is older than this.
pub keep_alive_interval_secs: u64,
/// Upper bound, in bytes, on a serialized message; larger messages are rejected before transmission.
pub max_message_size: usize,
/// TLS toggle (not read in this file).
pub enable_tls: bool,
/// Optional TLS certificate path (not read in this file).
pub tls_cert_path: Option<String>,
/// Optional TLS private-key path (not read in this file).
pub tls_key_path: Option<String>,
/// Optional TLS CA path (not read in this file).
pub tls_ca_path: Option<String>,
/// Compression toggle (not read in this file).
pub enable_compression: bool,
/// Connection pool size (not read in this file).
pub connection_pool_size: usize,
/// Retry limit (not read in this file).
pub max_retry_attempts: u32,
/// Delay between retries; the name suggests milliseconds (not read in this file).
pub retry_delay_ms: u64,
}
impl Default for NetworkConfig {
    /// Returns the baseline configuration: listen on `0.0.0.0:8081`, TLS off
    /// with no certificate paths, compression on, a 1 MiB message limit, a
    /// pool of 10 connections and 3 retries spaced 1000 ms apart.
    fn default() -> Self {
        Self {
            listen_addr: String::from("0.0.0.0"),
            cluster_port: 8081,
            connection_timeout_ms: 30_000,
            keep_alive_interval_secs: 30,
            // 1 MiB.
            max_message_size: 1024 * 1024,
            enable_tls: false,
            tls_cert_path: None,
            tls_key_path: None,
            tls_ca_path: None,
            enable_compression: true,
            connection_pool_size: 10,
            max_retry_attempts: 3,
            retry_delay_ms: 1_000,
        }
    }
}
/// Envelope for a single message exchanged between cluster nodes.
///
/// The whole struct is serialized with `serde_json` before transmission.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMessage {
/// Unique identifier; the send paths in this file assign a fresh v4 UUID.
pub message_id: Uuid,
/// Node that produced the message.
pub source: Uuid,
/// Target node; `None` means the message is broadcast to every known connection.
pub destination: Option<Uuid>,
/// Category used to select the handler on the receiving side.
pub message_type: MessageType,
/// Opaque payload bytes; its length is what the byte counters record.
pub payload: Vec<u8>,
/// Creation time (UTC); compared against `ttl_secs` to detect expiry before sending.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Priority tag carried with the message (no visible code orders by it).
pub priority: MessagePriority,
/// Time-to-live in seconds; a message older than this is dropped instead of sent.
pub ttl_secs: u64,
/// Retry counter; initialised to 0 by the send paths and not modified in this file.
pub retry_count: u32,
}
/// Category of a [`NetworkMessage`].
///
/// Used as the key under which message handlers are registered and looked up,
/// hence the `Hash`/`Eq` derives. The per-variant meanings are not established
/// by this file beyond their names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum MessageType {
Raft,
Discovery,
Replication,
Failover,
Management,
Heartbeat,
NetworkControl,
}
/// Priority tag attached to a [`NetworkMessage`].
///
/// The derived ordering follows declaration order, so
/// `Low < Normal < High < Critical`. No code visible in this file reorders
/// queued messages by priority.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MessagePriority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
/// Bookkeeping record for one peer connection, keyed by node id in the manager.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
/// Identifier of the remote node.
pub node_id: Uuid,
/// Socket address messages for this node are transmitted to.
pub address: SocketAddr,
/// Current connection state; only `Connected` entries are used for sending.
pub state: ConnectionState,
/// Time of the most recent send/receive recorded for this connection; drives idle removal.
pub last_activity: Instant,
/// Number of messages recorded as sent to this node.
pub messages_sent: u64,
/// Number of messages recorded as received from this node.
pub messages_received: u64,
/// Payload bytes recorded as sent to this node.
pub bytes_sent: u64,
/// Payload bytes recorded as received from this node.
pub bytes_received: u64,
/// Latency figure averaged into the statistics; initialised to 0 and never
/// updated by the code visible in this file. Unit per the name: milliseconds.
pub latency_ms: u64,
}
/// Lifecycle state of a peer connection.
///
/// The code visible in this file only ever assigns and tests `Connected`;
/// the other variants are presumably set elsewhere (TODO confirm).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionState {
Connecting,
Connected,
Disconnected,
Failed,
}
/// Errors produced by the networking layer.
///
/// Values are wrapped in `ClusterError::Network` when returned from this file.
/// Only `InvalidMessageFormat`, `MessageTooLarge` and `AddressResolutionFailed`
/// are constructed by the code visible here.
#[derive(Debug, thiserror::Error)]
pub enum NetworkError {
/// A connection could not be established; carries a description.
#[error("Connection failed: {0}")]
ConnectionFailed(String),
/// Sending a message failed; carries a description.
#[error("Message send failed: {0}")]
MessageSendFailed(String),
/// Receiving a message failed; carries a description.
#[error("Message receive failed: {0}")]
MessageReceiveFailed(String),
/// A connection attempt timed out.
#[error("Connection timeout")]
ConnectionTimeout,
/// The serialized message exceeds the configured maximum; carries the serialized size in bytes.
#[error("Message too large: {0} bytes")]
MessageTooLarge(usize),
/// The message could not be serialized; carries the serializer's error text.
#[error("Invalid message format: {0}")]
InvalidMessageFormat(String),
/// A TLS-related failure; carries a description.
#[error("TLS error: {0}")]
TlsError(String),
/// A network partition was detected.
#[error("Network partition detected")]
NetworkPartition,
/// A peer address could not be parsed/resolved; carries the parser's error text.
#[error("Address resolution failed: {0}")]
AddressResolutionFailed(String),
}
/// Coordinates peer connections, message queues, handlers and statistics for one node.
///
/// All mutable state sits behind `Arc`-wrapped tokio locks, so clones of the
/// manager share the same state.
pub struct NetworkManager {
/// Identifier of the local node; used as `source` on outgoing messages.
node_id: Uuid,
/// Networking settings supplied at construction.
config: NetworkConfig,
/// Known peer connections keyed by node id.
connections: Arc<RwLock<HashMap<Uuid, ConnectionInfo>>>,
/// At most one handler per message type.
message_handlers: Arc<Mutex<HashMap<MessageType, Box<dyn MessageHandler + Send + Sync>>>>,
/// Observers notified about connection and message events.
callbacks: Arc<Mutex<Vec<Box<dyn NetworkCallback + Send + Sync>>>>,
/// Messages waiting to be sent; drained by the message processing loop.
outgoing_queue: Arc<Mutex<Vec<NetworkMessage>>>,
/// Messages waiting to be handled; drained by the message processing loop
/// (nothing in the visible code pushes to it).
incoming_queue: Arc<Mutex<Vec<NetworkMessage>>>,
/// Aggregate counters exposed through `get_statistics`.
statistics: Arc<RwLock<NetworkStatistics>>,
}
/// Aggregate counters for the networking layer; all fields start at zero.
#[derive(Debug, Clone, Default)]
pub struct NetworkStatistics {
/// Messages successfully handed to the transmit step.
pub messages_sent: u64,
/// Messages taken from the incoming queue for handling.
pub messages_received: u64,
/// Payload bytes of sent messages.
pub bytes_sent: u64,
/// Payload bytes of received messages.
pub bytes_received: u64,
/// Number of entries in the connection table, refreshed by connection management.
pub active_connections: usize,
/// Incremented each time transmitting a message to a connected node fails.
pub failed_connections: u64,
/// Mean of `latency_ms` over all connections, refreshed by the statistics loop
/// while at least one connection exists.
pub average_latency_ms: f64,
/// Not written by any code visible in this file.
pub message_loss_rate: f64,
}
/// Observer interface for events raised by [`NetworkManager`].
///
/// Callbacks are awaited one after another while the manager holds its
/// callback list lock.
#[async_trait::async_trait]
pub trait NetworkCallback {
/// Invoked after a connection entry for `node_id` at `address` has been stored.
async fn on_connection_established(&self, node_id: Uuid, address: SocketAddr);
/// Invoked after the connection entry for `node_id` has been removed;
/// `reason` is a short description (the visible code passes `"timeout"`).
async fn on_connection_lost(&self, node_id: Uuid, reason: &str);
/// Invoked after `message` was sent; `success` indicates the outcome
/// (the visible code only reports `true`).
async fn on_message_sent(&self, message: &NetworkMessage, success: bool);
/// Invoked for each incoming message before it is dispatched to its handler.
async fn on_message_received(&self, message: &NetworkMessage);
/// Partition notification; not invoked anywhere in the visible code.
async fn on_network_partition(&self, affected_nodes: Vec<Uuid>);
}
/// Handler for incoming messages of a single [`MessageType`].
#[async_trait::async_trait]
pub trait MessageHandler {
/// Processes one incoming message of this handler's type.
///
/// May return a follow-up message. NOTE(review): confirm with callers how the
/// returned message is expected to be treated.
async fn handle_message(&self, message: NetworkMessage) -> ClusterResult<Option<NetworkMessage>>;
/// The message type this handler is registered under; registering another
/// handler for the same type replaces the previous one.
fn message_type(&self) -> MessageType;
}
impl NetworkManager {
pub fn new(node_id: Uuid, config: NetworkConfig) -> Self {
Self {
node_id,
config,
connections: Arc::new(RwLock::new(HashMap::new())),
message_handlers: Arc::new(Mutex::new(HashMap::new())),
callbacks: Arc::new(Mutex::new(Vec::new())),
outgoing_queue: Arc::new(Mutex::new(Vec::new())),
incoming_queue: Arc::new(Mutex::new(Vec::new())),
statistics: Arc::new(RwLock::new(NetworkStatistics::default())),
}
}
/// Spawns the three background loops (message processing, connection
/// management, statistics collection) on clones of this manager, then starts
/// the listener.
///
/// The spawned tasks are detached; their handles are not kept.
///
/// # Errors
/// Returns whatever `start_listener` returns.
pub async fn start(&self) -> ClusterResult<()> {
    info!("Starting network manager for node {}", self.node_id);

    let messaging = self.clone();
    tokio::spawn(async move { messaging.message_processing_loop().await });

    let housekeeping = self.clone();
    tokio::spawn(async move { housekeeping.connection_management_loop().await });

    let metrics = self.clone();
    tokio::spawn(async move { metrics.statistics_collection_loop().await });

    self.start_listener().await
}
async fn start_listener(&self) -> ClusterResult<()> {
let listen_addr = format!("{}:{}", self.config.listen_addr, self.config.cluster_port);
info!("Starting network listener on {}", listen_addr);
Ok(())
}
async fn message_processing_loop(&self) {
let mut interval = interval(Duration::from_millis(10));
loop {
interval.tick().await;
if let Err(e) = self.process_outgoing_messages().await {
error!("Failed to process outgoing messages: {}", e);
}
if let Err(e) = self.process_incoming_messages().await {
error!("Failed to process incoming messages: {}", e);
}
}
}
async fn connection_management_loop(&self) {
let mut interval = interval(Duration::from_secs(30));
loop {
interval.tick().await;
if let Err(e) = self.manage_connections().await {
error!("Connection management failed: {}", e);
}
}
}
async fn statistics_collection_loop(&self) {
let mut interval = interval(Duration::from_secs(60));
loop {
interval.tick().await;
if let Err(e) = self.update_statistics().await {
error!("Statistics collection failed: {}", e);
}
}
}
async fn process_outgoing_messages(&self) -> ClusterResult<()> {
let messages: Vec<_> = {
let mut queue = self.outgoing_queue.lock().await;
queue.drain(..).collect()
};
for message in messages {
if let Err(e) = self.send_message_internal(message).await {
error!("Failed to send message: {}", e);
}
}
Ok(())
}
async fn process_incoming_messages(&self) -> ClusterResult<()> {
let messages: Vec<_> = {
let mut queue = self.incoming_queue.lock().await;
queue.drain(..).collect()
};
for message in messages {
if let Err(e) = self.handle_message_internal(message).await {
error!("Failed to handle message: {}", e);
}
}
Ok(())
}
async fn manage_connections(&self) -> ClusterResult<()> {
let now = Instant::now();
let timeout_duration = Duration::from_secs(self.config.keep_alive_interval_secs);
let mut connections_to_remove = Vec::new();
{
let connections = self.connections.read().await;
for (node_id, conn_info) in connections.iter() {
if now.duration_since(conn_info.last_activity) > timeout_duration {
connections_to_remove.push(*node_id);
}
}
}
for node_id in connections_to_remove {
self.remove_connection(node_id, "timeout").await?;
}
{
let mut stats = self.statistics.write().await;
let connections = self.connections.read().await;
stats.active_connections = connections.len();
}
Ok(())
}
/// Recomputes `average_latency_ms` as the mean `latency_ms` over all
/// connections. When there are no connections the previous value is kept.
///
/// The connection table guard is released before the statistics lock is
/// taken, so the two locks are never held together. The sum is accumulated
/// in `f64`, which cannot overflow the way a `u64` sum can.
///
/// Always returns `Ok(())`.
async fn update_statistics(&self) -> ClusterResult<()> {
    let average_latency = {
        let connections = self.connections.read().await;
        if connections.is_empty() {
            None
        } else {
            let total: f64 = connections
                .values()
                .map(|conn| conn.latency_ms as f64)
                .sum();
            Some(total / connections.len() as f64)
        }
    };

    if let Some(average) = average_latency {
        self.statistics.write().await.average_latency_ms = average;
    }
    Ok(())
}
/// Queues a message addressed to `destination` for asynchronous delivery.
///
/// The message gets a fresh id, this node as source, the current UTC
/// timestamp, a 60 second TTL and a retry count of 0. Returning `Ok(())`
/// only means the message was queued, not that it was delivered.
pub async fn send_message(&self, destination: Uuid, message_type: MessageType, payload: Vec<u8>, priority: MessagePriority) -> ClusterResult<()> {
    let outbound = NetworkMessage {
        message_id: Uuid::new_v4(),
        source: self.node_id,
        destination: Some(destination),
        message_type,
        payload,
        timestamp: chrono::Utc::now(),
        priority,
        ttl_secs: 60,
        retry_count: 0,
    };
    self.outgoing_queue.lock().await.push(outbound);
    Ok(())
}
/// Queues a broadcast message (no destination) for asynchronous delivery to
/// every known connection.
///
/// The message gets a fresh id, this node as source, the current UTC
/// timestamp, a 60 second TTL and a retry count of 0. Returning `Ok(())`
/// only means the message was queued, not that it was delivered.
pub async fn broadcast_message(&self, message_type: MessageType, payload: Vec<u8>, priority: MessagePriority) -> ClusterResult<()> {
    let outbound = NetworkMessage {
        message_id: Uuid::new_v4(),
        source: self.node_id,
        // `None` marks the message as a broadcast.
        destination: None,
        message_type,
        payload,
        timestamp: chrono::Utc::now(),
        priority,
        ttl_secs: 60,
        retry_count: 0,
    };
    self.outgoing_queue.lock().await.push(outbound);
    Ok(())
}
/// Sends one queued message: to its destination if it has one, otherwise to
/// every known connection.
///
/// A message older than its `ttl_secs` is dropped with a warning and
/// reported as `Ok(())`.
///
/// # Errors
/// Returns the error from the unicast or broadcast send path.
async fn send_message_internal(&self, message: NetworkMessage) -> ClusterResult<()> {
    let age_secs = chrono::Utc::now()
        .signed_duration_since(message.timestamp)
        .num_seconds();
    // Saturate instead of `as i64`: a TTL above i64::MAX would otherwise wrap
    // to a negative number and make the message look expired immediately.
    let ttl_secs = message.ttl_secs.min(i64::MAX as u64) as i64;
    if age_secs > ttl_secs {
        warn!("Message {} expired, dropping", message.message_id);
        return Ok(());
    }

    match message.destination {
        Some(destination) => self.send_to_node(destination, message).await,
        None => self.broadcast_to_all_nodes(message).await,
    }
}
async fn send_to_node(&self, destination: Uuid, message: NetworkMessage) -> ClusterResult<()> {
let connections = self.connections.read().await;
if let Some(conn_info) = connections.get(&destination) {
if conn_info.state == ConnectionState::Connected {
if let Err(e) = self.transmit_message(conn_info.address, &message).await {
error!("Failed to transmit message to {}: {}", destination, e);
{
let mut stats = self.statistics.write().await;
stats.failed_connections += 1;
}
return Err(e);
}
drop(connections);
self.update_connection_stats(destination, message.payload.len(), true).await;
{
let mut stats = self.statistics.write().await;
stats.messages_sent += 1;
stats.bytes_sent += message.payload.len() as u64;
}
let callbacks = self.callbacks.lock().await;
for callback in callbacks.iter() {
callback.on_message_sent(&message, true).await;
}
} else {
drop(connections);
if let Err(e) = self.establish_connection(destination).await {
error!("Failed to establish connection to {}: {}", destination, e);
return Err(e);
}
return Box::pin(self.send_to_node(destination, message)).await;
}
} else {
if let Err(e) = self.establish_connection(destination).await {
error!("Failed to establish connection to {}: {}", destination, e);
return Err(e);
}
return Box::pin(self.send_to_node(destination, message)).await;
}
Ok(())
}
/// Sends a copy of `message` to every node currently in the connection table.
///
/// The set of targets is snapshotted up front and the table lock released
/// before sending. Individual failures are logged as warnings and do not
/// abort the broadcast, so this always returns `Ok(())`.
async fn broadcast_to_all_nodes(&self, message: NetworkMessage) -> ClusterResult<()> {
    let targets: Vec<Uuid> = self.connections.read().await.keys().copied().collect();

    for target in targets {
        match self.send_to_node(target, message.clone()).await {
            Ok(()) => {}
            Err(e) => warn!("Failed to broadcast to node {}: {}", target, e),
        }
    }
    Ok(())
}
/// Serializes `message` as JSON and validates its size against
/// `max_message_size`.
///
/// As written, nothing is put on the wire: after validation the call only
/// emits a debug log line.
///
/// # Errors
/// * `NetworkError::InvalidMessageFormat` if serialization fails;
/// * `NetworkError::MessageTooLarge` if the serialized form exceeds the limit.
async fn transmit_message(&self, address: SocketAddr, message: &NetworkMessage) -> ClusterResult<()> {
    let encoded = match serde_json::to_vec(message) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ClusterError::Network(NetworkError::InvalidMessageFormat(
                e.to_string(),
            )))
        }
    };

    let encoded_len = encoded.len();
    if encoded_len > self.config.max_message_size {
        return Err(ClusterError::Network(NetworkError::MessageTooLarge(encoded_len)));
    }

    debug!("Transmitting message {} to {}", message.message_id, address);
    Ok(())
}
/// Processes one incoming message: updates counters, notifies callbacks,
/// dispatches to the handler registered for the message type, and queues
/// any reply the handler returns.
///
/// The callback list lock is released before the handler runs, so a slow
/// handler does not block other users of the callback list. A reply returned
/// by the handler (`Ok(Some(..))`) is pushed onto the outgoing queue instead
/// of being discarded.
///
/// Handler errors and missing handlers are logged; this always returns
/// `Ok(())`.
async fn handle_message_internal(&self, message: NetworkMessage) -> ClusterResult<()> {
    let payload_len = message.payload.len();
    {
        let mut stats = self.statistics.write().await;
        stats.messages_received += 1;
        stats.bytes_received += payload_len as u64;
    }
    self.update_connection_stats(message.source, payload_len, false).await;

    {
        let callbacks = self.callbacks.lock().await;
        for callback in callbacks.iter() {
            callback.on_message_received(&message).await;
        }
    }

    let reply = {
        let handlers = self.message_handlers.lock().await;
        match handlers.get(&message.message_type) {
            Some(handler) => match handler.handle_message(message).await {
                Ok(reply) => reply,
                Err(e) => {
                    error!("Message handler failed: {}", e);
                    None
                }
            },
            None => {
                warn!("No handler for message type: {:?}", message.message_type);
                None
            }
        }
    };

    if let Some(reply) = reply {
        self.outgoing_queue.lock().await.push(reply);
    }
    Ok(())
}
/// Records a connection to `node_id` and notifies callbacks.
///
/// NOTE(review): the peer address is always `127.0.0.1:<cluster_port>` and
/// the entry is marked `Connected` without any socket being opened. This
/// looks like a placeholder until real address resolution and connection
/// logic exist.
///
/// An existing entry for the node is overwritten with fresh counters.
///
/// # Errors
/// `NetworkError::AddressResolutionFailed` if the address string fails to parse.
async fn establish_connection(&self, node_id: Uuid) -> ClusterResult<()> {
    info!("Establishing connection to node {}", node_id);

    let endpoint = format!("127.0.0.1:{}", self.config.cluster_port);
    let address = match endpoint.parse::<SocketAddr>() {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(ClusterError::Network(NetworkError::AddressResolutionFailed(
                e.to_string(),
            )))
        }
    };

    let entry = ConnectionInfo {
        node_id,
        address,
        state: ConnectionState::Connected,
        last_activity: Instant::now(),
        messages_sent: 0,
        messages_received: 0,
        bytes_sent: 0,
        bytes_received: 0,
        latency_ms: 0,
    };
    // The write guard is a statement temporary and is released right away.
    self.connections.write().await.insert(node_id, entry);

    let observers = self.callbacks.lock().await;
    for observer in observers.iter() {
        observer.on_connection_established(node_id, address).await;
    }
    Ok(())
}
/// Removes the connection entry for `node_id`, if any, and notifies callbacks
/// with `reason`.
///
/// Callbacks are only invoked when an entry was actually removed. Always
/// returns `Ok(())`.
async fn remove_connection(&self, node_id: Uuid, reason: &str) -> ClusterResult<()> {
    info!("Removing connection to node {}: {}", node_id, reason);

    // The write guard is released at the end of this statement.
    let was_present = self.connections.write().await.remove(&node_id).is_some();
    if !was_present {
        return Ok(());
    }

    let observers = self.callbacks.lock().await;
    for observer in observers.iter() {
        observer.on_connection_lost(node_id, reason).await;
    }
    Ok(())
}
/// Records one message of `bytes` payload bytes against the connection for
/// `node_id` and refreshes its `last_activity`.
///
/// `is_sent` selects the sent counters; otherwise the received counters are
/// bumped. Does nothing when the node has no connection entry.
async fn update_connection_stats(&self, node_id: Uuid, bytes: usize, is_sent: bool) {
    let mut table = self.connections.write().await;
    let entry = match table.get_mut(&node_id) {
        Some(entry) => entry,
        None => return,
    };

    entry.last_activity = Instant::now();
    // Pick the counter pair once, then update both the same way.
    let (message_counter, byte_counter) = if is_sent {
        (&mut entry.messages_sent, &mut entry.bytes_sent)
    } else {
        (&mut entry.messages_received, &mut entry.bytes_received)
    };
    *message_counter += 1;
    *byte_counter += bytes as u64;
}
/// Registers `handler` under the type it reports via `message_type()`,
/// replacing any handler previously registered for that type.
pub async fn register_handler(&self, handler: Box<dyn MessageHandler + Send + Sync>) {
    let mut registry = self.message_handlers.lock().await;
    let kind = handler.message_type();
    registry.insert(kind, handler);
}
/// Appends `callback` to the list of observers; callbacks are invoked in
/// registration order.
pub async fn add_callback(&self, callback: Box<dyn NetworkCallback + Send + Sync>) {
    self.callbacks.lock().await.push(callback);
}
pub async fn get_statistics(&self) -> NetworkStatistics {
self.statistics.read().await.clone()
}
/// Returns a copy of the connection record for `node_id`, or `None` when the
/// node has no entry in the connection table.
pub async fn get_connection_info(&self, node_id: Uuid) -> Option<ConnectionInfo> {
    let table = self.connections.read().await;
    table.get(&node_id).cloned()
}
/// Returns a snapshot copy of the whole connection table, keyed by node id.
pub async fn get_all_connections(&self) -> HashMap<Uuid, ConnectionInfo> {
    let table = self.connections.read().await;
    HashMap::clone(&table)
}
}
impl Clone for NetworkManager {
    /// Produces a handle that shares all state with `self`.
    ///
    /// Only the node id and configuration are copied by value; every
    /// `Arc`-wrapped container is shared, so changes made through one clone
    /// are visible through all others.
    fn clone(&self) -> Self {
        let Self {
            node_id,
            config,
            connections,
            message_handlers,
            callbacks,
            outgoing_queue,
            incoming_queue,
            statistics,
        } = self;

        Self {
            node_id: *node_id,
            config: config.clone(),
            connections: Arc::clone(connections),
            message_handlers: Arc::clone(message_handlers),
            callbacks: Arc::clone(callbacks),
            outgoing_queue: Arc::clone(outgoing_queue),
            incoming_queue: Arc::clone(incoming_queue),
            statistics: Arc::clone(statistics),
        }
    }
}