realizar 0.8.4

Pure Rust ML inference engine built from scratch - model serving for GGUF and safetensors
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
//! Speculative decoding methods for CUDA-accelerated inference
//!
//! This module contains speculative decoding implementations:
//! - `generate_speculative_cuda`: Self-speculative decoding (same model for draft/verify)
//! - `generate_speculative_with_draft`: Two-model speculative decoding (small draft + large verify)
//!
//! # Theory
//!
//! Speculative decoding reduces the memory-bandwidth bottleneck by:
//! 1. Drafting K tokens quickly (smaller model or layer-skipping)
//! 2. Verifying all K tokens in one forward pass of the target model
//! 3. Accepting the prefix of matching tokens and rejecting the rest

use super::{OwnedQuantizedKVCache, OwnedQuantizedModelCuda, QuantizedGenerateConfig};
use crate::error::{RealizarError, Result};

impl OwnedQuantizedModelCuda {
    /// PAR-100: Self-speculative decoding on the GPU-resident path
    ///
    /// Each round drafts up to `speculation_k` tokens greedily, rewinds the KV
    /// cache to the pre-draft snapshot, and re-runs the model one token at a time
    /// to verify the drafts. The matching prefix is kept; at the first mismatch
    /// the verifier's token is taken instead and the round ends.
    ///
    /// NOTE(review): this was designed around a layer-skipping draft (every 2nd
    /// layer), but the draft and verify passes below both call
    /// `forward_gpu_resident_to_token_id`; no layer skipping is selected in this
    /// function. Confirm whether the draft path is meant to differ.
    ///
    /// # Arguments
    ///
    /// * `prompt` - Initial token IDs
    /// * `config` - Generation configuration (`max_tokens`, `stop_tokens`)
    /// * `speculation_k` - Number of tokens to draft speculatively (typically 4-8).
    ///   With `0` nothing is drafted and the prompt is returned unchanged.
    ///
    /// # Returns
    ///
    /// Generated token sequence including prompt. An empty prompt yields an empty
    /// vector. At most `config.max_tokens` tokens are appended to the prompt.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::UnsupportedOperation` when the model does not
    /// support the GPU-resident path, and propagates errors from weight preload
    /// and from the forward passes.
    pub fn generate_speculative_cuda(
        &mut self,
        prompt: &[u32],
        config: &QuantizedGenerateConfig,
        speculation_k: usize,
    ) -> Result<Vec<u32>> {
        use std::time::Instant;

        if prompt.is_empty() {
            return Ok(Vec::new());
        }

        // Check architecture support
        if !self.supports_gpu_resident() {
            return Err(RealizarError::UnsupportedOperation {
                operation: "generate_speculative_cuda".to_string(),
                reason: "Model architecture not supported for GPU-resident path".to_string(),
            });
        }

        // Pre-upload all weights to GPU
        let bytes_uploaded = self.preload_weights_gpu()?;
        eprintln!(
            "PAR-100: Pre-uploaded {} MB of weights to GPU",
            bytes_uploaded / (1024 * 1024)
        );

        // PAR-100: Setup KV cache with GQA-aware dimensions
        let num_kv_heads = self.model.config.num_kv_heads;
        let head_dim = self.model.config.hidden_dim / self.model.config.num_heads;
        let kv_dim = num_kv_heads * head_dim;
        // Capacity leaves headroom for one full round of unaccepted drafts.
        let mut cache = OwnedQuantizedKVCache::new(
            self.model.config.num_layers,
            kv_dim,
            prompt.len() + config.max_tokens + speculation_k,
        );

        // Reset GPU KV cache positions before generation
        self.executor.reset_kv_cache_gpu();

        let mut tokens = prompt.to_vec();

        // Prefill: run every prompt token except the last through the model.
        // The last prompt token is the input of the first decode step.
        let prefill_start = Instant::now();
        for (pos, &token_id) in prompt[..prompt.len() - 1].iter().enumerate() {
            let _ = self.forward_gpu_resident(token_id, &mut cache, pos)?;
        }
        let prefill_time = prefill_start.elapsed();

        // Start decode from last prompt token
        let mut position = prompt.len() - 1;
        let mut last_token = prompt[prompt.len() - 1];

        // Statistics for throughput calculation
        let decode_start = Instant::now();
        let mut accepted_tokens = 0usize;
        let mut total_drafts = 0usize;
        let mut total_speculative_batches = 0usize;

        while tokens.len() - prompt.len() < config.max_tokens {
            // Step 1: Draft tokens greedily using GPU-resident forward
            let cache_snapshot = cache.snapshot_len();

            // Never draft more tokens than the remaining budget: every draft can
            // end up in the output, so an uncapped round could overshoot
            // `max_tokens` by up to `speculation_k - 1` tokens.
            let remaining = config.max_tokens - (tokens.len() - prompt.len());
            let draft_budget = speculation_k.min(remaining);
            let mut draft_tokens = Vec::with_capacity(draft_budget);

            let mut input_token = last_token;
            for i in 0..draft_budget {
                let draft =
                    self.forward_gpu_resident_to_token_id(input_token, &mut cache, position + i)?;

                if config.stop_tokens.contains(&draft) {
                    if i == 0 {
                        // First draft is stop token: emit it and finish below
                        tokens.push(draft);
                    }
                    break;
                }

                draft_tokens.push(draft);
                input_token = draft;
            }

            if draft_tokens.is_empty() {
                break; // Stop token on first draft (or nothing to draft)
            }

            total_drafts += draft_tokens.len();

            // Step 2: Rollback cache to snapshot for verification.
            // The GPU cache is rewound to the same length instead of being reset:
            // a full reset also discards the prefill and all previously accepted
            // tokens (the PAR-105 failure mode), so verification would run
            // without any history.
            cache.rollback_to(cache_snapshot, kv_dim);
            self.executor.rollback_kv_cache_gpu(cache_snapshot);

            // Step 3: Verify - use single-token GPU-resident to check each draft
            // NOTE: Batched verification would be faster but requires refactoring
            // For now, verify sequentially to ensure correctness
            let mut num_accepted = 0usize;

            let mut input_token = last_token;
            for (i, &draft) in draft_tokens.iter().enumerate() {
                let verified =
                    self.forward_gpu_resident_to_token_id(input_token, &mut cache, position + i)?;

                if verified == draft {
                    // Accept this token
                    tokens.push(draft);
                    num_accepted += 1;
                    input_token = draft;
                } else {
                    // Reject: accept the model's correction instead
                    if !config.stop_tokens.contains(&verified) {
                        tokens.push(verified);
                        num_accepted += 1;
                    }
                    break;
                }
            }

            total_speculative_batches += 1;

            // Handle edge case: nothing was accepted. This only happens when the
            // first verification disagreed with the draft and produced a stop token.
            if num_accepted == 0 {
                // Just generate one token normally
                cache.rollback_to(cache_snapshot, kv_dim);
                self.executor.rollback_kv_cache_gpu(cache_snapshot);
                let fallback =
                    self.forward_gpu_resident_to_token_id(last_token, &mut cache, position)?;
                if config.stop_tokens.contains(&fallback) {
                    break;
                }
                tokens.push(fallback);
                num_accepted = 1;
            }

            accepted_tokens += num_accepted;

            // Step 4: Update position and last_token
            position += num_accepted;
            last_token = *tokens.last().unwrap_or(&0);

            // Rollback caches to keep only accepted entries. CPU and GPU must
            // agree on the length, otherwise the GPU keeps the entry written by a
            // rejected verification step.
            let target_cache_len = cache_snapshot + num_accepted;
            cache.rollback_to(target_cache_len, kv_dim);
            self.executor.rollback_kv_cache_gpu(target_cache_len);
        }

        let decode_time = decode_start.elapsed();
        let generated_tokens = tokens.len() - prompt.len();
        let decode_tok_s = if decode_time.as_secs_f64() > 0.0 {
            generated_tokens as f64 / decode_time.as_secs_f64()
        } else {
            0.0
        };

        // Corrections taken from the verifier count as accepted here, so this is
        // "tokens emitted per drafted token" rather than a pure match rate.
        let acceptance_rate = if total_drafts > 0 {
            accepted_tokens as f64 / total_drafts as f64 * 100.0
        } else {
            0.0
        };

        eprintln!(
            "[PAR-100] Speculative decode: {} tokens in {:.2}ms ({:.1} tok/s)",
            generated_tokens,
            decode_time.as_secs_f64() * 1000.0,
            decode_tok_s
        );
        eprintln!(
            "[PAR-100] Prefill: {:.2}ms, Drafts: {}, Accepted: {}, Rate: {:.1}%",
            prefill_time.as_secs_f64() * 1000.0,
            total_drafts,
            accepted_tokens,
            acceptance_rate
        );
        eprintln!(
            "[PAR-100] Batched verifications: {}",
            total_speculative_batches
        );

        Ok(tokens)
    }

