use crate::error::ComputeError;
/// A wgpu device/queue pair plus information about the adapter it came from.
///
/// Construct one with [`ComputeContext::new`], [`ComputeContext::builder`], or
/// wrap an existing device with [`ComputeContext::from_device`].
pub struct ComputeContext {
    /// The wgpu device, either requested by `ContextBuilder::build_async` or
    /// supplied by the caller through `from_device`.
    pub device: wgpu::Device,
    /// The queue obtained together with `device`.
    pub queue: wgpu::Queue,
    /// Optional dedicated transfer queue. Both `ContextBuilder::build_async`
    /// and `from_device` currently leave this as `None`.
    transfer_queue: Option<wgpu::Queue>,
    /// Description of the adapter backing `device`. When `from_device` is
    /// called without adapter info, this is a placeholder named "external".
    adapter_info: wgpu::AdapterInfo,
}
impl ComputeContext {
pub fn adapter_info(&self) -> &wgpu::AdapterInfo {
&self.adapter_info
}
pub fn transfer_queue(&self) -> Option<&wgpu::Queue> {
self.transfer_queue.as_ref()
}
pub fn builder() -> ContextBuilder {
ContextBuilder::default()
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug"))]
pub fn new() -> Result<Self, ComputeError> {
ContextBuilder::default().build()
}
pub fn try_new() -> Option<Self> {
Self::new().ok()
}
pub async fn new_async() -> Result<Self, ComputeError> {
ContextBuilder::default().build_async().await
}
pub fn from_device(
device: wgpu::Device,
queue: wgpu::Queue,
adapter_info: Option<wgpu::AdapterInfo>,
) -> Self {
let adapter_info = adapter_info.unwrap_or_else(|| wgpu::AdapterInfo {
name: "external".into(),
vendor: 0,
device: 0,
device_type: wgpu::DeviceType::Other,
device_pci_bus_id: String::new(),
driver: String::new(),
driver_info: String::new(),
backend: wgpu::Backend::Noop,
subgroup_min_size: 0,
subgroup_max_size: 0,
transient_saves_memory: false,
});
ComputeContext {
device,
queue,
transfer_queue: None,
adapter_info,
}
}
pub fn dispatcher(&self) -> crate::dispatch::Dispatcher<'_> {
crate::dispatch::Dispatcher::new(self)
}
pub fn with_limits(limits: wgpu::Limits) -> ContextBuilder {
ContextBuilder::default().with_limits(limits)
}
pub fn with_features(features: wgpu::Features) -> ContextBuilder {
ContextBuilder::default().with_features(features)
}
pub fn with_power_preference(pref: wgpu::PowerPreference) -> ContextBuilder {
ContextBuilder::default().with_power_preference(pref)
}
#[cfg(feature = "hot-reload")]
pub fn watcher(&self) -> crate::hot_reload::ShaderWatcher {
crate::hot_reload::ShaderWatcher::new()
}
}
/// Configures and creates a [`ComputeContext`].
///
/// Obtain one via `ContextBuilder::default()` or `ComputeContext::builder()`,
/// chain the `with_*` setters, then call `build` or `build_async`.
#[derive(Debug, Default)]
pub struct ContextBuilder {
    /// Passed to `request_adapter` when selecting an adapter.
    power_preference: wgpu::PowerPreference,
    /// Features the device must support. They are checked against the adapter
    /// before the device is requested.
    required_features: wgpu::Features,
    /// Limits to request. `None` means `wgpu::Limits::default()` is used.
    required_limits: Option<wgpu::Limits>,
    /// Set by `with_multi_queue`. The built context's transfer queue is
    /// currently `None` regardless of this flag.
    multi_queue: bool,
}
impl ContextBuilder {
    /// Sets the power preference used when requesting an adapter.
    pub fn with_power_preference(mut self, pref: wgpu::PowerPreference) -> Self {
        self.power_preference = pref;
        self
    }

    /// Sets the features the device must support.
    ///
    /// This replaces, rather than extends, any previously configured features.
    pub fn with_features(mut self, features: wgpu::Features) -> Self {
        self.required_features = features;
        self
    }

    /// Sets the device limits to request.
    ///
    /// If this is never called, `build_async` requests `wgpu::Limits::default()`.
    pub fn with_limits(mut self, limits: wgpu::Limits) -> Self {
        self.required_limits = Some(limits);
        self
    }

    /// Asks for a dedicated transfer queue.
    ///
    /// This currently has no effect. `request_device` hands back a single
    /// queue, so [`ComputeContext::transfer_queue`] is always `None`.
    pub fn with_multi_queue(mut self) -> Self {
        self.multi_queue = true;
        self
    }

    /// Builds the context synchronously by blocking the current thread on
    /// [`ContextBuilder::build_async`] (via `pollster::block_on`).
    ///
    /// # Errors
    /// Same as [`ContextBuilder::build_async`].
    pub fn build(self) -> Result<ComputeContext, ComputeError> {
        pollster::block_on(self.build_async())
    }

