cortex_rust 0.6.0

High-performance LLM inference with 4-bit quantization and Test-Time Training (TTT)
Documentation
//! Multi-GPU Detection and Management Utilities
//!
//! Provides mechanisms for GPU discovery, VRAM tracking, and multi-GPU support
//! for distributed model inference.

#[cfg(feature = "cuda")]
use anyhow::bail;
use anyhow::Result;
use candle_core::Device;
use std::collections::HashMap;

/// Represents a detected GPU device
#[derive(Debug, Clone)]
pub struct GPUDevice {
    /// Unique device ID
    pub id: usize,
    /// Total VRAM in bytes
    pub total_vram: u64,
    /// Available VRAM in bytes
    pub free_vram: u64,
    /// GPU name/identifier
    pub name: String,
    /// Compute capability (for advanced routing)
    pub compute_capability: (u32, u32),
    /// Is this device CUDA-compatible?
    pub is_cuda: bool,
}

/// GPU detection and management system
pub struct MultiGPUManager {
    /// Detected GPU devices
    devices: Vec<GPUDevice>,
    /// Current strategy for device placement
    current_strategy: DevicePlacementStrategy,
}

/// Strategy for distributing model layers across GPUs
#[derive(Debug, Clone)]
pub enum DevicePlacementStrategy {
    /// All layers on the first GPU
    SingleGPU,
    /// Fixed number of layers per GPU
    Distributed { layers_per_gpu: usize },
    /// Dynamic load balancing
    Adaptive,
}

impl MultiGPUManager {
    /// Detect available CUDA-capable GPUs
    pub fn detect_gpus() -> Result<Vec<GPUDevice>> {
        // Note: This is a placeholder. Real implementation requires CUDA/GPU-specific calls.
        #[cfg(feature = "cuda")]
        {
            let cuda_devices = Self::detect_cuda_devices()?;
            Ok(cuda_devices)
        }

        #[cfg(not(feature = "cuda"))]
        {
            let primary_device = GPUDevice {
                id: 0,
                total_vram: 4 * 1024 * 1024 * 1024, // Default 4GB
                free_vram: 3 * 1024 * 1024 * 1024,  // Default 3GB free
                name: "Default Device".to_string(),
                compute_capability: (7, 5), // Default compute capability
                is_cuda: false,
            };
            Ok(vec![primary_device])
        }
    }

    /// Detect CUDA devices (CUDA-enabled implementation)
    #[cfg(feature = "cuda")]
    fn detect_cuda_devices() -> Result<Vec<GPUDevice>> {
        // Uses cuda_runtime_sys or similar for detection
        // This is a mock implementation - replace with actual CUDA device enumeration
        use cuda_runtime_sys as cuda;
        use cuda_runtime_sys::cudaError::cudaSuccess;

        let mut devices = Vec::new();
        let mut device_count: std::os::raw::c_int = 0;

        unsafe {
            // Get number of CUDA devices
            if cuda::cudaGetDeviceCount(&mut device_count) != cudaSuccess {
                bail!("Failed to get CUDA device count");
            }

            for device_id in 0..device_count {
                let mut props: cuda::cudaDeviceProp = std::mem::zeroed();
                if cuda::cudaGetDeviceProperties(&mut props, device_id) != cudaSuccess {
                    continue; // Skip this device if properties can't be retrieved
                }

                // Get memory info
                let mut free_memory: usize = 0;
                let mut total_memory: usize = 0;
                cuda::cudaMemGetInfo(&mut free_memory, &mut total_memory);

                devices.push(GPUDevice {
                    id: device_id as usize,
                    total_vram: total_memory as u64,
                    free_vram: free_memory as u64,
                    name: std::ffi::CStr::from_ptr(props.name.as_ptr())
                        .to_str()
                        .unwrap_or("Unknown GPU")
                        .to_string(),
                    compute_capability: (props.major as u32, props.minor as u32),
                    is_cuda: true,
                });
            }
        }

        Ok(devices)
    }

    /// Create a new MultiGPU manager
    pub fn new() -> Result<Self> {
        let devices = Self::detect_gpus()?;

        // Default strategy: Single GPU or distributed based on count
        let current_strategy = match devices.len() {
            0 => DevicePlacementStrategy::SingleGPU,
            1 => DevicePlacementStrategy::SingleGPU,
            _n => DevicePlacementStrategy::Distributed {
                layers_per_gpu: 4, // Default 4 layers per GPU
            },
        };

        Ok(Self {
            devices,
            current_strategy,
        })
    }

    /// Get the total number of detected GPUs
    pub fn gpu_count(&self) -> usize {
        self.devices.len()
    }

    /// Get a list of GPU IDs
    pub fn gpu_ids(&self) -> Vec<usize> {
        self.devices.iter().map(|d| d.id).collect()
    }

    /// Get detailed GPU information
    pub fn gpu_info(&self) -> &Vec<GPUDevice> {
        &self.devices
    }

