#![allow(clippy::await_holding_lock)]
#![allow(dead_code)]
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot};
/// RDMA transport protocol family used by a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RdmaProtocol {
    /// Native InfiniBand fabric.
    InfiniBand,
    /// RDMA over Converged Ethernet, version 1.
    RoCEv1,
    /// RDMA over Converged Ethernet, version 2.
    RoCEv2,
    /// Internet Wide Area RDMA Protocol (RDMA over TCP/IP).
    IWARP,
}
/// RDMA verbs that can be posted as work requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RdmaOperation {
    /// One-sided read from remote memory.
    Read,
    /// One-sided write to remote memory.
    Write,
    /// One-sided write that also delivers 32-bit immediate data.
    WriteImmediate,
    /// Two-sided send.
    Send,
    /// Two-sided receive.
    Recv,
    /// Atomic compare-and-swap.
    CompareSwap,
    /// Atomic fetch-and-add.
    FetchAdd,
}
/// Quality-of-service class requested for a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RdmaQoS {
    /// No delivery guarantees.
    BestEffort,
    /// Favor low latency.
    LowLatency,
    /// Favor throughput.
    HighBandwidth,
    /// Real-time traffic class.
    RealTime,
}
/// Strategy used when registering memory with the RDMA device.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryRegistration {
    /// Conventional memory registration.
    Standard,
    /// Fast memory registration (FRMR-style).
    FastReg,
    /// Memory-window based registration.
    MemoryWindow,
    /// Globally visible registration.
    Global,
}
/// Tunable parameters for an RDMA transport instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RdmaConfig {
    /// Transport protocol to use.
    pub protocol: RdmaProtocol,
    /// Quality-of-service class for connections.
    pub qos: RdmaQoS,
    /// Largest message, in bytes, accepted by a single operation.
    pub max_message_size: usize,
    /// Send/receive queue depth per queue pair.
    pub queue_depth: u32,
    /// Completion queue size.
    pub cq_size: u32,
    /// Memory registration strategy.
    pub memory_registration: MemoryRegistration,
    /// Enable hardware checksum offload.
    pub hardware_checksum: bool,
    /// Allow adaptive routing on the fabric.
    pub adaptive_routing: bool,
    /// How long to wait when establishing a connection.
    pub connection_timeout: Duration,
    /// Number of retransmission attempts.
    pub retry_count: u8,
    /// Path MTU in bytes.
    pub path_mtu: u32,
}
impl Default for RdmaConfig {
fn default() -> Self {
Self {
protocol: RdmaProtocol::RoCEv2,
qos: RdmaQoS::HighBandwidth,
max_message_size: 4 * 1024 * 1024, queue_depth: 256,
cq_size: 512,
memory_registration: MemoryRegistration::FastReg,
hardware_checksum: true,
adaptive_routing: true,
connection_timeout: Duration::from_secs(30),
retry_count: 7,
path_mtu: 4096,
}
}
}
/// A registered memory region usable in RDMA operations.
#[derive(Debug, Clone)]
pub struct MemoryRegion {
    /// Base virtual address of the region.
    pub addr: u64,
    /// Region length in bytes.
    pub size: usize,
    /// Remote key handed to peers for one-sided access.
    pub rkey: u32,
    /// Local key used by the owning queue pair.
    pub lkey: u32,
    /// Access permissions granted at registration time.
    pub access: MemoryAccess,
    /// Registration strategy this region was created with.
    pub registration_type: MemoryRegistration,
}
/// Access permission flags for a registered memory region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryAccess {
    /// Local read allowed.
    pub read: bool,
    /// Local write allowed.
    pub write: bool,
    /// Local atomic operations allowed.
    pub atomic: bool,
    /// Remote peers may read.
    pub remote_read: bool,
    /// Remote peers may write.
    pub remote_write: bool,
    /// Remote peers may perform atomics.
    pub remote_atomic: bool,
}
impl Default for MemoryAccess {
fn default() -> Self {
Self {
read: true,
write: true,
atomic: false,
remote_read: true,
remote_write: true,
remote_atomic: false,
}
}
}
/// Addressing information for one side of an RDMA connection.
#[derive(Debug, Clone)]
pub struct RdmaEndpoint {
    /// Logical node identifier.
    pub node_id: usize,
    /// Network address of the endpoint.
    pub address: String,
    /// Port number.
    pub port: u16,
    /// 128-bit global identifier, if known.
    pub gid: Option<[u8; 16]>,
    /// Local identifier, if known.
    pub lid: Option<u16>,
    /// Queue pair number.
    pub qp_num: u32,
    /// Initial packet sequence number.
    pub psn: u32,
}
/// A single work-queue entry describing one RDMA operation.
#[derive(Debug)]
pub struct WorkRequest {
    /// Caller-assigned identifier echoed back in the completion.
    pub id: u64,
    /// The verb to execute.
    pub operation: RdmaOperation,
    /// Local buffer address.
    pub local_addr: u64,
    /// Local memory key protecting `local_addr`.
    pub lkey: u32,
    /// Remote address for one-sided operations (read/write/atomics).
    pub remote_addr: Option<u64>,
    /// Remote key for one-sided operations.
    pub rkey: Option<u32>,
    /// Transfer length in bytes.
    pub length: usize,
    /// Optional 32-bit immediate data.
    pub immediate: Option<u32>,
    /// Channel on which the completion is delivered.
    pub completion: oneshot::Sender<RdmaResult<WorkCompletion>>,
}
/// Result of a completed work request.
#[derive(Debug, Clone)]
pub struct WorkCompletion {
    /// Id of the work request this completion belongs to.
    pub wr_id: u64,
    /// Final status of the operation.
    pub status: CompletionStatus,
    /// Verb that was executed.
    pub operation: RdmaOperation,
    /// Number of bytes actually transferred.
    pub bytes_transferred: usize,
    /// Immediate data carried with the operation, if any.
    pub immediate: Option<u32>,
    /// When the completion was generated.
    pub timestamp: Instant,
}
/// Work-completion status codes, mirroring the InfiniBand verbs
/// completion-status set (`IBV_WC_*`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionStatus {
    /// Operation completed successfully; all other variants are errors.
    Success,
    LocalLengthError,
    LocalQpOperationError,
    LocalProtectionError,
    WorkRequestFlushed,
    MemoryManagementError,
    BadResponseError,
    LocalAccessError,
    RemoteInvalidRequestError,
    RemoteAccessError,
    RemoteOperationError,
    RetryExceededError,
    RnrRetryExceededError,
    LocalRddViolationError,
    RemoteInvalidRdRequest,
    RemoteAborted,
    InvalidEecnError,
    InvalidEecStateError,
    Fatal,
}
/// Errors produced by the RDMA layer.
#[derive(Debug, thiserror::Error)]
pub enum RdmaError {
    /// Connection establishment or maintenance failed.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// Registering a memory region failed.
    #[error("Memory registration failed: {0}")]
    MemoryRegistrationFailed(String),
    /// A posted operation failed.
    #[error("Operation failed: {0}")]
    OperationFailed(String),
    /// An operation exceeded its time budget.
    #[error("Timeout: {0}")]
    Timeout(String),
    /// Configuration was rejected.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// Device or fabric reported a hardware problem.
    #[error("Hardware error: {0}")]
    HardwareError(String),
    /// Wire-protocol violation.
    #[error("Protocol error: {0}")]
    ProtocolError(String),
}
/// Convenience alias for results produced by this module.
pub type RdmaResult<T> = Result<T, RdmaError>;
/// Aggregate transfer statistics for a connection or manager.
#[derive(Debug, Clone, Default)]
pub struct RdmaStatistics {
    /// Total operations submitted.
    pub total_operations: u64,
    /// Per-verb operation counts.
    pub operations_by_type: HashMap<RdmaOperation, u64>,
    /// Total payload bytes moved.
    pub bytes_transferred: u64,
    /// Average completion latency, in microseconds.
    pub avg_latency_us: f64,
    /// Highest observed bandwidth, in Gbit/s.
    pub peak_bandwidth_gbps: f64,
    /// Most recent bandwidth sample, in Gbit/s.
    pub current_bandwidth_gbps: f64,
    /// Number of failed operations.
    pub error_count: u64,
    /// Number of retransmissions.
    pub retry_count: u64,
    /// Time this connection/manager has been up.
    pub uptime: Duration,
    /// CPU usage attributed to RDMA processing, in percent.
    pub cpu_usage_percent: f64,
}
/// Pool of pre-registered memory regions, bucketed by size class.
pub struct RdmaMemoryPool {
    /// Free regions, keyed by region size.
    regions: RwLock<HashMap<usize, Vec<MemoryRegion>>>,
    /// Sizing and behavior knobs the pool was built with.
    config: RdmaMemoryPoolConfig,
    /// Allocation counters, shared so snapshots can be handed out cheaply.
    stats: Arc<Mutex<MemoryPoolStats>>,
}
/// Construction parameters for [`RdmaMemoryPool`].
#[derive(Debug, Clone)]
pub struct RdmaMemoryPoolConfig {
    /// Regions pre-allocated per size class at pool creation.
    pub min_pool_size: usize,
    /// Maximum cached regions per size class; excess frees are dropped.
    pub max_pool_size: usize,
    /// Region size classes in bytes. `allocate` picks the first class that
    /// fits, so this is expected to be sorted ascending — TODO confirm.
    pub region_sizes: Vec<usize>,
    /// Pre-fault pages at registration time.
    pub prefault: bool,
    /// Back regions with huge pages.
    pub huge_pages: bool,
}
/// Usage counters for [`RdmaMemoryPool`].
#[derive(Debug, Default, Clone)]
pub struct MemoryPoolStats {
    /// Total `allocate` calls.
    allocations: u64,
    /// Total `deallocate` calls.
    deallocations: u64,
    /// Allocations served from the cache.
    cache_hits: u64,
    /// Allocations that required registering a new region.
    cache_misses: u64,
    /// NOTE(review): never updated anywhere in this file — confirm intent.
    total_memory_allocated: usize,
    /// NOTE(review): never updated anywhere in this file — confirm intent.
    peak_memory_usage: usize,
}
impl RdmaMemoryPool {
    /// Builds a pool pre-populated with `min_pool_size` registered regions
    /// for every configured size class.
    ///
    /// # Errors
    /// Propagates any failure from region registration.
    pub fn new(config: RdmaMemoryPoolConfig) -> RdmaResult<Self> {
        let mut regions = HashMap::new();
        for &size in &config.region_sizes {
            let mut size_regions = Vec::with_capacity(config.min_pool_size);
            for _ in 0..config.min_pool_size {
                size_regions.push(Self::allocate_region(size, &config)?);
            }
            regions.insert(size, size_regions);
        }
        Ok(Self {
            regions: RwLock::new(regions),
            config,
            stats: Arc::new(Mutex::new(MemoryPoolStats::default())),
        })
    }

