Skip to main content

keyhog_scanner/
gpu.rs

//! GPU-accelerated batch inference for the MoE classifier via wgpu compute shaders.
//!
//! Scores a whole batch of feature vectors on the GPU, which is intended to
//! outperform the CPU path for large batches. Falls back to CPU when no GPU is
//! available or for batches smaller than the crossover threshold.
//!
//! Architecture mirrors ml_scorer.rs exactly:
//! - Gate: Linear(41→6) + softmax
//! - 6 experts: Linear(41→32)+ReLU → Linear(32→16)+ReLU → Linear(16→1)
//! - Output: sigmoid(weighted sum of expert logits)
11
#[cfg(feature = "gpu")]
mod backend {
    use std::sync::mpsc::{self, RecvTimeoutError};
    use std::sync::OnceLock;
    use std::time::Duration;

    use bytemuck::{Pod, Zeroable};
    use wgpu::util::DeviceExt;

    /// Error type produced while bringing up the GPU context. Boxed so that
    /// string messages, wgpu request errors and thread-spawn I/O errors can all
    /// be propagated with `?`.
    type InitError = Box<dyn std::error::Error + Send + Sync>;

    /// Minimum batch size before GPU dispatch is worthwhile.
    /// Below this, CPU is faster due to GPU dispatch overhead.
    const GPU_BATCH_THRESHOLD: usize = 64;

    /// Invocations per workgroup. Must match `@workgroup_size(64)` in `MOE_SHADER`.
    const WORKGROUP_SIZE: u32 = 64;

    /// Longest time startup will wait for adapter/device creation before the
    /// CPU path is used instead.
    const INIT_TIMEOUT: Duration = Duration::from_secs(2);

    const INPUT_DIM: usize = 41;
    const EXPERT_COUNT: usize = 6;
    const HIDDEN1: usize = 32;
    const HIDDEN2: usize = 16;

    /// Total f32 weights: gate(41*6 + 6) + 6 experts * (41*32+32 + 32*16+16 + 16+1).
    ///
    /// The shader hard-codes the same layout (`EXPERTS_OFF = 252`,
    /// `EXPERT_PARAMS = 1889`), so the real weight table is checked against
    /// this before it is uploaded.
    const TOTAL_WEIGHT_F32S: usize = (INPUT_DIM * EXPERT_COUNT + EXPERT_COUNT)
        + EXPERT_COUNT
            * (INPUT_DIM * HIDDEN1 + HIDDEN1 + HIDDEN1 * HIDDEN2 + HIDDEN2 + HIDDEN2 + 1);

    /// Host-side mirror of the shader's `Params` uniform, padded to 16 bytes.
    #[derive(Clone, Copy, Pod, Zeroable)]
    #[repr(C)]
    struct GpuParams {
        batch_size: u32,
        _pad: [u32; 3],
    }

    /// Process-wide GPU state shared by every scoring call.
    ///
    /// Holds only resources that never change after initialization. Per-batch
    /// buffers, including the params uniform, are created per call. A shared
    /// params buffer would let concurrent callers overwrite each other's batch
    /// size between `write_buffer` and `submit`.
    pub(super) struct GpuContext {
        device: wgpu::Device,
        queue: wgpu::Queue,
        adapter_info: wgpu::AdapterInfo,
        pipeline: wgpu::ComputePipeline,
        weights_buf: wgpu::Buffer,
        bind_group_layout: wgpu::BindGroupLayout,
    }

    impl GpuContext {
        /// Rough memory-capability figure in MiB.
        ///
        /// wgpu/WebGPU has no standard VRAM query, so this reports the device's
        /// `max_storage_buffer_binding_size`. The device is requested with
        /// `wgpu::Limits::default()`, so the value reflects the requested limit
        /// rather than the adapter's physical memory. It currently always
        /// returns `Some`; the `Option` leaves room for a real query.
        pub fn vram_mb(&self) -> Option<u64> {
            let limits = self.device.limits();
            Some((limits.max_storage_buffer_binding_size as u64) / (1024 * 1024))
        }

