1use vyre_driver::autotune_store::{AutotuneRecord, AutotuneStore};
10use vyre_driver::speculate::{
11 record_speculative_variant_race, SpeculativeVariantDecision, SpeculativeVariantKeys,
12 SpeculativeVariantRace,
13};
14use vyre_driver::speculation_substrate::{
15 decide_speculation, SpeculationObservation, SpeculationVerdict,
16};
17
/// One paired measurement of the conservative and speculative variants.
///
/// All durations are in nanoseconds. The whole sample is consumed by
/// `PairedSpeculationWindow::record_sample`, which accumulates the dispatch
/// times (and the speculative compile time) into the window and forwards every
/// field into a `SpeculativeVariantRace`.
#[derive(Debug, Clone)]
pub struct PairedSpeculationSample {
    /// Dispatch time of the conservative variant, in nanoseconds.
    pub conservative_dispatch_ns: u64,
    /// Dispatch time of the speculative variant, in nanoseconds.
    pub speculative_dispatch_ns: u64,
    /// Compile time of the conservative variant, in nanoseconds. Forwarded to
    /// the per-sample race only; the window does not accumulate it.
    pub conservative_compile_ns: u64,
    /// Compile time of the speculative variant, in nanoseconds. Forwarded to
    /// the per-sample race and also added to the window's side compile cost.
    pub speculative_compile_ns: u64,
    /// Autotune record for the conservative variant, handed to
    /// `record_speculative_variant_race` (NOTE(review): presumably persisted when
    /// this variant wins — confirm in `vyre_driver::speculate`).
    pub conservative_record: AutotuneRecord,
    /// Autotune record for the speculative variant, handed to
    /// `record_speculative_variant_race` (NOTE(review): presumably persisted when
    /// this variant wins — confirm in `vyre_driver::speculate`).
    pub speculative_record: AutotuneRecord,
}
34
/// Result of recording one paired sample into a `PairedSpeculationWindow`.
#[derive(Debug, Clone)]
pub struct PairedSpeculationUpdate {
    /// Outcome of `record_speculative_variant_race` for this single sample.
    pub race_decision: SpeculativeVariantDecision,
    /// Verdict from `decide_speculation` over the window's cumulative
    /// observation, including this sample.
    pub verdict: SpeculationVerdict,
    /// The cumulative observation that `verdict` was computed from.
    pub observation: SpeculationObservation,
}
45
/// Accumulates paired conservative/speculative timings across many samples so
/// the speculation verdict is based on running means rather than one race.
#[derive(Debug, Default, Clone)]
pub struct PairedSpeculationWindow {
    // Running mean of conservative-variant dispatch times.
    conservative: RunningMean,
    // Running mean of speculative-variant dispatch times.
    speculative: RunningMean,
    // Sum of `speculative_compile_ns` over every recorded sample, in nanoseconds.
    side_compile_cost_ns: u64,
}
53
54impl PairedSpeculationWindow {
55 #[must_use]
57 pub const fn new() -> Self {
58 Self {
59 conservative: RunningMean::new(),
60 speculative: RunningMean::new(),
61 side_compile_cost_ns: 0,
62 }
63 }
64
65 #[must_use]
67 pub fn len(&self) -> u32 {
68 self.conservative.count.min(self.speculative.count)
69 }
70
71 #[must_use]
73 pub fn is_empty(&self) -> bool {
74 self.len() == 0
75 }
76
77 #[must_use]
79 pub fn observation(&self) -> SpeculationObservation {
80 SpeculationObservation {
81 baseline_dispatches: self.conservative.count,
82 baseline_mean_ns: self.conservative.mean_ns(),
83 speculative_dispatches: self.speculative.count,
84 speculative_mean_ns: self.speculative.mean_ns(),
85 side_compile_cost_ns: self.side_compile_cost_ns,
86 }
87 }
88
89 pub fn record_sample(
92 &mut self,
93 store: &mut AutotuneStore,
94 keys: SpeculativeVariantKeys<'_>,
95 sample: PairedSpeculationSample,
96 ) -> PairedSpeculationUpdate {
97 self.conservative.record(sample.conservative_dispatch_ns);
98 self.speculative.record(sample.speculative_dispatch_ns);
99 self.side_compile_cost_ns = self
100 .side_compile_cost_ns
101 .checked_add(sample.speculative_compile_ns)
102 .unwrap_or_else(|| {
103 panic!(
104 "paired speculation side compile cost overflowed u64. Fix: reset the speculation window before accumulating more samples."
105 )
106 });
107
108 let race_decision = record_speculative_variant_race(
109 store,
110 keys,
111 SpeculativeVariantRace {
112 conservative_dispatch_ns: sample.conservative_dispatch_ns,
113 speculative_dispatch_ns: sample.speculative_dispatch_ns,
114 conservative_compile_ns: sample.conservative_compile_ns,
115 speculative_compile_ns: sample.speculative_compile_ns,
116 conservative_record: sample.conservative_record,
117 speculative_record: sample.speculative_record,
118 },
119 );
120 let observation = self.observation();
121 let verdict = decide_speculation(observation);
122 PairedSpeculationUpdate {
123 race_decision,
124 verdict,
125 observation,
126 }
127 }
128}
129
/// Exact running sum and sample count for one arm of the paired race.
///
/// The sum is kept as an integer `u128` rather than an incremental
/// floating-point average, so the reported mean is exact (floored).
#[derive(Debug, Default, Clone)]
struct RunningMean {
    // Number of samples recorded.
    count: u32,
    // Sum of all recorded samples, in nanoseconds.
    total_ns: u128,
}

