use crate::device::{Device, DeviceType};
/// Coarse classification of a tensor operation, used to drive device placement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpKind {
    /// Dense matrix multiplication.
    MatMul,
    /// Pointwise operations applied element-by-element.
    Elementwise,
    /// Reduction-style operations.
    Reduce,
    /// Any operation not covered by the categories above.
    Other,
}

impl OpKind {
    /// Returns `true` for kinds that the heuristic considers worth running on
    /// the GPU — only `MatMul` and `Elementwise` qualify.
    pub fn is_gpu_friendly(self) -> bool {
        // Exhaustive match so adding a variant forces a decision here.
        match self {
            OpKind::MatMul | OpKind::Elementwise => true,
            OpKind::Reduce | OpKind::Other => false,
        }
    }
}
/// Describes a single operation for the purpose of device selection.
#[derive(Debug, Clone)]
pub struct OpDescriptor {
// The coarse category of the operation (see `OpKind`).
pub kind: OpKind,
}
/// Strategy interface for placing an operation on a device.
/// `Send + Sync` so a selector can be shared across threads (e.g. behind a `Box` in `DeviceManager`).
pub trait DeviceSelector: Send + Sync {
// Chooses a device for `op` applied to a tensor with dimensions `shape`.
fn select(&self, op: &OpDescriptor, shape: &[usize]) -> Device;
}
/// Tunables for the built-in heuristic selector.
/// Fields are private; use the `with_*` / `force_*` builder methods.
#[derive(Debug, Clone)]
pub struct DeviceConfig {
// Minimum element count before the GPU is considered (default: 1_048_576).
gpu_threshold_elems: usize,
// Whether a GPU is present at all (default: false).
gpu_available: bool,
// Which GPU the heuristic picks when it chooses GPU (default: 0).
gpu_index: u32,
// When set, overrides the heuristic entirely (see `force_cpu` / `force_gpu`).
forced: Option<ForcedDevice>,
}
/// Explicit placement override stored in `DeviceConfig::forced`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ForcedDevice {
Cpu,
// Carries its own GPU index, independent of `DeviceConfig::gpu_index`.
Gpu(u32),
}
impl Default for DeviceConfig {
    /// Conservative defaults: no GPU assumed present, GPU index 0, a
    /// 1 Mi-element (2^20) threshold, and no forced placement.
    fn default() -> Self {
        Self {
            gpu_threshold_elems: 1 << 20, // 1_048_576 elements
            gpu_available: false,
            gpu_index: 0,
            forced: None,
        }
    }
}
impl DeviceConfig {
    /// Sets the minimum element count at which the heuristic considers the GPU.
    pub fn with_gpu_threshold(self, n: usize) -> Self {
        Self {
            gpu_threshold_elems: n,
            ..self
        }
    }

    /// Declares whether a GPU is present on this machine.
    pub fn with_gpu_available(self, avail: bool) -> Self {
        Self {
            gpu_available: avail,
            ..self
        }
    }

    /// Forces every selection to the CPU, overriding the heuristic.
    pub fn force_cpu(self) -> Self {
        Self {
            forced: Some(ForcedDevice::Cpu),
            ..self
        }
    }

    /// Forces every selection to the GPU with index `idx`, overriding the
    /// heuristic (including `gpu_available`).
    pub fn force_gpu(self, idx: u32) -> Self {
        Self {
            forced: Some(ForcedDevice::Gpu(idx)),
            ..self
        }
    }

    /// Sets which GPU the heuristic targets when it chooses the GPU.
    pub fn with_gpu_index(self, idx: u32) -> Self {
        Self {
            gpu_index: idx,
            ..self
        }
    }
}
/// Rule-based selector: picks the GPU only when it is available, the op is
/// GPU-friendly, and the tensor is large enough; otherwise picks the CPU.
/// Forced placements in the config always win.
pub struct HeuristicSelector {
    // Tunables controlling the placement decision.
    config: DeviceConfig,
}

impl HeuristicSelector {
    /// Builds a selector driven by the supplied configuration.
    pub fn new(config: DeviceConfig) -> Self {
        HeuristicSelector { config }
    }
}
impl DeviceSelector for HeuristicSelector {
    /// Chooses a device for `op` applied to a tensor of the given `shape`.
    ///
    /// Precedence:
    /// 1. A forced device in the config wins unconditionally.
    /// 2. Otherwise the GPU is chosen only when it is available, the op kind
    ///    is GPU-friendly, and the element count meets the threshold.
    /// 3. Everything else runs on the CPU.
    fn select(&self, op: &OpDescriptor, shape: &[usize]) -> Device {
        // An explicit override short-circuits the heuristic entirely.
        if let Some(forced) = self.config.forced {
            return match forced {
                ForcedDevice::Cpu => Device::cpu(),
                // Note: a forced GPU carries its own index; config.gpu_index
                // is only used by the heuristic path below.
                ForcedDevice::Gpu(idx) => Device {
                    device_type: DeviceType::Cuda,
                    index: idx as usize,
                },
            };
        }
        // Saturating product instead of `product()`: a pathological shape
        // must not wrap around (release builds wrap on overflow) and look
        // "small", misrouting a huge tensor. An empty shape yields 1 (scalar).
        let n_elems = shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim));
        if self.config.gpu_available
            && n_elems >= self.config.gpu_threshold_elems
            && op.kind.is_gpu_friendly()
        {
            Device {
                device_type: DeviceType::Cuda,
                index: self.config.gpu_index as usize,
            }
        } else {
            Device::cpu()
        }
    }
}
/// Facade owning a pluggable [`DeviceSelector`] and forwarding placement
/// queries to it.
pub struct DeviceManager {
    // Boxed so callers can install any selector implementation at runtime.
    selector: Box<dyn DeviceSelector>,
}

