use crate::error::{GpuError, GpuResult};
use std::sync::Arc;
use tracing::{debug, info};
use wgpu::{
Adapter, AdapterInfo, Backend, Backends, Device, DeviceDescriptor, Features, Instance,
InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
};
/// Graphics backend (or set of backends) a GPU context should be created on.
///
/// Converted into a wgpu `Backends` bit set by `BackendPreference::to_backends`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendPreference {
/// Vulkan only (`Backends::VULKAN`).
Vulkan,
/// Metal only (`Backends::METAL`).
Metal,
/// Direct3D 12 only (`Backends::DX12`).
DX12,
/// Browser WebGPU only (`Backends::BROWSER_WEBGPU`).
WebGPU,
/// wgpu's primary backend set (`Backends::PRIMARY`).
Auto,
/// Every backend wgpu knows about (`Backends::all()`).
All,
}
impl BackendPreference {
    /// Translates this preference into the equivalent wgpu [`Backends`] bit set.
    pub fn to_backends(&self) -> Backends {
        match *self {
            // Multi-backend selections.
            Self::Auto => Backends::PRIMARY,
            Self::All => Backends::all(),
            // Single-backend selections.
            Self::Vulkan => Backends::VULKAN,
            Self::Metal => Backends::METAL,
            Self::DX12 => Backends::DX12,
            Self::WebGPU => Backends::BROWSER_WEBGPU,
        }
    }

    /// Returns the backend preferred for the compilation target.
    ///
    /// macOS selects Metal, Windows selects DX12, Linux selects Vulkan and
    /// `wasm32` selects WebGPU; every other target falls back to
    /// [`BackendPreference::Auto`]. The choice is made at compile time.
    pub fn platform_default() -> Self {
        // The checks run in the same order as the original early returns, so
        // the first matching target wins.
        if cfg!(target_os = "macos") {
            Self::Metal
        } else if cfg!(target_os = "windows") {
            Self::DX12
        } else if cfg!(target_os = "linux") {
            Self::Vulkan
        } else if cfg!(target_arch = "wasm32") {
            Self::WebGPU
        } else {
            Self::Auto
        }
    }
}
/// Power/performance trade-off requested when selecting a GPU adapter.
///
/// Converts into wgpu's `PowerPreference` through the `From` impl in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPowerPreference {
/// Maps to `PowerPreference::LowPower`.
LowPower,
/// Maps to `PowerPreference::HighPerformance`.
HighPerformance,
/// Maps to `PowerPreference::None` (no preference expressed).
Default,
}
impl From<GpuPowerPreference> for PowerPreference {
fn from(pref: GpuPowerPreference) -> Self {
match pref {
GpuPowerPreference::LowPower => PowerPreference::LowPower,
GpuPowerPreference::HighPerformance => PowerPreference::HighPerformance,
GpuPowerPreference::Default => PowerPreference::None,
}
}
}
/// Settings consumed by `GpuContext::with_config` when creating a GPU context.
#[derive(Debug, Clone)]
pub struct GpuContextConfig {
/// Backend(s) the wgpu instance is created with.
pub backend: BackendPreference,
/// Power preference passed along with the adapter request.
pub power_preference: GpuPowerPreference,
/// Features requested from the device.
pub required_features: Features,
/// Limits requested from the device; when `None`, limits are derived from
/// the adapter's own limits.
pub required_limits: Option<Limits>,
/// Debug flag; defaults to `cfg!(debug_assertions)`.
/// NOTE(review): nothing in this file reads this field — confirm where it is consumed.
pub debug: bool,
/// Optional label given to the created device.
pub label: Option<String>,
}
impl Default for GpuContextConfig {
    /// Builds the default configuration: the platform's default backend,
    /// high-performance adapter selection, no extra features, adapter-derived
    /// limits, the debug flag following `debug_assertions`, and a fixed label.
    fn default() -> Self {
        let backend = BackendPreference::platform_default();
        let label = String::from("OxiGDAL GPU Context");

        Self {
            backend,
            power_preference: GpuPowerPreference::HighPerformance,
            required_features: Features::empty(),
            required_limits: None,
            debug: cfg!(debug_assertions),
            label: Some(label),
        }
    }
}
impl GpuContextConfig {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the configuration with its backend preference replaced.
    pub fn with_backend(self, backend: BackendPreference) -> Self {
        Self { backend, ..self }
    }

