1#[cfg(feature = "gpu")]
13mod backend {
14 use std::sync::OnceLock;
15
16 use bytemuck::{Pod, Zeroable};
17
/// Smallest batch that `batch_score_features` will run on the GPU; for shorter
/// batches it returns `None` before touching the GPU at all.
const GPU_BATCH_THRESHOLD: usize = 64;

// Network dimensions. `MOE_SHADER` repeats the same numbers as WGSL literals,
// so a change here has to be mirrored there by hand.

/// Length of one input feature vector.
#[allow(dead_code)]
const INPUT_DIM: usize = 41;
/// Number of experts mixed by the gate.
#[allow(dead_code)]
const EXPERT_COUNT: usize = 6;
/// Width of each expert's first hidden layer.
#[allow(dead_code)]
const HIDDEN1: usize = 32;
/// Width of each expert's second hidden layer.
#[allow(dead_code)]
const HIDDEN2: usize = 16;

/// Total number of `f32` parameters in the model: the gate layer followed by
/// `EXPERT_COUNT` experts, each made of three dense layers (weights + biases).
#[allow(dead_code)]
const TOTAL_WEIGHT_F32S: usize = {
    let gate = INPUT_DIM * EXPERT_COUNT + EXPERT_COUNT;
    let fc1 = INPUT_DIM * HIDDEN1 + HIDDEN1;
    let fc2 = HIDDEN1 * HIDDEN2 + HIDDEN2;
    let fc3 = HIDDEN2 + 1;
    gate + EXPERT_COUNT * (fc1 + fc2 + fc3)
};
36
37 #[derive(Clone, Copy, Pod, Zeroable)]
38 #[repr(C)]
39 struct GpuParams {
40 batch_size: u32,
41 _pad: [u32; 3],
42 }
43
44 pub(super) struct GpuContext {
45 device: wgpu::Device,
46 queue: wgpu::Queue,
47 adapter_info: wgpu::AdapterInfo,
48 pipeline: wgpu::ComputePipeline,
49 weights_buf: wgpu::Buffer,
50 params_buf: wgpu::Buffer,
51 bind_group_layout: wgpu::BindGroupLayout,
52 }
53
54 impl GpuContext {
55 pub fn vram_mb(&self) -> Option<u64> {
58 let limits = self.device.limits();
61 Some((limits.max_storage_buffer_binding_size as u64) / (1024 * 1024))
62 }
63
64 pub fn gpu_name(&self) -> &str {
66 &self.adapter_info.name
67 }
68 }
69
70 static GPU: OnceLock<Option<GpuContext>> = OnceLock::new();
71
72 fn init_gpu() -> Result<GpuContext, Box<dyn std::error::Error + Send + Sync>> {
73 let handle = std::thread::spawn(|| {
76 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
77 backends: wgpu::Backends::all(),
78 ..Default::default()
79 });
80
81 let adapter =
82 pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
83 power_preference: wgpu::PowerPreference::HighPerformance,
84 compatible_surface: None,
85 force_fallback_adapter: false,
86 }))
87 .ok_or("No GPU adapter found")?;
88
89 let adapter_info = adapter.get_info();
90
91 let (device, queue) = pollster::block_on(adapter.request_device(
92 &wgpu::DeviceDescriptor {
93 label: Some("keyhog-moe"),
94 required_features: wgpu::Features::empty(),
95 required_limits: wgpu::Limits::default(),
96 ..Default::default()
97 },
98 None,
99 ))?;
100
101 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
102 label: Some("moe_shader"),
103 source: wgpu::ShaderSource::Wgsl(MOE_SHADER.into()),
104 });
105
106 let bind_group_layout =
107 device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
108 label: Some("moe_bgl"),
109 entries: &[
110 bgl_entry(0, true),
112 bgl_entry(1, true),
114 bgl_entry(2, false),
116 wgpu::BindGroupLayoutEntry {
118 binding: 3,
119 visibility: wgpu::ShaderStages::COMPUTE,
120 ty: wgpu::BindingType::Buffer {
121 ty: wgpu::BufferBindingType::Uniform,
122 has_dynamic_offset: false,
123 min_binding_size: None,
124 },
125 count: None,
126 },
127 ],
128 });
129
130 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
131 label: Some("moe_pipeline_layout"),
132 bind_group_layouts: &[&bind_group_layout],
133 push_constant_ranges: &[],
134 });
135
136 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
137 label: Some("moe_pipeline"),
138 layout: Some(&pipeline_layout),
139 module: &shader,
140 entry_point: Some("moe_forward"),
141 compilation_options: Default::default(),
142 cache: None,
143 });
144
145 let all_weights = crate::ml_scorer::ml_weights::all_weights_slice();
147 let weights_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
148 label: Some("weights"),
149 contents: bytemuck::cast_slice(all_weights),
150 usage: wgpu::BufferUsages::STORAGE,
151 });
152
153 let params_buf = device.create_buffer(&wgpu::BufferDescriptor {
154 label: Some("params"),
155 size: std::mem::size_of::<GpuParams>() as u64,
156 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
157 mapped_at_creation: false,
158 });
159
160 Ok(GpuContext {
161 device,
162 queue,
163 adapter_info,
164 pipeline,
165 weights_buf,
166 params_buf,
167 bind_group_layout,
168 })
169 });
170 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
172 loop {
173 if handle.is_finished() {
174 return handle.join().map_err(|_| "GPU init thread panicked")?;
175 }
176 if std::time::Instant::now() > deadline {
177 return Err("GPU init timed out — falling back to CPU".into());
178 }
179 std::thread::sleep(std::time::Duration::from_millis(50));
180 }
181 }
182
183 fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
184 wgpu::BindGroupLayoutEntry {
185 binding,
186 visibility: wgpu::ShaderStages::COMPUTE,
187 ty: wgpu::BindingType::Buffer {
188 ty: wgpu::BufferBindingType::Storage { read_only },
189 has_dynamic_offset: false,
190 min_binding_size: None,
191 },
192 count: None,
193 }
194 }
195
196 pub fn get_gpu() -> Option<&'static GpuContext> {
205 GPU.get_or_init(|| match init_gpu() {
206 Ok(ctx) => {
207 tracing::info!("GPU MoE inference initialized");
208 Some(ctx)
209 }
210 Err(e) => {
211 tracing::debug!("GPU init failed, using CPU fallback: {e}");
212 None
213 }
214 })
215 .as_ref()
216 }
217
218 pub fn batch_score_features(features: &[[f32; INPUT_DIM]]) -> Option<Vec<f64>> {
228 if features.len() < GPU_BATCH_THRESHOLD {
229 return None; }
231
232 let gpu = get_gpu()?;
233 let batch_size = features.len();
234
235 let flat_features: Vec<f32> = features.iter().flat_map(|f| f.iter().copied()).collect();
237
238 let input_buf = gpu
239 .device
240 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
241 label: Some("input"),
242 contents: bytemuck::cast_slice(&flat_features),
243 usage: wgpu::BufferUsages::STORAGE,
244 });
245
246 let output_size = (batch_size * std::mem::size_of::<f32>()) as u64;
247 let output_buf = gpu.device.create_buffer(&wgpu::BufferDescriptor {
248 label: Some("output"),
249 size: output_size,
250 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
251 mapped_at_creation: false,
252 });
253
254 let staging_buf = gpu.device.create_buffer(&wgpu::BufferDescriptor {
255 label: Some("staging"),
256 size: output_size,
257 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
258 mapped_at_creation: false,
259 });
260
261 let params = GpuParams {
263 batch_size: batch_size as u32,
264 _pad: [0; 3],
265 };
266 gpu.queue
267 .write_buffer(&gpu.params_buf, 0, bytemuck::bytes_of(¶ms));
268
269 let bind_group = gpu.device.create_bind_group(&wgpu::BindGroupDescriptor {
270 label: Some("moe_bg"),
271 layout: &gpu.bind_group_layout,
272 entries: &[
273 wgpu::BindGroupEntry {
274 binding: 0,
275 resource: gpu.weights_buf.as_entire_binding(),
276 },
277 wgpu::BindGroupEntry {
278 binding: 1,
279 resource: input_buf.as_entire_binding(),
280 },
281 wgpu::BindGroupEntry {
282 binding: 2,
283 resource: output_buf.as_entire_binding(),
284 },
285 wgpu::BindGroupEntry {
286 binding: 3,
287 resource: gpu.params_buf.as_entire_binding(),
288 },
289 ],
290 });
291
292 let mut encoder = gpu
293 .device
294 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
295 label: Some("moe_encoder"),
296 });
297
298 {
299 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
300 label: Some("moe_pass"),
301 timestamp_writes: None,
302 });
303 pass.set_pipeline(&gpu.pipeline);
304 pass.set_bind_group(0, &bind_group, &[]);
305 let workgroups = (batch_size as u32).div_ceil(64);
307 pass.dispatch_workgroups(workgroups, 1, 1);
308 }
309
310 encoder.copy_buffer_to_buffer(&output_buf, 0, &staging_buf, 0, output_size);
311 gpu.queue.submit(std::iter::once(encoder.finish()));
312
313 let slice = staging_buf.slice(..);
315 let (sender, receiver) = std::sync::mpsc::channel();
316 slice.map_async(wgpu::MapMode::Read, move |result| {
317 let _ = sender.send(result);
318 });
319 gpu.device.poll(wgpu::Maintain::Wait);
320
321 receiver.recv().ok()?.ok()?;
322 let data = slice.get_mapped_range();
323 let scores: &[f32] = bytemuck::cast_slice(&data);
324 let result: Vec<f64> = scores.iter().map(|&s| s as f64).collect();
325 drop(data);
326 staging_buf.unmap();
327
328 Some(result)
329 }
330
331 use wgpu::util::DeviceExt;
332
/// WGSL source of the `moe_forward` compute kernel.
///
/// One invocation scores one batch row (`gid.x`), in workgroups of 64;
/// invocations at or beyond `params.batch_size` return immediately.
/// Per row the kernel:
/// 1. computes one gate logit per expert and turns them into probabilities
///    with a max-subtracted softmax;
/// 2. runs each of the 6 experts, a 41 -> 32 -> 16 -> 1 MLP with ReLU after
///    the first two layers;
/// 3. blends the expert outputs by the gate probabilities and applies a sigmoid.
///
/// Bindings (group 0): 0 = weights (read), 1 = inputs (read),
/// 2 = outputs (read_write), 3 = `Params` uniform.
///
/// Weight buffer layout, in `f32` units, as hard-coded below: gate weights at
/// 0 (246 values), gate biases at 246 (6 values), then 6 expert records of
/// 1889 values each starting at 252. Each record is ordered FC1 weights,
/// FC1 biases, FC2 weights, FC2 biases, FC3 weights, FC3 bias.
///
/// The numeric literals repeat the Rust-side dimension constants; the two
/// must be edited together.
const MOE_SHADER: &str = r#"
// MoE architecture constants
const INPUT_DIM: u32 = 41u;
const EXPERT_COUNT: u32 = 6u;
const HIDDEN1: u32 = 32u;
const HIDDEN2: u32 = 16u;

