1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
//! N8 substrate: predicted-next-shape fingerprint API.
//!
//! Async dispatch path already exists (D3 / D7); the wait window
//! between submission and completion is dead CPU time. This module
//! owns the *prediction*: given recent dispatch fingerprints, what
//! is the most likely next dispatch? The runtime can then prefetch
//! the predicted pipeline cache key during the wait.
//!
//! Three prediction strategies, in order of preference:
//!
//! 1. **Repeat** - when the two most recent fingerprints are equal, predict
//! that fingerprint again (covers tight loops dispatching the same kernel).
//! 2. **Cycle of length N** - fingerprint = the one N steps ago, even when
//! only a partial next cycle has been observed (covers attention's Q, K, V,
//! scale, softmax, attend cycle before the second full cycle completes).
//! 3. **None** - history too sparse to predict; runtime skips the
//! prefetch this iteration.
//!
//! Pure analysis; the history lives in a fixed-size inline array, so
//! recording and predicting never allocate.
/// Upper bound on how many fingerprints a [`ShapeHistory`] retains.
///
/// Sixteen slots are enough to catch attention-style 6-step cycles with
/// one repeat, while keeping the quadratic cycle scan in
/// [`ShapeHistory::predict_next`] cheap enough not to matter.
pub const MAX_HISTORY: usize = 16;

/// Opaque 8-word dispatch fingerprint; the predictor only compares these
/// for equality.
///
/// Has the same shape as the value returned by
/// [`crate::launch::program_vsa_fingerprint_words`].
pub type ShapeFingerprint = [u32; 8];

/// Fixed-capacity ring buffer of the most recent dispatch fingerprints.
///
/// At most [`MAX_HISTORY`] entries are kept. Logical position 0 is the
/// oldest retained entry and position `len - 1` the newest.
#[derive(Debug, Clone)]
pub struct ShapeHistory {
    /// Backing storage; only `len` slots, beginning at `start` and wrapping
    /// modulo [`MAX_HISTORY`], hold live entries.
    entries: [ShapeFingerprint; MAX_HISTORY],
    /// Physical slot of the oldest live entry.
    start: usize,
    /// Number of live entries (never more than [`MAX_HISTORY`]).
    len: usize,
}

impl Default for ShapeHistory {
    /// Equivalent to [`ShapeHistory::new`]: an empty history.
    fn default() -> Self {
        ShapeHistory::new()
    }
}

impl ShapeHistory {
    /// Creates an empty history. [`Self::predict_next`] yields `None`
    /// until at least two fingerprints have been recorded.
    #[must_use]
    pub fn new() -> Self {
        let blank: ShapeFingerprint = [0; 8];
        Self {
            entries: [blank; MAX_HISTORY],
            start: 0,
            len: 0,
        }
    }

    /// Appends `fingerprint` as the newest entry. Once [`MAX_HISTORY`]
    /// entries are held, the oldest one is overwritten.
    pub fn record(&mut self, fingerprint: ShapeFingerprint) {
        // When the buffer is full, `start + len` wraps back onto `start`,
        // so the write lands on the oldest slot.
        let write_slot = (self.start + self.len) % MAX_HISTORY;
        self.entries[write_slot] = fingerprint;
        if self.len < MAX_HISTORY {
            self.len += 1;
        } else {
            self.start = (self.start + 1) % MAX_HISTORY;
        }
    }

    /// Returns the newest recorded fingerprint, or `None` when nothing has
    /// been recorded.
    #[must_use]
    pub fn latest(&self) -> Option<&ShapeFingerprint> {
        let newest = self.len.checked_sub(1)?;
        Some(&self.entries[(self.start + newest) % MAX_HISTORY])
    }

    /// Number of fingerprints currently held.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when no fingerprint has been recorded yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Reports whether `fingerprint` is one of the retained entries.
    ///
    /// Backend-side prediction caches can use this to evict cloned predicted
    /// programs that the bounded history can no longer predict.
    #[must_use]
    pub fn contains(&self, fingerprint: &ShapeFingerprint) -> bool {
        (0..self.len)
            .map(|offset| self.get(offset))
            .any(|entry| entry == *fingerprint)
    }

    /// Fingerprint at logical position `offset` (0 = oldest).
    fn get(&self, offset: usize) -> ShapeFingerprint {
        debug_assert!(offset < self.len);
        let physical = (self.start + offset) % MAX_HISTORY;
        self.entries[physical]
    }