    /// Returns the configuration with its power preference replaced.
    pub fn with_power_preference(self, power: GpuPowerPreference) -> Self {
        Self {
            power_preference: power,
            ..self
        }
    }

    /// Returns the configuration with its required features replaced.
    pub fn with_features(self, features: Features) -> Self {
        Self {
            required_features: features,
            ..self
        }
    }

    /// Returns the configuration with explicit required limits.
    pub fn with_limits(self, limits: Limits) -> Self {
        Self {
            required_limits: Some(limits),
            ..self
        }
    }

    /// Returns the configuration with the debug flag replaced.
    pub fn with_debug(self, debug: bool) -> Self {
        Self { debug, ..self }
    }

    /// Returns the configuration with the device label replaced.
    pub fn with_label(self, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..self
        }
    }
}
/// Bundle of the wgpu objects that make up a GPU context.
///
/// The wgpu handles are held behind `Arc`, so cloning a `GpuContext` clones
/// the `Arc` pointers rather than creating new wgpu objects.
#[derive(Clone)]
pub struct GpuContext {
/// wgpu instance the adapter was requested from.
instance: Arc<Instance>,
/// Adapter selected during context creation.
adapter: Arc<Adapter>,
/// Device requested from the adapter.
device: Arc<Device>,
/// Queue returned together with the device.
queue: Arc<Queue>,
/// Adapter information captured when the context was created.
adapter_info: AdapterInfo,
/// Limits that were requested for the device.
limits: Limits,
}
impl GpuContext {
    /// Creates a context using [`GpuContextConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`GpuContext::with_config`] produces.
    pub async fn new() -> GpuResult<Self> {
        Self::with_config(GpuContextConfig::default()).await
    }

    /// Creates a context from an explicit configuration.
    ///
    /// The steps are: create a wgpu instance restricted to the configured
    /// backends, request an adapter, resolve the device limits (the
    /// configured ones, or adapter-derived defaults), validate them against
    /// the adapter, and finally request the device and queue.
    ///
    /// NOTE(review): `config.debug` is not read anywhere in this function —
    /// confirm whether it is meant to influence instance creation.
    ///
    /// # Errors
    ///
    /// * the error built by `GpuError::no_adapter` when no adapter is found;
    /// * the error built by `GpuError::device_request` when the resolved
    ///   limits exceed the adapter's, or when the device request itself fails.
    pub async fn with_config(config: GpuContextConfig) -> GpuResult<Self> {
        info!(
            "Initializing GPU context with backend: {:?}",
            config.backend
        );
        let instance = Instance::new(InstanceDescriptor {
            backends: config.backend.to_backends(),
            ..InstanceDescriptor::new_without_display_handle()
        });

        let adapter = Self::request_adapter(&instance, &config).await?;
        let adapter_info = adapter.get_info();
        info!(
            "Selected GPU adapter: {} ({:?})",
            adapter_info.name, adapter_info.backend
        );
        debug!("Adapter info: {:?}", adapter_info);

        let adapter_limits = adapter.limits();
        let limits = config
            .required_limits
            .unwrap_or_else(|| Self::default_limits(&adapter_limits));
        if !Self::validate_limits(&limits, &adapter_limits) {
            // Name every limit that actually failed instead of always
            // reporting the X workgroup size, which was misleading whenever
            // a different limit was the culprit.
            let exceeded = Self::limit_violations(&limits, &adapter_limits).join(", ");
            return Err(GpuError::device_request(format!(
                "Requested limits exceed adapter capabilities: {exceeded}"
            )));
        }

        let (device, queue) = adapter
            .request_device(&DeviceDescriptor {
                label: config.label.as_deref(),
                required_features: config.required_features,
                required_limits: limits.clone(),
                memory_hints: Default::default(),
                experimental_features: Default::default(),
                trace: Default::default(),
            })
            .await
            .map_err(|e| GpuError::device_request(e.to_string()))?;
        info!("GPU device created successfully");
        debug!("Device limits: {:?}", limits);

        Ok(Self {
            instance: Arc::new(instance),
            adapter: Arc::new(adapter),
            device: Arc::new(device),
            queue: Arc::new(queue),
            adapter_info,
            limits,
        })
    }

