use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use crate::federated_learning_v2_backup::types::*;
use trustformers_core::{Result, CoreError};
/// Top-level configuration for the federated-learning communication layer.
///
/// Bundles the protocol choice with transport-security, compression, bandwidth,
/// message-queue and timeout settings. It is consumed by `CommunicationManager::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationProtocolConfig {
    /// Wire protocol selection (type from the shared `types` module).
    /// NOTE(review): not read by `CommunicationManager` in this file.
    pub protocol: CommunicationProtocol,
    /// Settings handed to the `TransportSecurityManager`.
    pub transport_security: TransportSecurityConfig,
    /// Payload compression settings.
    pub compression: CompressionConfig,
    /// Bandwidth monitoring / adaptation settings.
    pub bandwidth_management: BandwidthManagementConfig,
    /// Message-queue backend settings.
    pub message_queue: MessageQueueConfig,
    /// Connection timeout settings.
    pub timeout_config: TimeoutConfig,
}
/// Transport-layer security settings used by `TransportSecurityManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportSecurityConfig {
    /// Selects the TLS 1.3, DTLS or custom-encryption routine in
    /// `TransportSecurityManager::establish_secure_connection`.
    pub protocol: TransportSecurity,
    /// Presumably whether peer certificates are validated.
    /// NOTE(review): not read by any code in this file.
    pub certificate_validation: bool,
    /// Presumably whether both peers authenticate with certificates.
    /// NOTE(review): not read by any code in this file.
    pub mutual_tls: bool,
    /// Allowed cipher-suite names. NOTE(review): not read by any code in this file.
    pub cipher_suites: Vec<String>,
    /// Allowed protocol-version names. NOTE(review): not read by any code in this file.
    pub protocol_versions: Vec<String>,
}
/// Payload compression settings used by `CommunicationManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Codec applied to message payloads.
    pub algorithm: CompressionAlgorithm,
    /// Codec level. NOTE(review): not read by any code in this file.
    pub compression_level: u8,
    /// Payloads shorter than this many bytes are queued without compression.
    pub min_size_for_compression: usize,
    /// NOTE(review): not read by any code in this file.
    pub adaptive_compression: bool,
}
/// Bandwidth management settings.
///
/// NOTE(review): only `adaptation_strategy` is consumed in this file (copied into the
/// manager's `BandwidthMonitor` by `CommunicationManager::new`); the other fields are not read.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandwidthManagementConfig {
    /// Bandwidth ceiling, in Mbps per the field name.
    pub max_bandwidth_mbps: f64,
    /// Strategy used to scale the bandwidth reading after each monitoring update.
    pub adaptation_strategy: BandwidthAdaptationStrategy,
    /// Congestion-control toggle.
    pub congestion_control: bool,
    /// Default QoS priority.
    pub qos_priority: QoSPriority,
    /// Rate-limiting toggle.
    pub rate_limiting: bool,
}
/// Priority level attached to a `Message` and configured in `BandwidthManagementConfig`.
///
/// Only `PartialEq`/`Eq` are derived, so priorities cannot be compared with `<`/`>`.
/// The `Default` impl in this file yields `Medium`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QoSPriority {
    Low,
    Medium,
    High,
    Critical,
}
/// Settings for the message-queue backend.
///
/// NOTE(review): none of these fields are read by `CommunicationManager` in this file. In
/// particular, `max_queue_size` and `message_ttl_seconds` are not enforced on its in-memory queue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageQueueConfig {
    /// Backend to use.
    pub queue_type: MessageQueueType,
    /// Maximum number of queued messages.
    pub max_queue_size: usize,
    /// Persistence toggle.
    pub persistence: bool,
    /// Dead-letter-queue toggle.
    pub dead_letter_queue: bool,
    /// Message time-to-live, in seconds per the field name.
    pub message_ttl_seconds: u64,
}
/// Message-queue backend named in `MessageQueueConfig`.
///
/// NOTE(review): `CommunicationManager` always uses an in-process `VecDeque`, whatever this
/// setting says; no other backend is implemented in this file. The `Default` impl in this
/// file yields `InMemory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageQueueType {
    InMemory,
    Redis,
    RabbitMQ,
    Kafka,
    MQTT,
}
/// Connection timeout settings, in seconds per the field names.
///
/// NOTE(review): only `keepalive_timeout_seconds` is consumed in this file, by
/// `CommunicationManager::check_connection_health`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutConfig {
    pub connection_timeout_seconds: u64,
    pub read_timeout_seconds: u64,
    pub write_timeout_seconds: u64,
    /// A connection whose `last_activity` is older than this is marked `Error` by the health check.
    pub keepalive_timeout_seconds: u64,
}
/// Coordinates participant connections, a local message queue, payload compression and
/// bandwidth monitoring.
#[derive(Debug)]
pub struct CommunicationManager {
    /// Full configuration passed to `new`; `timeout_config` is read by the health check.
    config: CommunicationProtocolConfig,
    /// Produces connection identifiers for `connect`.
    transport_security: TransportSecurityManager,
    /// Copy of `config.compression` used by the compress/decompress paths.
    compression_config: CompressionConfig,
    /// Connections keyed by participant id.
    active_connections: HashMap<String, ConnectionInfo>,
    /// Local FIFO: `send_message` pushes to the back and `receive_message` pops from the front.
    /// NOTE(review): not bounded by `MessageQueueConfig::max_queue_size`.
    message_queue: VecDeque<Message>,
    /// Bandwidth reading, sample history and congestion state.
    bandwidth_monitor: BandwidthMonitor,
}
/// Bookkeeping for one participant connection held by `CommunicationManager`.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Identifier returned by `TransportSecurityManager::establish_secure_connection`.
    pub id: String,
    /// Address the connection was opened to.
    pub address: String,
    /// Current lifecycle state.
    pub state: ConnectionState,
    /// Wall-clock time, in seconds since the UNIX epoch, of `connect` or of the latest
    /// send/receive that touched this connection.
    pub last_activity: u64,
    /// Payload bytes charged to this connection by `send_message`.
    pub bytes_sent: u64,
    /// Payload bytes credited to this connection by `receive_message`.
    pub bytes_received: u64,
}
/// Lifecycle state of a `ConnectionInfo`.
///
/// `CommunicationManager::connect` records new connections as `Connected`.
/// `CommunicationManager::check_connection_health` marks connections idle past the
/// keep-alive timeout as `Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    Disconnecting,
    Disconnected,
    Error,
}
/// A unit of communication between federated-learning participants.
///
/// While it sits in `CommunicationManager`'s queue, `payload` may be encoded by the
/// configured compression algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Message identifier. NOTE(review): uniqueness is not enforced in this file.
    pub id: String,
    /// Sending participant; its connection is credited when the message is received.
    pub sender_id: String,
    /// Addressed participant; its connection is charged when the message is sent.
    pub recipient_id: String,
    /// Kind of content carried in `payload`.
    pub message_type: MessageType,
    /// Message bytes.
    pub payload: Vec<u8>,
    /// Creation time. NOTE(review): unit and epoch are not established here; presumably
    /// seconds since the UNIX epoch, like `ConnectionInfo::last_activity`. Confirm with producers.
    pub timestamp: u64,
    /// Message priority. NOTE(review): the in-memory queue is FIFO and does not reorder by it.
    pub priority: QoSPriority,
}
/// Kind of content carried by a `Message`.
///
/// No code in this file interprets the variants. Their meaning is defined by whoever
/// produces and consumes the messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
    ModelUpdate,
    AggregatedModel,
    TrainingConfig,
    ParticipantRegistration,
    RoundSync,
    Heartbeat,
    Error,
}
/// Holds transport-security settings plus named certificate and private-key material.
///
/// NOTE(review): neither key-material map is read by the connection routines in this file.
#[derive(Debug)]
pub struct TransportSecurityManager {
    /// Settings whose `protocol` selects the connection routine.
    config: TransportSecurityConfig,
    /// Certificate bytes keyed by the name passed to `load_certificate`.
    certificates: HashMap<String, Vec<u8>>,
    /// Private-key bytes keyed by the name passed to `load_private_key`.
    /// NOTE(review): the derived `Debug` prints these bytes in `{:?}` output; consider a
    /// redacting `Debug` impl.
    private_keys: HashMap<String, Vec<u8>>,
}
impl TransportSecurityManager {
pub fn new(config: TransportSecurityConfig) -> Self {
Self {
config,
certificates: HashMap::new(),
private_keys: HashMap::new(),
}
}
pub fn load_certificate(&mut self, name: String, certificate: Vec<u8>) -> Result<()> {
self.certificates.insert(name, certificate);
Ok(())
}
pub fn load_private_key(&mut self, name: String, private_key: Vec<u8>) -> Result<()> {
self.private_keys.insert(name, private_key);
Ok(())
}
pub fn establish_secure_connection(&self, address: &str) -> Result<String> {
match self.config.protocol {
TransportSecurity::TLS13 => self.establish_tls13_connection(address),
TransportSecurity::DTLS => self.establish_dtls_connection(address),
TransportSecurity::CustomEncryption => self.establish_custom_connection(address),
}
}
fn establish_tls13_connection(&self, address: &str) -> Result<String> {
let connection_id = format!("tls13_{}", address.replace([':', '.'], "_"));
Ok(connection_id)
}
fn establish_dtls_connection(&self, address: &str) -> Result<String> {
let connection_id = format!("dtls_{}", address.replace([':', '.'], "_"));
Ok(connection_id)
}
fn establish_custom_connection(&self, address: &str) -> Result<String> {
let connection_id = format!("custom_{}", address.replace([':', '.'], "_"));
Ok(connection_id)
}
}
impl CommunicationManager {
    /// Frame header the GZIP placeholder codec prepends to a payload.
    const GZIP_HEADER: [u8; 2] = [0x1f, 0x8b];
    /// Frame header the LZ4 placeholder codec prepends to a payload.
    const LZ4_HEADER: [u8; 4] = [0x04, 0x22, 0x4d, 0x18];
    /// Frame header the Brotli placeholder codec prepends to a payload.
    const BROTLI_HEADER: [u8; 4] = [0xce, 0xb2, 0xcf, 0x81];

