abyo-speculate 0.5.0

Pure Rust Speculative Decoding library for local LLMs — vanilla SD + Medusa, Qwen2 + Llama, batch-1 optimised
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
//! EAGLE-2 (Li et al. 2024) — speculative decoding with target-hidden-state
//! conditioned drafts.
//!
//! ## What EAGLE adds vs Medusa / Vanilla SD
//!
//! Vanilla SD pairs the target with a *separate* small target-shaped draft
//! model — both run autoregressively. Medusa attaches `N` heads that
//! predict directly from the target's hidden state, no draft autoregression.
//! EAGLE sits in between: a tiny **1-layer transformer** runs autoregressively
//! over a sequence of `(target_hidden, token_embedding)` pairs, propagating
//! its own KV cache. The result is higher acceptance rates than Medusa
//! (because the draft sees real target context, not just the last hidden)
//! and lower draft cost than vanilla SD (because the draft is 1 layer, not
//! ~30).
//!
//! ## Reference checkpoint
//!
//! `yuhuili/EAGLE-LLaMA3-Instruct-8B` ships the draft for Llama 3 8B in a
//! 1.5 GB `pytorch_model.bin`. Key layout:
//!
//! ```text
//! embed_tokens.weight                       (vocab,  hidden)
//! fc.weight                                 (hidden, 2*hidden)   # concat input projection
//! layers.0.self_attn.{q,k,v,o}_proj.weight  Llama attention (no biases, GQA)
//! layers.0.mlp.{gate,up,down}_proj.weight   Llama SwiGLU MLP
//! layers.0.post_attention_layernorm.weight  RmsNorm before MLP
//! ```
//!
//! Notably absent: `input_layernorm`. The draft's input is the target's
//! last-layer norm output, so we skip the second normalisation.
//!
//! ## What's still v0.2.0 follow-up
//!
//! - Dynamic confidence-based tree expansion. The base loop builds a fixed
//!   Cartesian-product tree (like Medusa's `TreeTopology::CartesianProduct`),
//!   optionally pruned to the best-scoring paths via
//!   `EagleRunConfig::max_tree_nodes`. Full dynamic expansion improves
//!   acceptance rates by ~10-20% on the EAGLE paper's benchmarks.
//! - EAGLE-3's multi-layer feature aggregation. The 1-layer draft sees
//!   only the target's last hidden state today.
//! - Real-GPU end-to-end speedup measurement (needs the target's hidden
//!   state exposed for *each* draft step, not just the most recent commit —
//!   this loop obtains one target hidden per round, from
//!   `observe_returning_last_hidden` or the tree forward, and grows the
//!   tree from there with the draft's own forward; bench numbers land in
//!   v0.2.0 alongside the dynamic-tree improvement).

#![allow(missing_docs)]

