use super::gpu_shader::MOE_SHADER;
use bytemuck::{Pod, Zeroable};
use std::sync::mpsc::TryRecvError;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use wgpu::util::DeviceExt;
/// Minimum number of feature rows for which the GPU path is attempted;
/// `batch_score_features` returns `None` for anything smaller.
const GPU_BATCH_THRESHOLD: usize = 64;
/// Readback timeout in milliseconds, used when `KEYHOG_GPU_MOE_TIMEOUT_MS` is
/// unset, not a valid `u64`, or zero.
const DEFAULT_GPU_MOE_TIMEOUT_MS: u64 = 30_000;
/// Number of `f32` values in one input feature row (alias of
/// `crate::ml_scorer::NUM_FEATURES`).
const INPUT_DIM: usize = crate::ml_scorer::NUM_FEATURES;
/// Per-dispatch parameter block written into the uniform buffer bound at binding 3.
///
/// Laid out as four `u32`s (16 bytes) in `#[repr(C)]` order so it can be viewed as raw
/// bytes through `bytemuck`.
/// NOTE(review): presumably mirrors a uniform struct declared in `MOE_SHADER` — confirm
/// against the shader source before changing the layout.
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct GpuParams {
/// Number of feature rows in the current dispatch.
batch_size: u32,
/// Padding that brings the struct to 16 bytes; always written as zeros.
_pad: [u32; 3],
}
/// Long-lived GPU state for MoE scoring, built once by `init_gpu` and then reused for
/// every batch.
pub(super) struct GpuContext {
/// Device/queue pair obtained from the shared vyre wgpu backend.
device_queue: std::sync::Arc<(wgpu::Device, wgpu::Queue)>,
/// Adapter description (name, backend, device type, driver) captured at init time.
adapter_info: wgpu::AdapterInfo,
/// Device limits captured at init time.
device_limits: wgpu::Limits,
/// Compute pipeline built from `MOE_SHADER` with entry point `moe_forward`.
pipeline: wgpu::ComputePipeline,
/// Storage buffer holding the model weights; uploaded once during init.
weights_buf: wgpu::Buffer,
/// Uniform buffer that receives a `GpuParams` value before each dispatch.
params_buf: wgpu::Buffer,
/// Layout shared by every per-batch bind group (bindings 0 through 3).
bind_group_layout: wgpu::BindGroupLayout,
}
impl GpuContext {
    /// Reports a memory figure for the device, in MiB.
    ///
    /// The value is the device's `max_buffer_size` limit divided by 1 MiB and capped at
    /// 256 GiB. The function always returns `Some`.
    /// NOTE(review): this is a buffer-size limit, not a measured amount of video memory —
    /// confirm callers treat it as an estimate.
    pub fn vram_mb(&self) -> Option<u64> {
        const BYTES_PER_MB: u64 = 1024 * 1024;
        const SANE_CAP_MB: u64 = 256 * 1024;

        let limit_mb = self.device_limits.max_buffer_size / BYTES_PER_MB;
        Some(std::cmp::min(limit_mb, SANE_CAP_MB))
    }

    /// Returns the adapter name captured when the context was created.
    pub fn gpu_name(&self) -> &str {
        self.adapter_info.name.as_str()
    }

    /// Borrows the device half of the shared device/queue pair.
    #[inline]
    fn device(&self) -> &wgpu::Device {
        let (device, _) = &*self.device_queue;
        device
    }