// Weight layout offsets (in f32 units)
const GATE_W_OFF: u32 = 0u;
const GATE_W_COUNT: u32 = 246u; // 41 * 6
const GATE_B_OFF: u32 = 246u;
const GATE_B_COUNT: u32 = 6u;
const EXPERTS_OFF: u32 = 252u;

// Per-expert parameter counts
const E_FC1_W: u32 = 1312u; // 41 * 32
const E_FC1_B: u32 = 32u;
const E_FC2_W: u32 = 512u; // 32 * 16
const E_FC2_B: u32 = 16u;
const E_FC3_W: u32 = 16u;
const E_FC3_B: u32 = 1u;
const EXPERT_PARAMS: u32 = 1889u; // sum of above

struct Params {
 batch_size: u32,
}

@group(0) @binding(0) var<storage, read> weights: array<f32>;
@group(0) @binding(1) var<storage, read> inputs: array<f32>;
@group(0) @binding(2) var<storage, read_write> outputs: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

fn get_input(batch_idx: u32, feat_idx: u32) -> f32 {
 return inputs[batch_idx * INPUT_DIM + feat_idx];
}

fn gate_dot(batch_idx: u32, expert_idx: u32) -> f32 {
 var sum = weights[GATE_B_OFF + expert_idx];
 for (var i = 0u; i < INPUT_DIM; i++) {
 sum += weights[GATE_W_OFF + expert_idx * INPUT_DIM + i] * get_input(batch_idx, i);
 }
 return sum;
}