use crate::model::TreeDecoder;
use crate::{Error, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor, D};
use candle_nn::{linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use std::path::Path;

/// Hyper-parameters for an EAGLE draft model.
///
/// The draft is a single Llama-style decoder layer, so the fields follow
/// Llama config naming.
#[derive(Debug, Clone)]
pub struct EagleDraftConfig {
    /// Hidden width shared by the target and the draft. The draft's `fc`
    /// projection takes `2 * hidden_size` inputs (target hidden ++ token
    /// embedding) and returns `hidden_size`.
    pub hidden_size: usize,
    /// Number of rows in the draft's `embed_tokens` table.
    pub vocab_size: usize,
    /// Number of query heads; `head_dim = hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Number of key/value heads. Fewer than `num_attention_heads` means
    /// grouped-query attention (KV heads are repeated to match).
    pub num_key_value_heads: usize,
    /// Inner width of the SwiGLU MLP.
    pub intermediate_size: usize,
    /// Epsilon of `post_attention_layernorm`.
    pub rms_norm_eps: f64,
    /// RoPE base frequency.
    pub rope_theta: f32,
    /// Number of positions covered by the precomputed RoPE tables. A draft
    /// forward at a position `>= max_position_embeddings` fails, so this
    /// should be at least the target model's context length.
    pub max_position_embeddings: usize,
}

impl EagleDraftConfig {
    /// Defaults for `yuhuili/EAGLE-LLaMA3-Instruct-8B`.
    ///
    /// Llama 3 8B uses GQA (8 KV heads), RoPE base 500 000, a 128 256-entry
    /// vocab and an 8 192-token context window.
    pub fn eagle_llama3_8b() -> Self {
        Self {
            hidden_size: 4096,
            vocab_size: 128_256,
            num_attention_heads: 32,
            num_key_value_heads: 8,
            intermediate_size: 14_336,
            rms_norm_eps: 1e-5,
            rope_theta: 500_000.0, // Llama 3 uses 500k, not 10k
            // Match Llama 3's native context: the draft's RoPE lookup runs at
            // the target's absolute position, so a smaller table makes
            // drafting fail once the prefix outgrows it.
            max_position_embeddings: 8192,
        }
    }

    /// Defaults for `yuhuili/EAGLE-llama2-chat-7B`. Llama 2 7B is MHA
    /// (num_kv_heads = num_attention_heads), uses RoPE base 10 000, and
    /// has a 32 000-entry vocab (SentencePiece, not the Llama 3 BPE).
    pub fn eagle_llama2_chat_7b() -> Self {
        Self {
            hidden_size: 4096,
            vocab_size: 32_000,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            intermediate_size: 11_008,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            max_position_embeddings: 4096,
        }
    }

    /// Per-head width: `hidden_size / num_attention_heads`.
    fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

/// Self-attention of the draft's single decoder layer (`layers.0.self_attn`),
/// with its own precomputed RoPE tables and KV cache.
#[derive(Debug, Clone)]
struct DraftAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    // RoPE cos/sin tables, shape `(max_position_embeddings, head_dim / 2)`;
    // `forward` narrows them to the rows for the current positions.
    cos: Tensor,
    sin: Tensor,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    // Post-RoPE keys and values for every position processed since the last
    // `clear_kv_cache`, shape `(batch, n_kv_head, seq, head_dim)` (stored
    // before the GQA repeat). `None` means an empty cache.
    kv_cache: Option<(Tensor, Tensor)>,
}

impl DraftAttention {
    fn load(
        cfg: &EagleDraftConfig,
        vb: VarBuilder<'_>,
        dev: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let h = cfg.hidden_size;
        let n = cfg.num_attention_heads;
        let n_kv = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = linear_no_bias(h, n * head_dim, vb.pp("q_proj")).map_err(Error::Candle)?;
        let k_proj = linear_no_bias(h, n_kv * head_dim, vb.pp("k_proj")).map_err(Error::Candle)?;
        let v_proj = linear_no_bias(h, n_kv * head_dim, vb.pp("v_proj")).map_err(Error::Candle)?;
        let o_proj = linear_no_bias(n * head_dim, h, vb.pp("o_proj")).map_err(Error::Candle)?;

        // Precompute cos/sin tables.
        let inv_freq: Vec<f32> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_t = Tensor::from_vec(inv_freq.clone(), (1, inv_freq.len()), dev)
            .map_err(Error::Candle)?
            .to_dtype(dtype)
            .map_err(Error::Candle)?;
        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)
            .map_err(Error::Candle)?
            .to_dtype(dtype)
            .map_err(Error::Candle)?
            .reshape((cfg.max_position_embeddings, 1))
            .map_err(Error::Candle)?;
        let freqs = t.matmul(&inv_freq_t).map_err(Error::Candle)?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            cos: freqs.cos().map_err(Error::Candle)?,
            sin: freqs.sin().map_err(Error::Candle)?,
            n_head: n,
            n_kv_head: n_kv,
            head_dim,
            kv_cache: None,
        })
    }

    fn forward(&mut self, xs: &Tensor, position: usize) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3().map_err(Error::Candle)?;
        let q = self
            .q_proj
            .forward(xs)
            .map_err(Error::Candle)?
            .reshape((b_sz, q_len, self.n_head, self.head_dim))
            .map_err(Error::Candle)?
            .transpose(1, 2)
            .map_err(Error::Candle)?
            .contiguous()
            .map_err(Error::Candle)?;
        let k = self
            .k_proj
            .forward(xs)
            .map_err(Error::Candle)?
            .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))
            .map_err(Error::Candle)?
            .transpose(1, 2)
            .map_err(Error::Candle)?
            .contiguous()
            .map_err(Error::Candle)?;
        let v = self
            .v_proj
            .forward(xs)
            .map_err(Error::Candle)?
            .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))
            .map_err(Error::Candle)?
            .transpose(1, 2)
            .map_err(Error::Candle)?;

        // RoPE.
        let cos = self.cos.narrow(0, position, q_len).map_err(Error::Candle)?;
        let sin = self.sin.narrow(0, position, q_len).map_err(Error::Candle)?;
        let q = candle_nn::rotary_emb::rope(&q, &cos, &sin).map_err(Error::Candle)?;
        let k = candle_nn::rotary_emb::rope(&k, &cos, &sin).map_err(Error::Candle)?;

        // KV cache.
        let (k, v) = match &self.kv_cache {
            None => (k, v),
            Some((pk, pv)) => (
                Tensor::cat(&[pk, &k], 2).map_err(Error::Candle)?,
                Tensor::cat(&[pv, &v], 2).map_err(Error::Candle)?,
            ),
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        // GQA repeat.
        let n_rep = self.n_head / self.n_kv_head;
        let k = candle_transformers::utils::repeat_kv(k, n_rep)
            .map_err(Error::Candle)?
            .contiguous()
            .map_err(Error::Candle)?;
        let v = candle_transformers::utils::repeat_kv(v, n_rep)
            .map_err(Error::Candle)?
            .contiguous()
            .map_err(Error::Candle)?;

        let scale = 1f64 / (self.head_dim as f64).sqrt();
        let attn = (q
            .matmul(&k.t().map_err(Error::Candle)?)
            .map_err(Error::Candle)?
            * scale)
            .map_err(Error::Candle)?;
        // Build a causal mask for the new positions vs the cached prefix.
        let prev_len = self
            .kv_cache
            .as_ref()
            .map(|(k, _)| k.dim(2).unwrap_or(0))
            .unwrap_or(0)
            - q_len;
        let attn = if q_len <= 1 {
            attn
        } else {
            let total = prev_len + q_len;
            let mut data = vec![0f32; q_len * total];
            for i in 0..q_len {
                for j in 0..q_len {
                    if j > i {
                        data[i * total + prev_len + j] = f32::NEG_INFINITY;
                    }
                }
            }
            let bias = Tensor::from_slice(&data, (q_len, total), xs.device())
                .map_err(Error::Candle)?
                .to_dtype(xs.dtype())
                .map_err(Error::Candle)?
                .reshape((1, 1, q_len, total))
                .map_err(Error::Candle)?;
            attn.broadcast_add(&bias).map_err(Error::Candle)?
        };
        let attn = candle_nn::ops::softmax_last_dim(&attn).map_err(Error::Candle)?;
        let y = attn.matmul(&v).map_err(Error::Candle)?;
        let y = y
            .transpose(1, 2)
            .map_err(Error::Candle)?
            .reshape((b_sz, q_len, self.n_head * self.head_dim))
            .map_err(Error::Candle)?;
        self.o_proj.forward(&y).map_err(Error::Candle)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None;
    }
}

