use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::{Arc, Mutex, Weak};
use crate as wgpu;
/// A cache of requested adapters, keyed by the power preference they were requested with.
///
/// Cached adapters are handed out as shared `Arc`s; the map itself sits behind a mutex.
#[derive(Default)]
pub struct AdapterMap {
    /// Mutex-guarded adapter table; values are shared with callers via `Arc`.
    map: Mutex<HashMap<AdapterMapKey, Arc<ActiveAdapter>>>,
}
/// Cache key for `AdapterMap`: adapters are cached per power preference only.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AdapterMapKey {
    /// The power preference the adapter was requested with.
    power_preference: wgpu::PowerPreference,
}
/// A requested adapter together with the devices that have been requested from it.
pub struct ActiveAdapter {
    /// The underlying adapter; also reachable through this type's `Deref` impl.
    adapter: wgpu::Adapter,
    /// Weakly-held devices created from `adapter`, keyed by descriptor.
    device_map: DeviceMap,
}
/// Devices keyed by the descriptor each was requested with.
///
/// Entries hold `Weak` references, so they do not keep devices alive; a dropped device leaves a
/// dead entry behind until it is pruned.
#[derive(Default)]
pub struct DeviceMap {
    /// Mutex-guarded descriptor -> device table.
    map: Mutex<HashMap<DeviceMapKey, Weak<DeviceQueuePair>>>,
}
/// Hash-map key wrapping a device descriptor.
///
/// `Hash`, `PartialEq` and `Eq` are implemented manually rather than derived; only the
/// descriptor's label, features and limits take part in key identity.
#[derive(Debug, Clone)]
pub struct DeviceMapKey {
    /// The descriptor the device was requested with.
    descriptor: wgpu::DeviceDescriptor<'static>,
}
/// A device and its queue, obtained together from a single device request.
#[derive(Debug)]
pub struct DeviceQueuePair {
    /// The logical device.
    device: wgpu::Device,
    /// The queue created alongside `device`.
    queue: wgpu::Queue,
}
impl AdapterMap {
    /// Blocking wrapper around `get_or_request_async`, driven by the current Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as `get_or_request_async`, and when
    /// `tokio::runtime::Handle::current` / `Handle::block_on` panic (no Tokio runtime in
    /// context, or called from inside an asynchronous execution context).
    #[cfg(not(target_os = "unknown"))]
    pub fn get_or_request<'a, 'b>(
        &'a self,
        options: wgpu::RequestAdapterOptions<'b>,
        instance: &'a wgpu::Instance,
    ) -> Option<Arc<ActiveAdapter>> {
        let rt = tokio::runtime::Handle::current();
        rt.block_on(self.get_or_request_async(options, instance))
    }

    /// Blocking wrapper around `request_async`, driven by the current Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as `request_async`, and when
    /// `tokio::runtime::Handle::current` / `Handle::block_on` panic.
    #[cfg(not(target_os = "unknown"))]
    pub fn request<'a, 'b>(
        &'a self,
        options: wgpu::RequestAdapterOptions<'b>,
        instance: &'a wgpu::Instance,
    ) -> Option<Arc<ActiveAdapter>> {
        let rt = tokio::runtime::Handle::current();
        rt.block_on(self.request_async(options, instance))
    }

    /// Returns the cached adapter for `options.power_preference`. If none is cached, it
    /// requests one from `instance` and caches it.
    ///
    /// Only the power preference forms the cache key. The other fields of `options` are
    /// ignored when a cached adapter already exists.
    ///
    /// Returns `None` when `instance.request_adapter` yields no adapter.
    ///
    /// # Panics
    ///
    /// Panics if the `AdapterMap` lock is poisoned.
    pub async fn get_or_request_async<'a, 'b>(
        &'a self,
        options: wgpu::RequestAdapterOptions<'b>,
        instance: &'a wgpu::Instance,
    ) -> Option<Arc<ActiveAdapter>> {
        let key = AdapterMapKey {
            power_preference: options.power_preference,
        };

        // Look up under the lock but release it before awaiting. A `std::sync::Mutex` guard
        // held across `.await` makes this future `!Send` and stalls every other map user for
        // the whole request. It can also deadlock when another task on the same thread locks
        // the map.
        let cached = {
            let map = self
                .map
                .lock()
                .expect("failed to acquire `AdapterMap` lock");
            map.get(&key).cloned()
        };
        if let Some(adapter) = cached {
            return Some(adapter);
        }

        let adapter = instance.request_adapter(&options).await?;
        let adapter = Arc::new(ActiveAdapter {
            adapter,
            device_map: Default::default(),
        });

        let mut map = self
            .map
            .lock()
            .expect("failed to acquire `AdapterMap` lock");
        // A concurrent caller may have filled this key while we awaited. Keep its adapter so
        // every caller observes the same cached instance.
        Some(Arc::clone(map.entry(key).or_insert(adapter)))
    }

