ruvector-onnx-embeddings 0.1.0

ONNX-based embedding generation for RuVector - Reimagined embedding pipeline in pure Rust
Documentation
//! GPU Configuration for RuVector ONNX Embeddings
//!
//! Provides configuration options for GPU acceleration including
//! device selection, memory limits, and performance tuning.

use serde::{Deserialize, Serialize};

/// GPU execution mode.
///
/// Selects which compute backend a [`GpuConfig`] requests. `Auto` carries the
/// `#[default]` marker, so it is the value produced by `GpuMode::default()`.
/// No serde rename attribute is applied, so the derived (de)serializers use
/// the Rust variant names as written here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuMode {
    /// Automatically select best available backend (this is the default)
    #[default]
    Auto,
    /// Force WebGPU backend
    WebGpu,
    /// Force CUDA-WASM transpiled backend
    CudaWasm,
    /// CPU-only (disable GPU); `GpuConfig::should_use_gpu` returns `false`
    /// for every workload while this mode is selected
    CpuOnly,
}

/// Power preference for GPU device selection.
///
/// `HighPerformance` carries the `#[default]` marker, so it is the value
/// produced by `PowerPreference::default()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum PowerPreference {
    /// Prefer low power consumption (integrated GPU)
    LowPower,
    /// Prefer high performance (discrete GPU); this is the default
    #[default]
    HighPerformance,
    /// No preference.
    ///
    /// This is an ordinary enum variant and is unrelated to `Option::None`;
    /// write it as `PowerPreference::None` to keep call sites unambiguous.
    None,
}

/// GPU acceleration configuration.
///
/// A plain data struct: every field is public and can be set directly, via the
/// `with_*` builder methods, or through one of the preset constructors on the
/// inherent `impl`. The values quoted as "default" below are the ones assigned
/// by the `Default` implementation in this file.
///
/// Only `mode`, `min_batch_size` and `min_dimension` are interpreted in this
/// file (by `should_use_gpu`); the remaining fields are carried for whichever
/// component consumes the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// GPU execution mode (default: `GpuMode::Auto`)
    pub mode: GpuMode,

    /// Power preference for device selection
    /// (default: `PowerPreference::HighPerformance`)
    pub power_preference: PowerPreference,

    /// Maximum GPU memory usage (bytes, 0 = unlimited; default: 0)
    pub max_memory: u64,

    /// Workgroup size for compute shaders (0 = auto; default: 256)
    pub workgroup_size: u32,

    /// Enable async GPU operations (default: `true`)
    pub async_compute: bool,

    /// Minimum batch size to use GPU (smaller batches use CPU).
    ///
    /// The bound is inclusive: `should_use_gpu` compares with `>=`.
    /// Default: 16.
    pub min_batch_size: usize,

    /// Minimum vector dimension to use GPU.
    ///
    /// The bound is inclusive: `should_use_gpu` compares with `>=`.
    /// Default: 128.
    pub min_dimension: usize,

    /// Enable shader caching (default: `true`)
    pub cache_shaders: bool,

    /// Enable profiling and timing (default: `false`)
    pub enable_profiling: bool,

    /// Fallback to CPU on GPU error (default: `true`)
    pub fallback_to_cpu: bool,

    /// Device index (for multi-GPU systems; default: 0)
    pub device_index: u32,
}

impl Default for GpuConfig {
    fn default() -> Self {
        Self {
            mode: GpuMode::Auto,
            power_preference: PowerPreference::HighPerformance,
            max_memory: 0, // unlimited
            workgroup_size: 256,
            async_compute: true,
            min_batch_size: 16,
            min_dimension: 128,
            cache_shaders: true,
            enable_profiling: false,
            fallback_to_cpu: true,
            device_index: 0,
        }
    }
}

impl GpuConfig {
    // Preset constructors
    //
    // Each preset starts from `Default::default()` and overrides only the
    // fields listed in its body. All of them are `#[must_use]`: they have no
    // side effects, so a discarded result is always a mistake.

    /// Create configuration with automatic settings.
    ///
    /// Identical to `GpuConfig::default()`.
    #[must_use]
    pub fn auto() -> Self {
        Self::default()
    }

