use serde::{Deserialize, Serialize};
/// Backend selection stored in `GpuConfig::mode`.
///
/// `Auto` is the `Default` variant. The only variant whose effect is visible in
/// this file is `CpuOnly`, which makes `GpuConfig::should_use_gpu` return `false`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuMode {
/// Default mode, used by `GpuConfig::default()` and the `auto`,
/// `high_performance` and `low_power` presets.
/// NOTE(review): presumably lets the runtime pick a backend — confirm where the mode is consumed.
#[default]
Auto,
/// Mode set by the `GpuConfig::webgpu()` preset.
WebGpu,
/// Mode set by the `GpuConfig::cuda_wasm()` preset, which is only compiled
/// when the `cuda-wasm` feature is enabled. The variant itself always exists.
CudaWasm,
/// Mode set by `GpuConfig::cpu_only()`; `GpuConfig::should_use_gpu` always
/// returns `false` while this mode is selected.
CpuOnly,
}
/// Power/performance hint stored in `GpuConfig::power_preference`.
///
/// `HighPerformance` is the `Default` variant.
/// NOTE(review): presumably forwarded to GPU adapter selection — confirm where it is consumed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum PowerPreference {
/// Used by the `GpuConfig::low_power()` preset.
LowPower,
/// Default variant; used by `GpuConfig::default()` and `GpuConfig::high_performance()`.
#[default]
HighPerformance,
/// No stated preference. This is `PowerPreference::None`, a variant of this
/// enum, and is unrelated to `Option::None`.
None,
}
/// Configuration for GPU acceleration: backend choice, resource hints, and the
/// size thresholds checked by `should_use_gpu`.
///
/// Build one with `GpuConfig::default()` or a preset constructor (`auto`,
/// `high_performance`, `low_power`, `cpu_only`, `webgpu`) and refine it with the
/// chainable `with_*` methods. Every field is public and the struct is
/// serializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
/// Requested backend. `GpuMode::CpuOnly` makes `should_use_gpu` return `false`
/// regardless of the thresholds.
pub mode: GpuMode,
/// Power/performance hint. Defaults to `PowerPreference::HighPerformance`.
pub power_preference: PowerPreference,
/// Memory limit in bytes (the `with_max_memory` setter names its argument `bytes`).
/// Defaults to `0`.
/// NOTE(review): `0` presumably means "no explicit limit" — confirm where this value is read.
pub max_memory: u64,
/// Workgroup size. Defaults to 256; the presets use 512 (`high_performance`)
/// and 128 (`low_power`).
pub workgroup_size: u32,
/// Async-compute toggle. Defaults to `true`; `low_power` turns it off.
pub async_compute: bool,
/// Smallest batch size (inclusive) for which `should_use_gpu` can return `true`.
/// Defaults to 16.
pub min_batch_size: usize,
/// Smallest dimension (inclusive) for which `should_use_gpu` can return `true`.
/// Defaults to 128.
pub min_dimension: usize,
/// Shader-caching toggle. Defaults to `true`. Not read by any code visible in this file.
pub cache_shaders: bool,
/// Profiling toggle. Defaults to `false`; set through `with_profiling`.
pub enable_profiling: bool,
/// CPU-fallback toggle. Defaults to `true`; set through `with_fallback`.
/// NOTE(review): not consulted by `should_use_gpu` — confirm where the fallback is applied.
pub fallback_to_cpu: bool,
/// Device index. Defaults to `0`; set through `with_device`.
pub device_index: u32,
}
impl Default for GpuConfig {
    /// Builds the baseline configuration: automatic mode, high-performance
    /// power preference, no explicit memory limit value (`0`), and CPU
    /// fallback enabled.
    fn default() -> Self {
        // Numeric tunables are named here so the struct literal below reads as
        // a plain field-by-field listing.
        const WORKGROUP_SIZE: u32 = 256;
        const MIN_BATCH_SIZE: usize = 16;
        const MIN_DIMENSION: usize = 128;

        Self {
            mode: GpuMode::Auto,
            power_preference: PowerPreference::HighPerformance,
            max_memory: 0,
            workgroup_size: WORKGROUP_SIZE,
            async_compute: true,
            min_batch_size: MIN_BATCH_SIZE,
            min_dimension: MIN_DIMENSION,
            cache_shaders: true,
            enable_profiling: false,
            fallback_to_cpu: true,
            device_index: 0,
        }
    }
}
impl GpuConfig {
    /// Returns the default configuration; equivalent to `GpuConfig::default()`.
    #[must_use]
    pub fn auto() -> Self {
        Self::default()
    }