    /// Always requests a fresh adapter and caches it under `options.power_preference`. Any
    /// adapter previously cached for that key is replaced.
    ///
    /// Returns `None` when `instance.request_adapter` yields no adapter.
    ///
    /// # Panics
    ///
    /// Panics if the `AdapterMap` lock is poisoned.
    pub async fn request_async<'a, 'b>(
        &'a self,
        options: wgpu::RequestAdapterOptions<'b>,
        instance: &'a wgpu::Instance,
    ) -> Option<Arc<ActiveAdapter>> {
        let adapter = instance.request_adapter(&options).await?;
        let adapter = Arc::new(ActiveAdapter {
            adapter,
            device_map: Default::default(),
        });
        let key = AdapterMapKey {
            power_preference: options.power_preference,
        };
        let mut map = self
            .map
            .lock()
            .expect("failed to acquire `AdapterMap` lock");
        map.insert(key, Arc::clone(&adapter));
        Some(adapter)
    }

    /// Prunes dead device entries from every cached adapter. It then evicts adapters that have
    /// no device entries left.
    ///
    /// An evicted adapter stays alive while callers still hold an `Arc` to it; it is only
    /// removed from the cache.
    ///
    /// # Panics
    ///
    /// Panics if the `AdapterMap` or a `DeviceMap` lock is poisoned.
    pub fn clear_inactive_adapters_and_devices(&self) {
        let mut map = self
            .map
            .lock()
            .expect("failed to acquire `AdapterMap` lock");
        map.retain(|_, adapter| {
            adapter.clear_inactive_devices();
            adapter.device_count() > 0
        });
    }

    /// Polls every live device of every cached adapter with `maintain`.
    ///
    /// # Panics
    ///
    /// Panics if the `AdapterMap` or a `DeviceMap` lock is poisoned.
    pub(crate) fn _poll_all_devices(&self, maintain: wgpu::Maintain) {
        // Snapshot the adapters so the map lock is not held during polling. `Device::poll` can
        // block and can invoke user callbacks, which would deadlock if they touched this map.
        let adapters: Vec<Arc<ActiveAdapter>> = self
            .map
            .lock()
            .expect("failed to acquire `AdapterMap` lock")
            .values()
            .cloned()
            .collect();
        for adapter in &adapters {
            adapter._poll_all_devices(maintain.clone());
        }
    }
}
impl ActiveAdapter {
    /// Blocking wrapper around `get_or_request_device_async`, driven by the current Tokio
    /// runtime.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as `get_or_request_device_async`, and when
    /// `tokio::runtime::Handle::current` / `Handle::block_on` panic (no Tokio runtime in
    /// context, or called from inside an asynchronous execution context).
    #[cfg(not(target_os = "unknown"))]
    pub fn get_or_request_device(
        &self,
        descriptor: wgpu::DeviceDescriptor<'static>,
    ) -> Arc<DeviceQueuePair> {
        let rt = tokio::runtime::Handle::current();
        rt.block_on(self.get_or_request_device_async(descriptor))
    }

    /// Blocking wrapper around `request_device_async`, driven by the current Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as `request_device_async`, and when
    /// `tokio::runtime::Handle::current` / `Handle::block_on` panic.
    #[cfg(not(target_os = "unknown"))]
    pub fn request_device(
        &self,
        descriptor: wgpu::DeviceDescriptor<'static>,
    ) -> Arc<DeviceQueuePair> {
        let rt = tokio::runtime::Handle::current();
        rt.block_on(self.request_device_async(descriptor))
    }

    /// Returns a still-alive device previously requested with an equal `descriptor`. If there
    /// is none, it requests a new device and caches a weak reference to it.
    ///
    /// Descriptors are compared by label, features and limits.
    ///
    /// # Panics
    ///
    /// Panics if the device request fails or if the `DeviceMap` lock is poisoned.
    pub async fn get_or_request_device_async(
        &self,
        descriptor: wgpu::DeviceDescriptor<'static>,
    ) -> Arc<DeviceQueuePair> {
        let key = DeviceMapKey { descriptor };

        // Look up under the lock but release it before awaiting. A `std::sync::Mutex` guard
        // held across `.await` makes this future `!Send` and stalls every other user of the
        // map for the whole device request.
        let cached = {
            let map = self
                .device_map
                .map
                .lock()
                .expect("failed to acquire `DeviceMap` lock");
            map.get(&key).and_then(Weak::upgrade)
        };
        if let Some(device) = cached {
            return device;
        }

        let (device, queue) = self
            .adapter
            .request_device(&key.descriptor, None)
            .await
            .expect("could not get or request device");
        let device = Arc::new(DeviceQueuePair { device, queue });

        let mut map = self
            .device_map
            .map
            .lock()
            .expect("failed to acquire `DeviceMap` lock");
        // A concurrent caller may have cached a live device for this key while we awaited.
        // Hand out that one so equal descriptors keep sharing a single device.
        if let Some(existing) = map.get(&key).and_then(Weak::upgrade) {
            return existing;
        }
        map.insert(key, Arc::downgrade(&device));
        device
    }

