Skip to main content

hanzo_engine/speculative/
proposer.rs

1use std::sync::{Arc, Mutex};
2
3use hanzo_ml::{Result, Tensor};
4use rand_isaac::Isaac64Rng;
5
6use crate::pipeline::text_models_inputs_processor::PagedAttentionMeta;
7use crate::sequence::Sequence;
8
/// Borrowed callback type that maps a [`Tensor`] to another [`Tensor`] and may fail with
/// the `hanzo_ml` error type.
///
/// This is an unsized `dyn Fn` type, so it is only usable behind a pointer; in this file it
/// is passed as `Option<&TargetTokenEmbedder<'_>>` to [`SpeculativeProposer::propose`].
///
/// NOTE(review): the name suggests the callback turns token ids into embeddings using the
/// target model's embedding layer — confirm against the call sites that construct it.
pub type TargetTokenEmbedder<'a> = dyn Fn(&Tensor) -> Result<Tensor> + 'a;
10
/// Borrowed view of the KV cache that is handed to a speculative proposer.
///
/// Only a paged layout is represented at the moment; both fields are borrowed for `'a`.
pub enum SpeculativeKvCache<'a> {
    /// Cache described by paged-attention metadata plus a slice of tensor pairs.
    Paged {
        /// Paged-attention metadata for the current call.
        /// NOTE(review): exact contents depend on `PagedAttentionMeta` — see its definition.
        metadata: &'a PagedAttentionMeta,
        /// Slice of tensor pairs.
        /// NOTE(review): presumably one `(key, value)` pair per layer — TODO confirm.
        kv_cache: &'a [(Tensor, Tensor)],
    },
}
17
/// Inputs for one [`SpeculativeProposer::propose`] call; the struct is passed by value.
///
/// NOTE(review): the slice fields look like parallel per-sequence arrays (same length, same
/// order). Nothing in this file enforces that — confirm with the caller that builds it.
pub struct SpeculativeProposeBatchCtx<'a> {
    /// Sampled token ids.
    /// NOTE(review): presumably the latest token sampled for each sequence — TODO confirm.
    pub sampled_tokens: &'a [u32],
    /// Flag associated with `sampled_tokens`.
    /// NOTE(review): presumably whether those tokens were already emitted to the
    /// sequences — verify against the caller.
    pub sampled_tokens_emitted: bool,
    /// Sequence identifiers.
    pub seq_ids: &'a [usize],
    /// Base lengths.
    /// NOTE(review): presumably each sequence's length before speculation — TODO confirm.
    pub base_lens: &'a [usize],
    /// Borrowed sequences taking part in this call.
    pub sequences: &'a [&'a Sequence],
    /// Borrowed KV-cache view.
    pub cache: SpeculativeKvCache<'a>,
    /// Optional owned tensor.
    /// NOTE(review): presumably hidden states from the target model — TODO confirm shape
    /// and which layer they come from.
    pub target_hiddens: Option<Tensor>,
    /// Shared random number generator behind a mutex.
    pub rng: Arc<Mutex<Isaac64Rng>>,
}
28
/// A list of proposed token ids, optionally accompanied by a logits tensor.
#[derive(Clone, Debug)]
pub struct SpeculativeProposal {
    /// Proposed token ids.
    pub tokens: Vec<u32>,
    /// Logits tensor, if one was supplied; `None` when built with `SpeculativeProposal::new`.
    /// NOTE(review): shape and dtype are not established in this file — TODO confirm.
    pub logits: Option<Tensor>,
}
34
35impl SpeculativeProposal {
36    pub fn new(tokens: Vec<u32>) -> Self {
37        Self {
38            tokens,
39            logits: None,
40        }
41    }
42
43    pub fn with_logits(tokens: Vec<u32>, logits: Tensor) -> Self {
44        Self {
45            tokens,
46            logits: Some(logits),
47        }
48    }
49
50    pub fn is_empty(&self) -> bool {
51        self.tokens.is_empty()
52    }
53}
54
55pub struct SpeculativeProposalBatch {
56    pub proposals: Vec<SpeculativeProposal>,
57}
58
59impl SpeculativeProposalBatch {
60    pub fn new(proposals: Vec<SpeculativeProposal>) -> Self {
61        Self { proposals }
62    }
63}
64
/// Interface for components that produce speculative token proposals for a batch.
pub trait SpeculativeProposer {
    /// Returns this proposer's proposal length.
    ///
    /// NOTE(review): presumably the number of tokens proposed per sequence per call; whether
    /// it is an exact count or an upper bound is not established here — confirm.
    fn proposal_len(&self) -> usize;

    /// Produces a [`SpeculativeProposalBatch`] from the inputs in `ctx`.
    ///
    /// Takes `&mut self`, so implementations are allowed to update internal state. `ctx` is
    /// consumed by the call. `target_embedder` is an optional borrowed callback; when it is
    /// `None` the implementation has no such callback available.
    ///
    /// # Errors
    ///
    /// Returns the `hanzo_ml` error type on failure; the exact conditions are defined by
    /// each implementation.
    fn propose(
        &mut self,
        ctx: SpeculativeProposeBatchCtx<'_>,
        target_embedder: Option<&TargetTokenEmbedder<'_>>,
    ) -> Result<SpeculativeProposalBatch>;
}