use crate::tensor::{DType, DeviceBuffer};
use std::borrow::Cow;
use std::collections::{BTreeMap, VecDeque};
use std::env;
use std::num::NonZeroU64;
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use std::sync::{Arc, Mutex};
use anyhow::Result;
use ndarray::ArrayD;
use once_cell::sync::{Lazy, OnceCell};
use wgpu::util::DeviceExt;
use wide::f32x4;
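/// Fills `buf` with `buf[i] = i as f32 * scale`, four lanes at a time via
/// `wide::f32x4`, with a scalar loop for the tail.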
pub(crate) fn fill_arange_mul(buf: &mut [f32], scale: f32) {
let n = buf.len();
let mut i = 0usize;
let mut base = 0f32;
while i + 4 <= n {
let idxs = f32x4::new([base, base + 1.0, base + 2.0, base + 3.0]);
let mul = f32x4::splat(scale) * idxs;
let arr = mul.to_array();
buf[i] = arr[0];
buf[i + 1] = arr[1];
buf[i + 2] = arr[2];
buf[i + 3] = arr[3];
i += 4;
base += 4.0;
}
while i < n {
buf[i] = base * scale;
i += 1;
base += 1.0;
}
}
static SYS_ANALYSIS_SCORE: AtomicI32 = AtomicI32::new(0);
#[allow(dead_code)]
static SYS_ANALYSIS_STARTED: OnceCell<()> = OnceCell::new();
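/// Spawns (at most once) a background thread that runs the `system_analysis`
/// probe and publishes its GPU score for `adapter_score` to fold in.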
#[cfg(feature = "system_analysis")]
#[allow(dead_code)]
fn start_system_analysis_probe_once() {
if SYS_ANALYSIS_STARTED.set(()).is_ok() {
std::thread::spawn(|| {
let mut analyzer = system_analysis::SystemAnalyzer::new();
if let Ok(profile) = pollster::block_on(async { analyzer.analyze_system().await }) {
                // The `as i32` cast already saturates for float sources;
                // the former `.min(i32::MAX)` on an i32 value was a no-op.
                let g = profile.gpu_score() as i32;
                SYS_ANALYSIS_SCORE.store(g, Ordering::SeqCst);
}
});
}
}
#[cfg(not(feature = "system_analysis"))]
#[allow(dead_code)]
fn start_system_analysis_probe_once() {
}
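/// Maximum number of idle staging buffers retained per size bucket.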
const MAX_PER_BUCKET: usize = 4;
type StagingPool = Arc<Mutex<BTreeMap<NonZeroU64, VecDeque<StagingEntry>>>>;
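/// RAII handle to a pooled staging buffer: dropping it marks the entry idle
/// and returns it to its size bucket.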
pub struct StagingGuard {
pub key: NonZeroU64,
pub entry: Option<StagingEntry>,
pub pool: StagingPool,
pub created_mapped: bool,
}
impl StagingGuard {
pub fn buffer(&self) -> &wgpu::Buffer {
&self.entry.as_ref().expect("entry present").buffer
}
pub fn created_mapped(&self) -> bool {
self.created_mapped
}
}
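/// Write-mapping shim: callers write into a host-side shadow `Vec`, which is
/// flushed to the staging buffer via `queue.write_buffer` when the guard is
/// consumed by `take_guard` or dropped.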
pub struct StagingMappedWriteGuard {
guard: Option<StagingGuard>,
mapped: Option<Vec<u8>>,
device: Option<Arc<GpuDevice>>,
}
impl StagingMappedWriteGuard {
pub fn as_mut_slice(&mut self) -> &mut [u8] {
self.mapped.as_deref_mut().expect("mapped range present")
}
pub fn take_guard(mut self) -> StagingGuard {
if let (Some(data), Some(g), Some(dev)) = (
self.mapped.take(),
self.guard.as_ref(),
self.device.as_ref(),
) {
dev.queue.write_buffer(g.buffer(), 0, &data);
dev.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
}
self.guard.take().expect("guard present")
}
}
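/// Read-mapping shim: the mapped bytes are copied to a host `Vec` up front so
/// the underlying buffer can be unmapped immediately.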
pub struct StagingMappedReadGuard {
guard: Option<StagingGuard>,
mapped_data: Option<Vec<u8>>,
}
impl StagingMappedReadGuard {
pub fn as_slice(&self) -> &[u8] {
self.mapped_data.as_deref().expect("mapped range present")
}
pub fn take_guard(mut self) -> StagingGuard {
self.guard.take().expect("guard present")
}
}
impl Drop for StagingMappedReadGuard {
fn drop(&mut self) {
self.mapped_data.take();
self.guard.take();
}
}
impl Drop for StagingMappedWriteGuard {
fn drop(&mut self) {
if let (Some(data), Some(g), Some(dev)) = (
self.mapped.take(),
self.guard.as_ref(),
self.device.as_ref(),
) {
dev.queue.write_buffer(g.buffer(), 0, &data);
dev.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
}
self.guard.take();
}
}
impl Drop for StagingGuard {
fn drop(&mut self) {
if let Some(e) = self.entry.take() {
e.in_use.store(false, Ordering::SeqCst);
return_entry_to_pool(&self.pool, self.key, e);
}
}
}
fn return_entry_to_pool(pool: &StagingPool, key: NonZeroU64, e: StagingEntry) {
if let Ok(mut pool) = pool.lock() {
let deq = pool.entry(key).or_default();
deq.push_back(e);
while deq.len() > MAX_PER_BUCKET {
deq.pop_front();
}
}
}
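/// A wgpu device/queue pair plus a pool of reusable staging buffers keyed by
/// power-of-two size bucket.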
pub struct GpuDevice {
pub device: wgpu::Device,
pub queue: wgpu::Queue,
    pub staging_pool: StagingPool,
}
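/// One pooled staging buffer plus its usage flags and an in-use marker
/// consulted when matching reuse requests.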
pub struct StagingEntry {
pub buffer: wgpu::Buffer,
pub in_use: AtomicBool,
pub usage: wgpu::BufferUsages,
}
impl GpuDevice {
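    /// Hands out a staging buffer of `key` bytes, reusing an idle pool entry
    /// whose usage flags fit the requested direction, or creating a fresh one
    /// (optionally seeded with `initial_contents`).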
fn acquire_staging_entry(
&self,
key: NonZeroU64,
initial_contents: Option<&[u8]>,
map_for_write: bool,
) -> StagingGuard {
if let Ok(mut pool) = self.staging_pool.lock() {
let deq = pool.entry(key).or_default();
if let Some(pos) = deq.iter().position(|e| {
if !e.in_use.load(Ordering::SeqCst) {
if map_for_write {
return e
.usage
.contains(wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC);
} else {
return e
.usage
.contains(wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST);
}
}
false
}) {
let e = deq.remove(pos).unwrap();
e.in_use.store(true, Ordering::SeqCst);
return StagingGuard {
key,
entry: Some(e),
pool: self.staging_pool.clone(),
created_mapped: false,
};
}
}
let (entry, created_mapped) = if let Some(bytes) = initial_contents {
if map_for_write {
let usage = wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
let b = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: key.get(),
usage,
mapped_at_creation: false,
});
self.queue.write_buffer(&b, 0, bytes);
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
(
StagingEntry {
buffer: b,
in_use: AtomicBool::new(true),
usage,
},
false,
)
            } else {
                // wgpu only allows MAP_READ together with COPY_DST, and the
                // pool-reuse check above expects the same pairing; the old
                // MAP_READ | COPY_SRC combination failed validation and could
                // never be recycled.
                let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
let b = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytes,
usage,
});
(
StagingEntry {
buffer: b,
in_use: AtomicBool::new(true),
usage,
},
true,
)
}
} else if map_for_write {
let usage = wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC;
let b = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: key.get(),
usage,
mapped_at_creation: false,
});
(
StagingEntry {
buffer: b,
in_use: AtomicBool::new(true),
usage,
},
false,
)
} else {
let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
let b = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: key.get(),
usage,
mapped_at_creation: false,
});
(
StagingEntry {
buffer: b,
in_use: AtomicBool::new(true),
usage,
},
false,
)
};
StagingGuard {
key,
entry: Some(entry),
pool: self.staging_pool.clone(),
created_mapped,
}
}
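    /// Blocking write mapping: the returned guard wraps a host-side shadow
    /// buffer that is flushed to the staging buffer when the guard is
    /// consumed or dropped.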
fn map_staging_write_blocking(
&self,
key: NonZeroU64,
initial_contents: Option<&[u8]>,
) -> StagingMappedWriteGuard {
let guard = self.acquire_staging_entry(key, initial_contents, true);
let mapped_vec = if let Some(init) = initial_contents {
Some(init.to_vec())
} else {
Some(vec![0u8; guard.key.get() as usize])
};
StagingMappedWriteGuard {
guard: Some(guard),
mapped: mapped_vec,
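            // The manager currently initializes a single device, so this
            // resolves back to `self`; with multiple devices, round-robin
            // could hand the guard a different queue.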
device: Some(DeviceManager::global().get_or_init()),
}
}
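    /// Maps the staging buffer for reading, waits for `map_async` to resolve,
    /// and copies the bytes out before unmapping. Buffers flagged
    /// `created_mapped` skip the map and yield a zeroed placeholder.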
fn map_staging_read_blocking_from_guard(&self, guard: StagingGuard) -> StagingMappedReadGuard {
let mapped_data = if !guard.created_mapped() {
let buf = guard.buffer();
let slice = buf.slice(..);
let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
let _ = futures_executor::block_on(rx.receive());
let v = slice.get_mapped_range();
let copy = v.to_vec();
drop(v);
buf.unmap();
Some(copy)
} else {
Some(vec![0u8; guard.key.get() as usize])
};
StagingMappedReadGuard {
guard: Some(guard),
mapped_data,
}
}
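    /// Async counterpart of `map_staging_write_blocking`; the shadow-`Vec`
    /// strategy means no await points are actually required here.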
#[cfg(feature = "tokio")]
async fn map_staging_write_async(
&self,
key: NonZeroU64,
initial_contents: Option<&[u8]>,
) -> StagingMappedWriteGuard {
let guard = self.acquire_staging_entry(key, initial_contents, true);
let mapped_vec = if let Some(init) = initial_contents {
Some(init.to_vec())
} else {
Some(vec![0u8; guard.key.get() as usize])
};
StagingMappedWriteGuard {
guard: Some(guard),
mapped: mapped_vec,
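            // Same single-device assumption as the blocking variant above.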
device: Some(DeviceManager::global().get_or_init()),
}
}
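    /// Async read mapping: polls the device every millisecond while awaiting
    /// the `map_async` callback instead of blocking the thread.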
#[cfg(feature = "tokio")]
async fn map_staging_read_async_from_guard(
&self,
guard: StagingGuard,
) -> StagingMappedReadGuard {
let mapped_data = if !guard.created_mapped() {
let buf = guard.buffer();
let slice = buf.slice(..);
let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
loop {
tokio::select! {
res = rx.receive() => {
match res {
Some(Ok(())) => break,
Some(Err(_)) => panic!("map_async failed"),
None => panic!("map_async channel closed"),
}
}
_ = tokio::time::sleep(std::time::Duration::from_millis(1)) => {
self.device.poll(wgpu::PollType::Poll).unwrap();
}
}
}
let v = slice.get_mapped_range();
let copy = v.to_vec();
drop(v);
buf.unmap();
Some(copy)
} else {
Some(vec![0u8; guard.key.get() as usize])
};
StagingMappedReadGuard {
guard: Some(guard),
mapped_data,
}
}
#[allow(dead_code)]
fn map_staging_read_blocking(&self, key: NonZeroU64) -> StagingMappedReadGuard {
let guard = self.acquire_staging_entry(key, None::<&[u8]>, false);
self.map_staging_read_blocking_from_guard(guard)
}
#[cfg(feature = "tokio")]
#[allow(dead_code)]
async fn map_staging_read_async(&self, key: NonZeroU64) -> StagingMappedReadGuard {
let guard = self.acquire_staging_entry(key, None::<&[u8]>, false);
self.map_staging_read_async_from_guard(guard).await
}
pub fn new_from_adapter(adapter: &wgpu::Adapter) -> anyhow::Result<Self> {
let (device, queue) =
pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor::default()))?;
Ok(GpuDevice {
device,
queue,
staging_pool: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
})
}
pub fn new() -> anyhow::Result<Self> {
        let instance = default_instance();
        if let Ok(sel) = env::var("EENN_GPU_ADAPTER") {
            let adapters = instance.enumerate_adapters(wgpu::Backends::all());
            if let Ok(idx) = sel.parse::<usize>()
                && let Some(adapter) = adapters.get(idx)
            {
                return GpuDevice::new_from_adapter(adapter);
            }
            // Fall back to a substring match on the adapter name.
            for adapter in &adapters {
                if adapter.get_info().name.contains(&sel) {
                    return GpuDevice::new_from_adapter(adapter);
                }
            }
        }
for adapter in instance.enumerate_adapters(wgpu::Backends::all()) {
if adapter_matches_requirements(&adapter) {
return GpuDevice::new_from_adapter(&adapter);
}
}
let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
}))
.map_err(|e| anyhow::anyhow!("no suitable GPU adapter found: {}", e))?;
GpuDevice::new_from_adapter(&adapter)
}
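    /// Uploads an `f32` slice into a new GPU buffer. Payloads up to 64 KiB go
    /// through `queue.write_buffer` directly; larger ones are staged through
    /// the pool and copied on the GPU.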
pub fn create_buffer_from_f32(&self, data: &[f32], usage: wgpu::BufferUsages) -> wgpu::Buffer {
let size = std::mem::size_of_val(data) as u64;
let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage: usage | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let bytes = bytemuck::cast_slice(data);
if bytes.len() <= 64 * 1024 {
self.queue.write_buffer(&buf, 0, bytes);
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
buf
} else {
let size = bytes.len() as u64;
let key = bucket_for_size(size);
let mapped_guard = self.map_staging_write_blocking(key, Some(bytes));
        // `map_staging_write_blocking` always seeds the shadow Vec today, so
        // the else branch below is an unreachable fallback.
        if mapped_guard.mapped.is_some() {
let mut mg = mapped_guard;
let slice = mg.as_mut_slice();
slice.copy_from_slice(bytes);
let guard = mg.take_guard();
let staging_entry_buf = guard.buffer();
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(staging_entry_buf, 0, &buf, 0, size);
self.queue.submit(Some(encoder.finish()));
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
drop(guard);
} else {
            let staging_entry_buf = mapped_guard.guard.as_ref().expect("guard present").buffer();
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(staging_entry_buf, 0, &buf, 0, size);
self.queue.submit(Some(encoder.finish()));
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
}
buf
}
}
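    /// Reads `elements` f32 values out of `buffer` via a pooled staging
    /// buffer, blocking until the bytes reach the host.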
pub fn read_buffer_to_host(
&self,
buffer: &wgpu::Buffer,
elements: usize,
) -> anyhow::Result<Vec<f32>> {
let size = (elements * std::mem::size_of::<f32>()) as u64;
let key = bucket_for_size(size);
let guard = self.acquire_staging_entry(key, None::<&[u8]>, false);
let staging_buf = guard.buffer();
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(buffer, 0, staging_buf, 0, size);
self.queue.submit(Some(encoder.finish()));
let mapped_read = self.map_staging_read_blocking_from_guard(guard);
let mapped_slice = mapped_read.as_slice();
let trimmed = &mapped_slice[..(size as usize)];
let vec: Vec<f32> = bytemuck::cast_slice(trimmed).to_vec();
drop(mapped_read);
Ok(vec)
}
}
#[cfg(feature = "tokio")]
impl GpuDevice {
pub async fn read_buffer_to_host_async(
&self,
buffer: &wgpu::Buffer,
elements: usize,
) -> anyhow::Result<Vec<f32>> {
let size = (elements * std::mem::size_of::<f32>()) as u64;
let key = bucket_for_size(size);
let guard = self.acquire_staging_entry(key, None::<&[u8]>, false);
let staging_buf = guard.buffer();
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(buffer, 0, staging_buf, 0, size);
self.queue.submit(Some(encoder.finish()));
let mapped_read = self.map_staging_read_async_from_guard(guard).await;
let mapped_slice = mapped_read.as_slice();
let trimmed = &mapped_slice[..(size as usize)];
let vec: Vec<f32> = bytemuck::cast_slice(trimmed).to_vec();
drop(mapped_read);
Ok(vec)
}
}
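/// Shared instance configuration (all backends, FXC for DX12); factored out
/// of `GpuDevice::new` and the `DeviceManager` enumeration paths, which
/// previously repeated it verbatim.
fn default_instance() -> wgpu::Instance {
    wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        flags: wgpu::InstanceFlags::default(),
        backend_options: wgpu::BackendOptions {
            dx12: wgpu::Dx12BackendOptions {
                shader_compiler: wgpu::Dx12Compiler::Fxc,
                ..Default::default()
            },
            ..Default::default()
        },
        memory_budget_thresholds: Default::default(),
    })
}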
fn adapter_matches_requirements(adapter: &wgpu::Adapter) -> bool {
let features = adapter.features();
features.contains(wgpu::Features::STORAGE_RESOURCE_BINDING_ARRAY)
}
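/// Heuristic adapter ranking: device class, a software-rasterizer penalty,
/// useful compute features, storage limits, vendor, and an optional bonus
/// from the background system-analysis probe.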
fn adapter_score(adapter: &wgpu::Adapter) -> u32 {
let info = adapter.get_info();
let features = adapter.features();
let limits = adapter.limits();
let mut score: u32 = 0;
match info.device_type {
wgpu::DeviceType::DiscreteGpu => score += 60,
wgpu::DeviceType::IntegratedGpu => score += 30,
wgpu::DeviceType::VirtualGpu => score += 10,
_ => score += 5,
}
if info.name.to_lowercase().contains("swiftshader")
|| info.name.to_lowercase().contains("llvmpipe")
{
score = score.saturating_sub(50);
}
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
score += 5;
}
if features.contains(wgpu::Features::STORAGE_RESOURCE_BINDING_ARRAY) {
score += 20;
}
if features.contains(wgpu::Features::PUSH_CONSTANTS) {
score += 5;
}
score += limits.max_storage_buffers_per_shader_stage.min(16) * 2;
if limits.min_storage_buffer_offset_alignment >= 256 {
score += 10;
}
    // `info.vendor` is a numeric PCI vendor ID (0x10DE for NVIDIA); matching
    // the string "nvidia" against its digits could never succeed, so compare
    // the ID directly alongside the name check.
    if info.name.to_lowercase().contains("nvidia") || info.vendor == 0x10DE {
        score += 10;
    }
let sys_score = SYS_ANALYSIS_SCORE.load(Ordering::SeqCst);
if sys_score > 0 {
score = score.saturating_add((sys_score as u32).min(20));
}
score
}
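/// Cheaper ranking used when only selecting an adapter index: name and device
/// class heuristics, without querying features or limits.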
fn adapter_score_quick(adapter: &wgpu::Adapter) -> u32 {
let info = adapter.get_info();
let mut score: u32 = 0;
match info.device_type {
wgpu::DeviceType::DiscreteGpu => score += 50,
wgpu::DeviceType::IntegratedGpu => score += 25,
_ => score += 5,
}
let name = info.name.to_lowercase();
if name.contains("nvidia") {
score += 10;
} else if name.contains("amd") || name.contains("radeon") {
score += 8;
}
if name.contains("swiftshader") || name.contains("llvmpipe") {
score = score.saturating_sub(40);
}
score
}
fn bucket_for_size(size: u64) -> NonZeroU64 {
    // Round up to the next power of two so similarly sized requests share a
    // pool bucket; `0u64.next_power_of_two()` is 1, so the result is nonzero.
    NonZeroU64::new(size.next_power_of_two()).unwrap()
}
impl GpuDevice {
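    /// Despite the `async` signature this delegates to the blocking readback
    /// and never yields; the signature is kept for call-site compatibility.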
pub async fn readback_f32_async(
&self,
buffer: &wgpu::Buffer,
elements: usize,
) -> anyhow::Result<Vec<f32>> {
self.read_buffer_to_host(buffer, elements)
}
}
#[cfg(feature = "tokio")]
impl GpuDevice {
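    /// Uploads an `ArrayD<f32>` (flattened in memory order). Small payloads
    /// go through `queue.write_buffer`; larger ones through a pooled staging
    /// buffer and a GPU-side copy.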
pub async fn upload_array_f32_async(
&self,
data: Arc<ArrayD<f32>>,
usage: wgpu::BufferUsages,
) -> anyhow::Result<Arc<GpuBuffer>> {
let flat: Vec<f32> = if let Some(slice) = data.as_slice_memory_order() {
slice.to_vec()
} else {
let mut v = Vec::with_capacity(data.len());
v.extend(data.iter().cloned());
v
};
let bytes = bytemuck::cast_slice(&flat);
let size = bytes.len() as u64;
if bytes.len() <= 64 * 1024 {
let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage: usage | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
self.queue.write_buffer(&buf, 0, bytes);
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
let device_arc = DeviceManager::global().get_or_init();
let gb = GpuBuffer {
device: device_arc.clone(),
buffer: Arc::new(buf),
dtype: DType::F32,
shape: data.shape().to_vec(),
};
return Ok(Arc::new(gb));
}
let key = bucket_for_size(size);
let mapped_guard = self.map_staging_write_async(key, Some(bytes)).await;
        // The shadow Vec is always populated today; the else branch below is
        // an unreachable fallback kept for safety.
        if mapped_guard.mapped.is_some() {
let mut mg = mapped_guard;
mg.as_mut_slice().copy_from_slice(bytes);
let guard = mg.take_guard();
            // Keep the staging buffer's flags distinct from the caller's
            // requested usage: the destination must carry the requested usage
            // (e.g. STORAGE), which the old shadowing `usage` binding
            // silently dropped.
            let staging_usage = guard.entry.as_ref().unwrap().usage;
            assert!(
                staging_usage.contains(wgpu::BufferUsages::COPY_SRC),
                "staging entry missing COPY_SRC usage: {:?}",
                staging_usage
            );
            let staging_buf = guard.buffer();
            let dst = self.device.create_buffer(&wgpu::BufferDescriptor {
                label: None,
                size,
                usage: usage | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC,
                mapped_at_creation: false,
            });
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(staging_buf, 0, &dst, 0, size);
self.queue.submit(Some(encoder.finish()));
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
drop(guard);
let device_arc = DeviceManager::global().get_or_init();
let gb = GpuBuffer {
device: device_arc.clone(),
buffer: Arc::new(dst),
dtype: DType::F32,
shape: data.shape().to_vec(),
};
Ok(Arc::new(gb))
} else {
let guard_ref = mapped_guard.guard.as_ref().expect("guard present");
            let staging_usage = guard_ref.entry.as_ref().unwrap().usage;
            assert!(
                staging_usage.contains(wgpu::BufferUsages::COPY_SRC),
                "staging entry missing COPY_SRC usage: {:?}",
                staging_usage
            );
            let staging_buf = guard_ref.buffer();
            let dst = self.device.create_buffer(&wgpu::BufferDescriptor {
                label: None,
                size,
                usage: usage | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC,
                mapped_at_creation: false,
            });
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(staging_buf, 0, &dst, 0, size);
self.queue.submit(Some(encoder.finish()));
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).unwrap();
let device_arc = DeviceManager::global().get_or_init();
let gb = GpuBuffer {
device: device_arc.clone(),
buffer: Arc::new(dst),
dtype: DType::F32,
shape: data.shape().to_vec(),
};
Ok(Arc::new(gb))
}
}
}
#[cfg(not(feature = "tokio"))]
impl GpuDevice {
pub async fn upload_array_f32_async(
&self,
data: Arc<ArrayD<f32>>,
usage: wgpu::BufferUsages,
) -> anyhow::Result<Arc<GpuBuffer>> {
let device_arc = DeviceManager::global().get_or_init();
let flat: Vec<f32> = if let Some(slice) = data.as_slice_memory_order() {
slice.to_vec()
} else {
let mut v = Vec::with_capacity(data.len());
v.extend(data.iter().cloned());
v
};
let buf = device_arc.create_buffer_from_f32(
&flat,
usage | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::STORAGE,
);
let gb = GpuBuffer {
device: device_arc.clone(),
buffer: Arc::new(buf),
dtype: DType::F32,
shape: data.shape().to_vec(),
};
Ok(Arc::new(gb))
}
}
struct DeviceManagerInner {
devices: Vec<Arc<GpuDevice>>,
rr: usize,
}
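/// Process-wide registry of initialized GPU devices. Initialization is lazy,
/// honors the `EENN_GPU_ADAPTER` override (index or name substring), and
/// currently registers a single best-scored device.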
pub struct DeviceManager {
inner: Mutex<DeviceManagerInner>,
}
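/// How `DeviceManager` picks among initialized devices (or among enumerated
/// adapters in `select_adapter_index_by_policy`).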
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeviceSelectionPolicy {
RoundRobin,
BestScore,
Index(usize),
}
impl DeviceManager {
fn new() -> Self {
Self {
inner: Mutex::new(DeviceManagerInner {
devices: Vec::new(),
rr: 0,
}),
}
}
pub fn global() -> &'static DeviceManager {
static INSTANCE: Lazy<DeviceManager> = Lazy::new(DeviceManager::new);
&INSTANCE
}
pub fn enumerate_adapters_info(&self) -> Vec<wgpu::AdapterInfo> {
        let instance = default_instance();
instance
.enumerate_adapters(wgpu::Backends::all())
.into_iter()
.map(|a| a.get_info())
.collect()
}
pub fn select_adapter_index_by_policy(&self, policy: DeviceSelectionPolicy) -> Option<usize> {
        let instance = default_instance();
let adapters: Vec<wgpu::Adapter> = instance.enumerate_adapters(wgpu::Backends::all());
if adapters.is_empty() {
return None;
}
match policy {
DeviceSelectionPolicy::RoundRobin => Some(0),
DeviceSelectionPolicy::Index(i) => Some(i % adapters.len()),
DeviceSelectionPolicy::BestScore => {
let mut scored: Vec<(u32, usize)> = adapters
.iter()
.enumerate()
.map(|(idx, a)| (adapter_score_quick(a), idx))
.collect();
scored.sort_by(|a, b| b.0.cmp(&a.0));
Some(scored.first().map(|t| t.1).unwrap_or(0))
}
}
}
fn ensure_initialized_locked(&self, guard: &mut std::sync::MutexGuard<'_, DeviceManagerInner>) {
if !guard.devices.is_empty() {
return;
}
        let instance = default_instance();
let mut scored: Vec<(u32, wgpu::Adapter)> = instance
.enumerate_adapters(wgpu::Backends::all())
.into_iter()
.map(|a| (adapter_score(&a), a))
.collect();
if let Ok(sel) = env::var("EENN_GPU_ADAPTER") {
if let Ok(idx) = sel.parse::<usize>() {
if let Some((_, adapter)) = scored.get(idx) {
match GpuDevice::new_from_adapter(adapter) {
Ok(dev) => guard.devices.push(Arc::new(dev)),
Err(e) => eprintln!("warning: failed to init adapter: {}", e),
}
}
} else {
for (_, adapter) in &scored {
if adapter
.get_info()
.name
.to_lowercase()
.contains(&sel.to_lowercase())
{
match GpuDevice::new_from_adapter(adapter) {
Ok(dev) => guard.devices.push(Arc::new(dev)),
Err(e) => eprintln!("warning: failed to init adapter: {}", e),
}
break;
}
}
}
        }
        // If the env override selected a device, stop here rather than also
        // pushing the best-scored adapter below.
        if !guard.devices.is_empty() {
            return;
        }
scored.sort_by(|a, b| b.0.cmp(&a.0));
for (_score, adapter) in scored {
match GpuDevice::new_from_adapter(&adapter) {
Ok(dev) => {
guard.devices.push(Arc::new(dev));
break;
}
Err(e) => {
eprintln!("warning: failed to init adapter: {}", e);
}
}
}
if guard.devices.is_empty() {
let maybe_adapter =
pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
}));
match maybe_adapter {
Ok(adapter) => {
match GpuDevice::new_from_adapter(&adapter) {
Ok(dev) => guard.devices.push(Arc::new(dev)),
Err(e) => panic!("failed to init any gpu device: {}", e),
}
}
Err(e) => {
panic!("no gpu adapters found: {}", e);
}
}
}
}
pub fn get_device_round_robin(&self) -> Arc<GpuDevice> {
let mut guard = self.inner.lock().unwrap();
self.ensure_initialized_locked(&mut guard);
let idx = guard.rr % guard.devices.len();
guard.rr = guard.rr.wrapping_add(1);
guard.devices[idx].clone()
}
pub fn get_device_by_policy(&self, policy: DeviceSelectionPolicy) -> Arc<GpuDevice> {
let mut guard = self.inner.lock().unwrap();
self.ensure_initialized_locked(&mut guard);
match policy {
            // Calling `get_device_round_robin` here would relock the mutex we
            // already hold and deadlock, so rotate inline.
            DeviceSelectionPolicy::RoundRobin => {
                let idx = guard.rr % guard.devices.len();
                guard.rr = guard.rr.wrapping_add(1);
                guard.devices[idx].clone()
            }
DeviceSelectionPolicy::BestScore => {
guard
.devices
.first()
.cloned()
.expect("no devices available")
}
DeviceSelectionPolicy::Index(i) => {
if guard.devices.is_empty() {
panic!("no devices available");
}
let idx = i % guard.devices.len();
guard.devices[idx].clone()
}
}
}
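    /// Convenience accessor used throughout the module; with a single
    /// registered device the round-robin rotation always returns it.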
pub fn get_or_init(&self) -> Arc<GpuDevice> {
self.get_device_round_robin()
}
}
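/// A device-resident tensor buffer: shape/dtype metadata plus shared handles
/// to the owning device and the underlying `wgpu::Buffer`.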
pub struct GpuBuffer {
pub device: Arc<GpuDevice>,
pub buffer: Arc<wgpu::Buffer>,
pub dtype: DType,
pub shape: Vec<usize>,
}
impl DeviceBuffer for GpuBuffer {
fn dtype(&self) -> DType {
self.dtype
}
fn shape(&self) -> Vec<usize> {
self.shape.clone()
}
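    // One-shot readback path: allocates a throwaway MAP_READ staging buffer
    // rather than going through the device's staging pool.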
fn to_host_f32(&self) -> anyhow::Result<Vec<f32>> {
let size = self.shape.iter().product::<usize>() * std::mem::size_of::<f32>();
let buffer: &wgpu::Buffer = self.buffer.as_ref();
let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: size as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder = self
.device
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
self.device.queue.submit(Some(encoder.finish()));
let buffer_slice = staging.slice(..);
let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
self.device
.device
.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
})
.unwrap();
let _ = futures_executor::block_on(rx.receive());
let data = buffer_slice.get_mapped_range();
let vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
drop(data);
staging.unmap();
Ok(vec)
}
fn box_clone(&self) -> Box<dyn DeviceBuffer> {
Box::new(GpuBuffer {
device: self.device.clone(),
buffer: self.buffer.clone(),
dtype: self.dtype,
shape: self.shape.clone(),
})
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl GpuBuffer {
pub fn raw_buffer(&self) -> &wgpu::Buffer {
self.buffer.as_ref()
}
}
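/// Convenience wrapper: flattens the array and uploads it on the globally
/// managed device with STORAGE | COPY usages.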
pub fn buffer_from_array(a: Arc<ArrayD<f32>>) -> anyhow::Result<Box<dyn DeviceBuffer>> {
let device = DeviceManager::global().get_or_init();
let flat: Vec<f32> = if let Some(slice) = a.as_slice_memory_order() {
slice.to_vec()
} else {
let mut v = Vec::with_capacity(a.len());
v.extend(a.iter().cloned());
v
};
let buf = device.create_buffer_from_f32(
&flat,
wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::STORAGE,
);
let gb = GpuBuffer {
device: device.clone(),
buffer: Arc::new(buf),
dtype: DType::F32,
shape: a.shape().to_vec(),
};
Ok(Box::new(gb))
}
#[cfg(all(test, feature = "gpu"))]
mod gpu_tests {
use super::*;
#[test]
fn smoke_selection_api() {
let manager = DeviceManager::global();
let infos = manager.enumerate_adapters_info();
let _ = infos.len();
let _idx0 = manager.select_adapter_index_by_policy(DeviceSelectionPolicy::Index(0));
let _best = manager.select_adapter_index_by_policy(DeviceSelectionPolicy::BestScore);
let _rr = manager.select_adapter_index_by_policy(DeviceSelectionPolicy::RoundRobin);
}
#[test]
fn staging_mapped_read_guard_lifecycle() {
let device = DeviceManager::global().get_or_init();
let size = 1024u64;
        let key = bucket_for_size(size);
        let guard = device.acquire_staging_entry(key, None::<&[u8]>, false);
        let mapped = device.map_staging_read_blocking_from_guard(guard);
        assert_eq!(mapped.as_slice().len() as u64, key.get());
drop(mapped);
if let Ok(pool) = device.staging_pool.lock() {
if let Some(deq) = pool.get(&key) {
assert!(deq.iter().any(|e| !e.in_use.load(Ordering::SeqCst)));
} else {
panic!("expected pool bucket present");
}
} else {
panic!("failed to lock staging_pool");
}
}
#[test]
fn staging_mapped_write_guard_lifecycle() {
let device = DeviceManager::global().get_or_init();
let size = 1024u64;
let key = bucket_for_size(size);
let mut mapped = device.map_staging_write_blocking(key, None::<&[u8]>);
        if mapped.mapped.is_some() {
let slice = mapped.as_mut_slice();
for (i, b) in slice.iter_mut().enumerate().take(16) {
*b = (i & 0xff) as u8;
}
}
drop(mapped);
if let Ok(pool) = device.staging_pool.lock() {
if let Some(deq) = pool.get(&key) {
assert!(deq.iter().any(|e| !e.in_use.load(Ordering::SeqCst)));
} else {
panic!("expected pool bucket present");
}
} else {
panic!("failed to lock staging_pool");
}
}
}
#[cfg(all(test, feature = "gpu", feature = "tokio"))]
mod gpu_tokio_tests {
use super::*;
use ndarray::ArrayD;
#[tokio::test]
async fn async_small_upload_roundtrip() -> anyhow::Result<()> {
let device = DeviceManager::global().get_or_init();
let a = Arc::new(ArrayD::from_elem(vec![4usize], std::f32::consts::PI));
let gb = device
.upload_array_f32_async(a.clone(), wgpu::BufferUsages::STORAGE)
.await?;
let elements = a.len();
let host = device
.read_buffer_to_host_async(gb.raw_buffer(), elements)
.await?;
assert_eq!(host.len(), elements);
Ok(())
}
#[tokio::test]
async fn async_large_upload_roundtrip() -> anyhow::Result<()> {
let device = DeviceManager::global().get_or_init();
let mut vec = vec![0.0f32; 20000];
fill_arange_mul(&mut vec, 0.125f32);
let a = Arc::new(ArrayD::from_shape_vec(vec![20000usize], vec.clone()).unwrap());
let gb = device
.upload_array_f32_async(a.clone(), wgpu::BufferUsages::STORAGE)
.await?;
let host = device
.read_buffer_to_host_async(gb.raw_buffer(), a.len())
.await?;
assert_eq!(host.len(), a.len());
assert_eq!(host[0], vec[0]);
assert_eq!(host[1234], vec[1234]);
assert_eq!(host[19999], vec[19999]);
Ok(())
}
#[tokio::test]
async fn async_strict_exact_roundtrip() -> anyhow::Result<()> {
let device = DeviceManager::global().get_or_init();
        let num_f32 = 30000usize;
        let size_bytes = (num_f32 * std::mem::size_of::<f32>()) as u64;
let key = bucket_for_size(size_bytes);
let mut mapped = device.map_staging_write_async(key, None::<&[u8]>).await;
        if mapped.mapped.is_some() {
let slice = mapped.as_mut_slice();
for (i, b) in slice.iter_mut().enumerate() {
*b = (i & 0xFF) as u8;
}
}
let guard = mapped.take_guard();
let dst = device.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: size_bytes,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging_buf = guard.buffer();
let mut encoder = device
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
encoder.copy_buffer_to_buffer(staging_buf, 0, &dst, 0, size_bytes);
device.queue.submit(Some(encoder.finish()));
device
.device
.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
})
.unwrap();
drop(guard);
let host = device.read_buffer_to_host_async(&dst, num_f32).await?;
let mut expected_bytes = Vec::with_capacity(num_f32 * 4);
for i in 0..(num_f32 * 4) {
expected_bytes.push((i & 0xFF) as u8);
}
let host_bytes = bytemuck::cast_slice::<f32, u8>(&host);
if host_bytes != expected_bytes.as_slice() {
eprintln!(
"mismatch: host len={} expected len={}",
host_bytes.len(),
expected_bytes.len()
);
eprintln!(
"host first 64: {:?}",
&host_bytes[..host_bytes.len().min(64)]
);
eprintln!(
"exp first 64: {:?}",
&expected_bytes[..expected_bytes.len().min(64)]
);
}
assert_eq!(host_bytes, expected_bytes.as_slice());
Ok(())
}
#[tokio::test]
async fn staging_pool_reuse_and_cap() -> anyhow::Result<()> {
let device = DeviceManager::global().get_or_init();
let mut handles = Vec::new();
for _ in 0..12 {
let mut vec = vec![0.0f32; 20000];
fill_arange_mul(&mut vec, 1.0f32);
let a = Arc::new(ArrayD::from_shape_vec(vec![20000usize], vec.clone()).unwrap());
let h = device.upload_array_f32_async(a.clone(), wgpu::BufferUsages::STORAGE);
handles.push(h);
}
for h in handles {
let _ = h.await?;
}
let size = (20000 * std::mem::size_of::<f32>()) as u64;
let key = bucket_for_size(size);
if let Ok(pool) = device.staging_pool.lock()
&& let Some(deq) = pool.get(&key)
{
            assert!(deq.len() <= MAX_PER_BUCKET, "pool exceeded cap");
}
Ok(())
}
}
impl GpuDevice {
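    /// Compiles `wgsl_source` and dispatches its `main` entry point with
    /// `src`/`dst` bound at @group(0) bindings 0 and 1. The dispatch assumes
    /// the shader declares @workgroup_size(64).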
pub fn dispatch_wgsl_copy(
&self,
wgsl_source: &str,
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
elements: u32,
) -> Result<()> {
let module = self
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(wgsl_source)),
});
let bind_layout = self
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = self
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_layout],
push_constant_ranges: &[],
});
let pipeline = self
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &module,
entry_point: Some("main"),
cache: None,
compilation_options: Default::default(),
});
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: src.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: dst.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
let wg_x = elements.div_ceil(64);
cpass.dispatch_workgroups(wg_x, 1, 1);
}
self.queue.submit(Some(encoder.finish()));
self.device
.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
})
.unwrap();
Ok(())
}
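    /// SPIR-V counterpart of `dispatch_wgsl_copy`: identical bind-group
    /// layout and the same assumption of a 64-wide workgroup in `main`.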
pub fn dispatch_spirv_copy(
&self,
spirv_words: &[u32],
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
elements: u32,
) -> Result<()> {
let bytes: &[u8] = bytemuck::cast_slice(spirv_words);
let module = self
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::util::make_spirv(bytes),
});
let bind_layout = self
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = self
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_layout],
push_constant_ranges: &[],
});
let pipeline = self
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &module,
entry_point: Some("main"),
cache: None,
compilation_options: Default::default(),
});
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: src.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: dst.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
let wg_x = elements.div_ceil(64);
cpass.dispatch_workgroups(wg_x, 1, 1);
}
self.queue.submit(Some(encoder.finish()));
self.device
.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
})
.unwrap();
Ok(())
}
}