#[cfg(feature = "gpu")]
impl ContinuousBatchScheduler {
/// Lock the slots mutex, panicking on poison.
fn lock_slots(&self) -> std::sync::MutexGuard<'_, Vec<SlotState>> {
self.slots.lock().expect("Mutex poisoned")
}
/// Lock the completed queue mutex, panicking on poison.
fn lock_completed(&self) -> std::sync::MutexGuard<'_, Vec<(u64, Vec<u32>)>> {
self.completed.lock().expect("Mutex poisoned")
}
/// Lock the KV caches mutex, panicking on poison.
fn lock_caches(&self) -> std::sync::MutexGuard<'_, Vec<OwnedQuantizedKVCache>> {
self.caches.lock().expect("Mutex poisoned")
}
/// Create scheduler with specified number of slots
///
/// # Arguments
/// * `num_slots` - Maximum concurrent requests (typically 32-64)
/// * `num_layers` - Number of transformer layers (for KV cache)
/// * `hidden_dim` - Hidden dimension (for KV cache)
/// * `max_seq_len` - Maximum sequence length (for KV cache)
pub fn new(num_slots: usize, num_layers: usize, hidden_dim: usize, max_seq_len: usize) -> Self {
let slots = vec![SlotState::Empty; num_slots];
let caches = (0..num_slots)
.map(|_| OwnedQuantizedKVCache::new(num_layers, hidden_dim, max_seq_len))
.collect();
Self {
slots: std::sync::Mutex::new(slots),
caches: std::sync::Mutex::new(caches),
num_slots,
completed: std::sync::Mutex::new(Vec::new()),
next_id: std::sync::atomic::AtomicU64::new(0),
}
}
/// Count slots matching a predicate.
fn count_slots_where(&self, predicate: fn(&SlotState) -> bool) -> usize {
self.lock_slots().iter().filter(|s| predicate(s)).count()
}
/// Submit a new request to the scheduler
///
    /// Returns `Some(request_id)` if a slot is available, or `None` if all slots are full
pub fn submit(
&self,
prompt_tokens: Vec<u32>,
max_tokens: usize,
temperature: f32,
top_k: usize,
) -> Option<u64> {
let mut slots = self.lock_slots();
// Find first empty slot
let empty_idx = slots.iter().position(SlotState::is_empty)?;
let request_id = self
.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
slots[empty_idx] = SlotState::Active {
request_id,
prompt_tokens,
generated_tokens: Vec::new(),
max_tokens,
temperature,
top_k,
};
Some(request_id)
}
/// Get number of active slots
pub fn active_count(&self) -> usize {
self.count_slots_where(SlotState::is_active)
}
/// Get number of empty slots
pub fn empty_count(&self) -> usize {
self.count_slots_where(SlotState::is_empty)
}
    /// Check whether any completed results are waiting to be polled
pub fn has_completed(&self) -> bool {
!self.lock_completed().is_empty()
}
    /// Retrieve all completed request results, draining the queue
pub fn poll_completed(&self) -> Vec<(u64, Vec<u32>)> {
std::mem::take(&mut *self.lock_completed())
}
    /// Mark a request as completed and move its result to the completed queue
pub fn complete_request(&self, slot_idx: usize, tokens: Vec<u32>) {
let mut slots = self.lock_slots();
let mut completed = self.lock_completed();
if slot_idx < slots.len() {
if let SlotState::Active { request_id, .. } = &slots[slot_idx] {
let id = *request_id;
// Move to completed
completed.push((id, tokens));
// Free the slot
slots[slot_idx] = SlotState::Empty;
// Reset KV cache for this slot
self.lock_caches()[slot_idx].reset();
}
}
}
/// Get active slot indices and their current positions
pub fn get_active_slots(&self) -> Vec<(usize, usize)> {
self.lock_slots()
.iter()
.enumerate()
.filter_map(|(idx, slot)| match slot {
SlotState::Active {
prompt_tokens,
generated_tokens,
..
} => {
let pos = prompt_tokens.len() + generated_tokens.len();
Some((idx, pos))
},
_ => None,
})
.collect()
}
/// Get utilization (active_slots / total_slots)
pub fn utilization(&self) -> f64 {
let active = self.active_count();
active as f64 / self.num_slots as f64
}
}
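// Illustrative lifecycle sketch (a test sketch, not part of the scheduler's
// API surface): submits a request, completes it, and polls the result.
// Dimensions are arbitrary toy values; this assumes
// `OwnedQuantizedKVCache::new` succeeds for these sizes.
#[cfg(all(test, feature = "gpu"))]
mod scheduler_lifecycle_sketch {
    use super::*;

    #[test]
    fn submit_complete_poll_roundtrip() {
        // 4 slots, 2 layers, hidden_dim 8, max_seq_len 16 (toy sizes).
        let sched = ContinuousBatchScheduler::new(4, 2, 8, 16);
        let id = sched.submit(vec![1, 2, 3], 8, 0.0, 1).expect("slot free");
        assert_eq!(sched.active_count(), 1);
        assert_eq!(sched.empty_count(), 3);

        // Position is prompt length + generated length (3 + 0 here).
        let (slot_idx, pos) = sched.get_active_slots()[0];
        assert_eq!(pos, 3);

        // Completing frees the slot and queues the result for polling.
        sched.complete_request(slot_idx, vec![10, 11]);
        assert!(sched.has_completed());
        assert_eq!(sched.poll_completed(), vec![(id, vec![10, 11])]);
        assert_eq!(sched.active_count(), 0);
    }
}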
/// Speculative decoding configuration (PARITY-029)
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
/// Number of tokens to speculatively generate per step
pub speculation_length: usize,
/// Temperature for draft model (lower = more deterministic)
pub draft_temperature: f32,
/// Whether to use same model for draft (self-speculative)
pub self_speculative: bool,
}
#[cfg(feature = "gpu")]
impl Default for SpeculativeConfig {
fn default() -> Self {
Self {
speculation_length: 4,
draft_temperature: 0.0,
self_speculative: true,
}
}
}
/// Result of speculative decoding verification step
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct VerificationResult {
/// Number of draft tokens accepted
pub accepted_count: usize,
/// Total draft tokens generated
pub draft_count: usize,
    /// Tokens to emit: the accepted draft tokens, plus the target's substitute token when a draft token was rejected
pub accepted_tokens: Vec<u32>,
/// Whether all draft tokens were accepted
pub all_accepted: bool,
}
/// Speculative decoder for accelerated token generation (PARITY-029)
///
/// Implements speculative decoding (Leviathan et al., 2023):
/// 1. Draft model generates K candidate tokens quickly
/// 2. Target model verifies all K tokens in parallel
/// 3. Accept tokens until first rejection, then resample
///
/// This yields up to a K-fold speedup when the draft acceptance rate is high.
#[cfg(feature = "gpu")]
pub struct SpeculativeDecoder {
/// Speculative decoding configuration
pub config: SpeculativeConfig,
/// Statistics: total draft tokens generated
pub total_draft_tokens: std::sync::atomic::AtomicU64,
/// Statistics: total draft tokens accepted
pub total_accepted_tokens: std::sync::atomic::AtomicU64,
}
#[cfg(feature = "gpu")]
impl SpeculativeDecoder {
/// Create a new `AtomicU64` initialized to zero.
fn new_counter() -> std::sync::atomic::AtomicU64 {
std::sync::atomic::AtomicU64::new(0)
}
/// Load an atomic counter with Relaxed ordering.
fn load_relaxed(counter: &std::sync::atomic::AtomicU64) -> u64 {
counter.load(std::sync::atomic::Ordering::Relaxed)
}
/// Add a value to an atomic counter (Relaxed ordering).
fn add(counter: &std::sync::atomic::AtomicU64, val: u64) {
counter.fetch_add(val, std::sync::atomic::Ordering::Relaxed);
}
/// Reset an atomic counter to zero (Relaxed ordering).
fn reset_counter(counter: &std::sync::atomic::AtomicU64) {
counter.store(0, std::sync::atomic::Ordering::Relaxed);
}
/// Create new speculative decoder with default config
pub fn new() -> Self {
Self {
config: SpeculativeConfig::default(),
total_draft_tokens: Self::new_counter(),
total_accepted_tokens: Self::new_counter(),
}
}
/// Create speculative decoder with custom config
pub fn with_config(config: SpeculativeConfig) -> Self {
Self {
config,
total_draft_tokens: Self::new_counter(),
total_accepted_tokens: Self::new_counter(),
}
}
/// Get acceptance rate (accepted / total draft tokens)
pub fn acceptance_rate(&self) -> f64 {
let total = Self::load_relaxed(&self.total_draft_tokens);
let accepted = Self::load_relaxed(&self.total_accepted_tokens);
if total == 0 {
return 0.0;
}
accepted as f64 / total as f64
}
/// Verify draft tokens against target model logits
///
/// # Arguments
/// * `draft_tokens` - Candidate tokens from draft model
/// * `target_logits` - Logits from target model for each position
    /// * `temperature` - Sampling temperature; 0.0 selects greedy (argmax) verification
///
/// # Returns
/// VerificationResult with accepted tokens and statistics
pub fn verify_draft(
&self,
draft_tokens: &[u32],
target_logits: &[Vec<f32>],
temperature: f32,
) -> VerificationResult {
let mut accepted_tokens = Vec::with_capacity(draft_tokens.len());
let mut accepted_count = 0;
// Verify each draft token against target model distribution
for (i, &draft_token) in draft_tokens.iter().enumerate() {
if i >= target_logits.len() {
break;
}
let logits = &target_logits[i];
// Find target model's top token
let (target_token, _) = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or((0, &0.0));
            // Greedy case: accept the draft token only if it matches the
            // target model's argmax.
            if temperature == 0.0 {
                if draft_token == target_token as u32 {
                    accepted_tokens.push(draft_token);
                    accepted_count += 1;
                } else {
                    // Rejected: emit the target's token in its place, but do
                    // not count it as an accepted draft token, so acceptance
                    // statistics and `all_accepted` stay accurate.
                    accepted_tokens.push(target_token as u32);
                    break; // Stop at first mismatch
                }
            } else {
                // Non-greedy decoding. Exact rejection sampling would accept
                // with probability min(1, p_target(x) / p_draft(x)); as a
                // simpler heuristic, accept the draft token if it falls in
                // the target's top-k.
let mut sorted_indices: Vec<usize> = (0..logits.len()).collect();
sorted_indices.sort_by(|&a, &b| {
logits[b]
.partial_cmp(&logits[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let top_k = 10; // Accept if in top-10
let in_top_k = sorted_indices
.iter()
.take(top_k)
.any(|&idx| idx == draft_token as usize);
if in_top_k {
accepted_tokens.push(draft_token);
accepted_count += 1;
                } else {
                    // Rejected: fall back to the target's top token; as in
                    // the greedy path, the substitute is emitted but not
                    // counted as an accepted draft token.
                    accepted_tokens.push(sorted_indices[0] as u32);
                    break;
                }
}
}
// Update statistics
Self::add(&self.total_draft_tokens, draft_tokens.len() as u64);
Self::add(&self.total_accepted_tokens, accepted_count as u64);
VerificationResult {
accepted_count,
draft_count: draft_tokens.len(),
accepted_tokens,
all_accepted: accepted_count == draft_tokens.len(),
}
}
    /// Estimate the expected speedup from the observed acceptance rate.
    ///
    /// Uses the approximation `K * acceptance_rate + 1`: each verification
    /// pass yields at least one token (the target's own), plus on average
    /// `acceptance_rate` of the K drafted tokens. For example, K = 4 at a
    /// 75% acceptance rate gives an expected 4.0 tokens per target pass.
pub fn expected_speedup(&self) -> f64 {
let k = self.config.speculation_length as f64;
let acceptance_rate = self.acceptance_rate();
k * acceptance_rate + 1.0
}
/// Reset statistics
pub fn reset_stats(&self) {
Self::reset_counter(&self.total_draft_tokens);
Self::reset_counter(&self.total_accepted_tokens);
}
}
#[cfg(feature = "gpu")]
impl Default for SpeculativeDecoder {
fn default() -> Self {
Self::new()
}
}
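// Illustrative verification sketch for the greedy (temperature 0.0) path:
// the target's argmax agrees with the first draft token and disagrees with
// the second, so one draft token is accepted and the target's pick is
// substituted. Logits are hand-written toy values.
#[cfg(all(test, feature = "gpu"))]
mod speculative_verify_sketch {
    use super::*;

    #[test]
    fn greedy_verification_stops_at_first_mismatch() {
        let decoder = SpeculativeDecoder::new();
        let draft = [2_u32, 0];
        let target_logits = vec![
            vec![0.1, 0.2, 0.9], // argmax = 2, matches draft[0]
            vec![0.3, 0.8, 0.1], // argmax = 1, rejects draft[1] = 0
        ];
        let result = decoder.verify_draft(&draft, &target_logits, 0.0);
        assert_eq!(result.accepted_count, 1); // only draft[0] accepted
        assert_eq!(result.accepted_tokens, vec![2, 1]); // match + correction
        assert!(!result.all_accepted);
        // 1 of 2 drafted tokens accepted.
        assert_eq!(decoder.acceptance_rate(), 0.5);
    }
}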
/// GPU Buffer Pool for zero-allocation inference (PARITY-031, IMP-309)
///
/// Pre-allocates GPU buffers during warmup to eliminate allocation overhead
/// during generation. Uses a pool of reusable buffers for each tensor type.
///
/// # Key Properties
/// - Zero GPU malloc after warmup phase
/// - Pre-allocated buffers for common tensor sizes
/// - Thread-safe buffer borrowing and return
///
/// # Buffer Types
/// - Hidden state buffers: [batch_size, hidden_dim]
/// - Intermediate buffers: [batch_size, intermediate_dim]
/// - Attention score buffers: [batch_size, num_heads, seq_len]
/// - KV cache buffers: [num_layers, seq_len, hidden_dim]
#[cfg(feature = "gpu")]
pub struct GpuBufferPool {
/// Pre-allocated hidden state buffers
hidden_buffers: std::sync::Mutex<Vec<Vec<f32>>>,
/// Pre-allocated intermediate buffers (FFN)
intermediate_buffers: std::sync::Mutex<Vec<Vec<f32>>>,
/// Pre-allocated attention score buffers
attention_buffers: std::sync::Mutex<Vec<Vec<f32>>>,
/// Buffer dimensions for validation
hidden_dim: usize,
intermediate_dim: usize,
max_seq_len: usize,
num_heads: usize,
/// Pool size per buffer type
pool_size: usize,
/// Statistics: buffers borrowed
pub borrows: std::sync::atomic::AtomicU64,
/// Statistics: buffers returned
pub returns: std::sync::atomic::AtomicU64,
/// Statistics: allocations after warmup (should be 0)
pub post_warmup_allocs: std::sync::atomic::AtomicU64,
/// Whether warmup is complete
warmed_up: std::sync::atomic::AtomicBool,
}
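// Minimal sketch of the borrow/return pattern described above (a standalone
// toy, not the GpuBufferPool API; names like `borrow` and `give_back` are
// illustrative only): pre-size a free list during warmup, then pop and push
// buffers so steady-state use performs no new allocations.
#[cfg(all(test, feature = "gpu"))]
mod buffer_pool_pattern_sketch {
    use std::sync::Mutex;

    struct ToyPool {
        free: Mutex<Vec<Vec<f32>>>,
        buf_len: usize,
    }

    impl ToyPool {
        /// Pre-allocate `n` buffers of `buf_len` f32s (the warmup phase).
        fn new(n: usize, buf_len: usize) -> Self {
            let free = (0..n).map(|_| vec![0.0; buf_len]).collect();
            Self { free: Mutex::new(free), buf_len }
        }

        /// Borrow a buffer; allocates fresh only if the pool is exhausted
        /// (the real pool counts this as a post-warmup allocation).
        fn borrow(&self) -> Vec<f32> {
            self.free
                .lock()
                .expect("Mutex poisoned")
                .pop()
                .unwrap_or_else(|| vec![0.0; self.buf_len])
        }

        /// Return a buffer to the free list for reuse.
        fn give_back(&self, buf: Vec<f32>) {
            self.free.lock().expect("Mutex poisoned").push(buf);
        }
    }

    #[test]
    fn reuse_without_reallocation() {
        let pool = ToyPool::new(2, 8);
        let a = pool.borrow();
        let ptr = a.as_ptr();
        pool.give_back(a);
        // The returned buffer is handed out again rather than reallocated.
        let b = pool.borrow();
        assert_eq!(b.as_ptr(), ptr);
    }
}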