fn expert_base(expert_idx: u32) -> u32 {
 return EXPERTS_OFF + expert_idx * EXPERT_PARAMS;
}

fn expert_forward(batch_idx: u32, expert_idx: u32) -> f32 {
 let base = expert_base(expert_idx);

 // FC1: input(41) -> hidden1(32) + ReLU
 var h1: array<f32, 32>;
 let fc1_w_off = base;
 let fc1_b_off = base + E_FC1_W;
 for (var j = 0u; j < HIDDEN1; j++) {
 var sum = weights[fc1_b_off + j];
 for (var i = 0u; i < INPUT_DIM; i++) {
 sum += weights[fc1_w_off + j * INPUT_DIM + i] * get_input(batch_idx, i);
 }
 h1[j] = max(sum, 0.0); // ReLU
 }

 // FC2: hidden1(32) -> hidden2(16) + ReLU
 var h2: array<f32, 16>;
 let fc2_w_off = base + E_FC1_W + E_FC1_B;
 let fc2_b_off = fc2_w_off + E_FC2_W;
 for (var j = 0u; j < HIDDEN2; j++) {
 var sum = weights[fc2_b_off + j];
 for (var i = 0u; i < HIDDEN1; i++) {
 sum += weights[fc2_w_off + j * HIDDEN1 + i] * h1[i];
 }
 h2[j] = max(sum, 0.0); // ReLU
 }

 // FC3: hidden2(16) -> output(1)
 let fc3_w_off = base + E_FC1_W + E_FC1_B + E_FC2_W + E_FC2_B;
 let fc3_b_off = fc3_w_off + E_FC3_W;
 var out = weights[fc3_b_off];
 for (var i = 0u; i < HIDDEN2; i++) {
 out += weights[fc3_w_off + i] * h2[i];
 }
 return out;
}