impl DeviceManager {
    /// Wraps an arbitrary selector implementation.
    pub fn new(selector: impl DeviceSelector + 'static) -> Self {
        let selector: Box<dyn DeviceSelector> = Box::new(selector);
        Self { selector }
    }

    /// Convenience constructor using the built-in heuristic selector.
    pub fn with_heuristic(config: DeviceConfig) -> Self {
        Self {
            selector: Box::new(HeuristicSelector::new(config)),
        }
    }

    /// Delegates the placement decision to the installed selector.
    pub fn select(&self, op: &OpDescriptor, shape: &[usize]) -> Device {
        self.selector.select(op, shape)
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Default config, but with a GPU reported as present.
fn gpu_config() -> DeviceConfig {
DeviceConfig::default().with_gpu_available(true)
}
// 10 elements: well under the default 1_048_576-element GPU threshold.
fn tiny_shape() -> [usize; 2] {
[2, 5] }
// 2_097_152 elements: comfortably above the default GPU threshold.
fn large_shape() -> [usize; 2] {
[1024, 2048] }
// Only MatMul and Elementwise are classified as GPU-friendly.
#[test]
fn test_op_kind_gpu_friendly() {
assert!(OpKind::MatMul.is_gpu_friendly());
assert!(OpKind::Elementwise.is_gpu_friendly());
assert!(!OpKind::Reduce.is_gpu_friendly());
assert!(!OpKind::Other.is_gpu_friendly());
}
// Below-threshold tensors stay on CPU even with a GPU available.
#[test]
fn test_tiny_tensor_routes_to_cpu() {
let mgr = DeviceManager::with_heuristic(gpu_config());
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let dev = mgr.select(&op, &tiny_shape());
assert!(dev.is_cpu(), "tiny tensor should use CPU");
}
// Large GPU-friendly op + GPU available => GPU.
#[test]
fn test_large_matmul_routes_to_gpu_when_available() {
let mgr = DeviceManager::with_heuristic(gpu_config());
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let dev = mgr.select(&op, &large_shape());
assert!(
dev.is_gpu(),
"large MatMul with GPU available should use GPU"
);
}
// gpu_available=false blocks GPU placement regardless of size/kind.
#[test]
fn test_large_tensor_cpu_when_gpu_unavailable() {
let cfg = DeviceConfig::default().with_gpu_available(false);
let mgr = DeviceManager::with_heuristic(cfg);
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let dev = mgr.select(&op, &large_shape());
assert!(dev.is_cpu(), "no GPU available → must stay on CPU");
}
// Non-GPU-friendly kinds (Reduce, Other) run on CPU even when large.
#[test]
fn test_large_non_gpu_friendly_op_routes_to_cpu() {
let mgr = DeviceManager::with_heuristic(gpu_config());
for kind in [OpKind::Reduce, OpKind::Other] {
let op = OpDescriptor { kind };
let dev = mgr.select(&op, &large_shape());
assert!(
dev.is_cpu(),
"{kind:?} is not GPU-friendly and should run on CPU"
);
}
}
// force_cpu wins even when the heuristic would pick the GPU.
#[test]
fn test_force_cpu_overrides_gpu_eligible() {
let cfg = gpu_config().force_cpu();
let mgr = DeviceManager::with_heuristic(cfg);
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let dev = mgr.select(&op, &large_shape());
assert!(dev.is_cpu(), "force_cpu must override GPU eligibility");
}
// force_gpu wins even with gpu_available=false, a tiny tensor,
// and a non-GPU-friendly op kind.
#[test]
fn test_force_gpu_overrides_cpu_config() {
let cfg = DeviceConfig::default()
.with_gpu_available(false)
.force_gpu(0);
let mgr = DeviceManager::with_heuristic(cfg);
let op = OpDescriptor {
kind: OpKind::Other,
};
let dev = mgr.select(&op, &tiny_shape());
assert!(dev.is_gpu(), "force_gpu must override all other conditions");
}
// Elementwise is also GPU-friendly, not just MatMul.
#[test]
fn test_large_elementwise_routes_to_gpu() {
let mgr = DeviceManager::with_heuristic(gpu_config());
let op = OpDescriptor {
kind: OpKind::Elementwise,
};
let dev = mgr.select(&op, &large_shape());
assert!(
dev.is_gpu(),
"large Elementwise with GPU available should use GPU"
);
}
// DeviceManager::new accepts any DeviceSelector impl, replacing the heuristic.
#[test]
fn test_custom_selector_always_cpu() {
struct AlwaysCpu;
impl DeviceSelector for AlwaysCpu {
fn select(&self, _op: &OpDescriptor, _shape: &[usize]) -> Device {
Device::cpu()
}
}
let mgr = DeviceManager::new(AlwaysCpu);
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let dev = mgr.select(&op, &large_shape());
assert!(dev.is_cpu(), "custom selector should override heuristic");
}
// with_gpu_threshold lowers the cutoff: 16*32 = 512 elems ≥ 256 => GPU.
#[test]
fn test_device_config_builder_threshold() {
let cfg = DeviceConfig::default()
.with_gpu_available(true)
.with_gpu_threshold(256);
let mgr = DeviceManager::with_heuristic(cfg);
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let shape = [16_usize, 32]; let dev = mgr.select(&op, &shape);
assert!(dev.is_gpu(), "512 elems > 256 threshold should use GPU");
}
// Default config reports no GPU, so everything stays on CPU.
#[test]
fn test_device_config_default_no_gpu() {
let mgr = DeviceManager::with_heuristic(DeviceConfig::default());
let op = OpDescriptor {
kind: OpKind::MatMul,
};
let dev = mgr.select(&op, &large_shape());
assert!(dev.is_cpu(), "default config has no GPU available");
}
}