    /// Determine layer distribution across available GPUs
    pub fn distribute_layers(&self, total_layers: usize) -> Result<HashMap<usize, usize>> {
        match self.current_strategy {
            DevicePlacementStrategy::SingleGPU => {
                // All layers on first GPU (or CPU if no GPU)
                let device_id = self.devices.first().map(|d| d.id).unwrap_or(0);
                let layer_map = (0..total_layers)
                    .map(|layer_id| (layer_id, device_id))
                    .collect();
                Ok(layer_map)
            }
            DevicePlacementStrategy::Distributed { layers_per_gpu } => {
                let mut layer_map = HashMap::new();
                let devices = &self.devices;

                for layer_id in 0..total_layers {
                    // Determine GPU based on layer index and layers_per_gpu
                    let gpu_index = layer_id / layers_per_gpu;
                    let device_id = if gpu_index < devices.len() {
                        devices[gpu_index].id
                    } else {
                        // Fallback to first GPU or CPU
                        devices.first().map(|d| d.id).unwrap_or(0)
                    };

                    layer_map.insert(layer_id, device_id);
                }

                Ok(layer_map)
            }
            DevicePlacementStrategy::Adaptive => {
                // Complex adaptive strategy: consider VRAM, compute capability
                // Placeholder: simple round-robin
                let mut layer_map = HashMap::new();
                let devices = &self.devices;

                for layer_id in 0..total_layers {
                    let device_id = devices[layer_id % devices.len()].id;
                    layer_map.insert(layer_id, device_id);
                }

                Ok(layer_map)
            }
        }
    }

    /// Get target device for a specific layer
    pub fn get_layer_device(&self, layer_id: usize, total_layers: usize) -> Result<Device> {
        let layer_map = self.distribute_layers(total_layers)?;

        // Get device for this layer
        let device_id = layer_map.get(&layer_id).copied().unwrap_or(0);

        // Convert to Candle Device
        let device = if device_id == 0 {
            // Fallback to CPU
            Device::Cpu
        } else {
            #[cfg(feature = "cuda")]
            {
                Device::cuda_if_available(device_id).unwrap_or(Device::Cpu)
            }
            #[cfg(not(feature = "cuda"))]
            {
                Device::Cpu
            }
        };

        Ok(device)
    }

    /// Update strategy dynamically
    pub fn set_strategy(&mut self, strategy: DevicePlacementStrategy) {
        self.current_strategy = strategy;
    }
}

/// Utility function to get VRAM info (directly usable from isomorphic layer)
pub fn get_vram_info(gpu_device_id: usize) -> Result<(usize, usize)> {
    #[cfg(feature = "cuda")]
    {
        use cuda_runtime_sys as cuda;
        use cuda_runtime_sys::cudaError::cudaSuccess;
        let mut free_memory: usize = 0;
        let mut total_memory: usize = 0;

        unsafe {
            if cuda::cudaSetDevice(gpu_device_id as std::os::raw::c_int) != cudaSuccess {
                bail!("Failed to set CUDA device");
            }

            if cuda::cudaMemGetInfo(&mut free_memory, &mut total_memory) != cudaSuccess {
                bail!("Failed to get CUDA memory info");
            }
        }

        Ok((free_memory, total_memory))
    }

    #[cfg(not(feature = "cuda"))]
    {
        let _ = gpu_device_id; // Suppress unused warning
                               // Fallback for non-CUDA builds (2GB for 32-bit, 4GB for 64-bit)
        #[cfg(target_pointer_width = "32")]
        let mem = 2usize * 1024 * 1024 * 1024;
        #[cfg(target_pointer_width = "64")]
        let mem = 4usize * 1024 * 1024 * 1024;
        Ok((mem, mem))
    }
}

/// Unit tests for Multi-GPU utilities
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_gpu_detection() {
        let manager = MultiGPUManager::new().expect("Failed to create MultiGPUManager");

        // Basic checks
        let devices = manager.gpu_info();
        assert!(
            !devices.is_empty(),
            "At least one device should be detected"
        );

        // Check device information
        for device in devices {
            assert!(device.total_vram > 0, "Total VRAM should be > 0");
            assert!(device.free_vram > 0, "Free VRAM should be > 0");
            assert!(
                device.free_vram <= device.total_vram,
                "Free VRAM cannot exceed total VRAM"
            );
        }
    }

    #[test]
    fn test_layer_distribution() {
        let manager = MultiGPUManager::new().expect("Failed to create MultiGPUManager");

        // Test layer distribution
        let total_layers = 32; // Typical LLaMA model
        let layer_map = manager
            .distribute_layers(total_layers)
            .expect("Failed to distribute layers");

        assert_eq!(
            layer_map.len(),
            total_layers,
            "All layers should be assigned"
        );
    }

    #[test]
    #[ignore = "Requires multiple GPUs"]
    fn test_layer_device_mapping() {
        let mut manager = MultiGPUManager::new().expect("Failed to create MultiGPUManager");

        // Test different strategies
        let total_layers = 32;

        // Test SingleGPU
        manager.set_strategy(DevicePlacementStrategy::SingleGPU);
        let single_gpu_map = manager
            .distribute_layers(total_layers)
            .expect("Failed to distribute layers");

        // In SingleGPU, all layers should map to same device
        let first_device = *single_gpu_map.values().next().unwrap();
        assert!(
            single_gpu_map.values().all(|&d| d == first_device),
            "All layers should map to same device in SingleGPU"
        );

        // Test Distributed
        manager.set_strategy(DevicePlacementStrategy::Distributed { layers_per_gpu: 4 });
        let distributed_map = manager
            .distribute_layers(total_layers)
            .expect("Failed to distribute layers");

        // In Distributed, layers should be spread across devices
        assert!(
            distributed_map
                .values()
                .collect::<std::collections::HashSet<_>>()
                .len()
                > 1,
            "Layers should be spread across multiple devices"
        );
    }
}