hanzo_engine/speculative/
proposer.rs1use std::sync::{Arc, Mutex};
2
3use hanzo_ml::{Result, Tensor};
4use rand_isaac::Isaac64Rng;
5
6use crate::pipeline::text_models_inputs_processor::PagedAttentionMeta;
7use crate::sequence::Sequence;
8
9pub type TargetTokenEmbedder<'a> = dyn Fn(&Tensor) -> Result<Tensor> + 'a;
10
11pub enum SpeculativeKvCache<'a> {
12 Paged {
13 metadata: &'a PagedAttentionMeta,
14 kv_cache: &'a [(Tensor, Tensor)],
15 },
16}
17
18pub struct SpeculativeProposeBatchCtx<'a> {
19 pub sampled_tokens: &'a [u32],
20 pub sampled_tokens_emitted: bool,
21 pub seq_ids: &'a [usize],
22 pub base_lens: &'a [usize],
23 pub sequences: &'a [&'a Sequence],
24 pub cache: SpeculativeKvCache<'a>,
25 pub target_hiddens: Option<Tensor>,
26 pub rng: Arc<Mutex<Isaac64Rng>>,
27}
28
29#[derive(Clone, Debug)]
30pub struct SpeculativeProposal {
31 pub tokens: Vec<u32>,
32 pub logits: Option<Tensor>,
33}
34
35impl SpeculativeProposal {
36 pub fn new(tokens: Vec<u32>) -> Self {
37 Self {
38 tokens,
39 logits: None,
40 }
41 }
42
43 pub fn with_logits(tokens: Vec<u32>, logits: Tensor) -> Self {
44 Self {
45 tokens,
46 logits: Some(logits),
47 }
48 }
49
50 pub fn is_empty(&self) -> bool {
51 self.tokens.is_empty()
52 }
53}
54
55pub struct SpeculativeProposalBatch {
56 pub proposals: Vec<SpeculativeProposal>,
57}
58
59impl SpeculativeProposalBatch {
60 pub fn new(proposals: Vec<SpeculativeProposal>) -> Self {
61 Self { proposals }
62 }
63}
64
65pub trait SpeculativeProposer {
66 fn proposal_len(&self) -> usize;
67
68 fn propose(
69 &mut self,
70 ctx: SpeculativeProposeBatchCtx<'_>,
71 target_embedder: Option<&TargetTokenEmbedder<'_>>,
72 ) -> Result<SpeculativeProposalBatch>;
73}