    /// Create configuration for high performance.
    ///
    /// Uses a larger workgroup (512) and lower GPU thresholds
    /// (batch size >= 8, dimension >= 64) than the default, so more workloads
    /// qualify for the GPU.
    #[must_use]
    pub fn high_performance() -> Self {
        Self {
            mode: GpuMode::Auto,
            power_preference: PowerPreference::HighPerformance,
            workgroup_size: 512,
            async_compute: true,
            min_batch_size: 8,
            min_dimension: 64,
            ..Default::default()
        }
    }

    /// Create configuration for low power usage.
    ///
    /// Requests the low-power device preference, a smaller workgroup (128),
    /// disables async compute and raises the GPU thresholds
    /// (batch size >= 32, dimension >= 256) so fewer workloads qualify.
    #[must_use]
    pub fn low_power() -> Self {
        Self {
            mode: GpuMode::Auto,
            power_preference: PowerPreference::LowPower,
            workgroup_size: 128,
            async_compute: false,
            min_batch_size: 32,
            min_dimension: 256,
            ..Default::default()
        }
    }

    /// Create CPU-only configuration.
    ///
    /// With this mode `should_use_gpu` returns `false` for every workload.
    #[must_use]
    pub fn cpu_only() -> Self {
        Self {
            mode: GpuMode::CpuOnly,
            ..Default::default()
        }
    }

    /// Create WebGPU-specific configuration (default settings with the mode
    /// forced to `GpuMode::WebGpu`).
    #[must_use]
    pub fn webgpu() -> Self {
        Self {
            mode: GpuMode::WebGpu,
            ..Default::default()
        }
    }

    /// Create CUDA-WASM specific configuration.
    ///
    /// Only compiled when the `cuda-wasm` Cargo feature is enabled. The
    /// workgroup size is pinned to 256 explicitly rather than inherited from
    /// `Default`.
    #[cfg(feature = "cuda-wasm")]
    #[must_use]
    pub fn cuda_wasm() -> Self {
        Self {
            mode: GpuMode::CudaWasm,
            workgroup_size: 256,
            ..Default::default()
        }
    }

    // Builder methods
    //
    // Every builder consumes `self` and returns the updated value, so calls
    // are meant to be chained. They are `#[must_use]` because dropping the
    // return value silently discards the change.

    /// Set GPU mode
    #[must_use]
    pub fn with_mode(mut self, mode: GpuMode) -> Self {
        self.mode = mode;
        self
    }

    /// Set power preference
    #[must_use]
    pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
        self.power_preference = pref;
        self
    }

    /// Set maximum memory in bytes (0 = unlimited)
    #[must_use]
    pub fn with_max_memory(mut self, bytes: u64) -> Self {
        self.max_memory = bytes;
        self
    }

    /// Set workgroup size (0 = auto)
    #[must_use]
    pub fn with_workgroup_size(mut self, size: u32) -> Self {
        self.workgroup_size = size;
        self
    }

    /// Enable or disable async GPU operations
    #[must_use]
    pub fn with_async_compute(mut self, enable: bool) -> Self {
        self.async_compute = enable;
        self
    }

    /// Set minimum batch size for GPU usage (inclusive lower bound)
    #[must_use]
    pub fn with_min_batch_size(mut self, size: usize) -> Self {
        self.min_batch_size = size;
        self
    }

    /// Set minimum dimension for GPU usage (inclusive lower bound)
    #[must_use]
    pub fn with_min_dimension(mut self, dim: usize) -> Self {
        self.min_dimension = dim;
        self
    }

    /// Enable or disable shader caching
    #[must_use]
    pub fn with_shader_cache(mut self, enable: bool) -> Self {
        self.cache_shaders = enable;
        self
    }

    /// Enable or disable profiling
    #[must_use]
    pub fn with_profiling(mut self, enable: bool) -> Self {
        self.enable_profiling = enable;
        self
    }

    /// Enable or disable CPU fallback
    #[must_use]
    pub fn with_fallback(mut self, enable: bool) -> Self {
        self.fallback_to_cpu = enable;
        self
    }

    /// Set device index
    #[must_use]
    pub fn with_device(mut self, index: u32) -> Self {
        self.device_index = index;
        self
    }

    /// Check if GPU should be used for given workload.
    ///
    /// Returns `true` only when all three conditions hold:
    /// the mode is not `GpuMode::CpuOnly`, `batch_size` is at least
    /// `min_batch_size`, and `dimension` is at least `min_dimension`.
    /// Both thresholds are inclusive.
    ///
    /// This is a pure policy check on the configuration; it does not probe
    /// whether a GPU is actually present.
    #[must_use]
    pub fn should_use_gpu(&self, batch_size: usize, dimension: usize) -> bool {
        self.mode != GpuMode::CpuOnly
            && batch_size >= self.min_batch_size
            && dimension >= self.min_dimension
    }
}