        /// Human-readable GPU name from the adapter.
        pub fn gpu_name(&self) -> &str {
            &self.adapter_info.name
        }

        /// Largest number of feature vectors a single dispatch can handle
        /// without exceeding the device's workgroup-count limit or the storage
        /// binding limit for the input buffer (the largest per-call buffer).
        fn max_items_per_dispatch(&self) -> usize {
            let limits = self.device.limits();
            let by_workgroups = (limits.max_compute_workgroups_per_dimension as usize)
                .saturating_mul(WORKGROUP_SIZE as usize);
            let by_binding = limits.max_storage_buffer_binding_size as usize
                / (INPUT_DIM * std::mem::size_of::<f32>());
            by_workgroups.min(by_binding).max(1)
        }
    }

    static GPU: OnceLock<Option<GpuContext>> = OnceLock::new();

    /// Create the GPU context on a dedicated OS thread and wait at most
    /// `INIT_TIMEOUT` for it.
    ///
    /// The blocking wgpu calls run off the calling thread, so an async runtime
    /// worker (e.g. tokio) is not starved. On timeout the init thread is left
    /// detached. Whatever it eventually produces is dropped when its `send`
    /// fails.
    fn init_gpu() -> Result<GpuContext, InitError> {
        let (tx, rx) = mpsc::sync_channel(1);
        std::thread::Builder::new()
            .name("keyhog-gpu-init".into())
            .spawn(move || {
                // The receiver is gone if we already timed out; nothing to do then.
                let _ = tx.send(create_context());
            })?;

        match rx.recv_timeout(INIT_TIMEOUT) {
            Ok(result) => result,
            Err(RecvTimeoutError::Timeout) => Err("GPU init timed out — falling back to CPU".into()),
            // The sender was dropped without sending: the init thread panicked.
            Err(RecvTimeoutError::Disconnected) => Err("GPU init thread panicked".into()),
        }
    }

