1use vyre_driver::autotune_store::{AutotuneRecord, AutotuneStore};
10use vyre_driver::speculate::{
11 record_speculative_variant_race, SpeculativeVariantDecision, SpeculativeVariantKeys,
12 SpeculativeVariantRace,
13};
14use vyre_driver::speculation_substrate::{
15 decide_speculation, SpeculationObservation, SpeculationVerdict,
16};
17
18#[derive(Debug, Clone)]
20pub struct PairedSpeculationSample {
21 pub conservative_dispatch_ns: u64,
23 pub speculative_dispatch_ns: u64,
25 pub conservative_compile_ns: u64,
27 pub speculative_compile_ns: u64,
29 pub conservative_record: AutotuneRecord,
31 pub speculative_record: AutotuneRecord,
33}
34
35#[derive(Debug, Clone)]
37pub struct PairedSpeculationUpdate {
38 pub race_decision: SpeculativeVariantDecision,
40 pub verdict: SpeculationVerdict,
42 pub observation: SpeculationObservation,
44}
45
46#[derive(Debug, Default, Clone)]
48pub struct PairedSpeculationWindow {
49 conservative: RunningMean,
50 speculative: RunningMean,
51 side_compile_cost_ns: u64,
52}
53
54impl PairedSpeculationWindow {
55 #[must_use]
57 pub const fn new() -> Self {
58 Self {
59 conservative: RunningMean::new(),
60 speculative: RunningMean::new(),
61 side_compile_cost_ns: 0,
62 }
63 }
64
65 #[must_use]
67 pub fn len(&self) -> u32 {
68 self.conservative.count.min(self.speculative.count)
69 }
70
71 #[must_use]
73 pub fn is_empty(&self) -> bool {
74 self.len() == 0
75 }
76
77 #[must_use]
79 pub fn observation(&self) -> SpeculationObservation {
80 SpeculationObservation {
81 baseline_dispatches: self.conservative.count,
82 baseline_mean_ns: self.conservative.mean_ns(),
83 speculative_dispatches: self.speculative.count,
84 speculative_mean_ns: self.speculative.mean_ns(),
85 side_compile_cost_ns: self.side_compile_cost_ns,
86 }
87 }
88
89 pub fn record_sample(
92 &mut self,
93 store: &mut AutotuneStore,
94 keys: SpeculativeVariantKeys<'_>,
95 sample: PairedSpeculationSample,
96 ) -> PairedSpeculationUpdate {
97 self.conservative.record(sample.conservative_dispatch_ns);
98 self.speculative.record(sample.speculative_dispatch_ns);
99 self.side_compile_cost_ns = self
100 .side_compile_cost_ns
101 .saturating_add(sample.speculative_compile_ns);
102
103 let race_decision = record_speculative_variant_race(
104 store,
105 keys,
106 SpeculativeVariantRace {
107 conservative_dispatch_ns: sample.conservative_dispatch_ns,
108 speculative_dispatch_ns: sample.speculative_dispatch_ns,
109 conservative_compile_ns: sample.conservative_compile_ns,
110 speculative_compile_ns: sample.speculative_compile_ns,
111 conservative_record: sample.conservative_record,
112 speculative_record: sample.speculative_record,
113 },
114 );
115 let observation = self.observation();
116 let verdict = decide_speculation(observation);
117 PairedSpeculationUpdate {
118 race_decision,
119 verdict,
120 observation,
121 }
122 }
123}
124
/// Sample counter and running sum used to derive an integer mean of
/// nanosecond timings.
#[derive(Debug, Default, Clone)]
struct RunningMean {
    /// Number of samples folded into `total_ns`; never exceeds `u32::MAX`.
    count: u32,
    /// Sum of the folded samples, widened to `u128` so `u64` inputs cannot wrap it.
    total_ns: u128,
}

impl RunningMean {
    /// Creates an accumulator with no samples.
    const fn new() -> Self {
        Self {
            count: 0,
            total_ns: 0,
        }
    }

    /// Folds `value_ns` into the accumulator.
    ///
    /// Once `count` has reached `u32::MAX` the sample is dropped. Adding to the
    /// sum while the divisor can no longer grow would push `mean_ns()` upwards
    /// with every call, so count and sum are frozen together instead.
    fn record(&mut self, value_ns: u64) {
        if let Some(next_count) = self.count.checked_add(1) {
            self.count = next_count;
            self.total_ns = self.total_ns.saturating_add(u128::from(value_ns));
        }
    }