    /// Preset with a larger workgroup (512) and lower thresholds than the
    /// default (batch >= 8, dimension >= 64), so smaller workloads already
    /// satisfy [`should_use_gpu`](Self::should_use_gpu).
    #[must_use]
    pub fn high_performance() -> Self {
        Self {
            mode: GpuMode::Auto,
            power_preference: PowerPreference::HighPerformance,
            workgroup_size: 512,
            async_compute: true,
            min_batch_size: 8,
            min_dimension: 64,
            ..Self::default()
        }
    }

    /// Preset that requests `PowerPreference::LowPower`, uses a smaller
    /// workgroup (128), turns `async_compute` off, and raises the thresholds
    /// (batch >= 32, dimension >= 256) so only larger workloads satisfy
    /// [`should_use_gpu`](Self::should_use_gpu).
    #[must_use]
    pub fn low_power() -> Self {
        Self {
            mode: GpuMode::Auto,
            power_preference: PowerPreference::LowPower,
            workgroup_size: 128,
            async_compute: false,
            min_batch_size: 32,
            min_dimension: 256,
            ..Self::default()
        }
    }

    /// Default configuration with the mode set to `GpuMode::CpuOnly`, for which
    /// [`should_use_gpu`](Self::should_use_gpu) always returns `false`.
    #[must_use]
    pub fn cpu_only() -> Self {
        Self {
            mode: GpuMode::CpuOnly,
            ..Self::default()
        }
    }

    /// Default configuration with the mode set to `GpuMode::WebGpu`.
    #[must_use]
    pub fn webgpu() -> Self {
        Self {
            mode: GpuMode::WebGpu,
            ..Self::default()
        }
    }

    /// Default configuration with the mode set to `GpuMode::CudaWasm`.
    ///
    /// Only compiled when the `cuda-wasm` feature is enabled.
    #[cfg(feature = "cuda-wasm")]
    #[must_use]
    pub fn cuda_wasm() -> Self {
        Self {
            mode: GpuMode::CudaWasm,
            // Pinned explicitly so this preset keeps 256 even if the default changes.
            workgroup_size: 256,
            ..Self::default()
        }
    }

    /// Sets the requested backend mode and returns the updated configuration.
    #[must_use]
    pub fn with_mode(mut self, mode: GpuMode) -> Self {
        self.mode = mode;
        self
    }