impl RunningMean {
    /// Creates an accumulator with no samples.
    const fn new() -> Self {
        Self {
            count: 0,
            total_ns: 0,
        }
    }

    /// Adds one sample of `value_ns` nanoseconds.
    ///
    /// Both the new count and the new total are computed before either field
    /// is written, so an overflow panic leaves the accumulator unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the sample count would overflow `u32` or the total would
    /// overflow `u128`.
    fn record(&mut self, value_ns: u64) {
        let count = self.count.checked_add(1).unwrap_or_else(|| {
            panic!(
                "paired speculation sample count overflowed u32. Fix: reset the speculation window before accumulating more samples."
            )
        });
        // While the fields are only updated through `record`, the sum is at most
        // u32::MAX * u64::MAX < 2^96, so this guard is purely defensive.
        let total_ns = self
            .total_ns
            .checked_add(u128::from(value_ns))
            .unwrap_or_else(|| {
                panic!(
                    "paired speculation total nanoseconds overflowed u128. Fix: reset the speculation window before accumulating more samples."
                )
            });
        self.count = count;
        self.total_ns = total_ns;
    }

    /// Floored integer mean of the recorded samples, or `0` when empty.
    fn mean_ns(&self) -> u64 {
        if self.count == 0 {
            return 0;
        }
        let mean = self.total_ns / u128::from(self.count);
        // The mean of `u64` samples never exceeds `u64::MAX`. The conversion can
        // only fail if the fields were modified other than through `record`.
        u64::try_from(mean).unwrap_or_else(|error| {
            panic!(
                "paired speculation mean nanoseconds cannot fit u64: {error}. Fix: reset the speculation window before accumulating more samples."
            )
        })
    }
}
169
170#[cfg(test)]
171mod tests {
172 use super::*;
173 use vyre_driver::specialization::SpecCacheKey;
174 use vyre_driver::speculate::SpeculativeVariantKind;
175
176 fn key(id: u64) -> SpecCacheKey {
177 SpecCacheKey {
178 shader_hash: id,
179 binding_sig: id << 8,
180 workgroup_size: [64, 1, 1],
181 spec_hash: id << 16,
182 }
183 }
184
185 fn record(workgroup: u32) -> AutotuneRecord {
186 AutotuneRecord {
187 workgroup_size: [workgroup, 1, 1],
188 unroll: 1,
189 tile: [0, 0, 0],
190 recorded_at: "2026-05-02".to_string(),
191 }
192 }
193
194 fn sample(conservative_ns: u64, speculative_ns: u64) -> PairedSpeculationSample {
195 PairedSpeculationSample {
196 conservative_dispatch_ns: conservative_ns,
197 speculative_dispatch_ns: speculative_ns,
198 conservative_compile_ns: 0,
199 speculative_compile_ns: 0,
200 conservative_record: record(64),
201 speculative_record: record(128),
202 }
203 }
204
205 #[test]
206 fn paired_window_keeps_racing_under_threshold() {
207 let mut store = AutotuneStore::default();
208 let conservative = key(1);
209 let speculative = key(2);
210 let keys = SpeculativeVariantKeys {
211 conservative: &conservative,
212 speculative: &speculative,
213 adapter_id: "test-adapter",
214 };
215 let mut window = PairedSpeculationWindow::new();
216 let update = window.record_sample(&mut store, keys, sample(100_000, 50_000));
217 assert_eq!(update.verdict, SpeculationVerdict::KeepRacing);
218 assert_eq!(update.observation.baseline_dispatches, 1);
219 assert_eq!(update.observation.speculative_dispatches, 1);
220 }
221
222 #[test]
223 fn paired_window_adopts_after_sustained_win() {
224 let mut store = AutotuneStore::default();
225 let conservative = key(3);
226 let speculative = key(4);
227 let keys = SpeculativeVariantKeys {
228 conservative: &conservative,
229 speculative: &speculative,
230 adapter_id: "test-adapter",
231 };
232 let mut window = PairedSpeculationWindow::new();
233 let mut last = None;
234 for _ in 0..8 {
235 last = Some(window.record_sample(&mut store, keys, sample(100_000, 50_000)));
236 }
237 let update = last.expect("Fix: loop records at least one sample");
238 assert_eq!(update.verdict, SpeculationVerdict::Adopt);
239 assert_eq!(
240 update.race_decision.winner,
241 SpeculativeVariantKind::Speculative
242 );
243 assert_eq!(store.len(), 1);
244 }
245
246 #[test]
247 fn paired_window_rejects_sustained_loss() {
248 let mut store = AutotuneStore::default();
249 let conservative = key(5);
250 let speculative = key(6);
251 let keys = SpeculativeVariantKeys {
252 conservative: &conservative,
253 speculative: &speculative,
254 adapter_id: "test-adapter",
255 };
256 let mut window = PairedSpeculationWindow::new();
257 let mut verdict = SpeculationVerdict::KeepRacing;
258 for _ in 0..8 {
259 verdict = window
260 .record_sample(&mut store, keys, sample(50_000, 100_000))
261 .verdict;
262 }
263 assert_eq!(verdict, SpeculationVerdict::Reject);
264 }
265
266 #[test]
267 fn paired_window_amortizes_speculative_compile_cost() {
268 let mut store = AutotuneStore::default();
269 let conservative = key(7);
270 let speculative = key(8);
271 let keys = SpeculativeVariantKeys {
272 conservative: &conservative,
273 speculative: &speculative,
274 adapter_id: "test-adapter",
275 };
276 let mut window = PairedSpeculationWindow::new();
277 let mut update = None;
278 for _ in 0..8 {
279 let mut s = sample(100_000, 50_000);
280 s.speculative_compile_ns = 1_000_000;
281 update = Some(window.record_sample(&mut store, keys, s));
282 }
283 let update = update.expect("Fix: loop records at least one sample");
284 assert_eq!(update.verdict, SpeculationVerdict::Reject);
285 assert_eq!(update.observation.side_compile_cost_ns, 8_000_000);
286 }
287}