    /// Hands out a region of at least `size` bytes, preferring a cached one.
    ///
    /// The request is rounded up to the smallest configured size class that
    /// fits; if none fits, the next power of two is used instead.
    ///
    /// # Errors
    /// Propagates failures from fresh region registration on a cache miss.
    pub fn allocate(&self, size: usize) -> RdmaResult<MemoryRegion> {
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.allocations += 1;
        // Round the request up to a configured size class (power-of-two fallback).
        let region_size = self
            .config
            .region_sizes
            .iter()
            .find(|&&s| s >= size)
            .copied()
            .unwrap_or_else(|| size.next_power_of_two());
        let mut regions = self.regions.write().expect("lock should not be poisoned");
        // Fast path: reuse a pooled region of this size class.
        // (Fix: the original read `get_mut(®ion_size)` — mojibake for
        // `&region_size` that did not compile.)
        if let Some(size_regions) = regions.get_mut(&region_size) {
            if let Some(region) = size_regions.pop() {
                stats.cache_hits += 1;
                return Ok(region);
            }
        }
        // Slow path: register a fresh region of the rounded size.
        stats.cache_misses += 1;
        Self::allocate_region(region_size, &self.config)
    }

    /// Returns `region` to its size-class cache, or drops it when that
    /// class is already at `max_pool_size`.
    pub fn deallocate(&self, mut region: MemoryRegion) {
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.deallocations += 1;
        let mut regions = self.regions.write().expect("lock should not be poisoned");
        let size_regions = regions.entry(region.size).or_default();
        if size_regions.len() < self.config.max_pool_size {
            // NOTE(review): the address is zeroed before pooling, so a region
            // served from the cache carries `addr == 0` while a freshly
            // registered one carries the simulated base address — confirm
            // this asymmetry is intended.
            region.addr = 0;
            size_regions.push(region);
        }
    }