    /// Requests an adapter matching the configured power preference.
    ///
    /// No fallback adapter is forced and no surface compatibility is
    /// required.
    ///
    /// # Errors
    ///
    /// Returns the error built by `GpuError::no_adapter`, carrying a
    /// description of the configured backend, when the request fails. The
    /// underlying wgpu error is logged at debug level.
    async fn request_adapter(instance: &Instance, config: &GpuContextConfig) -> GpuResult<Adapter> {
        let adapter = instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: config.power_preference.into(),
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await;
        adapter.map_err(|err| {
            // Keep the underlying reason visible; the returned error only
            // carries the backend description.
            debug!("Adapter request failed: {}", err);
            let backends = match config.backend {
                BackendPreference::Auto => "Auto (PRIMARY)".to_string(),
                BackendPreference::All => "All".to_string(),
                backend => format!("{backend:?}"),
            };
            GpuError::no_adapter(backends)
        })
    }

    /// Derives the limits used when the configuration does not supply any.
    ///
    /// Workgroup sizes are capped at 256 (X and Y) and 64 (Z), invocations
    /// per workgroup at 256, the workgroups-per-dimension limit is taken
    /// from the adapter, and every other limit comes from
    /// `Limits::default()`.
    fn default_limits(adapter_limits: &Limits) -> Limits {
        Limits {
            max_compute_workgroup_size_x: adapter_limits.max_compute_workgroup_size_x.min(256),
            max_compute_workgroup_size_y: adapter_limits.max_compute_workgroup_size_y.min(256),
            max_compute_workgroup_size_z: adapter_limits.max_compute_workgroup_size_z.min(64),
            max_compute_invocations_per_workgroup: adapter_limits
                .max_compute_invocations_per_workgroup
                .min(256),
            max_compute_workgroups_per_dimension: adapter_limits
                .max_compute_workgroups_per_dimension,
            ..Default::default()
        }
    }

    /// Lists each checked compute limit in `requested` that exceeds `adapter`.
    ///
    /// Every entry is formatted as `name: requested (adapter: available)`.
    /// An empty list means all checked limits fit.
    fn limit_violations(requested: &Limits, adapter: &Limits) -> Vec<String> {
        let checks = [
            (
                "max_compute_workgroup_size_x",
                requested.max_compute_workgroup_size_x,
                adapter.max_compute_workgroup_size_x,
            ),
            (
                "max_compute_workgroup_size_y",
                requested.max_compute_workgroup_size_y,
                adapter.max_compute_workgroup_size_y,
            ),
            (
                "max_compute_workgroup_size_z",
                requested.max_compute_workgroup_size_z,
                adapter.max_compute_workgroup_size_z,
            ),
            (
                "max_compute_invocations_per_workgroup",
                requested.max_compute_invocations_per_workgroup,
                adapter.max_compute_invocations_per_workgroup,
            ),
            // `default_limits` sets this one too, so it is validated with
            // the rest of the compute limits.
            (
                "max_compute_workgroups_per_dimension",
                requested.max_compute_workgroups_per_dimension,
                adapter.max_compute_workgroups_per_dimension,
            ),
        ];

        checks
            .iter()
            .filter(|(_, wanted, available)| wanted > available)
            .map(|(name, wanted, available)| format!("{name}: {wanted} (adapter: {available})"))
            .collect()
    }

    /// Returns `true` when none of the checked compute limits in `requested`
    /// exceeds the corresponding limit in `adapter`.
    fn validate_limits(requested: &Limits, adapter: &Limits) -> bool {
        Self::limit_violations(requested, adapter).is_empty()
    }

    /// Returns the device.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Returns the queue created together with the device.
    pub fn queue(&self) -> &Queue {
        &self.queue
    }

    /// Returns the adapter the device was requested from.
    pub fn adapter(&self) -> &Adapter {
        &self.adapter
    }

    /// Returns the wgpu instance.
    pub fn instance(&self) -> &Instance {
        &self.instance
    }