@compute @workgroup_size(64)
fn moe_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x;
 if (idx >= params.batch_size) {
 return;
 }

 // Compute gate logits and softmax
 var gate_logits: array<f32, 6>;
 var max_logit = -1e30;
 for (var e = 0u; e < EXPERT_COUNT; e++) {
 gate_logits[e] = gate_dot(idx, e);
 max_logit = max(max_logit, gate_logits[e]);
 }

 var exp_sum = 0.0;
 var gate_probs: array<f32, 6>;
 for (var e = 0u; e < EXPERT_COUNT; e++) {
 gate_probs[e] = exp(gate_logits[e] - max_logit);
 exp_sum += gate_probs[e];
 }
 for (var e = 0u; e < EXPERT_COUNT; e++) {
 gate_probs[e] /= exp_sum;
 }

 // Weighted sum of expert outputs
 var score_logit = 0.0;
 for (var e = 0u; e < EXPERT_COUNT; e++) {
 score_logit += gate_probs[e] * expert_forward(idx, e);
 }

 // Sigmoid
 outputs[idx] = 1.0 / (1.0 + exp(-score_logit));
}
"#;
454}
455
456pub fn batch_ml_inference(
472 candidates: &[(String, String)],
473 config: &crate::types::ScannerConfig,
474) -> Vec<f64> {
475 if candidates.is_empty() {
476 return Vec::new();
477 }
478
479 #[cfg(feature = "ml")]
480 {
481 #[cfg(feature = "gpu")]
483 {
484 let features: Vec<[f32; 41]> = candidates
485 .iter()
486 .map(|(text, ctx)| {
487 crate::ml_scorer::compute_features_with_config(
488 text,
489 ctx,
490 &config.known_prefixes,
491 &config.secret_keywords,
492 &config.test_keywords,
493 &config.placeholder_keywords,
494 )
495 })
496 .collect();
497
498 if let Some(scores) = backend::batch_score_features(&features) {
499 return scores;
500 }
501 }
502
503 candidates
505 .iter()
506 .map(|(text, ctx)| {
507 crate::ml_scorer::score_with_config(
508 text,
509 ctx,
510 &config.known_prefixes,
511 &config.secret_keywords,
512 &config.test_keywords,
513 &config.placeholder_keywords,
514 )
515 })
516 .collect()
517 }
518
519 #[cfg(not(feature = "ml"))]
520 {
521 let _ = candidates;
522 let _ = config;
523 Vec::new()
524 }
525}
526
/// Reports whether GPU inference can be used.
///
/// With the `gpu` feature this asks the backend for its context, which
/// initialises the GPU on first use. Without the feature it is always `false`.
pub fn gpu_available() -> bool {
    #[cfg(feature = "gpu")]
    let available = backend::get_gpu().is_some();

    #[cfg(not(feature = "gpu"))]
    let available = false;

    available
}
546
/// Probes the GPU and returns `(available, adapter name, vram_mb)`.
///
/// When a GPU context exists the tuple is `(true, Some(name), vram_mb())`;
/// otherwise — or when the `gpu` feature is disabled — it is
/// `(false, None, None)`.
#[must_use]
pub fn gpu_probe() -> (bool, Option<String>, Option<u64>) {
    #[cfg(feature = "gpu")]
    {
        backend::get_gpu().map_or((false, None, None), |gpu| {
            (true, Some(gpu.gpu_name().to_owned()), gpu.vram_mb())
        })
    }

    #[cfg(not(feature = "gpu"))]
    {
        (false, None, None)
    }
}