pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "avx2")]
pub mod avx2;
#[cfg(feature = "vulkan")]
pub mod vulkan;
use crate::{
backend::cpu::CpuDevice,
error::GpuError,
kernel::{Kernel, KernelDispatch},
kernels::{
em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
hello_backend::{HelloBackend, HelloBackendInput, HelloBackendOutput},
},
};
/// The compute backend a kernel is executed on.
///
/// `Cpu` is always present; every other variant exists only when the cargo
/// feature of the same name is enabled, so code matching on this enum must
/// gate the corresponding arms with the same `#[cfg(feature = "...")]`.
pub enum DeviceBackend {
/// Scalar CPU implementation; available in every build.
Cpu,
/// An initialised CUDA device (requires the `cuda` feature).
#[cfg(feature = "cuda")]
Cuda(cuda::CudaDevice),
/// An initialised Vulkan device (requires the `vulkan` feature).
#[cfg(feature = "vulkan")]
Vulkan(vulkan::VulkanDevice),
/// AVX2 CPU implementation (requires the `avx2` feature); carries no state.
#[cfg(feature = "avx2")]
Avx2,
}
/// Alias for [`DeviceBackend`].
// NOTE(review): presumably kept so callers using the older `GpuBackend` name keep compiling — confirm.
pub type GpuBackend = DeviceBackend;
/// A caller's request for which backend [`DeviceBackend::from_preference`]
/// should construct.
///
/// Every variant is always present, regardless of which cargo features are
/// enabled; asking for a backend that was not compiled in is reported as an
/// error at construction time rather than at compile time.
///
/// The type is a plain tag, so it is `Copy` and comparable, which lets callers
/// log it, store it in configuration and reuse it after passing it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BackendPreference {
    /// Let the library pick the best backend that initialises successfully.
    Auto,
    /// Require the CUDA backend.
    Cuda,
    /// Require the Vulkan backend.
    Vulkan,
    /// Require the AVX2 CPU backend.
    Avx2,
    /// Require the scalar CPU backend.
    Cpu,
}
impl DeviceBackend {
    /// Probes the compiled-in backends in priority order and returns the first
    /// one that is usable.
    ///
    /// The order is CUDA, Vulkan, AVX2, then the scalar CPU fallback. A tier is
    /// only attempted when its cargo feature is enabled. This never fails: when
    /// nothing else is usable it returns [`DeviceBackend::Cpu`].
    pub fn auto_detect() -> Self {
        // The "trying <next>" wording in the warnings names the next tier of the
        // full chain, even when that tier is not compiled into this build.
        #[cfg(feature = "cuda")]
        match cuda::CudaDevice::init() {
            Ok(dev) => {
                tracing::info!(
                    device_name = %dev.name(),
                    vram_bytes = dev.total_vram_bytes(),
                    "compute backend: CUDA selected"
                );
                return Self::Cuda(dev);
            }
            Err(e) => tracing::warn!(%e, "CUDA init failed, trying Vulkan"),
        }
        #[cfg(feature = "vulkan")]
        match vulkan::VulkanDevice::init() {
            Ok(dev) => {
                tracing::info!(
                    device_name = %dev.name(),
                    vram_bytes = dev.total_vram_bytes(),
                    "compute backend: Vulkan selected"
                );
                return Self::Vulkan(dev);
            }
            Err(e) => tracing::warn!(%e, "Vulkan init failed, trying AVX2"),
        }
        #[cfg(feature = "avx2")]
        if is_x86_feature_detected!("avx2") {
            tracing::info!("compute backend: AVX2 selected");
            return Self::Avx2;
        }
        tracing::warn!("compute backend: scalar CPU fallback");
        Self::Cpu
    }

    /// Returns the scalar CPU backend. Always succeeds.
    pub fn cpu() -> Self {
        Self::Cpu
    }

    /// Initialises the CUDA backend.
    ///
    /// # Errors
    /// Propagates whatever error `CudaDevice::init` reports.
    #[cfg(feature = "cuda")]
    pub fn cuda() -> Result<Self, GpuError> {
        Ok(Self::Cuda(cuda::CudaDevice::init()?))
    }