    /// Simulates registering a new memory region of `size` bytes.
    fn allocate_region(size: usize, _config: &RdmaMemoryPoolConfig) -> RdmaResult<MemoryRegion> {
        Ok(MemoryRegion {
            // Placeholder base address; a real implementation would register
            // actual memory with the RDMA device and return its address/keys.
            addr: 0x1000_0000,
            size,
            rkey: thread_rng().random::<u32>(),
            lkey: thread_rng().random::<u32>(),
            access: MemoryAccess::default(),
            registration_type: MemoryRegistration::FastReg,
        })
    }

    /// Returns a snapshot of the pool's counters.
    pub fn statistics(&self) -> MemoryPoolStats {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }
}
/// Manages RDMA connections, their statistics, and a shared memory pool.
pub struct RdmaConnectionManager {
    /// Active connections keyed by connection id (the remote node id).
    connections: RwLock<HashMap<usize, RdmaConnection>>,
    /// Transport configuration applied to this manager.
    config: RdmaConfig,
    /// Aggregate statistics across all connections.
    stats: Arc<Mutex<RdmaStatistics>>,
    /// Shared pool of registered memory regions.
    memory_pool: Arc<RdmaMemoryPool>,
    /// Work-request sender. NOTE(review): its receiver is dropped in `new`,
    /// so the channel is currently unused — verify before relying on it.
    work_sender: mpsc::UnboundedSender<WorkRequest>,
}
/// State for a single RDMA connection.
pub struct RdmaConnection {
    /// Our side of the connection.
    pub local_endpoint: RdmaEndpoint,
    /// Peer side of the connection.
    pub remote_endpoint: RdmaEndpoint,
    /// Current lifecycle state.
    pub state: ConnectionState,
    /// Opaque queue-pair handle.
    pub qp_handle: u64,
    /// Opaque completion-queue handle.
    pub cq_handle: u64,
    /// Per-connection statistics.
    pub stats: RdmaStatistics,
}
/// Lifecycle states of an RDMA connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection established.
    Disconnected,
    /// Handshake in progress.
    Connecting,
    /// Connection is up and usable.
    Connected,
    /// Connection failed and must be re-established.
    Error,
}
impl RdmaConnectionManager {
pub fn new(config: RdmaConfig) -> RdmaResult<Self> {
let memory_pool_config = RdmaMemoryPoolConfig {
min_pool_size: 16,
max_pool_size: 256,
region_sizes: vec![4096, 65536, 1048576, 16777216], prefault: true,
huge_pages: config.max_message_size > 2 * 1024 * 1024,
};
let memory_pool = Arc::new(RdmaMemoryPool::new(memory_pool_config)?);
let (work_sender, _work_receiver) = mpsc::unbounded_channel();
Ok(Self {
connections: RwLock::new(HashMap::new()),
config,
stats: Arc::new(Mutex::new(RdmaStatistics::default())),
memory_pool,
work_sender,
})
}
pub async fn connect(&self, remote_endpoint: RdmaEndpoint) -> RdmaResult<usize> {
let connection_id = remote_endpoint.node_id;
let local_endpoint = RdmaEndpoint {
node_id: 0, address: "0.0.0.0".to_string(),
port: 0,
gid: None,
lid: None,
qp_num: thread_rng().random::<u32>(),
psn: thread_rng().random::<u32>(),
};
let connection = RdmaConnection {
local_endpoint,
remote_endpoint,
state: ConnectionState::Connected,
qp_handle: thread_rng().random::<u64>(),
cq_handle: thread_rng().random::<u64>(),
stats: RdmaStatistics::default(),
};
self.connections
.write()
.expect("lock should not be poisoned")
.insert(connection_id, connection);
Ok(connection_id)
}
pub async fn rdma_read(
&self,
_connection_id: usize,
local_addr: u64,
remote_addr: u64,
length: usize,
lkey: u32,
rkey: u32,
) -> RdmaResult<WorkCompletion> {
self.submit_work_request(WorkRequest {
id: thread_rng().random::<u64>(),
operation: RdmaOperation::Read,
local_addr,
lkey,
remote_addr: Some(remote_addr),
rkey: Some(rkey),
length,
immediate: None,
completion: oneshot::channel().0,
})
.await
}
pub async fn rdma_write(
&self,
_connection_id: usize,
local_addr: u64,
remote_addr: u64,
length: usize,
lkey: u32,
rkey: u32,
) -> RdmaResult<WorkCompletion> {
self.submit_work_request(WorkRequest {
id: thread_rng().random::<u64>(),
operation: RdmaOperation::Write,
local_addr,
lkey,
remote_addr: Some(remote_addr),
rkey: Some(rkey),
length,
immediate: None,
completion: oneshot::channel().0,
})
.await
}
pub async fn atomic_compare_swap(
&self,
_connection_id: usize,
remote_addr: u64,
compare: u64,
_swap: u64,
rkey: u32,
) -> RdmaResult<u64> {
let _completion = self
.submit_work_request(WorkRequest {
id: thread_rng().random::<u64>(),
operation: RdmaOperation::CompareSwap,
local_addr: 0,
lkey: 0,
remote_addr: Some(remote_addr),
rkey: Some(rkey),
length: 8,
immediate: None,
completion: oneshot::channel().0,
})
.await?;
Ok(compare) }
async fn submit_work_request(&self, work_request: WorkRequest) -> RdmaResult<WorkCompletion> {
tokio::time::sleep(Duration::from_micros(1)).await;
let completion = WorkCompletion {
wr_id: work_request.id,
status: CompletionStatus::Success,
operation: work_request.operation,
bytes_transferred: work_request.length,
immediate: work_request.immediate,
timestamp: Instant::now(),
};
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.total_operations += 1;
*stats
.operations_by_type
.entry(work_request.operation)
.or_insert(0) += 1;
stats.bytes_transferred += work_request.length as u64;
Ok(completion)
}
pub fn statistics(&self) -> RdmaStatistics {
self.stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn memory_pool_statistics(&self) -> MemoryPoolStats {
self.memory_pool.statistics()
}
}
/// Schedules tensor collective operations over RDMA connections.
pub struct RdmaTensorScheduler {
    /// Transport used to execute operations.
    connection_manager: Arc<RdmaConnectionManager>,
    /// Pending operations awaiting execution.
    operation_queue: Arc<Mutex<Vec<TensorOperation>>>,
    /// Link-level bandwidth model used for scheduling decisions.
    bandwidth_optimizer: BandwidthOptimizer,
}
/// One queued tensor collective operation.
#[derive(Debug)]
pub struct TensorOperation {
    /// Identifier of the tensor being moved.
    tensor_id: String,
    /// Which collective to run.
    operation_type: TensorOperationType,
    /// Node originating the data.
    source_node: usize,
    /// Participating destination nodes.
    target_nodes: Vec<usize>,
    /// Payload size in bytes.
    data_size: usize,
    /// Scheduling priority.
    priority: OperationPriority,
    /// Optional completion deadline.
    deadline: Option<Instant>,
}
/// Collective communication patterns supported by the scheduler.
#[derive(Debug, Clone, Copy)]
enum TensorOperationType {
    AllReduce,
    AllGather,
    ReduceScatter,
    Broadcast,
    AllToAll,
}
/// Scheduling priority. `Ord` is derived, so declaration order defines the
/// ordering: `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum OperationPriority {
    Low,
    Normal,
    High,
    Critical,
}
/// Per-link bandwidth model used by the scheduler.
/// NOTE(review): the maps are never populated in this file — confirm where
/// they are meant to be fed from.
#[derive(Debug)]
struct BandwidthOptimizer {
    /// Bandwidth per (source, destination) link.
    link_bandwidth: HashMap<(usize, usize), f64>,
    /// Current utilization per link.
    link_utilization: HashMap<(usize, usize), f64>,
    /// Active optimization strategy.
    optimization_strategy: BandwidthStrategy,
}
/// Bandwidth optimization goals the scheduler can pursue.
#[derive(Debug, Clone, Copy)]
enum BandwidthStrategy {
    MinimizeLatency,
    MaximizeThroughput,
    BalanceLatencyThroughput,
    AdaptiveDynamic,
}
impl RdmaTensorScheduler {
    /// Creates a scheduler over `connection_manager` with an empty queue and
    /// the adaptive bandwidth strategy.
    pub fn new(connection_manager: Arc<RdmaConnectionManager>) -> Self {
        Self {
            connection_manager,
            operation_queue: Arc::new(Mutex::new(Vec::new())),
            bandwidth_optimizer: BandwidthOptimizer {
                link_bandwidth: HashMap::new(),
                link_utilization: HashMap::new(),
                optimization_strategy: BandwidthStrategy::AdaptiveDynamic,
            },
        }
    }