/// GPU memory statistics.
///
/// A plain snapshot of byte counters. All fields default to 0 via the derived
/// `Default`. Nothing in this file enforces a relationship between the fields
/// (for example `used + free == total`); they are reported as supplied.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuMemoryStats {
    /// Total GPU memory (bytes); `usage_percent` treats 0 as "unknown" and
    /// reports 0%
    pub total: u64,
    /// Used GPU memory (bytes)
    pub used: u64,
    /// Free GPU memory (bytes)
    pub free: u64,
    /// Peak usage (bytes)
    pub peak: u64,
}

impl GpuMemoryStats {
    /// Get usage percentage.
    ///
    /// Returns `used / total * 100`. When `total` is 0 the result is `0.0`
    /// rather than a division by zero (which would yield NaN or infinity).
    /// The value is not clamped: if `used` exceeds `total` the result is
    /// greater than 100.
    #[must_use]
    pub fn usage_percent(&self) -> f32 {
        if self.total == 0 {
            return 0.0;
        }

        // Byte counts routinely exceed 2^24, past which `f32` cannot represent
        // every integer. Converting the operands to `f32` first would round
        // them before the division; dividing in `f64` and narrowing only the
        // final percentage keeps the error to a single `f32` rounding.
        (self.used as f64 / self.total as f64 * 100.0) as f32
    }
}

/// GPU profiling data.
///
/// A plain record of counters; all fields default to 0 via the derived
/// `Default`. Nothing in this file populates or reads it.
// NOTE(review): the reason for `allow(dead_code)` is not documented; it is
// presumably here because the struct is not yet used inside the crate.
// Confirm, and remove the attribute once the type is wired in.
#[allow(dead_code)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuProfilingData {
    /// Total operations executed
    pub operations: u64,
    /// Total GPU time (microseconds)
    pub gpu_time_us: u64,
    /// Total CPU time (microseconds)
    pub cpu_time_us: u64,
    /// GPU speedup over CPU.
    ///
    /// Stored as supplied, not derived from the timing fields in this file;
    /// presumably `cpu_time_us / gpu_time_us` (TODO confirm with the producer).
    pub speedup: f32,
    /// Memory transfers (bytes)
    pub memory_transferred: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration auto-selects a backend, prefers the
    /// high-performance device and keeps the CPU fallback enabled.
    #[test]
    fn test_default_config() {
        let GpuConfig {
            mode,
            power_preference,
            fallback_to_cpu,
            ..
        } = GpuConfig::default();

        assert_eq!(mode, GpuMode::Auto);
        assert_eq!(power_preference, PowerPreference::HighPerformance);
        assert!(fallback_to_cpu);
    }

    /// Both the batch-size and the dimension threshold must be met.
    #[test]
    fn test_should_use_gpu() {
        let config = GpuConfig::default()
            .with_min_batch_size(16)
            .with_min_dimension(128);

        // (batch_size, dimension, expected decision)
        let cases = [
            (8, 384, false), // batch too small
            (32, 64, false), // dimension too small
            (32, 384, true), // both ok
        ];

        for &(batch_size, dimension, expected) in &cases {
            assert_eq!(
                config.should_use_gpu(batch_size, dimension),
                expected,
                "batch_size={}, dimension={}",
                batch_size,
                dimension
            );
        }
    }

    /// CPU-only mode rejects the GPU no matter how large the workload is.
    #[test]
    fn test_cpu_only() {
        let uses_gpu = GpuConfig::cpu_only().should_use_gpu(1000, 1000);
        assert!(!uses_gpu);
    }

    /// Chained builder calls each land in the matching field.
    #[test]
    fn test_builder() {
        const ONE_GIB: u64 = 1024 * 1024 * 1024;

        let config = GpuConfig::auto()
            .with_mode(GpuMode::WebGpu)
            .with_max_memory(ONE_GIB)
            .with_workgroup_size(512)
            .with_profiling(true);

        assert_eq!(config.mode, GpuMode::WebGpu);
        assert_eq!(config.max_memory, ONE_GIB);
        assert_eq!(config.workgroup_size, 512);
        assert!(config.enable_profiling);
    }
}