    /// Returns the adapter information captured at creation time.
    pub fn adapter_info(&self) -> &AdapterInfo {
        &self.adapter_info
    }

    /// Returns the limits that were requested for the device.
    pub fn limits(&self) -> &Limits {
        &self.limits
    }

    /// Returns the backend reported by the adapter information.
    pub fn backend(&self) -> Backend {
        self.adapter_info.backend
    }

    /// Returns `true` when the device's feature set contains `feature`.
    pub fn supports_feature(&self, feature: Features) -> bool {
        self.device.features().contains(feature)
    }

    /// Returns the requested maximum workgroup size as `(x, y, z)`.
    pub fn max_workgroup_size(&self) -> (u32, u32, u32) {
        (
            self.limits.max_compute_workgroup_size_x,
            self.limits.max_compute_workgroup_size_y,
            self.limits.max_compute_workgroup_size_z,
        )
    }

    /// Currently a no-op: the body is empty and `_wait` is ignored.
    ///
    /// NOTE(review): no device polling happens here, so callers that rely on
    /// this to advance GPU work presumably have to poll the device
    /// themselves — confirm whether this stub is intentional.
    pub fn poll(&self, _wait: bool) {
    }

    /// Probes the device by creating a 4-byte uniform buffer, then returns
    /// `true`.
    ///
    /// The created buffer is dropped immediately and the outcome of the
    /// creation is not inspected, so this method never returns `false`.
    /// NOTE(review): how a failing buffer creation surfaces depends on wgpu's
    /// error handling — confirm this is the intended health check.
    pub fn is_valid(&self) -> bool {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("health_check"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM,
            mapped_at_creation: false,
        });
        true
    }
}
impl std::fmt::Debug for GpuContext {
    /// Formats the context as a `GpuContext` struct showing the adapter name,
    /// backend, device type and limits; the wgpu handles are left out.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let info = &self.adapter_info;

        let mut out = formatter.debug_struct("GpuContext");
        out.field("adapter", &info.name);
        out.field("backend", &info.backend);
        out.field("device_type", &info.device_type);
        out.field("limits", &self.limits);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a default context either succeeds and passes the health
    /// check, or the failure is only printed so GPU-less machines still pass.
    #[tokio::test]
    async fn test_gpu_context_creation() {
        let ctx = match GpuContext::new().await {
            Ok(ctx) => ctx,
            Err(e) => {
                println!("GPU not available (expected in CI): {}", e);
                return;
            }
        };
        println!("GPU Context created: {:?}", ctx);
        assert!(ctx.is_valid());
    }

    /// Builds a context with the platform default backend and prints the
    /// backend that was actually selected.
    #[tokio::test]
    async fn test_backend_preference() {
        let config = GpuContextConfig::new()
            .with_backend(BackendPreference::platform_default())
            .with_power_preference(GpuPowerPreference::HighPerformance);

        let ctx = match GpuContext::with_config(config).await {
            Ok(ctx) => ctx,
            Err(e) => {
                println!("GPU not available: {}", e);
                return;
            }
        };
        println!("Backend: {:?}", ctx.backend());
    }

    /// Single-backend preferences map to the matching wgpu backend flag.
    #[test]
    fn test_backend_conversion() {
        let expected = [
            (BackendPreference::Vulkan, Backends::VULKAN),
            (BackendPreference::Metal, Backends::METAL),
            (BackendPreference::DX12, Backends::DX12),
        ];
        for (preference, backends) in expected.iter() {
            assert_eq!(preference.to_backends(), *backends);
        }
    }

    /// The platform default matches the compilation target on the three
    /// desktop operating systems.
    #[test]
    fn test_platform_default() {
        let preferred = BackendPreference::platform_default();
        println!("Platform default backend: {:?}", preferred);

        if cfg!(target_os = "macos") {
            assert_eq!(preferred, BackendPreference::Metal);
        }
        if cfg!(target_os = "windows") {
            assert_eq!(preferred, BackendPreference::DX12);
        }
        if cfg!(target_os = "linux") {
            assert_eq!(preferred, BackendPreference::Vulkan);
        }
    }
}