    /// Enqueues `operation` and immediately runs one scheduling pass.
    pub async fn schedule_operation(&self, operation: TensorOperation) -> RdmaResult<()> {
        self.operation_queue
            .lock()
            .expect("lock should not be poisoned")
            .push(operation);
        self.optimize_scheduling().await
    }

    /// Executes the most urgent queued operation, if any.
    async fn optimize_scheduling(&self) -> RdmaResult<()> {
        // Select the next operation inside a short critical section and drop
        // the guard BEFORE awaiting. (Fix: the original held the std mutex
        // across `execute_tensor_operation(...).await`, which blocks the
        // async runtime and is what clippy::await_holding_lock flags.)
        let next = {
            let mut queue = self
                .operation_queue
                .lock()
                .expect("lock should not be poisoned");
            // Order best-first: highest priority, then earliest deadline;
            // operations with a deadline beat those without.
            queue.sort_by(|a, b| {
                b.priority
                    .cmp(&a.priority)
                    .then_with(|| match (a.deadline, b.deadline) {
                        (Some(da), Some(db)) => da.cmp(&db),
                        (Some(_), None) => std::cmp::Ordering::Less,
                        (None, Some(_)) => std::cmp::Ordering::Greater,
                        (None, None) => std::cmp::Ordering::Equal,
                    })
            });
            // Take from the FRONT of the sorted queue. (Fix: the original
            // `pop()` removed the BACK element, executing the lowest-priority
            // operation first — a priority inversion.)
            if queue.is_empty() {
                None
            } else {
                Some(queue.remove(0))
            }
        };
        if let Some(operation) = next {
            self.execute_tensor_operation(operation).await?;
        }
        Ok(())
    }