    /// Guesses the fingerprint of the next dispatch, or `None` when the
    /// history is too short or shows no pattern.
    ///
    /// Rules, tried in this order:
    /// 1. If the two newest entries are equal, predict that entry again.
    /// 2. Otherwise find the smallest period `p >= 2` such that every
    ///    retained entry with a partner `p` positions earlier equals that
    ///    partner, and predict the entry `p` positions before the next slot.
    ///    Partial cycles qualify, so `A,B,C,A,B` already predicts `C`
    ///    without waiting for `A,B,C,A,B,C`.
    /// 3. Otherwise return `None`.
    #[must_use]
    pub fn predict_next(&self) -> Option<ShapeFingerprint> {
        let count = self.len;
        let newest = *self.latest()?;

        // Rule 1: back-to-back identical dispatches.
        if count >= 2 && self.get(count - 2) == newest {
            return Some(newest);
        }

        // Rule 2: shortest consistent period. Predicting partial cycles
        // gains one dispatch worth of prefetch overlap.
        let is_period = |p: usize| (p..count).all(|i| self.get(i) == self.get(i - p));
        (2..count)
            .find(|&p| is_period(p))
            .map(|p| self.get(count - p))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test fingerprint. Distinct seeds yield distinct
    /// fingerprints because multiplying by the odd constant 31 is a
    /// bijection on `u32` under wrapping arithmetic.
    fn fp(seed: u32) -> ShapeFingerprint {
        let mut a = [0u32; 8];
        for (i, slot) in a.iter_mut().enumerate() {
            *slot = seed.wrapping_mul(31).wrapping_add(i as u32);
        }
        a
    }

    #[test]
    fn empty_history_predicts_nothing() {
        let h = ShapeHistory::new();
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn new_and_default_histories_are_empty() {
        for h in [ShapeHistory::new(), ShapeHistory::default()] {
            assert!(h.is_empty());
            assert_eq!(h.len(), 0);
            assert!(h.latest().is_none());
            // The zero-filled backing storage must not leak into lookups.
            assert!(!h.contains(&[0u32; 8]));
        }
    }

    #[test]
    fn single_entry_history_cannot_predict() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn single_entry_is_latest() {
        let mut h = ShapeHistory::new();
        h.record(fp(7));
        assert!(!h.is_empty());
        assert_eq!(h.len(), 1);
        assert_eq!(h.latest(), Some(&fp(7)));
        assert!(h.contains(&fp(7)));
    }

    #[test]
    fn repeated_fingerprint_predicts_repeat() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(1));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn repeat_only_looks_at_last_two_entries() {
        // A non-repeating prefix must not block the repeat strategy.
        let mut h = ShapeHistory::new();
        for seed in [2, 3, 4, 4] {
            h.record(fp(seed));
        }
        assert_eq!(h.predict_next(), Some(fp(4)));
    }

    #[test]
    fn two_step_cycle_is_predicted() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(1));
        h.record(fp(2));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn three_step_cycle_is_predicted() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn partial_three_step_cycle_is_predicted_before_second_cycle_completes() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(1));
        h.record(fp(2));
        assert_eq!(h.predict_next(), Some(fp(3)));
    }

    #[test]
    fn partial_long_cycle_prefetches_next_phase() {
        let mut h = ShapeHistory::new();
        for seed in [10, 20, 30, 40, 10, 20, 30] {
            h.record(fp(seed));
        }
        assert_eq!(h.predict_next(), Some(fp(40)));
    }

    #[test]
    fn no_pattern_means_no_prediction() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(4));
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn history_caps_at_max_entries() {
        let mut h = ShapeHistory::new();
        for i in 0..(MAX_HISTORY + 5) {
            h.record(fp(i as u32));
        }
        assert_eq!(h.len(), MAX_HISTORY);
        // Earliest entry is fp(5), latest is fp(MAX_HISTORY+4).
        assert_eq!(h.latest(), Some(&fp((MAX_HISTORY + 4) as u32)));
        assert!(!h.contains(&fp(0)));
        // fp(4) is the last entry evicted; check the exact boundary.
        assert!(!h.contains(&fp(4)));
        assert!(h.contains(&fp(5)));
        assert!(h.contains(&fp((MAX_HISTORY + 4) as u32)));
    }

    #[test]
    fn cycle_is_predicted_after_ring_wraps() {
        // Exercises `predict_next` with a non-zero ring start, which the
        // other prediction tests never reach.
        let mut h = ShapeHistory::new();
        let total = MAX_HISTORY + 5;
        for i in 0..total {
            h.record(fp((i % 3) as u32));
        }
        // The next logical position is `total`, so the cycle continues
        // with seed `total % 3`.
        assert_eq!(h.predict_next(), Some(fp((total % 3) as u32)));
    }

    #[test]
    fn distinct_entries_after_ring_wraps_predict_nothing() {
        let mut h = ShapeHistory::new();
        for i in 0..(MAX_HISTORY + 5) {
            h.record(fp(i as u32));
        }
        assert!(h.predict_next().is_none());
    }
}