    /// Sets the power preference and returns the updated configuration.
    #[must_use]
    pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
        self.power_preference = pref;
        self
    }

    /// Sets `max_memory` (in bytes) and returns the updated configuration.
    #[must_use]
    pub fn with_max_memory(mut self, bytes: u64) -> Self {
        self.max_memory = bytes;
        self
    }

    /// Sets the workgroup size and returns the updated configuration.
    ///
    /// The value is stored as given; no validation is performed here.
    #[must_use]
    pub fn with_workgroup_size(mut self, size: u32) -> Self {
        self.workgroup_size = size;
        self
    }

    /// Enables or disables `async_compute` and returns the updated configuration.
    #[must_use]
    pub fn with_async_compute(mut self, enable: bool) -> Self {
        self.async_compute = enable;
        self
    }

    /// Sets the minimum batch size (inclusive) required by
    /// [`should_use_gpu`](Self::should_use_gpu).
    #[must_use]
    pub fn with_min_batch_size(mut self, size: usize) -> Self {
        self.min_batch_size = size;
        self
    }

    /// Sets the minimum dimension (inclusive) required by
    /// [`should_use_gpu`](Self::should_use_gpu).
    #[must_use]
    pub fn with_min_dimension(mut self, dim: usize) -> Self {
        self.min_dimension = dim;
        self
    }

    /// Enables or disables `cache_shaders` and returns the updated configuration.
    #[must_use]
    pub fn with_shader_cache(mut self, enable: bool) -> Self {
        self.cache_shaders = enable;
        self
    }

    /// Enables or disables `enable_profiling` and returns the updated configuration.
    #[must_use]
    pub fn with_profiling(mut self, enable: bool) -> Self {
        self.enable_profiling = enable;
        self
    }

    /// Enables or disables `fallback_to_cpu` and returns the updated configuration.
    #[must_use]
    pub fn with_fallback(mut self, enable: bool) -> Self {
        self.fallback_to_cpu = enable;
        self
    }

    /// Sets `device_index` and returns the updated configuration.
    #[must_use]
    pub fn with_device(mut self, index: u32) -> Self {
        self.device_index = index;
        self
    }

    /// Decides, from the configuration alone, whether a workload qualifies for
    /// the GPU path.
    ///
    /// Returns `true` only when the mode is not `GpuMode::CpuOnly` and both
    /// `batch_size >= min_batch_size` and `dimension >= min_dimension` hold.
    /// This compares configuration values only; it does not check whether a
    /// GPU is actually available.
    #[must_use]
    pub fn should_use_gpu(&self, batch_size: usize, dimension: usize) -> bool {
        self.mode != GpuMode::CpuOnly
            && batch_size >= self.min_batch_size
            && dimension >= self.min_dimension
    }
}
/// Snapshot of GPU memory counters.
///
/// The derived `Default` sets every counter to zero. `usage_percent` relates
/// `used` to `total`.
/// NOTE(review): units are presumably bytes — confirm at the site that fills this struct.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuMemoryStats {
/// Total memory; the denominator in `usage_percent`.
pub total: u64,
/// Memory currently in use; the numerator in `usage_percent`.
pub used: u64,
/// Free memory. Stored independently; nothing visible here enforces `free == total - used`.
pub free: u64,
/// Peak usage. NOTE(review): presumably the high-water mark of `used` — confirm.
pub peak: u64,
}
impl GpuMemoryStats {
    /// Returns `used` as a percentage of `total`.
    ///
    /// Yields `0.0` when `total` is zero, which avoids a division by zero.
    /// The result is not clamped, so `used > total` produces a value above 100.
    pub fn usage_percent(&self) -> f32 {
        if self.total == 0 {
            return 0.0;
        }
        let ratio = self.used as f32 / self.total as f32;
        ratio * 100.0
    }
}
/// Aggregated profiling counters for GPU work.
///
/// The derived `Default` zeroes every field. No code visible in this file
/// fills or reads these values.
// NOTE(review): `dead_code` is allowed presumably because nothing constructs this
// type yet — confirm, and drop the attribute once the type is in use.
#[allow(dead_code)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuProfilingData {
/// Number of operations recorded.
pub operations: u64,
/// Time attributed to the GPU; the `_us` suffix suggests microseconds — TODO confirm.
pub gpu_time_us: u64,
/// Time attributed to the CPU; the `_us` suffix suggests microseconds — TODO confirm.
pub cpu_time_us: u64,
/// Speedup factor. NOTE(review): presumably CPU time relative to GPU time; it is
/// stored, not derived from the two timing fields here — confirm with the producer.
pub speedup: f32,
/// Amount of memory transferred. NOTE(review): presumably bytes — confirm.
pub memory_transferred: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses automatic mode, prefers high
    /// performance, and keeps the CPU fallback switched on.
    #[test]
    fn test_default_config() {
        let defaults = GpuConfig::default();

        assert_eq!(defaults.mode, GpuMode::Auto);
        assert_eq!(defaults.power_preference, PowerPreference::HighPerformance);
        assert!(defaults.fallback_to_cpu);
    }

    /// Both thresholds must be met before the GPU path is chosen.
    #[test]
    fn test_should_use_gpu() {
        let thresholds = GpuConfig::default()
            .with_min_batch_size(16)
            .with_min_dimension(128);

        // Batch size below its minimum.
        assert!(!thresholds.should_use_gpu(8, 384));
        // Dimension below its minimum.
        assert!(!thresholds.should_use_gpu(32, 64));
        // Both minimums satisfied.
        assert!(thresholds.should_use_gpu(32, 384));
    }

    /// CPU-only mode rejects the GPU path no matter how large the workload is.
    #[test]
    fn test_cpu_only() {
        let cpu = GpuConfig::cpu_only();

        assert!(!cpu.should_use_gpu(1000, 1000));
    }

    /// Chained setters each store their value on the resulting configuration.
    #[test]
    fn test_builder() {
        let built = GpuConfig::auto()
            .with_mode(GpuMode::WebGpu)
            .with_max_memory(1024 * 1024 * 1024)
            .with_workgroup_size(512)
            .with_profiling(true);

        assert_eq!(built.mode, GpuMode::WebGpu);
        assert_eq!(built.max_memory, 1024 * 1024 * 1024);
        assert_eq!(built.workgroup_size, 512);
        assert!(built.enable_profiling);
    }
}