use ndarray::{s, Array1, Array2, Axis};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use crate::error::{Error, Result};
use crate::gpu::operations::{GpuMatrix, GpuVector};
use crate::gpu::{GpuConfig, GpuDeviceStatus, GpuError, GpuManager};
use crate::lock_safe;
#[cfg(cuda_available)]
use cudarc::driver::CudaContext as CudarcContext;
/// Configuration for distributing work across multiple GPU devices.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Device ids to use (e.g. `[0, 1]`).
    pub device_ids: Vec<i32>,
    /// How matrices are partitioned across the devices.
    pub distribution_strategy: DistributionStrategy,
    /// Whether peer-to-peer (device-to-device) transfers should be attempted.
    pub enable_p2p: bool,
    /// Soft memory cap per device, in bytes.
    pub memory_limit_per_device: usize,
}
impl Default for MultiGpuConfig {
    /// Single device 0, data-parallel distribution, P2P enabled,
    /// 1 GiB memory limit per device.
    fn default() -> Self {
        Self {
            device_ids: vec![0],
            distribution_strategy: DistributionStrategy::DataParallel,
            enable_p2p: true,
            // 1 GiB per device.
            memory_limit_per_device: 1024 * 1024 * 1024,
        }
    }
}
/// Strategy used to partition a matrix across devices.
#[derive(Debug, Clone, PartialEq)]
pub enum DistributionStrategy {
    /// Split rows across devices (each device processes a slice of the data).
    DataParallel,
    /// Split columns across devices (each device holds part of the model).
    ModelParallel,
    /// Stage-wise execution; the first device receives the full input.
    PipelineParallel,
    /// User-defined strategy; currently falls back to the data-parallel split.
    Custom,
}
/// Coordinates distribution of matrix work across several GPU devices and
/// collection of the partial results.
pub struct MultiGpuManager {
    // Active configuration (device ids, strategy, memory limits).
    config: MultiGpuConfig,
    // One per-device manager, keyed by device id.
    device_managers: HashMap<i32, GpuManager>,
    // Device status snapshots captured once at construction time.
    device_statuses: HashMap<i32, GpuDeviceStatus>,
    // Whether peer-to-peer transfers are assumed available — presumably an
    // optimistic flag; confirm against `check_p2p_support`.
    p2p_available: bool,
}
impl MultiGpuManager {
pub fn new(config: MultiGpuConfig) -> Result<Self> {
let mut device_managers = HashMap::new();
let mut device_statuses = HashMap::new();
for &device_id in &config.device_ids {
let device_config = GpuConfig {
device_id,
memory_limit: config.memory_limit_per_device,
..GpuConfig::default()
};
let manager = GpuManager::with_config(device_config);
let status = manager.device_info();
device_managers.insert(device_id, manager);
device_statuses.insert(device_id, status);
}
let p2p_available = Self::check_p2p_support(&config.device_ids);
Ok(Self {
config,
device_managers,
device_statuses,
p2p_available,
})
}
fn check_p2p_support(device_ids: &[i32]) -> bool {
#[cfg(cuda_available)]
{
for &id1 in device_ids {
for &id2 in device_ids {
if id1 != id2 {
}
}
}
true
}
#[cfg(not(cuda_available))]
{
false
}
}
pub fn device_count(&self) -> usize {
self.config.device_ids.len()
}
pub fn get_device_statuses(&self) -> &HashMap<i32, GpuDeviceStatus> {
&self.device_statuses
}
pub fn distribute_matrix(&self, matrix: &GpuMatrix) -> Result<Vec<(i32, GpuMatrix)>> {
match self.config.distribution_strategy {
DistributionStrategy::DataParallel => self.distribute_data_parallel(matrix),
DistributionStrategy::ModelParallel => self.distribute_model_parallel(matrix),
DistributionStrategy::PipelineParallel => self.distribute_pipeline(matrix),
DistributionStrategy::Custom => self.distribute_custom(matrix),
}
}
fn distribute_data_parallel(&self, matrix: &GpuMatrix) -> Result<Vec<(i32, GpuMatrix)>> {
let num_devices = self.device_count();
let rows = matrix.data.shape()[0];
let cols = matrix.data.shape()[1];
let rows_per_device = (rows + num_devices - 1) / num_devices; let mut distributed = Vec::new();
for (i, &device_id) in self.config.device_ids.iter().enumerate() {
let start_row = i * rows_per_device;
let end_row = ((i + 1) * rows_per_device).min(rows);
if start_row < rows {
let chunk = matrix.data.slice(s![start_row..end_row, ..]).to_owned();
let gpu_chunk = GpuMatrix {
data: chunk,
on_gpu: false, };
distributed.push((device_id, gpu_chunk));
}
}
Ok(distributed)
}
fn distribute_model_parallel(&self, matrix: &GpuMatrix) -> Result<Vec<(i32, GpuMatrix)>> {
let num_devices = self.device_count();
let rows = matrix.data.shape()[0];
let cols = matrix.data.shape()[1];
let cols_per_device = (cols + num_devices - 1) / num_devices;
let mut distributed = Vec::new();
for (i, &device_id) in self.config.device_ids.iter().enumerate() {
let start_col = i * cols_per_device;
let end_col = ((i + 1) * cols_per_device).min(cols);
if start_col < cols {
let chunk = matrix.data.slice(s![.., start_col..end_col]).to_owned();
let gpu_chunk = GpuMatrix {
data: chunk,
on_gpu: false,
};
distributed.push((device_id, gpu_chunk));
}
}
Ok(distributed)
}
fn distribute_pipeline(&self, matrix: &GpuMatrix) -> Result<Vec<(i32, GpuMatrix)>> {
if let Some(&first_device) = self.config.device_ids.first() {
Ok(vec![(first_device, GpuMatrix::new(matrix.to_cpu()?))])
} else {
Err(Error::from(GpuError::DeviceError(
"No devices available".to_string(),
)))
}
}
fn distribute_custom(&self, matrix: &GpuMatrix) -> Result<Vec<(i32, GpuMatrix)>> {
self.distribute_data_parallel(matrix)
}
pub fn collect_results(&self, distributed_results: Vec<(i32, GpuMatrix)>) -> Result<GpuMatrix> {
if distributed_results.is_empty() {
return Err(Error::from(GpuError::KernelExecutionError(
"No results to collect".to_string(),
)));
}
match self.config.distribution_strategy {
DistributionStrategy::DataParallel => self.collect_data_parallel(distributed_results),
DistributionStrategy::ModelParallel => self.collect_model_parallel(distributed_results),
DistributionStrategy::PipelineParallel => self.collect_pipeline(distributed_results),
DistributionStrategy::Custom => self.collect_custom(distributed_results),
}
}
fn collect_data_parallel(
&self,
mut distributed_results: Vec<(i32, GpuMatrix)>,
) -> Result<GpuMatrix> {
distributed_results.sort_by_key(|(device_id, _)| *device_id);
let matrices: Vec<Array2<f64>> = distributed_results
.into_iter()
.map(|(_, matrix)| matrix.data)
.collect();
if matrices.is_empty() {
return Err(Error::from(GpuError::KernelExecutionError(
"No matrices to concatenate".to_string(),
)));
}
let first_shape = matrices[0].shape();
let total_rows: usize = matrices.iter().map(|m| m.shape()[0]).sum();
let cols = first_shape[1];
let mut result = Array2::zeros((total_rows, cols));
let mut current_row = 0;
for matrix in matrices {
let matrix_rows = matrix.shape()[0];
result
.slice_mut(s![current_row..current_row + matrix_rows, ..])
.assign(&matrix);
current_row += matrix_rows;
}
Ok(GpuMatrix {
data: result,
on_gpu: false,
})
}
fn collect_model_parallel(
&self,
mut distributed_results: Vec<(i32, GpuMatrix)>,
) -> Result<GpuMatrix> {
distributed_results.sort_by_key(|(device_id, _)| *device_id);
let matrices: Vec<Array2<f64>> = distributed_results
.into_iter()
.map(|(_, matrix)| matrix.data)
.collect();
if matrices.is_empty() {
return Err(Error::from(GpuError::KernelExecutionError(
"No matrices to concatenate".to_string(),
)));
}
let first_shape = matrices[0].shape();
let rows = first_shape[0];
let total_cols: usize = matrices.iter().map(|m| m.shape()[1]).sum();
let mut result = Array2::zeros((rows, total_cols));
let mut current_col = 0;
for matrix in matrices {
let matrix_cols = matrix.shape()[1];
result
.slice_mut(s![.., current_col..current_col + matrix_cols])
.assign(&matrix);
current_col += matrix_cols;
}
Ok(GpuMatrix {
data: result,
on_gpu: false,
})
}
fn collect_pipeline(&self, distributed_results: Vec<(i32, GpuMatrix)>) -> Result<GpuMatrix> {
if let Some((_, result)) = distributed_results.into_iter().last() {
Ok(result)
} else {
Err(Error::from(GpuError::KernelExecutionError(
"No pipeline result".to_string(),
)))
}
}
fn collect_custom(&self, distributed_results: Vec<(i32, GpuMatrix)>) -> Result<GpuMatrix> {
self.collect_data_parallel(distributed_results)
}
pub fn distributed_matmul(&self, a: &GpuMatrix, b: &GpuMatrix) -> Result<GpuMatrix> {
let distributed_a = self.distribute_matrix(a)?;
let mut distributed_results = Vec::new();
for (device_id, a_chunk) in distributed_a {
if let Some(manager) = self.device_managers.get(&device_id) {
let result_chunk = self.matmul_on_device(&a_chunk, b, device_id)?;
distributed_results.push((device_id, result_chunk));
}
}
self.collect_results(distributed_results)
}
fn matmul_on_device(&self, a: &GpuMatrix, b: &GpuMatrix, device_id: i32) -> Result<GpuMatrix> {
let result_data = a.data.dot(&b.data);
Ok(GpuMatrix {
data: result_data,
on_gpu: false,
})
}
pub fn synchronize_all(&self) -> Result<()> {
#[cfg(cuda_available)]
{
for &device_id in &self.config.device_ids {
if let Some(manager) = self.device_managers.get(&device_id) {
}
}
}
Ok(())
}
pub fn get_memory_usage(&self) -> HashMap<i32, (usize, usize)> {
let mut usage = HashMap::new();
for (&device_id, status) in &self.device_statuses {
let used_memory = status
.total_memory
.unwrap_or(0)
.saturating_sub(status.free_memory.unwrap_or(0));
let total_memory = status.total_memory.unwrap_or(0);
usage.insert(device_id, (used_memory, total_memory));
}
usage
}
pub fn balance_load(&mut self) -> Result<()> {
let memory_usage = self.get_memory_usage();
let mut low_util_devices = Vec::new();
let mut high_util_devices = Vec::new();
for (&device_id, &(used, total)) in &memory_usage {
let utilization = used as f64 / total as f64;
if utilization < 0.3 {
low_util_devices.push(device_id);
} else if utilization > 0.8 {
high_util_devices.push(device_id);
}
}
log::info!(
"Load balancing: {} low util devices, {} high util devices",
low_util_devices.len(),
high_util_devices.len()
);
Ok(())
}
}
static MULTI_GPU_MANAGER: OnceLock<Mutex<MultiGpuManager>> = OnceLock::new();
pub fn init_multi_gpu(config: MultiGpuConfig) -> Result<()> {
let manager = MultiGpuManager::new(config)?;
MULTI_GPU_MANAGER.set(Mutex::new(manager)).map_err(|_| {
Error::InvalidOperation("Multi-GPU manager already initialized".to_string())
})?;
Ok(())
}
/// Returns a handle to the global multi-GPU manager, lazily initializing it
/// with the default configuration on first use.
///
/// NOTE(review): the returned `Arc<Mutex<_>>` wraps a *clone* of the global
/// manager, so mutations through the handle do not affect the shared
/// instance; true sharing would require the static itself to hold the
/// `Arc<Mutex<_>>`.
///
/// # Errors
/// Propagates lock-poisoning errors and genuine initialization failures.
pub fn get_multi_gpu_manager() -> Result<Arc<Mutex<MultiGpuManager>>> {
    if MULTI_GPU_MANAGER.get().is_none() {
        // Lazy default initialization. A concurrent caller may win the race
        // to `set`, in which case `init_multi_gpu` reports "already
        // initialized" even though the static is usable — the original code
        // propagated that spurious error. Only surface the error when the
        // static really is still empty.
        if let Err(err) = init_multi_gpu(MultiGpuConfig::default()) {
            if MULTI_GPU_MANAGER.get().is_none() {
                return Err(err);
            }
        }
    }
    let manager = MULTI_GPU_MANAGER
        .get()
        .expect("static is populated by the branch above");
    Ok(Arc::new(Mutex::new(
        lock_safe!(manager, "multi gpu manager lock")?.clone(),
    )))
}
/// NOTE(review): this is not a field-wise clone — it re-runs construction
/// via `MultiGpuManager::new` with the same config, producing fresh device
/// managers and fresh status snapshots. Presumably `GpuManager` is not
/// `Clone`, which forces this design; confirm before attempting to derive.
impl Clone for MultiGpuManager {
    fn clone(&self) -> Self {
        Self::new(self.config.clone()).unwrap_or_else(|_| {
            // On failure, retry with a single-device (device 0) config.
            let fallback_config = MultiGpuConfig {
                device_ids: vec![0],
                ..self.config.clone()
            };
            // NOTE(review): this still panics if even the fallback fails;
            // `Clone` cannot return an error, so a panic is the only escape
            // hatch here.
            Self::new(fallback_config).expect("operation should succeed")
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// A manager configured with one device reports exactly one device.
    #[test]
    fn test_multi_gpu_manager_creation() {
        let config = MultiGpuConfig {
            device_ids: vec![0],
            ..MultiGpuConfig::default()
        };
        let built = MultiGpuManager::new(config);
        assert!(built.is_ok());
        let built = built.expect("operation should succeed");
        assert_eq!(built.device_count(), 1);
    }

    /// Data-parallel distribution splits a 4x3 matrix into two row bands
    /// that together cover all four rows.
    #[test]
    fn test_data_parallel_distribution() {
        let manager = MultiGpuManager::new(MultiGpuConfig {
            device_ids: vec![0, 1],
            distribution_strategy: DistributionStrategy::DataParallel,
            ..MultiGpuConfig::default()
        })
        .expect("operation should succeed");
        let values: Vec<f64> = (1..=12).map(f64::from).collect();
        let source = GpuMatrix {
            data: Array2::from_shape_vec((4, 3), values).expect("operation should succeed"),
            on_gpu: false,
        };
        let chunks = manager
            .distribute_matrix(&source)
            .expect("operation should succeed");
        assert_eq!(chunks.len(), 2);
        let row_total: usize = chunks.iter().map(|(_, m)| m.data.shape()[0]).sum();
        assert_eq!(row_total, 4);
    }

    /// Collecting two 2x3 data-parallel chunks yields a 4x3 matrix.
    #[test]
    fn test_collect_data_parallel() {
        let manager = MultiGpuManager::new(MultiGpuConfig {
            device_ids: vec![0, 1],
            distribution_strategy: DistributionStrategy::DataParallel,
            ..MultiGpuConfig::default()
        })
        .expect("operation should succeed");
        // Builds a 2x3 chunk whose entries count up from `start`.
        let make_chunk = |start: f64| {
            let data =
                Array2::from_shape_vec((2, 3), (0..6).map(|i| start + f64::from(i)).collect())
                    .expect("operation should succeed");
            GpuMatrix {
                data,
                on_gpu: false,
            }
        };
        let parts = vec![(0, make_chunk(1.0)), (1, make_chunk(7.0))];
        let merged = manager
            .collect_results(parts)
            .expect("operation should succeed");
        assert_eq!(merged.data.shape(), &[4, 3]);
    }
}