    /// Always requests a new device and records a weak reference to it under `descriptor`.
    /// Any existing entry for an equal descriptor is replaced.
    ///
    /// # Panics
    ///
    /// Panics if the device request fails or if the `DeviceMap` lock is poisoned.
    pub async fn request_device_async(
        &self,
        descriptor: wgpu::DeviceDescriptor<'static>,
    ) -> Arc<DeviceQueuePair> {
        let (device, queue) = self
            .adapter
            .request_device(&descriptor, None)
            .await
            .expect("could not request device async");
        let device = Arc::new(DeviceQueuePair { device, queue });
        let key = DeviceMapKey { descriptor };
        let mut map = self
            .device_map
            .map
            .lock()
            .expect("failed to acquire `DeviceMap` lock");
        map.insert(key, Arc::downgrade(&device));
        device
    }

    /// Number of entries in the device map.
    ///
    /// This includes entries whose device has already been dropped. Call
    /// `clear_inactive_devices` first to count only live devices.
    ///
    /// # Panics
    ///
    /// Panics if the `DeviceMap` lock is poisoned.
    pub fn device_count(&self) -> usize {
        let map = self
            .device_map
            .map
            .lock()
            .expect("failed to acquire `DeviceMap` lock");
        map.len()
    }

    /// Removes entries whose device has been dropped (the weak reference no longer upgrades).
    ///
    /// # Panics
    ///
    /// Panics if the `DeviceMap` lock is poisoned.
    pub fn clear_inactive_devices(&self) {
        let mut map = self
            .device_map
            .map
            .lock()
            .expect("failed to acquire `DeviceMap` lock");
        map.retain(|_, pair| pair.upgrade().is_some());
    }

    /// Polls every still-alive device of this adapter with `maintain`.
    ///
    /// # Panics
    ///
    /// Panics if the `DeviceMap` lock is poisoned.
    fn _poll_all_devices(&self, maintain: wgpu::Maintain) {
        // Upgrade the devices up front and release the lock before polling. `Device::poll`
        // can block and can run user callbacks, which would deadlock if they touched this map.
        let devices: Vec<Arc<DeviceQueuePair>> = self
            .device_map
            .map
            .lock()
            .expect("failed to acquire `DeviceMap` lock")
            .values()
            .filter_map(Weak::upgrade)
            .collect();
        for pair in &devices {
            pair.device().poll(maintain.clone());
        }
    }
}
impl DeviceQueuePair {
    /// Borrows the logical device.
    pub fn device(&self) -> &wgpu::Device {
        let Self { device, .. } = self;
        device
    }

    /// Borrows the queue that was created together with the device.
    pub fn queue(&self) -> &wgpu::Queue {
        let Self { queue, .. } = self;
        queue
    }
}
/// Lets an `ActiveAdapter` be used wherever a `&wgpu::Adapter` is expected.
impl Deref for ActiveAdapter {
    type Target = wgpu::Adapter;

    fn deref(&self) -> &wgpu::Adapter {
        let Self { adapter, .. } = self;
        adapter
    }
}
/// Hashes only the descriptor fields that `eq_device_descriptor` compares, which keeps `Hash`
/// consistent with `Eq`.
impl Hash for DeviceMapKey {
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        let descriptor = &self.descriptor;
        hash_device_descriptor(descriptor, state)
    }
}
/// Two keys are equal when their descriptors agree on label, features and limits.
impl PartialEq for DeviceMapKey {
    fn eq(&self, rhs: &Self) -> bool {
        let (lhs_desc, rhs_desc) = (&self.descriptor, &rhs.descriptor);
        eq_device_descriptor(lhs_desc, rhs_desc)
    }
}

/// The comparison above is a full equivalence relation, so the key may be used in hash maps.
impl Eq for DeviceMapKey {}
/// Compares two device descriptors by label, features and limits.
///
/// Must stay in sync with `hash_device_descriptor`, which hashes the same fields.
fn eq_device_descriptor(
    a: &wgpu::DeviceDescriptor<'static>,
    b: &wgpu::DeviceDescriptor<'static>,
) -> bool {
    // Tuple equality checks the fields left to right and stops at the first mismatch.
    (&a.label, &a.features, &a.limits) == (&b.label, &b.features, &b.limits)
}
/// Feeds a device descriptor's label, features and limits, in that order, into `state`.
///
/// Must hash exactly the fields `eq_device_descriptor` compares so `Hash` agrees with `Eq`.
fn hash_device_descriptor<H: Hasher>(desc: &wgpu::DeviceDescriptor<'static>, state: &mut H) {
    // A tuple hashes its elements in order with no extra framing, and `&T` hashes like `T`.
    (&desc.label, &desc.features, &desc.limits).hash(state);
}