    /// Returns the truncated integer mean of the recorded samples.
    ///
    /// Yields `0` when nothing has been recorded, and `u64::MAX` if the quotient
    /// does not fit in a `u64`.
    fn mean_ns(&self) -> u64 {
        // `checked_div` returns `None` for a zero divisor, i.e. the empty case.
        self.total_ns
            .checked_div(u128::from(self.count))
            .map_or(0, |mean| u64::try_from(mean).unwrap_or(u64::MAX))
    }
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_driver::specialization::SpecCacheKey;
    use vyre_driver::speculate::SpeculativeVariantKind;

    /// Builds a cache key whose hash fields are all derived from `id`.
    fn key(id: u64) -> SpecCacheKey {
        let binding_sig = id << 8;
        let spec_hash = id << 16;
        SpecCacheKey {
            shader_hash: id,
            binding_sig,
            workgroup_size: [64, 1, 1],
            spec_hash,
        }
    }

    /// Builds an autotune record that differs only in its first workgroup dimension.
    fn record(workgroup: u32) -> AutotuneRecord {
        AutotuneRecord {
            workgroup_size: [workgroup, 1, 1],
            unroll: 1,
            tile: [0, 0, 0],
            recorded_at: String::from("2026-05-02"),
        }
    }

    /// Builds a sample with the given dispatch timings and zero compile cost.
    fn sample(conservative_ns: u64, speculative_ns: u64) -> PairedSpeculationSample {
        PairedSpeculationSample {
            conservative_dispatch_ns: conservative_ns,
            speculative_dispatch_ns: speculative_ns,
            conservative_compile_ns: 0,
            speculative_compile_ns: 0,
            conservative_record: record(64),
            speculative_record: record(128),
        }
    }

    /// Pairs two cache keys under the adapter id shared by every test.
    fn variant_keys<'a>(
        conservative: &'a SpecCacheKey,
        speculative: &'a SpecCacheKey,
    ) -> SpeculativeVariantKeys<'a> {
        SpeculativeVariantKeys {
            conservative,
            speculative,
            adapter_id: "test-adapter",
        }
    }

    #[test]
    fn paired_window_keeps_racing_under_threshold() {
        let (conservative, speculative) = (key(1), key(2));
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        let update = window.record_sample(
            &mut store,
            variant_keys(&conservative, &speculative),
            sample(100_000, 50_000),
        );

        assert_eq!(update.verdict, SpeculationVerdict::KeepRacing);
        let observed = &update.observation;
        assert_eq!(observed.baseline_dispatches, 1);
        assert_eq!(observed.speculative_dispatches, 1);
    }

    #[test]
    fn paired_window_adopts_after_sustained_win() {
        let (conservative, speculative) = (key(3), key(4));
        let keys = variant_keys(&conservative, &speculative);
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        // Feed eight identical winning samples and keep only the final update.
        let update = (0..8)
            .fold(None, |_, _| {
                Some(window.record_sample(&mut store, keys, sample(100_000, 50_000)))
            })
            .expect("Fix: loop records at least one sample");

        assert_eq!(update.verdict, SpeculationVerdict::Adopt);
        assert_eq!(
            update.race_decision.winner,
            SpeculativeVariantKind::Speculative
        );
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn paired_window_rejects_sustained_loss() {
        let (conservative, speculative) = (key(5), key(6));
        let keys = variant_keys(&conservative, &speculative);
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        // Feed eight identical losing samples; only the last verdict matters.
        let verdict = (0..8).fold(SpeculationVerdict::KeepRacing, |_, _| {
            window
                .record_sample(&mut store, keys, sample(50_000, 100_000))
                .verdict
        });

        assert_eq!(verdict, SpeculationVerdict::Reject);
    }

    #[test]
    fn paired_window_amortizes_speculative_compile_cost() {
        let (conservative, speculative) = (key(7), key(8));
        let keys = variant_keys(&conservative, &speculative);
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        // Each winning sample also carries a speculative compile cost.
        let update = (0..8)
            .fold(None, |_, _| {
                let costly = PairedSpeculationSample {
                    speculative_compile_ns: 1_000_000,
                    ..sample(100_000, 50_000)
                };
                Some(window.record_sample(&mut store, keys, costly))
            })
            .expect("Fix: loop records at least one sample");

        assert_eq!(update.verdict, SpeculationVerdict::Reject);
        assert_eq!(update.observation.side_compile_cost_ns, 8_000_000);
    }
}