    /// Dispatches `operation` to the collective-specific executor.
    async fn execute_tensor_operation(&self, operation: TensorOperation) -> RdmaResult<()> {
        match operation.operation_type {
            TensorOperationType::AllReduce => self.execute_all_reduce(&operation).await,
            TensorOperationType::AllGather => self.execute_all_gather(&operation).await,
            TensorOperationType::ReduceScatter => self.execute_reduce_scatter(&operation).await,
            TensorOperationType::Broadcast => self.execute_broadcast(&operation).await,
            TensorOperationType::AllToAll => self.execute_all_to_all(&operation).await,
        }
    }

    // The collective executors below are placeholders: each reports success
    // without moving any data.
    async fn execute_all_reduce(&self, _operation: &TensorOperation) -> RdmaResult<()> {
        Ok(())
    }
    async fn execute_all_gather(&self, _operation: &TensorOperation) -> RdmaResult<()> {
        Ok(())
    }
    async fn execute_reduce_scatter(&self, _operation: &TensorOperation) -> RdmaResult<()> {
        Ok(())
    }
    async fn execute_broadcast(&self, _operation: &TensorOperation) -> RdmaResult<()> {
        Ok(())
    }
    async fn execute_all_to_all(&self, _operation: &TensorOperation) -> RdmaResult<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Pool requests round up to a size class and counters track both
    /// allocations and frees.
    #[tokio::test]
    async fn test_rdma_memory_pool() {
        let config = RdmaMemoryPoolConfig {
            min_pool_size: 2,
            max_pool_size: 10,
            region_sizes: vec![4096, 65536],
            prefault: true,
            huge_pages: false,
        };
        let pool = RdmaMemoryPool::new(config).unwrap();
        // 2048 rounds up to the 4096 class; 8192 rounds up to 65536.
        let region1 = pool.allocate(2048).unwrap();
        assert!(region1.size >= 2048);
        let region2 = pool.allocate(8192).unwrap();
        assert!(region2.size >= 8192);
        pool.deallocate(region1);
        pool.deallocate(region2);
        let stats = pool.statistics();
        assert_eq!(stats.allocations, 2);
        assert_eq!(stats.deallocations, 2);
    }
    /// Connecting returns the remote node id and a simulated read completes
    /// with a success status and the requested byte count.
    #[tokio::test]
    async fn test_rdma_connection_manager() {
        let config = RdmaConfig::default();
        let manager = RdmaConnectionManager::new(config).unwrap();
        let remote_endpoint = RdmaEndpoint {
            node_id: 1,
            address: "192.168.1.100".to_string(),
            port: 18515,
            gid: None,
            lid: None,
            qp_num: 12345,
            psn: 67890,
        };
        let connection_id = manager.connect(remote_endpoint).await.unwrap();
        assert_eq!(connection_id, 1);
        let result = manager
            .rdma_read(connection_id, 0x1000, 0x2000, 1024, 0x12345678, 0x87654321)
            .await
            .unwrap();
        assert_eq!(result.status, CompletionStatus::Success);
        assert_eq!(result.operation, RdmaOperation::Read);
        assert_eq!(result.bytes_transferred, 1024);
    }
    /// The config survives a JSON round-trip unchanged.
    #[test]
    fn test_rdma_config_serialization() {
        let config = RdmaConfig::default();
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: RdmaConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(config.protocol, deserialized.protocol);
        assert_eq!(config.qos, deserialized.qos);
        assert_eq!(config.max_message_size, deserialized.max_message_size);
    }
    /// The simulated compare-and-swap echoes the compare value back as the
    /// previous value.
    #[tokio::test]
    async fn test_atomic_operations() {
        let config = RdmaConfig::default();
        let manager = RdmaConnectionManager::new(config).unwrap();
        let remote_endpoint = RdmaEndpoint {
            node_id: 1,
            address: "192.168.1.100".to_string(),
            port: 18515,
            gid: None,
            lid: None,
            qp_num: 12345,
            psn: 67890,
        };
        let connection_id = manager.connect(remote_endpoint).await.unwrap();
        let previous_value = manager
            .atomic_compare_swap(connection_id, 0x3000, 42, 84, 0x12345678)
            .await
            .unwrap();
        assert_eq!(previous_value, 42);
    }
}