    /// Borrows the queue half of the shared device/queue pair.
    #[inline]
    fn queue(&self) -> &wgpu::Queue {
        let (_, queue) = &*self.device_queue;
        queue
    }
}
/// Process-wide GPU context slot, filled at most once by `get_gpu`; it holds `None`
/// when initialization failed.
static GPU: OnceLock<Option<GpuContext>> = OnceLock::new();
/// Returns the readback timeout for GPU MoE scoring.
///
/// The value is read from `KEYHOG_GPU_MOE_TIMEOUT_MS` (milliseconds) the first time this
/// function is called and cached for the rest of the process. A missing, unparsable, or
/// zero value selects `DEFAULT_GPU_MOE_TIMEOUT_MS`.
fn gpu_moe_timeout() -> Duration {
    static TIMEOUT: OnceLock<Duration> = OnceLock::new();

    *TIMEOUT.get_or_init(|| {
        let configured_ms = match std::env::var("KEYHOG_GPU_MOE_TIMEOUT_MS") {
            Ok(raw) => match raw.parse::<u64>() {
                // Zero is rejected so the readback loop always gets a real deadline.
                Ok(ms) if ms > 0 => Some(ms),
                _ => None,
            },
            Err(_) => None,
        };
        Duration::from_millis(configured_ms.unwrap_or(DEFAULT_GPU_MOE_TIMEOUT_MS))
    })
}
fn init_gpu() -> Result<GpuContext, Box<dyn std::error::Error + Send + Sync>> {
let vyre_backend = vyre_driver_wgpu::WgpuBackend::shared()
.map_err(|e| format!("vyre WgpuBackend unavailable: {e}"))?;
let adapter_info = vyre_backend.adapter_info().clone();
if adapter_info.device_type == wgpu::DeviceType::Cpu {
return Err(format!(
"GPU adapter is a software fallback ({} on {:?}); refusing to use",
adapter_info.name, adapter_info.backend
)
.into());
}
let device_limits = vyre_backend.device_limits().clone();
let dq = vyre_backend.device_queue();
tracing::info!(
gpu = %adapter_info.name,
backend = ?adapter_info.backend,
device_type = ?adapter_info.device_type,
driver = %adapter_info.driver,
"GPU MoE: reusing vyre shared device"
);
let device = &dq.0;
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("moe_shader"),
source: wgpu::ShaderSource::Wgsl(MOE_SHADER.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("moe_bgl"),
entries: &[
bgl_entry(0, true),
bgl_entry(1, true),
bgl_entry(2, false),
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("moe_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("moe_pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("moe_forward"),
compilation_options: Default::default(),
cache: None,
});
let all_weights = crate::ml_scorer::ml_weights::all_weights_slice();
let weights_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("weights"),
contents: bytemuck::cast_slice(all_weights),
usage: wgpu::BufferUsages::STORAGE,
});
let params_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("params"),
size: std::mem::size_of::<GpuParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Ok(GpuContext {
device_queue: dq,
adapter_info,
device_limits,
pipeline,
weights_buf,
params_buf,
bind_group_layout,
})
}
/// Builds a compute-stage bind-group-layout entry for a storage buffer.
///
/// `binding` is the slot index; `read_only` selects read-only versus read-write storage
/// access. The entry has no dynamic offset, no minimum binding size, and is not an array.
fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
    let storage_buffer = wgpu::BindingType::Buffer {
        ty: wgpu::BufferBindingType::Storage { read_only },
        has_dynamic_offset: false,
        min_binding_size: None,
    };

    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: storage_buffer,
        count: None,
    }
}
/// Returns the process-wide GPU context, initializing it on first use.
///
/// Initialization runs at most once; later calls return the cached outcome. On failure:
/// - if `KEYHOG_REQUIRE_GPU` is exactly `1`, an error is printed to stderr and the process
///   exits with status 2;
/// - otherwise, if the hardware probe reports a GPU and `super::env_no_gpu()` is false, a
///   notice about the CPU fallback is printed to stderr;
/// - the failure is logged at debug level and `None` is cached and returned.
pub fn get_gpu() -> Option<&'static GpuContext> {
    let cached = GPU.get_or_init(|| {
        let e = match init_gpu() {
            Ok(ctx) => {
                tracing::info!("GPU MoE inference initialized (shared device)");
                return Some(ctx);
            }
            Err(e) => e,
        };

        let no_gpu = super::env_no_gpu();
        if matches!(std::env::var("KEYHOG_REQUIRE_GPU").as_deref(), Ok("1")) {
            eprintln!("keyhog: KEYHOG_REQUIRE_GPU=1 but GPU MoE init failed: {e}");
            std::process::exit(2);
        }

        // Only tell the user when a GPU exists and they did not opt out of it.
        let gpu_present = crate::hw_probe::probe_hardware().gpu_available;
        if gpu_present && !no_gpu {
            eprintln!(
                "keyhog: a GPU was detected but could not be initialized; using the \
                 CPU/SIMD scan path. Set KEYHOG_NO_GPU=1 to silence this, or KEYHOG_REQUIRE_GPU=1 to fail instead."
            );
        }

        tracing::debug!("GPU MoE init failed, using CPU fallback: {e}");
        None
    });

    cached.as_ref()
}
/// Scores a batch of feature rows on the GPU using the MoE compute pipeline.
///
/// Each row of `features` yields one score. Scores are read back as `f32`, widened to
/// `f64`, and clamped to `[0.0, 1.0]`; a non-finite score is replaced by `0.5`.
///
/// # Returns
///
/// `Some(scores)` with one entry per input row, or `None` when the caller should use the
/// CPU path instead. `None` is returned when:
/// - the batch has fewer than `GPU_BATCH_THRESHOLD` rows;
/// - no GPU context is available (`get_gpu` returned `None`);
/// - the batch is too large for the device (row count does not fit in `u32`, the
///   dispatch would exceed `max_compute_workgroups_per_dimension`, or the input buffer
///   would exceed the device's buffer / storage-binding size limits);
/// - the readback fails, its callback is dropped, `device.poll` fails, or the readback
///   does not finish within `gpu_moe_timeout()`.
pub fn batch_score_features(features: &[[f32; INPUT_DIM]]) -> Option<Vec<f64>> {
    // NOTE(review): presumably must equal the `@workgroup_size` declared for
    // `moe_forward` in `MOE_SHADER` — confirm against the shader.
    const WORKGROUP_SIZE: u32 = 64;

    if features.len() < GPU_BATCH_THRESHOLD {
        return None;
    }
    let gpu = get_gpu()?;

    // The shader parameter is a `u32`; refuse instead of silently truncating the count.
    let Ok(batch_size) = u32::try_from(features.len()) else {
        tracing::warn!(
            rows = features.len(),
            "GPU MoE batch row count does not fit in u32; falling back to CPU MoE for this scan"
        );
        return None;
    };

    // Oversized work would be rejected by wgpu validation, so check the limits up front
    // and fall back to the CPU path instead.
    let workgroups = batch_size.div_ceil(WORKGROUP_SIZE);
    let input_bytes = std::mem::size_of_val(features) as u64;
    let limits = &gpu.device_limits;
    let max_input_bytes = limits
        .max_buffer_size
        .min(u64::from(limits.max_storage_buffer_binding_size));
    if workgroups > limits.max_compute_workgroups_per_dimension || input_bytes > max_input_bytes {
        tracing::warn!(
            workgroups,
            input_bytes,
            "GPU MoE batch exceeds device limits; falling back to CPU MoE for this scan"
        );
        return None;
    }

    let device = gpu.device();
    let queue = gpu.queue();

    let input_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("input"),
        contents: bytemuck::cast_slice(features),
        usage: wgpu::BufferUsages::STORAGE,
    });

    // One `f32` score per row. The shader writes into `output_buf`; `staging_buf` is the
    // CPU-mappable copy target.
    let output_size = u64::from(batch_size) * std::mem::size_of::<f32>() as u64;
    let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("output"),
        size: output_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    let staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("staging"),
        size: output_size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // NOTE(review): `params_buf` is a single buffer shared by every call; concurrent
    // callers could overwrite each other's batch size before submission — confirm that
    // callers are serialized.
    let params = GpuParams {
        batch_size,
        _pad: [0; 3],
    };
    queue.write_buffer(&gpu.params_buf, 0, bytemuck::bytes_of(&params));

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("moe_bg"),
        layout: &gpu.bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: gpu.weights_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: input_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: output_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: gpu.params_buf.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("moe_encoder"),
    });
    {
        // Scoped so the pass is ended (dropped) before the encoder is used again.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("moe_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&gpu.pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output_buf, 0, &staging_buf, 0, output_size);
    queue.submit(std::iter::once(encoder.finish()));

    // Request the mapping; the callback forwards the outcome through a channel.
    let slice = staging_buf.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });

    // Drive the device with non-blocking polls until the callback fires or the deadline
    // passes, so a stuck GPU cannot hang the scan.
    let timeout = gpu_moe_timeout();
    let deadline = Instant::now() + timeout;
    let map_recv = loop {
        match receiver.try_recv() {
            Ok(result) => break result,
            Err(TryRecvError::Disconnected) => {
                tracing::warn!(
                    "GPU MoE staging-buffer callback disconnected; falling back to CPU MoE for this scan"
                );
                return None;
            }
            Err(TryRecvError::Empty) => {}
        }
        if Instant::now() >= deadline {
            tracing::warn!(
                ?timeout,
                "GPU MoE staging-buffer readback timed out; falling back to CPU MoE for this scan"
            );
            return None;
        }
        if let Err(error) = device.poll(wgpu::PollType::Poll) {
            tracing::warn!(
                ?error,
                "GPU MoE device.poll() failed; falling back to CPU MoE for this scan"
            );
            return None;
        }
        // The poll may have just run the callback; check before sleeping.
        if let Ok(result) = receiver.try_recv() {
            break result;
        }
        std::thread::sleep(Duration::from_millis(1));
    };
    if let Err(error) = map_recv {
        tracing::warn!(
            ?error,
            "GPU MoE staging-buffer map_async failed; falling back to CPU MoE for this scan"
        );
        return None;
    }

    let data = slice.get_mapped_range();
    let scores: &[f32] = bytemuck::cast_slice(&data);
    let result: Vec<f64> = scores
        .iter()
        .map(|&score| {
            let value = f64::from(score);
            if value.is_finite() {
                value.clamp(0.0, 1.0)
            } else {
                // NaN or infinity from the shader maps to 0.5.
                0.5
            }
        })
        .collect();

    // The mapped view must be released before the buffer is unmapped.
    drop(data);
    staging_buf.unmap();
    Some(result)
}