/// SwiGLU feed-forward block of the draft layer (`layers.0.mlp`):
/// `down(silu(gate(x)) * up(x))`, with all three projections bias-free.
#[derive(Debug, Clone)]
struct DraftMlp {
    // `gate_proj`: hidden_size -> intermediate_size, passed through SiLU.
    gate: Linear,
    // `up_proj`: hidden_size -> intermediate_size, multiplied elementwise
    // with the activated gate.
    up: Linear,
    // `down_proj`: intermediate_size -> hidden_size.
    down: Linear,
}

impl DraftMlp {
    fn load(cfg: &EagleDraftConfig, vb: VarBuilder<'_>) -> Result<Self> {
        let h = cfg.hidden_size;
        let i = cfg.intermediate_size;
        Ok(Self {
            gate: linear_no_bias(h, i, vb.pp("gate_proj")).map_err(Error::Candle)?,
            up: linear_no_bias(h, i, vb.pp("up_proj")).map_err(Error::Candle)?,
            down: linear_no_bias(i, h, vb.pp("down_proj")).map_err(Error::Candle)?,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let g = candle_nn::ops::silu(&self.gate.forward(xs).map_err(Error::Candle)?)
            .map_err(Error::Candle)?;
        let u = self.up.forward(xs).map_err(Error::Candle)?;
        self.down
            .forward(&(g * u).map_err(Error::Candle)?)
            .map_err(Error::Candle)
    }
}

/// EAGLE draft model loaded from a published checkpoint.
///
/// The draft consumes `concat(target_last_hidden, token_embedding)` per
/// position, projects through `fc` to hidden, then through a single Llama
/// block (no `input_layernorm` — the target's last RmsNorm already
/// normalises the input). Output hidden goes through the *target's*
/// `lm_head` for vocab logits (no separate draft head).
pub struct EagleDraftCandle {
    // Hyper-parameters the draft was loaded with (exposed via `config()`).
    config: EagleDraftConfig,
    // Draft-owned token embedding table (`embed_tokens`), `(vocab, hidden)`.
    embed_tokens: Embedding,
    // `fc`: bias-free projection from `2 * hidden` (concatenated input) to `hidden`.
    fc: Linear,
    // `layers.0.self_attn`; owns the draft's KV cache (cleared by `reset`).
    attn: DraftAttention,
    // `layers.0.post_attention_layernorm`, applied before the MLP.
    post_attention_layernorm: RmsNorm,
    // `layers.0.mlp` (SwiGLU).
    mlp: DraftMlp,
}

impl std::fmt::Debug for EagleDraftCandle {
    // Only the headline dimensions are printed; the weight tensors are
    // deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut out = f.debug_struct("EagleDraftCandle");
        out.field("hidden_size", &cfg.hidden_size);
        out.field("vocab_size", &cfg.vocab_size);
        out.finish()
    }
}

impl EagleDraftCandle {
    /// Read-only view of the config.
    pub fn config(&self) -> &EagleDraftConfig {
        &self.config
    }

    /// Load from a single PyTorch pickle file
    /// (`yuhuili/EAGLE-...`'s `pytorch_model.bin`).
    pub fn from_pth(
        config: &EagleDraftConfig,
        path: impl AsRef<Path>,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let vb = VarBuilder::from_pth(path.as_ref(), dtype, device).map_err(Error::Candle)?;
        Self::from_var_builder(config, vb, device, dtype)
    }