    /// Initialises the Vulkan backend.
    ///
    /// # Errors
    /// Propagates whatever error `VulkanDevice::init` reports.
    #[cfg(feature = "vulkan")]
    pub fn vulkan() -> Result<Self, GpuError> {
        Ok(Self::Vulkan(vulkan::VulkanDevice::init()?))
    }

    /// Returns the AVX2 backend if the running CPU supports AVX2.
    ///
    /// # Errors
    /// Returns [`GpuError::BackendUnavailable`] when runtime detection reports
    /// that AVX2 is not supported.
    #[cfg(feature = "avx2")]
    pub fn avx2() -> Result<Self, GpuError> {
        if is_x86_feature_detected!("avx2") {
            Ok(Self::Avx2)
        } else {
            Err(GpuError::BackendUnavailable(
                "AVX2 not supported by this CPU".into(),
            ))
        }
    }

    /// Builds the backend requested by `pref`.
    ///
    /// `Auto` defers to [`Self::auto_detect`] and `Cpu` always succeeds. The
    /// remaining preferences delegate to the matching dedicated constructor so
    /// both entry points share one implementation.
    ///
    /// # Errors
    /// Returns [`GpuError::BackendUnavailable`] when the requested backend was
    /// not compiled in (the message names the cargo feature to enable), and
    /// otherwise propagates the error of the dedicated constructor.
    pub fn from_preference(pref: BackendPreference) -> Result<Self, GpuError> {
        // Each gated arm holds a `cfg` / `cfg(not)` pair of blocks; exactly one
        // survives compilation and becomes the arm's value.
        match pref {
            BackendPreference::Auto => Ok(Self::auto_detect()),
            BackendPreference::Cpu => Ok(Self::cpu()),
            BackendPreference::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    Self::cuda()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "CUDA backend not compiled in; rebuild with --features cuda".into(),
                    ))
                }
            }
            BackendPreference::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    Self::vulkan()
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "Vulkan backend not compiled in; rebuild with --features vulkan".into(),
                    ))
                }
            }
            BackendPreference::Avx2 => {
                #[cfg(feature = "avx2")]
                {
                    Self::avx2()
                }
                #[cfg(not(feature = "avx2"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "AVX2 backend not compiled in; rebuild with --features avx2".into(),
                    ))
                }
            }
        }
    }

    /// Runs kernel `K` on this backend by forwarding to its
    /// [`KernelDispatch`] implementation.
    ///
    /// # Errors
    /// Returns whatever error the dispatch implementation produces.
    pub fn run<K: Kernel>(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>
    where
        Self: KernelDispatch<K>,
    {
        self.dispatch(input)
    }

    /// Short lowercase identifier of the active backend:
    /// `"cpu"`, `"cuda"`, `"vulkan"` or `"avx2"`.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Cpu => "cpu",
            #[cfg(feature = "cuda")]
            Self::Cuda(_) => "cuda",
            #[cfg(feature = "vulkan")]
            Self::Vulkan(_) => "vulkan",
            #[cfg(feature = "avx2")]
            Self::Avx2 => "avx2",
        }
    }

    /// Returns `true` for the CUDA and Vulkan backends, `false` for the CPU
    /// and AVX2 backends.
    pub fn is_gpu(&self) -> bool {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            #[cfg(feature = "cuda")]
            Self::Cuda(_) => true,
            #[cfg(feature = "vulkan")]
            Self::Vulkan(_) => true,
            Self::Cpu => false,
            #[cfg(feature = "avx2")]
            Self::Avx2 => false,
        }
    }

    /// Returns `true` for every backend except the scalar CPU one.
    pub fn is_accelerated(&self) -> bool {
        !matches!(self, Self::Cpu)
    }

    /// Currently available device memory in bytes, if the backend reports one.
    ///
    /// Always `None` for the CPU and AVX2 backends. For CUDA, a failed query
    /// is mapped to `None`; for Vulkan the device's own `Option` is returned.
    pub fn available_vram_bytes(&self) -> Option<u64> {
        match self {
            Self::Cpu => None,
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => dev.available_vram_bytes().ok(),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(dev) => dev.available_vram_bytes(),
            #[cfg(feature = "avx2")]
            Self::Avx2 => None,
        }
    }

    /// Total device memory in bytes for the CUDA and Vulkan backends;
    /// `None` for the CPU and AVX2 backends.
    pub fn total_vram_bytes(&self) -> Option<u64> {
        match self {
            Self::Cpu => None,
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => Some(dev.total_vram_bytes()),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(dev) => Some(dev.total_vram_bytes()),
            #[cfg(feature = "avx2")]
            Self::Avx2 => None,
        }
    }
}
/// Backend-specific state for an EM run, created by
/// `DeviceBackend::em_init_session` and passed back to
/// `em_run_iteration` / `em_drop_session`.
///
/// There is one variant per accelerated backend and none for the scalar CPU
/// backend; the type only exists when at least one of the `cuda`, `vulkan`
/// or `avx2` features is enabled.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
pub(crate) enum EmSession {
/// Session owned by the CUDA backend.
#[cfg(feature = "cuda")]
Cuda(cuda::launch::em_reduce::CudaEmSession),
/// Session owned by the Vulkan backend.
#[cfg(feature = "vulkan")]
Vulkan(vulkan::launch::em_reduce::VulkanEmSession),
/// Session owned by the AVX2 backend.
#[cfg(feature = "avx2")]
Avx2(avx2::launch::em_reduce::Avx2EmSession),
}
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
impl DeviceBackend {
    /// Creates a backend-specific EM session for the given comparison data.
    ///
    /// The three arguments are forwarded unchanged to the active backend's
    /// own `em_init_session`.
    ///
    /// # Errors
    /// Returns [`GpuError::BackendUnavailable`] on the scalar CPU backend,
    /// which has no session type, and otherwise propagates the backend's
    /// error (the AVX2 backend's initialiser is infallible).
    pub(crate) fn em_init_session(
        &self,
        comparison_levels: &[u32],
        n_pairs: usize,
        n_fields: usize,
    ) -> Result<EmSession, GpuError> {
        match self {
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Cuda),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Vulkan),
            #[cfg(feature = "avx2")]
            Self::Avx2 => Ok(EmSession::Avx2(avx2::device::Avx2Device::em_init_session(
                comparison_levels,
                n_pairs,
                n_fields,
            ))),
            // Named explicitly (instead of `_`) so a new variant must be handled here.
            Self::Cpu => Err(GpuError::BackendUnavailable(
                "em_init_session requires an accelerated backend".into(),
            )),
        }
    }

    /// Runs one EM iteration on `session` with the given `weights` and
    /// `log_prior_odds`.
    ///
    /// # Errors
    /// Returns [`GpuError::BackendUnavailable`] when this is the scalar CPU
    /// backend or when `session` was created by a different backend than
    /// `self`; otherwise propagates the backend's error.
    pub(crate) fn em_run_iteration(
        &self,
        session: &mut EmSession,
        weights: &[f32],
        log_prior_odds: f32,
    ) -> Result<EmReduceOutput, GpuError> {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(dev), EmSession::Cuda(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(s)) => {
                avx2::device::Avx2Device::em_run_iteration(s, weights, log_prior_odds)
            }
            // Everything that falls through is either the CPU backend or a
            // session paired with the wrong backend; report which one it was.
            _ => {
                let reason = if matches!(self, Self::Cpu) {
                    "em_run_iteration requires an accelerated backend"
                } else {
                    "em_run_iteration called with a session created by a different backend"
                };
                Err(GpuError::BackendUnavailable(reason.into()))
            }
        }
    }

    /// Consumes `session`, explicitly destroying it where the backend
    /// requires that.
    ///
    /// CUDA and AVX2 sessions are simply dropped. A Vulkan session is
    /// destroyed through the owning device and its allocator; if it is handed
    /// to a non-Vulkan backend the explicit destroy cannot run, which is
    /// logged as a warning.
    pub(crate) fn em_drop_session(&self, session: EmSession) {
        match session {
            #[cfg(feature = "cuda")]
            EmSession::Cuda(_s) => {}
            #[cfg(feature = "vulkan")]
            EmSession::Vulkan(s) => match self {
                Self::Vulkan(dev) => {
                    // This is a cleanup path: a mutex poisoned by an earlier
                    // panic must not turn the teardown into a second panic,
                    // so recover the guard and still destroy the session.
                    let mut alloc = dev
                        .allocator
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner);
                    s.destroy(&dev.device, &mut alloc);
                }
                _ => tracing::warn!(
                    backend = self.name(),
                    "Vulkan EM session dropped without its owning Vulkan device; explicit destroy skipped"
                ),
            },
            #[cfg(feature = "avx2")]
            EmSession::Avx2(_s) => {}
        }
    }
}
impl KernelDispatch<HelloBackend> for DeviceBackend {
fn dispatch(&self, input: HelloBackendInput) -> Result<HelloBackendOutput, GpuError> {
match self {
#[cfg(feature = "cuda")]
Self::Cuda(dev) => {
<cuda::CudaDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input)
}
#[cfg(feature = "vulkan")]
Self::Vulkan(dev) => {
<vulkan::VulkanDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input)
}
#[cfg(feature = "avx2")]
Self::Avx2 => <avx2::Avx2Device as KernelDispatch<HelloBackend>>::dispatch(
&avx2::Avx2Device,
input,
),
Self::Cpu => <CpuDevice as KernelDispatch<HelloBackend>>::dispatch(&CpuDevice, input),
}
}
}
impl KernelDispatch<EmReduce> for DeviceBackend {
fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
match self {
#[cfg(feature = "cuda")]
Self::Cuda(dev) => <cuda::CudaDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
#[cfg(feature = "vulkan")]
Self::Vulkan(dev) => {
<vulkan::VulkanDevice as KernelDispatch<EmReduce>>::dispatch(dev, input)
}
#[cfg(feature = "avx2")]
Self::Avx2 => {
<avx2::Avx2Device as KernelDispatch<EmReduce>>::dispatch(&avx2::Avx2Device, input)
}
Self::Cpu => <CpuDevice as KernelDispatch<EmReduce>>::dispatch(&CpuDevice, input),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Auto-detection must always yield one of the four known backend names.
    #[test]
    fn auto_detect_does_not_panic() {
        let name = DeviceBackend::auto_detect().name();
        let known = matches!(name, "cpu" | "cuda" | "vulkan" | "avx2");
        assert!(known, "unexpected backend name: {name}");
    }

    /// The scalar CPU backend reports no VRAM and is neither GPU nor accelerated.
    #[test]
    fn cpu_backend_has_no_vram() {
        let cpu = DeviceBackend::cpu();
        assert!(cpu.available_vram_bytes().is_none());
        assert!(cpu.total_vram_bytes().is_none());
        assert!(!cpu.is_gpu());
        assert!(!cpu.is_accelerated());
    }

    /// The scalar CPU backend identifies itself as "cpu".
    #[test]
    fn cpu_backend_name() {
        let cpu = DeviceBackend::cpu();
        assert_eq!(cpu.name(), "cpu");
    }

    /// Requesting the CPU backend can never fail.
    #[test]
    fn cpu_preference_always_succeeds() {
        let built = DeviceBackend::from_preference(BackendPreference::Cpu);
        assert!(built.is_ok());
    }

    /// AVX2 counts as accelerated but is not a GPU and has no VRAM.
    #[cfg(feature = "avx2")]
    #[test]
    fn avx2_backend_is_accelerated_not_gpu() {
        let simd = DeviceBackend::Avx2;
        assert_eq!(simd.name(), "avx2");
        assert!(simd.is_accelerated());
        assert!(!simd.is_gpu());
        assert!(simd.available_vram_bytes().is_none());
    }
}