    /// Creates a manager from `config`.
    ///
    /// The transport-security and compression settings are cloned out of `config`. The
    /// bandwidth monitor starts with a reading of 100.0, an empty history and no congestion.
    /// There are no connections and no queued messages.
    ///
    /// # Errors
    ///
    /// Never fails today; the `Result` is part of the public signature.
    pub fn new(config: CommunicationProtocolConfig) -> Result<Self> {
        Ok(Self {
            transport_security: TransportSecurityManager::new(config.transport_security.clone()),
            compression_config: config.compression.clone(),
            bandwidth_monitor: BandwidthMonitor {
                current_bandwidth_mbps: 100.0,
                bandwidth_history: VecDeque::new(),
                adaptation_strategy: config.bandwidth_management.adaptation_strategy,
                congestion_state: CongestionState::NoCongestion,
            },
            config,
            active_connections: HashMap::new(),
            message_queue: VecDeque::new(),
        })
    }

    /// Opens a secure connection to `address` and records it under `participant_id`.
    ///
    /// Any existing entry for `participant_id` is replaced.
    ///
    /// # Errors
    ///
    /// Propagates errors from `TransportSecurityManager::establish_secure_connection`.
    pub fn connect(&mut self, participant_id: &str, address: &str) -> Result<()> {
        let connection_id = self.transport_security.establish_secure_connection(address)?;
        let connection_info = ConnectionInfo {
            id: connection_id,
            address: address.to_string(),
            state: ConnectionState::Connected,
            last_activity: self.get_current_timestamp(),
            bytes_sent: 0,
            bytes_received: 0,
        };
        self.active_connections.insert(participant_id.to_string(), connection_info);
        Ok(())
    }