    /// Requests an adapter and a device, then assembles a [`ComputeContext`].
    ///
    /// # Errors
    /// - [`ComputeError::NoAdapter`] if `request_adapter` fails.
    /// - [`ComputeError::DeviceRequest`] if the adapter lacks any of the
    ///   required features, or if `request_device` fails.
    pub async fn build_async(self) -> Result<ComputeContext, ComputeError> {
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: self.power_preference,
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await
            .map_err(|_| ComputeError::NoAdapter)?;

        // Fail early with a descriptive message instead of letting
        // `request_device` reject the unsupported features.
        if !self.required_features.is_empty()
            && !adapter.features().contains(self.required_features)
        {
            return Err(ComputeError::DeviceRequest(format!(
                "adapter does not support requested features: {:?}",
                self.required_features
            )));
        }

        let adapter_info = adapter.get_info();
        // NOTE(review): `wgpu::Limits::default()` may exceed what downlevel
        // adapters support; consider `Limits::downlevel_defaults()`. Confirm
        // against the targeted backends.
        let limits = self.required_limits.unwrap_or_default();
        let (device, queue) = adapter
            .request_device(&wgpu::DeviceDescriptor {
                label: Some("oxiui-compute-wgpu"),
                required_features: self.required_features,
                required_limits: limits,
                ..Default::default()
            })
            .await
            .map_err(|e| ComputeError::DeviceRequest(e.to_string()))?;

        // `request_device` yields exactly one queue, so there is no separate
        // transfer queue to hand out. The `multi_queue` request is acknowledged
        // (this read keeps the field live) but cannot be honoured yet.
        let _ = self.multi_queue;
        let transfer_queue: Option<wgpu::Queue> = None;

        Ok(ComputeContext {
            device,
            queue,
            transfer_queue,
            adapter_info,
        })
    }
}
// Unit tests for `ComputeContext` / `ContextBuilder`. Tests that go through
// `oxiui_core::require_gpu!` presumably skip when no GPU context can be
// created (TODO confirm the macro's behaviour in oxiui_core). The macro binds
// its first argument (`ctx`) to the unwrapped context, as the uses below show.
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: `try_new` must turn every failure into `None`, never a panic.
    #[test]
    fn try_new_does_not_panic() {
        let _ = ComputeContext::try_new();
    }

    // `new` may fail only with `NoAdapter` (e.g. a GPU-less host). Any other
    // error variant fails the test.
    #[test]
    fn new_returns_result() {
        match ComputeContext::new() {
            Ok(_ctx) => { }
            // No adapter on this machine: acceptable.
            Err(ComputeError::NoAdapter) => {
            }
            Err(ComputeError::DeviceRequest(ref msg)) => {
                panic!("unexpected DeviceRequest error: {msg}")
            }
            Err(e) => {
                panic!("unexpected error: {e}")
            }
        }
    }

    // `try_new` must agree with `new`. This assumes adapter availability does
    // not change between the two back-to-back calls.
    #[test]
    fn try_new_consistent_with_new() {
        let via_new = ComputeContext::new();
        let via_try = ComputeContext::try_new();
        match (via_new, via_try) {
            (Ok(_), Some(_)) | (Err(_), None) => { }
            (Ok(_), None) => panic!("new() succeeded but try_new() returned None"),
            (Err(e), Some(_)) => panic!("new() failed but try_new() returned Some: {e}"),
        }
    }

    // Exercises every chainable setter; the build result is deliberately ignored.
    #[test]
    fn builder_chain_defaults() {
        let _builder = ContextBuilder::default()
            .with_power_preference(wgpu::PowerPreference::HighPerformance)
            .with_limits(wgpu::Limits::default())
            .with_features(wgpu::Features::empty());
        let _result = _builder.build();
    }

    // Requesting multi-queue must not panic, whatever the build outcome.
    #[test]
    fn builder_with_multi_queue_does_not_panic() {
        let _result = ContextBuilder::default().with_multi_queue().build();
    }

    // With a real device, the adapter's backend must have a non-empty Debug form.
    #[test]
    fn context_has_adapter_info() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let info = ctx.adapter_info();
        let backend_str = format!("{:?}", info.backend);
        assert!(!backend_str.is_empty(), "backend string must not be empty");
    }

    // A low-power preference must still yield a usable context when a GPU exists.
    #[test]
    fn builder_with_low_power() {
        oxiui_core::require_gpu!(
            ctx,
            ComputeContext::with_power_preference(wgpu::PowerPreference::LowPower)
                .build()
                .ok()
        );
        let _ = ctx;
    }

    // The async constructor, driven to completion with pollster.
    #[test]
    fn new_async_via_pollster() {
        oxiui_core::require_gpu!(ctx, pollster::block_on(ComputeContext::new_async()).ok());
        let _ = ctx;
    }

    // Requesting every feature should normally fail with `DeviceRequest`.
    // NOTE(review): despite the name, `Ok` and `NoAdapter` are accepted too;
    // only an unexpected error variant fails the test.
    #[test]
    fn with_unsupported_features_returns_error() {
        let result = ComputeContext::with_features(wgpu::Features::all()).build();
        match result {
            Ok(_) => { }
            Err(ComputeError::NoAdapter) => { }
            Err(ComputeError::DeviceRequest(_)) => { }
            Err(e) => panic!("unexpected error variant: {e}"),
        }
    }

    // A default-built context has no transfer queue.
    #[test]
    fn transfer_queue_none_without_multi_queue() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        assert!(
            ctx.transfer_queue().is_none(),
            "transfer_queue must be None when multi-queue was not requested"
        );
    }

    // Pins current behaviour: even with multi-queue requested, no transfer
    // queue is created.
    #[test]
    fn multi_queue_context_builds() {
        oxiui_core::require_gpu!(
            ctx,
            ComputeContext::builder().with_multi_queue().build().ok()
        );
        assert!(ctx.transfer_queue().is_none());
    }

    // Re-wraps a real device/queue (moved out of `ctx`) and checks that the
    // supplied adapter info is preserved rather than replaced by the placeholder.
    #[test]
    fn from_device_via_real_gpu() {
        oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
        let info = ctx.adapter_info().clone();
        let ctx2 = ComputeContext::from_device(ctx.device, ctx.queue, Some(info.clone()));
        assert_eq!(ctx2.adapter_info().name, info.name);
        assert!(ctx2.transfer_queue().is_none());
    }
}