    /// Load from a caller-supplied [`VarBuilder`].
    pub fn from_var_builder(
        config: &EagleDraftConfig,
        vb: VarBuilder<'_>,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let embed_tokens =
            candle_nn::embedding(config.vocab_size, config.hidden_size, vb.pp("embed_tokens"))
                .map_err(Error::Candle)?;
        let fc = linear_no_bias(2 * config.hidden_size, config.hidden_size, vb.pp("fc"))
            .map_err(Error::Candle)?;
        let attn = DraftAttention::load(config, vb.pp("layers.0.self_attn"), device, dtype)?;
        let post_attention_layernorm = rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            vb.pp("layers.0.post_attention_layernorm"),
        )
        .map_err(Error::Candle)?;
        let mlp = DraftMlp::load(config, vb.pp("layers.0.mlp"))?;
        Ok(Self {
            config: config.clone(),
            embed_tokens,
            fc,
            attn,
            post_attention_layernorm,
            mlp,
        })
    }

    /// Reset the draft's KV cache between rounds.
    pub fn reset(&mut self) {
        self.attn.clear_kv_cache();
    }

    /// Run one forward step.
    ///
    /// Inputs:
    /// - `target_hidden`: shape `[1, seq, hidden]` — target's last-layer
    ///   norm output for the same `seq` positions.
    /// - `token_ids`: shape `[1, seq]` — the token ids at those positions
    ///   (used for `embed_tokens`).
    /// - `position`: absolute position offset for RoPE (typically the
    ///   number of tokens already in the target prefix when starting a
    ///   round, then incremented).
    ///
    /// Returns: shape `[1, seq, hidden]` — the draft's hidden state output,
    /// suitable for feeding to the target's `lm_head` for next-token logits.
    pub fn forward(
        &mut self,
        target_hidden: &Tensor,
        token_ids: &Tensor,
        position: usize,
    ) -> Result<Tensor> {
        let token_emb = self
            .embed_tokens
            .forward(token_ids)
            .map_err(Error::Candle)?;
        // The target may dequantize to F32 (quantized_llama path) while
        // the draft was loaded in F16. Promote target_hidden to the
        // draft's dtype before concat.
        let target_hidden_owned;
        let target_hidden_use: &Tensor = if target_hidden.dtype() != token_emb.dtype() {
            target_hidden_owned = target_hidden
                .to_dtype(token_emb.dtype())
                .map_err(Error::Candle)?;
            &target_hidden_owned
        } else {
            target_hidden
        };
        let combined =
            Tensor::cat(&[target_hidden_use, &token_emb], D::Minus1).map_err(Error::Candle)?;
        let xs = self.fc.forward(&combined).map_err(Error::Candle)?;
        // EAGLE block: attention (no pre-LN) + post_attention_layernorm + mlp.
        let res = xs;
        let attn = self.attn.forward(&res, position)?;
        let xs = (attn + &res).map_err(Error::Candle)?;
        let res = &xs;
        let xs_n = self
            .post_attention_layernorm
            .forward(&xs)
            .map_err(Error::Candle)?;
        let m = self.mlp.forward(&xs_n)?;
        (m + res).map_err(Error::Candle)
    }
}

/// Run-loop config for EAGLE.
#[derive(Debug, Clone)]
pub struct EagleRunConfig {
    /// Branching factor: how many of each draft step's highest-scoring
    /// tokens become children at that depth of the Cartesian-product tree.
    pub top_k_per_step: usize,
    /// Draft autoregressive steps per round, i.e. the tree depth.
    pub draft_depth: usize,
    /// Optional node budget (the v0.2.0-3 dynamic tree). When `Some(n)`, the
    /// full Cartesian tree (`1 + Σ k^d ≈ k^depth` nodes) is pruned to at most
    /// `n` nodes, keeping the best cumulative-log-prob nodes plus every
    /// ancestor needed to keep them connected. `None` verifies the full tree.
    pub max_tree_nodes: Option<usize>,
    /// **Strict mode** (v0.4.2+). When `true`, the root's acceptance logits
    /// come from the target's single-token (GEMV-path) `next_logits()`,
    /// captured before the tree forward and written over the tree forward's
    /// root row. This restores the v0.2.2 invariant
    /// `tree_logits[0] argmax == next_logits argmax`, so greedy EAGLE output
    /// matches the AR baseline. Strict mode also disables the bonus-in-tree
    /// shortcut, so rounds that could have reused an in-tree bonus node pay
    /// the bonus `observe` forward instead.
    ///
    /// Default `false`: accept the tiny GEMV/GEMM precision drift on
    /// borderline-argmax prompts. Output is still semantically valid; it may
    /// just diverge from AR by a token or two.
    pub strict_root_gemv: bool,
    /// Target-side sampling temperature (`run_eagle` currently always
    /// accepts greedily and does not read this).
    pub temperature: f32,
    /// Top-p nucleus threshold (currently unused by `run_eagle`).
    pub top_p: f32,
}