    /// Forgets the connection recorded for `participant_id`, if there is one.
    ///
    /// Unknown participants are ignored. Always returns `Ok(())`.
    pub fn disconnect(&mut self, participant_id: &str) -> Result<()> {
        // The entry is removed outright, so nothing could ever observe an intermediate
        // `Disconnected` state; there is no need to set one first.
        self.active_connections.remove(participant_id);
        Ok(())
    }

    /// Compresses `message` (subject to the size threshold) and appends it to the local queue.
    ///
    /// The bandwidth reading and the recipient's `bytes_sent` are charged with the payload
    /// size as queued, i.e. after compression. This is the same size `receive_message`
    /// credits on the other side.
    ///
    /// # Errors
    ///
    /// Propagates compression errors; nothing is queued in that case.
    pub fn send_message(&mut self, message: Message) -> Result<()> {
        let queued = self.compress_message(&message)?;
        let wire_bytes = queued.payload.len();
        self.message_queue.push_back(queued);

        // NOTE(review): this adds mebibytes (bytes / 1024^2) to a field named `_mbps`.
        // Confirm the intended unit before relying on the congestion thresholds.
        self.bandwidth_monitor.current_bandwidth_mbps += wire_bytes as f64 / 1024.0 / 1024.0;

        // Read the clock before taking a mutable borrow of `active_connections`.
        // `get_current_timestamp` borrows all of `self`, which cannot overlap that borrow.
        let now = self.get_current_timestamp();
        if let Some(connection) = self.active_connections.get_mut(&message.recipient_id) {
            connection.bytes_sent += wire_bytes as u64;
            connection.last_activity = now;
        }
        Ok(())
    }

