keyhog_scanner/engine/gpu_program_fusion.rs
1//! GPU program fusion - collapses multiple sequential vyre `Program` dispatches
2//! into a single fused program for single-GPU-dispatch execution.
3//!
4//! keyhog currently dispatches the AC literal-set program, decode programs,
5//! and MoE scoring programs sequentially. Vyre's `fuse_programs` /
6//! `fuse_programs_vec` merge compatible programs into one fused `Program`,
7//! eliminating per-dispatch overhead (encoder record, submit, poll) and
8//! enabling cross-program data reuse on-chip.
9//!
10//! # Design
11//!
//! On the first call to `CompiledScanner::fused_program()`, this module
//! attempts to fuse the AC literal-set program with the rule pipeline
//! program (when one is available) into a single `vyre::Program`; decode
//! and MoE scoring programs are not fused yet.
//! The fused program is cached alongside individual programs in the same
//! on-disk cache directory (`~/.cache/keyhog/programs/`), keyed by a
//! SHA-256 digest (see `fused_cache_key`).
17//!
18//! If fusion fails (incompatible buffer layouts, over-dispatch geometry,
19//! self-aliasing), the module logs the failure and the scanner falls back
20//! to sequential dispatch. This is a pure optimization - correctness is
21//! never compromised.
22//!
23//! # Usage
24//!
25//! The fused program is lazily initialized via `OnceLock` on first access.
26//! `CompiledScanner::fused_program()` returns `Option<&vyre::Program>`.
27//! The dispatch path in `gpu_dispatch.rs` checks for the fused program
28//! first and uses it in preference to sequential individual dispatches.
29
30use super::CompiledScanner;
31
/// On-disk cache version for fused programs. Bumped whenever the fusion
/// IR layout or the constituent program shapes change in a way that
/// invalidates previously cached fused blobs.
///
/// The value is mixed into the digest computed by `fused_cache_key`, so
/// changing it changes every fused-program cache key.
///
/// NOTE(review): the public `FUSION_CACHE_VERSION` further down this file
/// currently holds the same value; presumably the two are meant to stay in
/// sync (or be merged) - confirm.
const FUSED_CACHE_VERSION: u32 = 1;
36
37impl CompiledScanner {
38 /// Lazily build a fused `Program` that merges the AC literal-set
39 /// program with the rule pipeline program (when available) into a
40 /// single GPU dispatch.
41 ///
42 /// Returns `None` when:
43 /// - No AC GPU program is available (no GPU adapter, no literals).
44 /// - Fusion fails due to incompatible buffer layouts, over-dispatch
45 /// geometry, or self-aliasing constraints.
46 /// - Only one program is available (fusion is identity; we skip the
47 /// overhead of the fused wrapper and dispatch the original directly).
48 ///
49 /// The fused program is cached on disk alongside individual programs
50 /// so cold starts after the first successful fusion are free.
51 pub fn fused_program(&self) -> Option<&vyre::Program> {
52 self.fused_program
53 .get_or_init(|| {
54 let ac_program = self.ac_gpu_program()?;
55 // Collect all programs eligible for fusion. Currently:
56 // 1. AC bounded-ranges program (always present if GPU path is active)
57 // 2. Rule pipeline program (when regex NFA compile succeeds)
58 //
59 // Future: decode programs, MoE scoring programs.
60 let mut programs: Vec<&vyre::Program> = vec![ac_program];
61
62 if let Some(pipeline) = self.rule_pipeline() {
63 programs.push(&pipeline.program);
64 }
65
66 // Single program → fusion is identity; skip overhead.
67 if programs.len() < 2 {
68 tracing::debug!(
69 target: "keyhog::gpu",
70 programs = programs.len(),
71 "program fusion skipped - fewer than 2 eligible programs"
72 );
73 return None;
74 }
75
76 let started = std::time::Instant::now();
77 match vyre_libs::scan::fuse_programs(
78 &programs.iter().map(|p| (*p).clone()).collect::<Vec<_>>(),
79 ) {
80 Ok(fused) => {
81 let elapsed_ms = started.elapsed().as_millis();
82 tracing::info!(
83 target: "keyhog::gpu",
84 input_programs = programs.len(),
85 fused_buffers = fused.buffers().len(),
86 fused_workgroup = ?fused.workgroup_size(),
87 elapsed_ms,
88 "program fusion succeeded - single GPU dispatch active"
89 );
90 // Attempt to cache the fused program on disk.
91 self.cache_fused_program(&fused, &programs);
92 Some(fused)
93 }
94 Err(error) => {
95 tracing::debug!(
96 target: "keyhog::gpu",
97 input_programs = programs.len(),
98 error = %error,
99 "program fusion failed - falling back to sequential dispatch. \
100 Common causes: incompatible buffer layouts, over-dispatch geometry, \
101 or self-aliasing constraints."
102 );
103 None
104 }
105 }
106 })
107 .as_ref()
108 }
109
110 /// Cache a fused program to disk for cold-start acceleration.
111 /// Mirrors the atomic-rename protocol used by `GpuLiteralSet` and
112 /// `RulePipeline` caching.
113 fn cache_fused_program(&self, fused: &vyre::Program, _programs: &[&vyre::Program]) {
114 let Some(cache_dir) = super::gpu_cache::gpu_matcher_cache_dir() else {
115 return;
116 };
117 let cache_key = format!("fused-{}", fused_cache_key(fused));
118 let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) else {
119 return;
120 };
121 let bytes = fused.to_bytes();
122 let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
123 if let Some(parent) = path.parent() {
124 let _ = std::fs::create_dir_all(parent);
125 }
126 if std::fs::write(&tmp, &bytes).is_ok() {
127 if let Err(error) = std::fs::rename(&tmp, &path) {
128 tracing::debug!(
129 target: "keyhog::gpu",
130 error = %error,
131 path = %path.display(),
132 "fused program cache rename failed"
133 );
134 let _ = std::fs::remove_file(&tmp);
135 }
136 }
137 }
138}
139
140/// Compute a SHA-256 cache key for a fused program based on its
141/// serialized IR bytes and the fusion cache version.
142fn fused_cache_key(program: &vyre::Program) -> String {
143 use sha2::{Digest, Sha256};
144 let mut h = Sha256::new();
145 h.update(FUSED_CACHE_VERSION.to_le_bytes());
146 let ir_bytes = program.to_bytes();
147 h.update((ir_bytes.len() as u64).to_le_bytes());
148 h.update(&ir_bytes);
149 let digest = h.finalize();
150 let mut hex = String::with_capacity(64);
151 for byte in digest {
152 use std::fmt::Write as _;
153 let _ = write!(hex, "{:02x}", byte);
154 }
155 hex
156}
157
/// Version tag mixed into the digest produced by `fusion_cache_key`;
/// changing it changes every key that function returns.
///
/// NOTE(review): holds the same value as the private `FUSED_CACHE_VERSION`
/// near the top of this file; presumably the two are meant to stay in sync
/// (or be merged) - confirm.
pub const FUSION_CACHE_VERSION: u32 = 1;
159
160pub fn try_fuse(programs: &[&vyre::Program]) -> std::result::Result<vyre::Program, String> {
161 if programs.is_empty() {
162 return Err("Cannot fuse empty program list".to_string());
163 }
164 let owned_programs: Vec<vyre::Program> = programs.iter().map(|p| (*p).clone()).collect();
165 vyre_libs::scan::fuse_programs(&owned_programs).map_err(|e| e.to_string())
166}
167
168pub fn fuse_or_fallback(programs: &[&vyre::Program]) -> Option<vyre::Program> {
169 try_fuse(programs).ok()
170}
171
172pub fn fusion_cache_key(programs: &[&vyre::Program]) -> String {
173 use sha2::{Digest, Sha256};
174 let mut h = Sha256::new();
175 h.update(FUSION_CACHE_VERSION.to_le_bytes());
176 for p in programs {
177 let ir_bytes = p.to_bytes();
178 h.update((ir_bytes.len() as u64).to_le_bytes());
179 h.update(&ir_bytes);
180 }
181 let digest = h.finalize();
182 let mut hex = String::with_capacity(64);
183 for byte in digest {
184 use std::fmt::Write as _;
185 let _ = write!(hex, "{:02x}", byte);
186 }
187 hex
188}