impl Default for EagleRunConfig {
    /// Greedy acceptance over an unpruned depth-4, top-2 Cartesian tree,
    /// without strict-root mode.
    fn default() -> Self {
        Self {
            temperature: 0.0, // greedy by default — strictest acceptance
            top_p: 1.0,
            top_k_per_step: 2,
            draft_depth: 4,
            max_tree_nodes: None,
            strict_root_gemv: false,
        }
    }
}

/// End-to-end EAGLE-2 loop.
///
/// Algorithm per round:
/// 1. Get the target's last hidden state for the most recent committed
///    token via [`TreeDecoder::last_hidden_state`].
/// 2. Run the draft forward `draft_depth` times. Each step:
///    - Input: that hidden state + the most recent token's id.
///    - Output: draft hidden → target's lm_head → vocab logits → top-k
///      (we cheat slightly by using the target's lm_head externally; full
///      EAGLE shares it through tied embeddings — close enough for a
///      static tree).
///    - For each top-k branch, the next iteration's input is the draft's
///      *own* output hidden + the candidate token's embedding.
/// 3. Build a Cartesian-product DraftTree from the per-step top-k.
/// 4. Verify the tree via `target.tree_logits` (Phase 2a tree attention).
/// 5. Walk paths, accept via greedy match (temperature 0 default).
/// 6. Commit the longest accepted prefix + bonus.
///
/// The per-step "draft logits" path uses [`TreeDecoder::apply_lm_head`]
/// directly — EAGLE shares the target's vocab head via tied embeddings, so
/// we don't keep a separate copy on the draft.
pub fn run_eagle<T, R>(
    target: &mut T,
    draft: &mut EagleDraftCandle,
    prompt: &[u32],
    max_new_tokens: usize,
    config: &EagleRunConfig,
    rng: &mut R,
) -> Result<Vec<u32>>
where
    T: TreeDecoder + ?Sized,
    R: rand::Rng + ?Sized,
{
    use crate::methods::medusa::top_k_indices;

    target.reset();
    if prompt.is_empty() {
        return Err(Error::Sampling("EAGLE requires non-empty prompt".into()));
    }
    // Observe the prompt and capture the final hidden state in one
    // forward — chained into the next round so we never call
    // `last_hidden_state` separately.
    let mut target_hidden = target.observe_returning_last_hidden(prompt)?;

    let mut generated = Vec::with_capacity(max_new_tokens);
    while generated.len() < max_new_tokens {
        let root_token = *target
            .history()
            .last()
            .ok_or_else(|| Error::Sampling("EAGLE requires non-empty prompt".into()))?;

        // Reshape stored hidden to [1, 1, hidden] for draft.forward.
        let hidden_reshaped = target_hidden
            .unsqueeze(0)
            .map_err(Error::Candle)?
            .unsqueeze(0)
            .map_err(Error::Candle)?;

        draft.reset();
        let history_len = target.history_len();

        // 2. Build Cartesian-product tree by running draft `draft_depth` times.
        let mut per_step_top_k: Vec<Vec<u32>> = Vec::with_capacity(config.draft_depth);
        let mut per_step_top_k_log_probs: Vec<Vec<f32>> = Vec::with_capacity(config.draft_depth);
        let mut current_hidden = hidden_reshaped;
        let mut current_token_ids =
            Tensor::from_slice(&[root_token], (1, 1), target_hidden.device())
                .map_err(Error::Candle)?;

        for step in 0..config.draft_depth {
            let draft_hidden =
                draft.forward(&current_hidden, &current_token_ids, history_len + step)?;
            // `target.apply_lm_head` auto-promotes the input dtype to match
            // its lm_head weight (BF16 for `LlamaDecoder`, F32 for the Q4
            // path on `LlamaQuantDecoder`), so passing the draft's hidden
            // tensor as-is is dtype-safe here.
            let logits = target.apply_lm_head(&draft_hidden)?;
            // Take the last position's logits — for a 1-token forward this
            // is just position 0.
            let last = logits
                .i((0, draft_hidden.dim(1).map_err(Error::Candle)? - 1, ..))
                .map_err(Error::Candle)?
                .to_dtype(DType::F32)
                .map_err(Error::Candle)?
                .to_vec1::<f32>()
                .map_err(Error::Candle)?;
            let top_idx: Vec<usize> = top_k_indices(&last, config.top_k_per_step);
            // Log-softmax over `last` for the kept top-k indices.
            let max_l = last.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let lse = last.iter().map(|&v| (v - max_l).exp()).sum::<f32>().ln() + max_l;
            let top_log_probs: Vec<f32> = top_idx.iter().map(|&i| last[i] - lse).collect();
            let top: Vec<u32> = top_idx.iter().map(|&i| i as u32).collect();
            per_step_top_k.push(top.clone());
            per_step_top_k_log_probs.push(top_log_probs);

            // For the next step, advance with the top-1 (single greedy
            // chain — Cartesian expansion happens in the tree, not the
            // draft autoregression).
            let next_id = top[0];
            current_token_ids = Tensor::from_slice(&[next_id], (1, 1), target_hidden.device())
                .map_err(Error::Candle)?;
            current_hidden = draft_hidden;
        }

        // 3. Build Cartesian-product tree, optionally pruned to top-N nodes.
        let full_tree = crate::methods::medusa::MedusaHeads::from_config(
            crate::methods::medusa::MedusaConfig {
                n_heads: config.draft_depth,
                hidden_size: draft.config.hidden_size,
                vocab_size: draft.config.vocab_size,
                residual_layers: 1,
            },
        )
        .build_draft_tree(
            root_token,
            &per_step_top_k,
            crate::methods::medusa::TreeTopology::CartesianProduct,
        )?;
        let tree = if let Some(max_n) = config.max_tree_nodes {
            prune_cartesian_tree(&full_tree, &per_step_top_k_log_probs, max_n)?
        } else {
            full_tree
        };

        // 4. Verify via the EAGLE fast-path: tree forward returns
        //    (logits, hidden) per node and leaves the KV cache populated
        //    with the tree (no restoration forward).
        // Strict mode (v0.4.2+): capture the GEMV-path root logits
        // *before* the tree forward invalidates them. They're already
        // cached on the target as a side effect of the prior `observe`
        // (initial prompt or bonus). Using them costs zero extra
        // forwards — just one `next_logits()` call (cache hit).
        let strict_root_logits: Option<Vec<f32>> = if config.strict_root_gemv {
            Some(target.next_logits()?)
        } else {
            None
        };

        let (mut per_node_logits, _per_node_hidden) =
            target.tree_logits_keep_kv(&tree)?;

        if let Some(root_gemv) = strict_root_logits {
            per_node_logits[0] = root_gemv;
        }

        // 5. Walk paths, greedy acceptance.
        let mut best_path: Vec<usize> = vec![0];
        for path in tree.paths() {
            let mut accepted = 0;
            for w in path.windows(2) {
                let parent = w[0];
                let child = w[1];
                let candidate = tree.token_at(child) as usize;
                let parent_dist = &per_node_logits[parent];
                let argmax = parent_dist
                    .iter()
                    .enumerate()
                    .fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
                        if v > bv {
                            (i, v)
                        } else {
                            (bi, bv)
                        }
                    })
                    .0;
                if argmax == candidate {
                    accepted += 1;
                } else {
                    break;
                }
            }
            if accepted + 1 > best_path.len() {
                best_path = path[..=accepted].to_vec();
            }
        }

        // 6. Bonus token from the deepest accepted node's distribution.
        let deepest_idx = *best_path.last().unwrap();
        let bonus = per_node_logits[deepest_idx]
            .iter()
            .enumerate()
            .fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
                if v > bv {
                    (i, v)
                } else {
                    (bi, bv)
                }
            })
            .0 as u32;

        // Compose the committed sequence (accepted draft tokens + bonus)
        // for the generated transcript. EOS check applies to the full
        // sequence; if hit we truncate and stop after this round.
        let mut committed: Vec<u32> = best_path
            .iter()
            .skip(1)
            .map(|&i| tree.token_at(i))
            .collect();
        committed.push(bonus);
        let eos_set = target.eos_token_ids();
        let eos_pos = committed.iter().position(|t| eos_set.contains(t));
        let stop = eos_pos.is_some();
        if let Some(p) = eos_pos {
            committed.truncate(p + 1);
        }

        // 7. Commit accepted path to KV cache via index_select reordering
        //    (no target forward — the accepted nodes' KVs are already in
        //    cache from the tree forward).
        //
        //    Optimization: if `bonus` happens to match a child of
        //    `deepest_idx` already in the tree, we can include that
        //    child's KV in the commit reorder and skip the bonus
        //    `observe` forward entirely (~45 ms saved per round on
        //    Llama 2 7B BF16). Hit rate is high for greedy acceptance —
        //    the draft's top-K predictions for a node usually contain
        //    the target's argmax.
        let path_committed: Vec<u32> = best_path
            .iter()
            .skip(1)
            .map(|&i| tree.token_at(i))
            .collect();
        let path_eos_index = path_committed.iter().position(|t| eos_set.contains(t));
        if let Some(idx) = path_eos_index {
            // EOS in accepted draft path: keep root + accepted up to and
            // including the EOS node; no bonus forward.
            let kept_path: Vec<usize> = best_path[..=idx + 1].to_vec();
            target.commit_tree_path(&tree, &kept_path)?;
            // No new target_hidden is computed; loop will exit via `stop`.
        } else {
            // Try to find `bonus` among the deepest accepted node's
            // children in the tree. If yes, fold the bonus's tree node
            // into the commit path (no observe forward) and pull its
            // hidden from the tree forward output.
            // The bonus-in-tree shortcut skips the bonus observe (~45 ms)
            // when target's argmax bonus matches an already-in-tree child of
            // the deepest accepted node. Disabled in strict mode because
            // the next round needs `last_logits` cached for the GEMV root
            // capture — and skipping observe leaves it invalidated, forcing
            // a slow-path forward there instead. Net effect of strict +
            // bonus_in_tree was *worse* throughput; keep them separate.
            let mut bonus_in_tree: Option<usize> = None;
            if !config.strict_root_gemv {
                for n in 1..tree.len() {
                    if tree.parent_of(n) == deepest_idx && tree.token_at(n) == bonus {
                        bonus_in_tree = Some(n);
                        break;
                    }
                }
            }
            if let Some(bn) = bonus_in_tree {
                let mut kept_path = best_path.clone();
                kept_path.push(bn);
                target.commit_tree_path(&tree, &kept_path)?;
                target_hidden = _per_node_hidden[bn].clone();
            } else {
                target.commit_tree_path(&tree, &best_path)?;
                target_hidden = target.observe_returning_last_hidden(&[bonus])?;
            }
        }

        generated.extend_from_slice(&committed);
        if stop {
            break;
        }
    }
    // Suppress unused-rng warning until temperature > 0 lands.
    let _ = (rng, config.temperature, config.top_p);

    generated.truncate(max_new_tokens);
    Ok(generated)
}