    /// PAR-099: Speculative decoding with separate draft model
    ///
    /// Uses a smaller draft model (e.g., 0.5B) for fast token generation,
    /// then verifies with the target model (e.g., 1.5B). Every emitted token is
    /// one the target model produced or confirmed, including the decision to stop.
    ///
    /// # Theory (Five-Whys Root Cause)
    ///
    /// WHY does draft model help?
    /// → Draft model is 3x smaller = 3x faster = 3x fewer weight reads
    /// → Verification with target model amortizes quality check
    ///
    /// Expected speedup with 0.5B draft + 1.5B target:
    /// - Draft 4 tokens: 4 × (2.5ms/3) = 3.3ms
    /// - Verify 4 tokens: 1 × 2.5ms = 2.5ms (batched)
    /// - Total: 5.8ms for ~3 accepted tokens = 517 tok/s (1.3x improvement)
    ///
    /// With k=8, 80% acceptance: theoretical ~700-800 tok/s
    ///
    /// Note that the verification below is sequential (one target forward per
    /// draft), not batched, so the figures above are an upper bound for this
    /// implementation.
    ///
    /// # Arguments
    ///
    /// * `draft_model` - Smaller model for fast token drafting
    /// * `prompt` - Initial token IDs
    /// * `config` - Generation configuration (`max_tokens`, `stop_tokens`)
    /// * `speculation_k` - Number of tokens to draft (typically 4-8).
    ///   With `0` nothing is drafted and the prompt is returned unchanged.
    ///
    /// # Returns
    ///
    /// Generated token sequence including prompt. An empty prompt yields an empty
    /// vector. At most `config.max_tokens` tokens are appended to the prompt.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::UnsupportedOperation` when either model does not
    /// support the GPU-resident path, and propagates errors from weight preload
    /// and from the forward passes of both models.
    pub fn generate_speculative_with_draft(
        &mut self,
        draft_model: &mut OwnedQuantizedModelCuda,
        prompt: &[u32],
        config: &QuantizedGenerateConfig,
        speculation_k: usize,
    ) -> Result<Vec<u32>> {
        use std::time::Instant;

        if prompt.is_empty() {
            return Ok(Vec::new());
        }

        // Check architecture support for both models
        if !self.supports_gpu_resident() {
            return Err(RealizarError::UnsupportedOperation {
                operation: "generate_speculative_with_draft".to_string(),
                reason: "Target model architecture not supported for GPU-resident path".to_string(),
            });
        }
        if !draft_model.supports_gpu_resident() {
            return Err(RealizarError::UnsupportedOperation {
                operation: "generate_speculative_with_draft".to_string(),
                reason: "Draft model architecture not supported for GPU-resident path".to_string(),
            });
        }

        // Pre-upload weights for both models
        let target_bytes = self.preload_weights_gpu()?;
        let draft_bytes = draft_model.preload_weights_gpu()?;
        eprintln!(
            "PAR-099: Pre-uploaded {} MB (target) + {} MB (draft) to GPU",
            target_bytes / (1024 * 1024),
            draft_bytes / (1024 * 1024)
        );

        // Setup KV caches for both models (GQA-aware: kv_heads * head_dim)
        let target_kv_dim = {
            let num_kv_heads = self.model.config.num_kv_heads;
            let head_dim = self.model.config.hidden_dim / self.model.config.num_heads;
            num_kv_heads * head_dim
        };
        let draft_kv_dim = {
            let num_kv_heads = draft_model.model.config.num_kv_heads;
            let head_dim = draft_model.model.config.hidden_dim / draft_model.model.config.num_heads;
            num_kv_heads * head_dim
        };

        // Capacity leaves headroom for one full round of unaccepted drafts.
        let cache_capacity = prompt.len() + config.max_tokens + speculation_k;
        let mut target_cache = OwnedQuantizedKVCache::new(
            self.model.config.num_layers,
            target_kv_dim,
            cache_capacity,
        );
        let mut draft_cache = OwnedQuantizedKVCache::new(
            draft_model.model.config.num_layers,
            draft_kv_dim,
            cache_capacity,
        );

        // Reset GPU KV cache positions
        self.executor.reset_kv_cache_gpu();
        draft_model.executor.reset_kv_cache_gpu();

        let mut tokens = prompt.to_vec();

        // Prefill both models with every prompt token except the last; the last
        // prompt token is the input of the first decode step.
        let prefill_start = Instant::now();
        for (pos, &token_id) in prompt[..prompt.len() - 1].iter().enumerate() {
            let _ = self.forward_gpu_resident(token_id, &mut target_cache, pos)?;
            let _ = draft_model.forward_gpu_resident(token_id, &mut draft_cache, pos)?;
        }
        let prefill_time = prefill_start.elapsed();

        // Start decode from last prompt token
        let mut position = prompt.len() - 1;
        let mut last_token = prompt[prompt.len() - 1];

        // Statistics
        let decode_start = Instant::now();
        let mut accepted_tokens = 0usize;
        let mut total_drafts = 0usize;
        let mut total_speculative_batches = 0usize;

        while tokens.len() - prompt.len() < config.max_tokens {
            // Step 1: Draft tokens using DRAFT model (fast, smaller)
            let draft_cache_snapshot = draft_cache.snapshot_len();
            let target_cache_snapshot = target_cache.snapshot_len();

            // Never draft more tokens than the remaining budget: every draft can
            // end up in the output, so an uncapped round could overshoot
            // `max_tokens` by up to `speculation_k - 1` tokens.
            let remaining = config.max_tokens - (tokens.len() - prompt.len());
            let draft_budget = speculation_k.min(remaining);
            let mut draft_tokens = Vec::with_capacity(draft_budget);
            let mut draft_proposed_stop = false;

            let mut input_token = last_token;
            for i in 0..draft_budget {
                let draft = draft_model.forward_gpu_resident_to_token_id(
                    input_token,
                    &mut draft_cache,
                    position + i,
                )?;

                if config.stop_tokens.contains(&draft) {
                    // A stop token is never drafted; remember whether it was the
                    // very first proposal so the target model can rule on it.
                    draft_proposed_stop = i == 0;
                    break;
                }

                draft_tokens.push(draft);
                input_token = draft;
            }

            if draft_tokens.is_empty() && !draft_proposed_stop {
                break; // Nothing to draft (speculation_k == 0)
            }

            // Step 2: Verify using TARGET model
            // PAR-105: Rollback draft cache to snapshot position, preserving prefill history
            // RESOLVED: reset_kv_cache_gpu() was clearing ALL history, causing 1/k acceptance
            draft_cache.rollback_to(draft_cache_snapshot, draft_kv_dim);
            draft_model
                .executor
                .rollback_kv_cache_gpu(draft_cache_snapshot);

            if draft_tokens.is_empty() {
                // The draft model wants to stop right away. Ending generation on
                // the draft model's word alone would let the small model truncate
                // the output, so the target model decides this position.
                let decided =
                    self.forward_gpu_resident_to_token_id(last_token, &mut target_cache, position)?;
                if config.stop_tokens.contains(&decided) {
                    // Both models agree: emit the stop token and finish.
                    tokens.push(decided);
                    break;
                }
                // Target disagrees: keep its token and bring the draft cache up
                // to the same length before the next round.
                let _ = draft_model.forward_gpu_resident(last_token, &mut draft_cache, position)?;
                tokens.push(decided);
                position += 1;
                last_token = decided;
                continue;
            }

            total_drafts += draft_tokens.len();

            let mut num_accepted = 0usize;

            let mut input_token = last_token;
            for (i, &draft) in draft_tokens.iter().enumerate() {
                let verify_pos = position + i;

                // Verify with target model
                let verified = self.forward_gpu_resident_to_token_id(
                    input_token,
                    &mut target_cache,
                    verify_pos,
                )?;

                if verified == draft {
                    // Accept: also update draft cache for consistency
                    let _ = draft_model.forward_gpu_resident(
                        input_token,
                        &mut draft_cache,
                        verify_pos,
                    )?;
                    tokens.push(draft);
                    num_accepted += 1;
                    input_token = draft;
                } else {
                    // Reject: accept target's correction
                    if !config.stop_tokens.contains(&verified) {
                        let _ = draft_model.forward_gpu_resident(
                            input_token,
                            &mut draft_cache,
                            verify_pos,
                        )?;
                        tokens.push(verified);
                        num_accepted += 1;
                    }
                    break;
                }
            }

            total_speculative_batches += 1;

            // Handle edge case: nothing was accepted. This only happens when the
            // first verification disagreed with the draft and produced a stop token.
            if num_accepted == 0 {
                // PAR-105: Use rollback instead of reset to preserve prefill history
                target_cache.rollback_to(target_cache_snapshot, target_kv_dim);
                draft_cache.rollback_to(draft_cache_snapshot, draft_kv_dim);
                self.executor.rollback_kv_cache_gpu(target_cache_snapshot);
                draft_model
                    .executor
                    .rollback_kv_cache_gpu(draft_cache_snapshot);

                let fallback =
                    self.forward_gpu_resident_to_token_id(last_token, &mut target_cache, position)?;
                let _ = draft_model.forward_gpu_resident(last_token, &mut draft_cache, position)?;

                if config.stop_tokens.contains(&fallback) {
                    break;
                }
                tokens.push(fallback);
                num_accepted = 1;
            }

            accepted_tokens += num_accepted;
            position += num_accepted;
            last_token = *tokens.last().unwrap_or(&0);

            // Rollback caches to accepted length (CPU AND GPU must stay in sync)
            let target_len = target_cache_snapshot + num_accepted;
            let draft_len = draft_cache_snapshot + num_accepted;
            target_cache.rollback_to(target_len, target_kv_dim);
            draft_cache.rollback_to(draft_len, draft_kv_dim);
            // PAR-105: RESOLVED - must also rollback GPU caches to match CPU
            // Without this, GPU cache has stale entries from rejected verifications
            self.executor.rollback_kv_cache_gpu(target_len);
            draft_model.executor.rollback_kv_cache_gpu(draft_len);
        }

        let decode_time = decode_start.elapsed();
        let generated_tokens = tokens.len() - prompt.len();
        let decode_tok_s = if decode_time.as_secs_f64() > 0.0 {
            generated_tokens as f64 / decode_time.as_secs_f64()
        } else {
            0.0
        };

        // Corrections taken from the target model count as accepted here, so this
        // is "tokens emitted per drafted token" rather than a pure match rate.
        let acceptance_rate = if total_drafts > 0 {
            accepted_tokens as f64 / total_drafts as f64 * 100.0
        } else {
            0.0
        };

        eprintln!(
            "[PAR-099] Speculative decode (draft model): {} tokens in {:.2}ms ({:.1} tok/s)",
            generated_tokens,
            decode_time.as_secs_f64() * 1000.0,
            decode_tok_s
        );
        eprintln!(
            "[PAR-099] Prefill: {:.2}ms, Drafts: {}, Accepted: {}, Rate: {:.1}%",
            prefill_time.as_secs_f64() * 1000.0,
            total_drafts,
            accepted_tokens,
            acceptance_rate
        );
        eprintln!(
            "[PAR-099] Speculative batches: {}",
            total_speculative_batches
        );

        Ok(tokens)
    }
}