use std::sync::{Arc, OnceLock};
use wgpu;
/// Bundle of wgpu objects needed to run the `encode` and `accumulate` compute
/// shaders: the device/queue pair, both compute pipelines, and the bind group
/// layouts those pipelines were created with.
pub struct GpuContext {
/// Logical device obtained from the selected adapter via `request_device`.
pub device: wgpu::Device,
/// Queue returned together with `device` by `request_device`.
pub queue: wgpu::Queue,
/// Compute pipeline for the `encode_main` entry point of `shaders/encode.wgsl`.
pub encode_pipeline: wgpu::ComputePipeline,
/// Compute pipeline for the `accumulate_main` entry point of `shaders/accumulate.wgsl`.
pub accumulate_pipeline: wgpu::ComputePipeline,
/// Layout used by `encode_pipeline`: binding 0 is a read-only storage buffer,
/// binding 1 a read-write storage buffer, binding 2 a uniform buffer.
pub encode_bind_group_layout: wgpu::BindGroupLayout,
/// Layout used by `accumulate_pipeline`: bindings 0 and 1 are read-only storage
/// buffers, binding 2 a read-write storage buffer, binding 3 a uniform buffer.
pub accumulate_bind_group_layout: wgpu::BindGroupLayout,
/// Adapter name as reported by `get_info()` on the selected adapter.
pub adapter_name: String,
}
// Lazily initialised, process-wide GPU context. The cell stores an `Option` so
// that a failed initialisation (`None`) is cached as well and is not retried
// by later calls to `get_context`.
static GPU_CTX: OnceLock<Option<Arc<GpuContext>>> = OnceLock::new();
/// Returns a handle to the shared [`GpuContext`], creating it on the first call.
///
/// The outcome of the first initialisation attempt is stored in `GPU_CTX`, so a
/// `None` result (context creation failed) is returned by every later call too.
pub fn get_context() -> Option<Arc<GpuContext>> {
    let cached = GPU_CTX.get_or_init(|| {
        let ctx = GpuContext::try_new()?;
        Some(Arc::new(ctx))
    });
    // Hand out another reference-counted handle; the cell keeps its own.
    cached.as_ref().map(Arc::clone)
}
/// Reports whether a GPU context could be obtained.
///
/// Calls [`get_context`], so the first invocation may trigger context creation.
pub fn is_available() -> bool {
    let context = get_context();
    context.is_some()
}
impl GpuContext {
    /// Attempts to build a complete GPU context.
    ///
    /// Returns `None` when no adapter is selected (see [`Self::select_adapter`])
    /// or when the adapter refuses the device request.
    ///
    /// Steps, in order: create an instance limited to the Vulkan, Metal and DX12
    /// backends; pick an adapter; request a device with no optional features and
    /// default limits; compile the two embedded WGSL shaders; create the bind
    /// group layouts and the compute pipelines that use them.
    fn try_new() -> Option<Self> {
        let backends = wgpu::Backends::VULKAN | wgpu::Backends::METAL | wgpu::Backends::DX12;
        let mut descriptor = wgpu::InstanceDescriptor::new_without_display_handle();
        descriptor.backends = backends;
        let instance = wgpu::Instance::new(descriptor);

        let adapter = Self::select_adapter(&instance, backends)?;
        // `get_info()` yields an owned value, so the name is moved out, not cloned.
        let adapter_name = adapter.get_info().name;

        let (device, queue) =
            pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
                label: Some("sc-neurocore-gpu"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::default(),
                ..Default::default()
            }))
            .ok()?;

        // Shader sources are embedded into the binary at compile time.
        let encode_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("encode"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/encode.wgsl").into()),
        });
        let accum_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("accumulate"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/accumulate.wgsl").into()),
        });

        // encode: 0 = read-only storage, 1 = read-write storage, 2 = uniform.
        let encode_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("encode_bgl"),
            entries: &[bgl_storage_ro(0), bgl_storage_rw(1), bgl_uniform(2)],
        });
        // accumulate: 0, 1 = read-only storage, 2 = read-write storage, 3 = uniform.
        let accum_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("accum_bgl"),
            entries: &[
                bgl_storage_ro(0),
                bgl_storage_ro(1),
                bgl_storage_rw(2),
                bgl_uniform(3),
            ],
        });

        let encode_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("encode_pl"),
            bind_group_layouts: &[Some(&encode_bgl)],
            immediate_size: 0,
        });
        let encode_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("encode_pipeline"),
            layout: Some(&encode_pl),
            module: &encode_shader,
            entry_point: Some("encode_main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let accum_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("accum_pl"),
            bind_group_layouts: &[Some(&accum_bgl)],
            immediate_size: 0,
        });
        let accumulate_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("accum_pipeline"),
                layout: Some(&accum_pl),
                module: &accum_shader,
                entry_point: Some("accumulate_main"),
                compilation_options: Default::default(),
                cache: None,
            });

        Some(GpuContext {
            device,
            queue,
            encode_pipeline,
            accumulate_pipeline,
            encode_bind_group_layout: encode_bgl,
            accumulate_bind_group_layout: accum_bgl,
            adapter_name,
        })
    }

    /// Picks the adapter the context will run on.
    ///
    /// If the `SC_GPU_ADAPTER` environment variable holds a non-blank value, the
    /// adapters for `backends` are enumerated and the first one whose name
    /// contains that value (both sides lowercased) is returned; when nothing
    /// matches the result is `None`, with no fallback. Otherwise wgpu is asked
    /// for a high-performance adapter with no surface requirement.
    ///
    /// A blank value is treated as unset: an empty substring is contained in
    /// every name, so it would otherwise silently pin the first enumerated
    /// adapter instead of the high-performance one.
    fn select_adapter(instance: &wgpu::Instance, backends: wgpu::Backends) -> Option<wgpu::Adapter> {
        let name_filter = std::env::var("SC_GPU_ADAPTER")
            .ok()
            .filter(|raw| !raw.trim().is_empty())
            .map(|raw| raw.to_lowercase());

        match name_filter {
            Some(needle) => pollster::block_on(instance.enumerate_adapters(backends))
                .into_iter()
                .find(|candidate| {
                    candidate
                        .get_info()
                        .name
                        .to_lowercase()
                        .contains(needle.as_str())
                }),
            None => pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            }))
            .ok(),
        }
    }
}
/// Builds the layout entry for a read-only storage buffer at slot `binding`.
///
/// The entry is visible to the compute stage only, has no dynamic offset, no
/// minimum binding size, and `count` is `None`.
fn bgl_storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
    let buffer_kind = wgpu::BufferBindingType::Storage { read_only: true };
    let ty = wgpu::BindingType::Buffer {
        ty: buffer_kind,
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty,
        count: None,
    }
}
fn bgl_storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}
/// Builds the layout entry for a uniform buffer at slot `binding`.
///
/// The entry is visible to the compute stage only, has no dynamic offset, no
/// minimum binding size, and `count` is `None`.
fn bgl_uniform(binding: u32) -> wgpu::BindGroupLayoutEntry {
    let visibility = wgpu::ShaderStages::COMPUTE;
    let uniform_buffer = wgpu::BindingType::Buffer {
        ty: wgpu::BufferBindingType::Uniform,
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility,
        ty: uniform_buffer,
        count: None,
    }
}