    /// Pops the oldest queued message and returns it decompressed.
    ///
    /// Returns `None` when the queue is empty. If the sender's connection is known, it is
    /// credited with the queued payload size and its `last_activity` is refreshed.
    ///
    /// NOTE(review): a message that fails to decompress is discarded and also yields `None`,
    /// which callers cannot tell apart from an empty queue.
    pub fn receive_message(&mut self) -> Option<Message> {
        let queued = self.message_queue.pop_front()?;
        let decompressed = self.decompress_message(&queued).ok()?;

        // Clock read hoisted for the same borrow reason as in `send_message`.
        let now = self.get_current_timestamp();
        if let Some(connection) = self.active_connections.get_mut(&queued.sender_id) {
            connection.bytes_received += queued.payload.len() as u64;
            connection.last_activity = now;
        }
        Some(decompressed)
    }

    /// Whether a payload of `len` bytes is small enough to bypass the codec entirely.
    fn is_below_compression_threshold(&self, len: usize) -> bool {
        len < self.compression_config.min_size_for_compression
    }

    /// Returns a copy of `message` with its payload encoded by the configured algorithm.
    ///
    /// Payloads shorter than `min_size_for_compression` are returned unchanged.
    fn compress_message(&self, message: &Message) -> Result<Message> {
        if self.is_below_compression_threshold(message.payload.len()) {
            return Ok(message.clone());
        }
        let payload = &message.payload;
        let compressed_payload = match self.compression_config.algorithm {
            CompressionAlgorithm::None => payload.clone(),
            CompressionAlgorithm::GZIP => self.gzip_compress(payload)?,
            CompressionAlgorithm::LZ4 => self.lz4_compress(payload)?,
            CompressionAlgorithm::Brotli => self.brotli_compress(payload)?,
            CompressionAlgorithm::Custom => self.custom_compress(payload)?,
        };
        let mut compressed_message = message.clone();
        compressed_message.payload = compressed_payload;
        Ok(compressed_message)
    }

    /// Inverse of `compress_message`.
    ///
    /// `compress_message` never encodes payloads shorter than `min_size_for_compression`, so
    /// those are returned unchanged. Decoding them would either reject them (GZIP) or strip
    /// real payload bytes (LZ4/Brotli). The length test is exact for the codecs below. Each
    /// one is either the identity or only prepends a header, so an encoded payload is never
    /// shorter than its input.
    ///
    /// NOTE(review): revisit if a real, size-reducing codec is plugged in. Messages would
    /// then need an explicit "compressed" marker.
    ///
    /// # Errors
    ///
    /// Fails when an encoded payload lacks the expected frame header.
    fn decompress_message(&self, message: &Message) -> Result<Message> {
        if self.is_below_compression_threshold(message.payload.len()) {
            return Ok(message.clone());
        }
        let payload = &message.payload;
        let decompressed_payload = match self.compression_config.algorithm {
            CompressionAlgorithm::None => payload.clone(),
            CompressionAlgorithm::GZIP => self.gzip_decompress(payload)?,
            CompressionAlgorithm::LZ4 => self.lz4_decompress(payload)?,
            CompressionAlgorithm::Brotli => self.brotli_decompress(payload)?,
            CompressionAlgorithm::Custom => self.custom_decompress(payload)?,
        };
        let mut decompressed_message = message.clone();
        decompressed_message.payload = decompressed_payload;
        Ok(decompressed_message)
    }