/// Pub-facing alias of `prune_cartesian_tree` for the EAGLE-3 module.
pub fn prune_cartesian_tree_pub(
    full: &crate::tree::DraftTree,
    per_step_log_probs: &[Vec<f32>],
    max_total_nodes: usize,
) -> Result<crate::tree::DraftTree> {
    prune_cartesian_tree(full, per_step_log_probs, max_total_nodes)
}

/// Prune a Cartesian-product `DraftTree` (built by
/// [`crate::methods::medusa::TreeTopology::CartesianProduct`]) down to at
/// most `max_total_nodes` nodes by keeping the highest-cumulative-log-prob
/// non-root nodes plus every ancestor needed to keep them connected.
///
/// `per_step_log_probs[d][k]` is the log-prob of the `k`-th candidate at
/// depth `d` (depth 0 = first step from the root). The full Cartesian
/// branching factor at depth `d` is `per_step_log_probs[d].len()`.
fn prune_cartesian_tree(
    full: &crate::tree::DraftTree,
    per_step_log_probs: &[Vec<f32>],
    max_total_nodes: usize,
) -> Result<crate::tree::DraftTree> {
    use crate::tree::DraftTree;

    if full.len() <= max_total_nodes {
        // Already small enough — return as-is via reconstruction so the
        // caller always gets an owned tree.
        return clone_tree(full);
    }

    // Score every non-root node by the sum of log-probs along its path
    // from the root. We rebuild the path by walking parent pointers and
    // looking up which candidate index (within its layer) was used.
    // Layer d branches per_step_log_probs[d].len() ways. In the BFS-built
    // Cartesian tree the candidate index of a node at depth d+1 is the
    // node's child position among its parent's children — which we can
    // recover from per-parent child tracking below.

    // Build child -> candidate-index lookup.
    // For each parent, the children appear in the same order as
    // per_step_log_probs[depth_of(parent)]; so child position 0 = candidate 0, etc.
    let mut children_of: Vec<Vec<usize>> = vec![Vec::new(); full.len()];
    for n in 1..full.len() {
        let p = full.parent_of(n);
        children_of[p].push(n);
    }
    // candidate_index_of[node] = which sibling rank (0..k) this node has.
    let mut candidate_index_of = vec![0usize; full.len()];
    for siblings in &children_of {
        for (rank, &c) in siblings.iter().enumerate() {
            candidate_index_of[c] = rank;
        }
    }

    // Score each non-root by walking up and summing log-probs.
    let mut scores: Vec<(usize, f32)> = Vec::with_capacity(full.len() - 1);
    for n in 1..full.len() {
        let depth = full.depth_of(n);
        let mut s = 0f32;
        let mut cur = n;
        for d in (0..depth).rev() {
            let cand = candidate_index_of[cur];
            s += per_step_log_probs[d][cand];
            cur = full.parent_of(cur);
        }
        scores.push((n, s));
    }

    // Pick top (max_total_nodes - 1) nodes (root is always kept) by score.
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let want = max_total_nodes.saturating_sub(1).min(scores.len());
    let mut keep = vec![false; full.len()];
    keep[0] = true;
    for &(n, _) in scores.iter().take(want) {
        keep[n] = true;
    }
    // Ancestor closure.
    for n in (1..full.len()).rev() {
        if keep[n] {
            keep[full.parent_of(n)] = true;
        }
    }

    // Re-emit kept nodes in BFS order to preserve parent < self invariant.
    let mut order: Vec<usize> = (0..full.len()).filter(|&n| keep[n]).collect();
    order.sort_by_key(|&n| full.depth_of(n));
    let new_index: std::collections::HashMap<usize, usize> = order
        .iter()
        .enumerate()
        .map(|(new_i, &old_i)| (old_i, new_i))
        .collect();
    let mut entries: Vec<(usize, u32)> = Vec::with_capacity(order.len());
    for &old in &order {
        let parent_old = if old == 0 { 0 } else { full.parent_of(old) };
        let parent_new = *new_index.get(&parent_old).expect("ancestor present");
        entries.push((parent_new, full.token_at(old)));
    }
    DraftTree::from_parent_table(&entries)
        .map_err(|e| Error::Sampling(format!("pruned tree invalid: {e}")))
}