    /// Blocking creation of adapter, device, pipeline and the weights buffer.
    fn create_context() -> Result<GpuContext, InitError> {
        // The shader indexes weights with hard-coded offsets; a short table
        // would make it read past the end of the buffer.
        let all_weights = crate::ml_scorer::ml_weights::all_weights_slice();
        if all_weights.len() < TOTAL_WEIGHT_F32S {
            return Err(format!(
                "MoE weight table has {} f32s, GPU shader expects {TOTAL_WEIGHT_F32S}",
                all_weights.len()
            )
            .into());
        }

        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        }))
        .ok_or("No GPU adapter found")?;

        let adapter_info = adapter.get_info();

        let (device, queue) = pollster::block_on(adapter.request_device(
            &wgpu::DeviceDescriptor {
                label: Some("keyhog-moe"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::default(),
                ..Default::default()
            },
            None,
        ))?;

        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("moe_shader"),
            source: wgpu::ShaderSource::Wgsl(MOE_SHADER.into()),
        });

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("moe_bgl"),
            entries: &[
                // Weights buffer (read-only storage)
                bgl_entry(0, true),
                // Input features buffer (read-only storage)
                bgl_entry(1, true),
                // Output scores buffer (read-write storage)
                bgl_entry(2, false),
                // Params uniform
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("moe_pipeline_layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("moe_pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("moe_forward"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Upload weights once; they are read-only for the life of the context.
        let weights_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("weights"),
            contents: bytemuck::cast_slice(all_weights),
            usage: wgpu::BufferUsages::STORAGE,
        });

        Ok(GpuContext {
            device,
            queue,
            adapter_info,
            pipeline,
            weights_buf,
            bind_group_layout,
        })
    }

    /// Compute-stage storage-buffer layout entry for `binding`.
    fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    /// Return the lazily initialized GPU context when GPU inference is available.
    ///
    /// The first call performs initialization, bounded by `INIT_TIMEOUT`. The
    /// outcome, success or failure, is cached for the life of the process.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use keyhog_scanner::gpu::get_gpu;
    /// let _ = get_gpu();
    /// ```
    pub fn get_gpu() -> Option<&'static GpuContext> {
        GPU.get_or_init(|| match init_gpu() {
            Ok(ctx) => {
                tracing::info!("GPU MoE inference initialized");
                Some(ctx)
            }
            Err(e) => {
                tracing::debug!("GPU init failed, using CPU fallback: {e}");
                None
            }
        })
        .as_ref()
    }

    /// Score a batch of precomputed feature vectors on the GPU.
    ///
    /// Returns one score per input, in input order. Returns `None` when the
    /// batch is smaller than `GPU_BATCH_THRESHOLD`, no GPU is available, or
    /// a readback fails. In that case the caller should score on the CPU.
    /// Batches larger than the device limits allow for one dispatch are split
    /// into several dispatches.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use keyhog_scanner::gpu::batch_score_features;
    /// let _ = batch_score_features(&[[0.0; 41]]);
    /// ```
    pub fn batch_score_features(features: &[[f32; INPUT_DIM]]) -> Option<Vec<f64>> {
        if features.len() < GPU_BATCH_THRESHOLD {
            return None; // Too small for GPU, caller should use CPU
        }

        let gpu = get_gpu()?;
        let chunk_len = gpu.max_items_per_dispatch();
        let mut scores = Vec::with_capacity(features.len());
        for chunk in features.chunks(chunk_len) {
            // Any failed chunk aborts the whole batch so the caller's CPU
            // fallback rescores everything consistently.
            scores.extend(score_chunk(gpu, chunk)?);
        }
        Some(scores)
    }

    /// Run one dispatch over `features`. The caller guarantees that
    /// `features.len()` is at most `GpuContext::max_items_per_dispatch()`.
    fn score_chunk(gpu: &GpuContext, features: &[[f32; INPUT_DIM]]) -> Option<Vec<f64>> {
        let batch_size = features.len();
        let batch_size_u32 = u32::try_from(batch_size).ok()?;

        // Flatten rows into one contiguous row-major f32 buffer.
        let flat_features: Vec<f32> = features.iter().flatten().copied().collect();

        let input_buf = gpu
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("input"),
                contents: bytemuck::cast_slice(&flat_features),
                usage: wgpu::BufferUsages::STORAGE,
            });

        let output_size = (batch_size * std::mem::size_of::<f32>()) as u64;
        let output_buf = gpu.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("output"),
            size: output_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let staging_buf = gpu.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging"),
            size: output_size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Per-call uniform: never shared, so concurrent callers cannot
        // clobber each other's batch size.
        let params = GpuParams {
            batch_size: batch_size_u32,
            _pad: [0; 3],
        };
        let params_buf = gpu
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params"),
                contents: bytemuck::bytes_of(&params),
                usage: wgpu::BufferUsages::UNIFORM,
            });

        let bind_group = gpu.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("moe_bg"),
            layout: &gpu.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: gpu.weights_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: input_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: output_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: params_buf.as_entire_binding(),
                },
            ],
        });

        let mut encoder = gpu
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("moe_encoder"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("moe_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&gpu.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One invocation per item; the shader discards ids >= batch_size.
            pass.dispatch_workgroups(batch_size_u32.div_ceil(WORKGROUP_SIZE), 1, 1);
        }

        encoder.copy_buffer_to_buffer(&output_buf, 0, &staging_buf, 0, output_size);
        gpu.queue.submit(std::iter::once(encoder.finish()));

        // Read back results; block until the map callback has fired.
        let slice = staging_buf.slice(..);
        let (sender, receiver) = mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });
        gpu.device.poll(wgpu::Maintain::Wait);

        receiver.recv().ok()?.ok()?;
        let result: Vec<f64> = {
            let data = slice.get_mapped_range();
            let scores: &[f32] = bytemuck::cast_slice(&data);
            scores.iter().map(|&s| f64::from(s)).collect()
            // `data` is dropped here, before `unmap` below.
        };
        staging_buf.unmap();

        Some(result)
    }

    /// WGSL compute shader implementing the full MoE forward pass.
    ///
    /// Its weight offsets must agree with `TOTAL_WEIGHT_F32S`, and its
    /// `@workgroup_size` must agree with `WORKGROUP_SIZE`.
    const MOE_SHADER: &str = r#"
