// keyhog_scanner/engine/gpu_program_fusion.rs

1//! GPU program fusion - collapses multiple sequential vyre `Program` dispatches
2//! into a single fused program for single-GPU-dispatch execution.
3//!
4//! keyhog currently dispatches the AC literal-set program, decode programs,
5//! and MoE scoring programs sequentially. Vyre's `fuse_programs` /
6//! `fuse_programs_vec` merge compatible programs into one fused `Program`,
7//! eliminating per-dispatch overhead (encoder record, submit, poll) and
8//! enabling cross-program data reuse on-chip.
9//!
10//! # Design
11//!
//! On first access (see Usage below), this module attempts to fuse the AC
//! literal-set program with the rule pipeline program (when one is
//! available) into a single `vyre::Program`; decode and MoE scoring programs
//! are future fusion candidates. The fused program is written to the same
//! on-disk cache directory as the individual programs
//! (`~/.cache/keyhog/programs/`), keyed by a versioned SHA-256 digest of
//! program IR.
17//!
18//! If fusion fails (incompatible buffer layouts, over-dispatch geometry,
19//! self-aliasing), the module logs the failure and the scanner falls back
20//! to sequential dispatch. This is a pure optimization - correctness is
21//! never compromised.
22//!
23//! # Usage
24//!
25//! The fused program is lazily initialized via `OnceLock` on first access.
26//! `CompiledScanner::fused_program()` returns `Option<&vyre::Program>`.
27//! The dispatch path in `gpu_dispatch.rs` checks for the fused program
28//! first and uses it in preference to sequential individual dispatches.
29
30use super::CompiledScanner;
31
/// Version salt hashed into every fused-program digest. Increment it
/// whenever the fusion IR layout or the shapes of the constituent programs
/// change such that fused blobs cached by older builds are no longer valid.
const FUSED_CACHE_VERSION: u32 = 1_u32;
36
37impl CompiledScanner {
38    /// Lazily build a fused `Program` that merges the AC literal-set
39    /// program with the rule pipeline program (when available) into a
40    /// single GPU dispatch.
41    ///
42    /// Returns `None` when:
43    /// - No AC GPU program is available (no GPU adapter, no literals).
44    /// - Fusion fails due to incompatible buffer layouts, over-dispatch
45    ///   geometry, or self-aliasing constraints.
46    /// - Only one program is available (fusion is identity; we skip the
47    ///   overhead of the fused wrapper and dispatch the original directly).
48    ///
49    /// The fused program is cached on disk alongside individual programs
50    /// so cold starts after the first successful fusion are free.
51    pub fn fused_program(&self) -> Option<&vyre::Program> {
52        self.fused_program
53            .get_or_init(|| {
54                let ac_program = self.ac_gpu_program()?;
55                // Collect all programs eligible for fusion. Currently:
56                //   1. AC bounded-ranges program (always present if GPU path is active)
57                //   2. Rule pipeline program (when regex NFA compile succeeds)
58                //
59                // Future: decode programs, MoE scoring programs.
60                let mut programs: Vec<&vyre::Program> = vec![ac_program];
61
62                if let Some(pipeline) = self.rule_pipeline() {
63                    programs.push(&pipeline.program);
64                }
65
66                // Single program → fusion is identity; skip overhead.
67                if programs.len() < 2 {
68                    tracing::debug!(
69                        target: "keyhog::gpu",
70                        programs = programs.len(),
71                        "program fusion skipped - fewer than 2 eligible programs"
72                    );
73                    return None;
74                }
75
76                let started = std::time::Instant::now();
77                match vyre_libs::scan::fuse_programs(
78                    &programs.iter().map(|p| (*p).clone()).collect::<Vec<_>>(),
79                ) {
80                    Ok(fused) => {
81                        let elapsed_ms = started.elapsed().as_millis();
82                        tracing::info!(
83                            target: "keyhog::gpu",
84                            input_programs = programs.len(),
85                            fused_buffers = fused.buffers().len(),
86                            fused_workgroup = ?fused.workgroup_size(),
87                            elapsed_ms,
88                            "program fusion succeeded - single GPU dispatch active"
89                        );
90                        // Attempt to cache the fused program on disk.
91                        self.cache_fused_program(&fused, &programs);
92                        Some(fused)
93                    }
94                    Err(error) => {
95                        tracing::debug!(
96                            target: "keyhog::gpu",
97                            input_programs = programs.len(),
98                            error = %error,
99                            "program fusion failed - falling back to sequential dispatch. \
100                             Common causes: incompatible buffer layouts, over-dispatch geometry, \
101                             or self-aliasing constraints."
102                        );
103                        None
104                    }
105                }
106            })
107            .as_ref()
108    }
109
110    /// Cache a fused program to disk for cold-start acceleration.
111    /// Mirrors the atomic-rename protocol used by `GpuLiteralSet` and
112    /// `RulePipeline` caching.
113    fn cache_fused_program(&self, fused: &vyre::Program, _programs: &[&vyre::Program]) {
114        let Some(cache_dir) = super::gpu_cache::gpu_matcher_cache_dir() else {
115            return;
116        };
117        let cache_key = format!("fused-{}", fused_cache_key(fused));
118        let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) else {
119            return;
120        };
121        let bytes = fused.to_bytes();
122        let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
123        if let Some(parent) = path.parent() {
124            let _ = std::fs::create_dir_all(parent);
125        }
126        if std::fs::write(&tmp, &bytes).is_ok() {
127            if let Err(error) = std::fs::rename(&tmp, &path) {
128                tracing::debug!(
129                    target: "keyhog::gpu",
130                    error = %error,
131                    path = %path.display(),
132                    "fused program cache rename failed"
133                );
134                let _ = std::fs::remove_file(&tmp);
135            }
136        }
137    }
138}
139
140/// Compute a SHA-256 cache key for a fused program based on its
141/// serialized IR bytes and the fusion cache version.
142fn fused_cache_key(program: &vyre::Program) -> String {
143    use sha2::{Digest, Sha256};
144    let mut h = Sha256::new();
145    h.update(FUSED_CACHE_VERSION.to_le_bytes());
146    let ir_bytes = program.to_bytes();
147    h.update((ir_bytes.len() as u64).to_le_bytes());
148    h.update(&ir_bytes);
149    let digest = h.finalize();
150    let mut hex = String::with_capacity(64);
151    for byte in digest {
152        use std::fmt::Write as _;
153        let _ = write!(hex, "{:02x}", byte);
154    }
155    hex
156}
157
/// Version salt mixed into [`fusion_cache_key`].
///
/// Bump this whenever the fused output for an unchanged list of input
/// programs can differ, for example after a change to the fusion pass or to
/// IR serialization. Previously cached fused programs then stop matching.
pub const FUSION_CACHE_VERSION: u32 = 1;
159
160pub fn try_fuse(programs: &[&vyre::Program]) -> std::result::Result<vyre::Program, String> {
161    if programs.is_empty() {
162        return Err("Cannot fuse empty program list".to_string());
163    }
164    let owned_programs: Vec<vyre::Program> = programs.iter().map(|p| (*p).clone()).collect();
165    vyre_libs::scan::fuse_programs(&owned_programs).map_err(|e| e.to_string())
166}
167
168pub fn fuse_or_fallback(programs: &[&vyre::Program]) -> Option<vyre::Program> {
169    try_fuse(programs).ok()
170}
171
172pub fn fusion_cache_key(programs: &[&vyre::Program]) -> String {
173    use sha2::{Digest, Sha256};
174    let mut h = Sha256::new();
175    h.update(FUSION_CACHE_VERSION.to_le_bytes());
176    for p in programs {
177        let ir_bytes = p.to_bytes();
178        h.update((ir_bytes.len() as u64).to_le_bytes());
179        h.update(&ir_bytes);
180    }
181    let digest = h.finalize();
182    let mut hex = String::with_capacity(64);
183    for byte in digest {
184        use std::fmt::Write as _;
185        let _ = write!(hex, "{:02x}", byte);
186    }
187    hex
188}