fn clone_tree(t: &crate::tree::DraftTree) -> Result<crate::tree::DraftTree> {
    let entries: Vec<(usize, u32)> = (0..t.len())
        .map(|i| {
            let parent = if i == 0 { 0 } else { t.parent_of(i) };
            (parent, t.token_at(i))
        })
        .collect();
    crate::tree::DraftTree::from_parent_table(&entries)
        .map_err(|e| Error::Sampling(format!("clone tree invalid: {e}")))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::methods::medusa::{MedusaConfig, MedusaHeads, TreeTopology};

    /// Tiny Medusa config used only to drive `build_draft_tree` in the
    /// pruning tests.
    fn tiny_medusa_config() -> MedusaConfig {
        MedusaConfig {
            n_heads: 2,
            hidden_size: 4,
            vocab_size: 100,
            residual_layers: 1,
        }
    }

    #[test]
    fn config_defaults_match_eagle_llama3_8b() {
        let cfg = EagleDraftConfig::eagle_llama3_8b();
        assert_eq!(cfg.hidden_size, 4096);
        assert_eq!(cfg.vocab_size, 128256);
        assert_eq!(cfg.num_attention_heads, 32);
        assert_eq!(cfg.num_key_value_heads, 8);
        assert_eq!(cfg.head_dim(), 128);
    }

    #[test]
    fn config_clone() {
        let original = EagleDraftConfig::eagle_llama3_8b();
        let copy = original.clone();
        assert_eq!(original.hidden_size, copy.hidden_size);
    }

    #[test]
    fn run_config_default_strict_off() {
        let defaults = EagleRunConfig::default();
        assert!(!defaults.strict_root_gemv, "fast mode is the default");
    }

    #[test]
    fn run_config_strict_toggle_compiles() {
        let strict = EagleRunConfig {
            strict_root_gemv: true,
            ..Default::default()
        };
        assert!(strict.strict_root_gemv);
    }

    #[test]
    fn prune_keeps_root_and_top_paths() {
        // Cartesian product with k=2 per step over 2 steps: 1 + 2 + 4 = 7 nodes.
        let cart = MedusaHeads::from_config(tiny_medusa_config())
            .build_draft_tree(
                42, // root token
                &[vec![10, 20], vec![30, 40]],
                TreeTopology::CartesianProduct,
            )
            .expect("build cart");
        assert_eq!(cart.len(), 7);

        // Step 0: candidate 0 -> -0.1, candidate 1 -> -2.0.
        // Step 1: candidate 0 -> -0.2, candidate 1 -> -3.0.
        // Best root path is 10 -> 30 (-0.3); worst is 20 -> 40 (-5.0).
        let step_log_probs = vec![vec![-0.1f32, -2.0], vec![-0.2, -3.0]];

        let pruned = prune_cartesian_tree(&cart, &step_log_probs, 4).expect("prune");
        assert!(pruned.len() <= 4, "pruned should be ≤ 4 nodes");
        assert!(!pruned.is_empty());
        assert_eq!(pruned.token_at(0), 42, "root preserved");

        // The single best path (root, 10, 30) must survive.
        let kept_tokens: Vec<u32> = (0..pruned.len()).map(|idx| pruned.token_at(idx)).collect();
        assert!(kept_tokens.contains(&10), "best layer-0 child kept");
        assert!(kept_tokens.contains(&30), "best layer-1 child kept");
    }

    #[test]
    fn prune_returns_full_tree_when_under_limit() {
        let chain = MedusaHeads::from_config(tiny_medusa_config())
            .build_draft_tree(1, &[vec![2], vec![3]], TreeTopology::CartesianProduct)
            .expect("build");
        assert_eq!(chain.len(), 3);

        let step_log_probs = vec![vec![-0.1f32], vec![-0.2f32]];
        let pruned = prune_cartesian_tree(&chain, &step_log_probs, 100).expect("prune");
        assert_eq!(pruned.len(), 3);
    }
}