// MoE architecture constants
const INPUT_DIM: u32 = 41u;
const EXPERT_COUNT: u32 = 6u;
const HIDDEN1: u32 = 32u;
const HIDDEN2: u32 = 16u;

// Weight layout offsets (in f32 units)
const GATE_W_OFF: u32 = 0u;
const GATE_W_COUNT: u32 = 246u;  // 41 * 6
const GATE_B_OFF: u32 = 246u;
const GATE_B_COUNT: u32 = 6u;
const EXPERTS_OFF: u32 = 252u;

// Per-expert parameter counts
const E_FC1_W: u32 = 1312u;  // 41 * 32
const E_FC1_B: u32 = 32u;
const E_FC2_W: u32 = 512u;   // 32 * 16
const E_FC2_B: u32 = 16u;
const E_FC3_W: u32 = 16u;
const E_FC3_B: u32 = 1u;
const EXPERT_PARAMS: u32 = 1889u;  // sum of above

struct Params {
    batch_size: u32,
}

@group(0) @binding(0) var<storage, read> weights: array<f32>;
@group(0) @binding(1) var<storage, read> inputs: array<f32>;
@group(0) @binding(2) var<storage, read_write> outputs: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

fn get_input(batch_idx: u32, feat_idx: u32) -> f32 {
    return inputs[batch_idx * INPUT_DIM + feat_idx];
}

fn gate_dot(batch_idx: u32, expert_idx: u32) -> f32 {
    var sum = weights[GATE_B_OFF + expert_idx];
    for (var i = 0u; i < INPUT_DIM; i++) {
        sum += weights[GATE_W_OFF + expert_idx * INPUT_DIM + i] * get_input(batch_idx, i);
    }
    return sum;
}

fn expert_base(expert_idx: u32) -> u32 {
    return EXPERTS_OFF + expert_idx * EXPERT_PARAMS;
}

fn expert_forward(batch_idx: u32, expert_idx: u32) -> f32 {
    let base = expert_base(expert_idx);

    // FC1: input(41) -> hidden1(32) + ReLU
    var h1: array<f32, 32>;
    let fc1_w_off = base;
    let fc1_b_off = base + E_FC1_W;
    for (var j = 0u; j < HIDDEN1; j++) {
        var sum = weights[fc1_b_off + j];
        for (var i = 0u; i < INPUT_DIM; i++) {
            sum += weights[fc1_w_off + j * INPUT_DIM + i] * get_input(batch_idx, i);
        }
        h1[j] = max(sum, 0.0);  // ReLU
    }

    // FC2: hidden1(32) -> hidden2(16) + ReLU
    var h2: array<f32, 16>;
    let fc2_w_off = base + E_FC1_W + E_FC1_B;
    let fc2_b_off = fc2_w_off + E_FC2_W;
    for (var j = 0u; j < HIDDEN2; j++) {
        var sum = weights[fc2_b_off + j];
        for (var i = 0u; i < HIDDEN1; i++) {
            sum += weights[fc2_w_off + j * HIDDEN1 + i] * h1[i];
        }
        h2[j] = max(sum, 0.0);  // ReLU
    }

    // FC3: hidden2(16) -> output(1)
    let fc3_w_off = base + E_FC1_W + E_FC1_B + E_FC2_W + E_FC2_B;
    let fc3_b_off = fc3_w_off + E_FC3_W;
    var out = weights[fc3_b_off];
    for (var i = 0u; i < HIDDEN2; i++) {
        out += weights[fc3_w_off + i] * h2[i];
    }
    return out;
}