    /// Returns `header` followed by `data` in one freshly allocated buffer.
    fn prepend_header(header: &[u8], data: &[u8]) -> Vec<u8> {
        let mut framed = Vec::with_capacity(header.len() + data.len());
        framed.extend_from_slice(header);
        framed.extend_from_slice(data);
        framed
    }

    /// Strips `header` from the front of `data`, failing with `error_message` if it is absent.
    fn strip_header(header: &[u8], data: &[u8], error_message: &str) -> Result<Vec<u8>> {
        if data.starts_with(header) {
            Ok(data[header.len()..].to_vec())
        } else {
            Err(TrustformersError::InvalidConfiguration(error_message.to_string()).into())
        }
    }

    /// Placeholder GZIP encoder: prepends the GZIP header and does not actually compress.
    fn gzip_compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(Self::prepend_header(&Self::GZIP_HEADER, data))
    }

    /// Placeholder GZIP decoder: checks and strips the GZIP header.
    fn gzip_decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Self::strip_header(&Self::GZIP_HEADER, data, "Invalid GZIP data")
    }

    /// Placeholder LZ4 encoder: prepends the LZ4 header and does not actually compress.
    fn lz4_compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(Self::prepend_header(&Self::LZ4_HEADER, data))
    }

    /// Placeholder LZ4 decoder: checks and strips the LZ4 header, as the GZIP decoder does.
    fn lz4_decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Self::strip_header(&Self::LZ4_HEADER, data, "Invalid LZ4 data")
    }

    /// Placeholder Brotli encoder: prepends a 4-byte header and does not actually compress.
    fn brotli_compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(Self::prepend_header(&Self::BROTLI_HEADER, data))
    }

    /// Placeholder Brotli decoder: checks and strips the 4-byte header.
    fn brotli_decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Self::strip_header(&Self::BROTLI_HEADER, data, "Invalid Brotli data")
    }

    /// Custom codec hook; currently the identity.
    fn custom_compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(data.to_vec())
    }

    /// Custom codec hook; currently the identity.
    fn custom_decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
        Ok(data.to_vec())
    }

    /// Records the current bandwidth reading, re-classifies congestion and resets the reading.
    ///
    /// The reading is appended to a history capped at 100 samples. Congestion is then
    /// classified from the history's mean: > 80 is heavy, > 60 moderate, > 40 light, and
    /// anything else is none. The adaptation strategy scales the reading, and the reading is
    /// finally reset to 0.0 for the next accumulation window.
    pub fn update_bandwidth_monitoring(&mut self) {
        const HISTORY_CAPACITY: usize = 100;

        let monitor = &mut self.bandwidth_monitor;
        monitor.bandwidth_history.push_back(monitor.current_bandwidth_mbps);
        if monitor.bandwidth_history.len() > HISTORY_CAPACITY {
            monitor.bandwidth_history.pop_front();
        }

        // The push above guarantees a non-empty history, so this never divides by zero.
        let avg_bandwidth: f64 = monitor.bandwidth_history.iter().sum::<f64>()
            / monitor.bandwidth_history.len() as f64;
        monitor.congestion_state = if avg_bandwidth > 80.0 {
            CongestionState::HeavyCongestion
        } else if avg_bandwidth > 60.0 {
            CongestionState::ModerateCongestion
        } else if avg_bandwidth > 40.0 {
            CongestionState::LightCongestion
        } else {
            CongestionState::NoCongestion
        };

        // NOTE(review): the scaling below is overwritten by the reset that follows, so the
        // adaptation strategy currently has no lasting effect. Decide whether the adapted
        // value should be kept (e.g. as a rate limit) or this step removed.
        match monitor.adaptation_strategy {
            BandwidthAdaptationStrategy::Conservative => {
                if monitor.congestion_state != CongestionState::NoCongestion {
                    monitor.current_bandwidth_mbps *= 0.9;
                }
            }
            BandwidthAdaptationStrategy::Aggressive => match monitor.congestion_state {
                CongestionState::NoCongestion => monitor.current_bandwidth_mbps *= 1.1,
                _ => monitor.current_bandwidth_mbps *= 0.8,
            },
            BandwidthAdaptationStrategy::Hybrid => match monitor.congestion_state {
                CongestionState::NoCongestion => monitor.current_bandwidth_mbps *= 1.05,
                CongestionState::LightCongestion => monitor.current_bandwidth_mbps *= 0.95,
                _ => monitor.current_bandwidth_mbps *= 0.85,
            },
        }
        monitor.current_bandwidth_mbps = 0.0;
    }

    /// Returns a snapshot (clone) of all tracked connections keyed by participant id.
    pub fn get_connection_statistics(&self) -> HashMap<String, ConnectionInfo> {
        self.active_connections.clone()
    }

    /// Borrows the bandwidth monitor.
    pub fn get_bandwidth_monitor(&self) -> &BandwidthMonitor {
        &self.bandwidth_monitor
    }

    /// Wall-clock seconds since the UNIX epoch; 0 if the clock reads before the epoch.
    fn get_current_timestamp(&self) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Marks every connection idle for longer than the keep-alive timeout as `Error`.
    ///
    /// Connections are kept in the map; only their state changes. Always returns `Ok(())`.
    pub fn check_connection_health(&mut self) -> Result<()> {
        let now = self.get_current_timestamp();
        let timeout = self.config.timeout_config.keepalive_timeout_seconds;
        for connection in self.active_connections.values_mut() {
            // `SystemTime` is not monotonic, so `last_activity` can be ahead of `now`.
            // Saturate instead of underflowing (a panic in debug builds, and a spurious
            // "stale" verdict in release builds).
            if now.saturating_sub(connection.last_activity) > timeout {
                connection.state = ConnectionState::Error;
            }
        }
        Ok(())
    }
}
impl Default for CommunicationProtocolConfig {
fn default() -> Self {
Self {
protocol: CommunicationProtocol::default(),
transport_security: TransportSecurityConfig::default(),
compression: CompressionConfig::default(),
bandwidth_management: BandwidthManagementConfig::default(),
message_queue: MessageQueueConfig::default(),
timeout_config: TimeoutConfig::default(),
}
}
}
impl Default for TransportSecurityConfig {
fn default() -> Self {
Self {
protocol: TransportSecurity::default(),
certificate_validation: true,
mutual_tls: false,
cipher_suites: vec!["TLS_AES_256_GCM_SHA384".to_string()],
protocol_versions: vec!["TLSv1.3".to_string()],
}
}
}
impl Default for CompressionConfig {
fn default() -> Self {
Self {
algorithm: CompressionAlgorithm::default(),
compression_level: 6,
min_size_for_compression: 1024,
adaptive_compression: true,
}
}
}
impl Default for BandwidthManagementConfig {
fn default() -> Self {
Self {
max_bandwidth_mbps: 100.0,
adaptation_strategy: BandwidthAdaptationStrategy::default(),
congestion_control: true,
qos_priority: QoSPriority::Medium,
rate_limiting: true,
}
}
}
impl Default for MessageQueueConfig {
fn default() -> Self {
Self {
queue_type: MessageQueueType::InMemory,
max_queue_size: 10000,
persistence: false,
dead_letter_queue: false,
message_ttl_seconds: 3600,
}
}
}
impl Default for TimeoutConfig {
fn default() -> Self {
Self {
connection_timeout_seconds: 30,
read_timeout_seconds: 60,
write_timeout_seconds: 60,
keepalive_timeout_seconds: 300,
}
}
}
impl Default for QoSPriority {
fn default() -> Self {
Self::Medium
}
}
impl Default for MessageQueueType {
fn default() -> Self {
Self::InMemory
}
}