@compute @workgroup_size(64)
fn moe_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.batch_size) {
        return;
    }

    // Compute gate logits and softmax
    var gate_logits: array<f32, 6>;
    var max_logit = -1e30;
    for (var e = 0u; e < EXPERT_COUNT; e++) {
        gate_logits[e] = gate_dot(idx, e);
        max_logit = max(max_logit, gate_logits[e]);
    }

    var exp_sum = 0.0;
    var gate_probs: array<f32, 6>;
    for (var e = 0u; e < EXPERT_COUNT; e++) {
        gate_probs[e] = exp(gate_logits[e] - max_logit);
        exp_sum += gate_probs[e];
    }
    for (var e = 0u; e < EXPERT_COUNT; e++) {
        gate_probs[e] /= exp_sum;
    }

    // Weighted sum of expert outputs
    var score_logit = 0.0;
    for (var e = 0u; e < EXPERT_COUNT; e++) {
        score_logit += gate_probs[e] * expert_forward(idx, e);
    }

    // Sigmoid
    outputs[idx] = 1.0 / (1.0 + exp(-score_logit));
}
"#;
}
455
456/// Score multiple (credential, context) pairs in a single batch.
457///
458/// Uses GPU compute shaders when available and the batch is large enough.
459/// Falls back to CPU for small batches or when no GPU is present.
460/// Score a batch of `(text, context)` candidates, using GPU when available.
461///
462/// # Examples
463///
464/// ```rust,ignore
465/// use keyhog_scanner::gpu::batch_ml_inference;
466/// use keyhog_scanner::ScannerConfig;
467/// let config = ScannerConfig::default();
468/// let scores = batch_ml_inference(&[("demo_ABC12345".into(), "API_KEY=".into())], &config);
469/// assert_eq!(scores.len(), 1);
470/// ```
471pub fn batch_ml_inference(
472    candidates: &[(String, String)],
473    config: &crate::types::ScannerConfig,
474) -> Vec<f64> {
475    if candidates.is_empty() {
476        return Vec::new();
477    }
478
479    #[cfg(feature = "ml")]
480    {
481        // Try GPU batch inference
482        #[cfg(feature = "gpu")]
483        {
484            let features: Vec<[f32; 41]> = candidates
485                .iter()
486                .map(|(text, ctx)| {
487                    crate::ml_scorer::compute_features_with_config(
488                        text,
489                        ctx,
490                        &config.known_prefixes,
491                        &config.secret_keywords,
492                        &config.test_keywords,
493                        &config.placeholder_keywords,
494                    )
495                })
496                .collect();
497
498            if let Some(scores) = backend::batch_score_features(&features) {
499                return scores;
500            }
501        }
502
503        // CPU fallback
504        candidates
505            .iter()
506            .map(|(text, ctx)| {
507                crate::ml_scorer::score_with_config(
508                    text,
509                    ctx,
510                    &config.known_prefixes,
511                    &config.secret_keywords,
512                    &config.test_keywords,
513                    &config.placeholder_keywords,
514                )
515            })
516            .collect()
517    }
518
519    #[cfg(not(feature = "ml"))]
520    {
521        let _ = candidates;
522        let _ = config;
523        Vec::new()
524    }
525}
526
/// Report whether GPU scoring can be used in this build and on this machine.
///
/// Always `false` when the crate is built without the `gpu` feature. With the
/// feature, the first call may block briefly while the GPU is probed; the
/// result is cached afterwards.
///
/// # Examples
///
/// ```rust
/// use keyhog_scanner::gpu::gpu_available;
/// let _ = gpu_available();
/// ```
pub fn gpu_available() -> bool {
    #[cfg(feature = "gpu")]
    let available = backend::get_gpu().is_some();
    #[cfg(not(feature = "gpu"))]
    let available = false;

    available
}
546
/// Probe GPU availability and adapter metadata without panicking.
///
/// Returns `(available, adapter_name, approx_memory_mib)`. The last two fields
/// are `None` whenever no GPU context exists, including builds without the
/// `gpu` feature.
#[must_use]
pub fn gpu_probe() -> (bool, Option<String>, Option<u64>) {
    #[cfg(not(feature = "gpu"))]
    {
        (false, None, None)
    }

    #[cfg(feature = "gpu")]
    {
        match backend::get_gpu() {
            Some(ctx) => (true, Some(ctx.gpu_name().to_owned()), ctx.vram_mb()),
            None => (false, None, None),
        }
    }
}