//! Core abstractions for model-level speculative decoding.
//!
//! The speculative decoding protocol of Leviathan et al. (2023) and
//! Chen et al. (2023) is driven by two cooperating models:
//!
//! * A cheap **draft model** that extends the current prefix by *k* candidate
//! tokens (cf. [`DraftModel`]).
//! * An expensive **target model** that, in a single parallel forward pass,
//! scores every prefix-continuation of the draft (cf. [`TargetModel`]).
//!
//! The engine then runs the Bernoulli acceptance test
//! `accept = min(1, p_target / p_draft)` position by position and re-samples
//! the first rejected position from the adjusted target distribution
//! `normalize(max(0, p_target - p_draft))` (see `acceptance.rs`). This trait
//! layer exchanges log-probabilities only; the engine converts between
//! log-probs and linear probabilities at the call sites. As a result, the
//! draft/target models can run on CPU or GPU, eager or graph-compiled, without
//! leaking those concerns into the acceptance math.
//!
//! ## Why full distributions, not just token log-probs
//!
//! The naive signature
//! `propose(prefix, k) -> Vec<(TokenId, LogProb)>`
//! collapses each step into a single `(token, logprob)` pair. That is
//! insufficient: the adjusted re-sampling distribution
//! `max(0, p_target - p_draft)` is defined over the **entire vocabulary**,
//! so both draft and target must return full per-position distributions.
//! The trait shapes encode this explicitly.
//!
//! ## Invariants enforced by [`DraftProposal`] / [`TargetScores`]
//!
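//! * [`DraftProposal`]: `tokens.len() == token_logprobs.len() == distributions.len() == k`.
//! * [`TargetScores`]: `distributions.len() == k + 1` (the extra row is the
//! bonus distribution).
//! * `distributions[i].len() == vocab_size` for the configured vocab.
//! * Every `LogProb` row is normalized (log-sum-exp ≈ 0). The traits do not
//! re-normalize; that is the implementation's responsibility.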
//!
//! The engine defensively checks shapes at runtime and short-circuits with a
//! [`crate::speculative_decoding::SpeculativeDecodingError`] if anything is malformed.
use crate::speculative_decoding::SpeculativeDecodingResult;
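// Illustrative sketch only (not the code in `acceptance.rs`): the Bernoulli
// acceptance probability and the adjusted re-sampling distribution described
// in the module docs, written against plain `f64` values. The names
// `acceptance_probability` and `residual_distribution` are hypothetical.
// Assumptions: the first helper takes natural-log probabilities with a finite
// draft log-prob; the second takes linear-space rows of equal length.
#[allow(dead_code)]
fn acceptance_probability(target_logprob: f64, draft_logprob: f64) -> f64 {
    // min(1, p_target / p_draft), evaluated as exp(min(0, log p_t - log p_d))
    // so the ratio is never formed in linear space.
    (target_logprob - draft_logprob).min(0.0).exp()
}
#[allow(dead_code)]
fn residual_distribution(target_probs: &[f64], draft_probs: &[f64]) -> Option<Vec<f64>> {
    if target_probs.len() != draft_probs.len() {
        return None;
    }
    // Element-wise max(0, p_target - p_draft) over the whole vocabulary.
    let residual: Vec<f64> = target_probs
        .iter()
        .zip(draft_probs)
        .map(|(&p, &q)| (p - q).max(0.0))
        .collect();
    let mass: f64 = residual.iter().sum();
    if mass <= 0.0 {
        // No positive residual mass (e.g. identical rows): nothing to re-sample from.
        return None;
    }
    // Renormalize so the result sums to 1.
    Some(residual.into_iter().map(|r| r / mass).collect())
}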
/// Vocabulary-scoped token identifier.
///
/// Matches the convention used by `rule_guided_decoder::TokenId` so the two
/// decoders can share mappers in future work.
pub type TokenId = usize;
/// Natural-log probability. We deliberately use a type alias rather than a
/// newtype so callers can freely mix with `f64` arithmetic; the engine is the
/// only place where domains matter (log vs. linear) and it converts locally.
pub type LogProb = f64;
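// Illustrative sketch only: one way to check the "log-sum-exp ≈ 0" row
// invariant listed in the module docs. `log_sum_exp`, `is_normalized`, and
// the `tol` parameter are hypothetical and not part of this crate's API.
// Rows are plain `f64` slices, the type behind the `LogProb` alias.
#[allow(dead_code)]
fn log_sum_exp(row: &[f64]) -> f64 {
    // Subtract the maximum before exponentiating to avoid overflow.
    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        // Empty row, or the maximum itself is infinite: return it unchanged.
        return max;
    }
    max + row.iter().map(|&x| (x - max).exp()).sum::<f64>().ln()
}
#[allow(dead_code)]
fn is_normalized(row: &[f64], tol: f64) -> bool {
    // A normalized log-prob row has total mass 1, i.e. log-sum-exp of 0.
    log_sum_exp(row).abs() <= tol
}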
/// Output of a single [`DraftModel::propose`] call.
///
/// Fields are aligned index-wise: `tokens[i]` is the draft's sampled token at
/// step *i*, `token_logprobs[i]` is its log-probability under the draft, and
/// `distributions[i]` is the draft's **full** log-probability row over the
/// vocabulary for that step (needed by the engine for the rejection test and
/// the adjusted re-sampling).
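// Sketch under assumptions: the field names follow the doc comment above;
// the concrete field types and the `is_well_formed` helper are hypothetical.
// Plain `usize` / `f64` are the types behind the `TokenId` / `LogProb` aliases.
#[derive(Debug, Clone, PartialEq)]
pub struct DraftProposal {
    /// Sampled draft token at each step.
    pub tokens: Vec<usize>,
    /// Log-probability of `tokens[i]` under the draft.
    pub token_logprobs: Vec<f64>,
    /// Full log-probability row over the vocabulary for each step.
    pub distributions: Vec<Vec<f64>>,
}
impl DraftProposal {
    /// Checks the shape invariants: `k` entries per field, `vocab_size`
    /// columns per row, and every token id inside the vocabulary.
    pub fn is_well_formed(&self, k: usize, vocab_size: usize) -> bool {
        self.tokens.len() == k
            && self.token_logprobs.len() == k
            && self.distributions.len() == k
            && self.distributions.iter().all(|row| row.len() == vocab_size)
            && self.tokens.iter().all(|&t| t < vocab_size)
    }
}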
/// Output of a single [`TargetModel::verify`] call.
///
/// For `k` draft tokens the target must return `k + 1` distributions: the
/// first `k` at the draft-covered positions (used by the acceptance test),
/// plus one **bonus** distribution at position `k + 1` (the last row) that the
/// engine uses if every draft token is accepted; see Leviathan et al. 2023 §3.2.
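// Sketch under assumptions: a single `distributions` field holding the
// `k + 1` rows described above. The field type and both helper methods are
// hypothetical; plain `f64` is the type behind the `LogProb` alias.
#[derive(Debug, Clone, PartialEq)]
pub struct TargetScores {
    /// `k` rows for the draft-covered positions, followed by the bonus row.
    pub distributions: Vec<Vec<f64>>,
}
impl TargetScores {
    /// Checks the shape invariants: `k + 1` rows of `vocab_size` columns each.
    pub fn is_well_formed(&self, k: usize, vocab_size: usize) -> bool {
        self.distributions.len() == k + 1
            && self.distributions.iter().all(|row| row.len() == vocab_size)
    }
    /// Returns the bonus row (the last distribution), if any rows are present.
    pub fn bonus(&self) -> Option<&[f64]> {
        self.distributions.last().map(Vec::as_slice)
    }
}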
/// A model capable of *cheaply* extending a prefix by `k` tokens while
/// exposing full vocabulary distributions at every step.
///
/// Implementations must be deterministic w.r.t. the supplied RNG so that the
/// engine's empirical-distribution tests are reproducible.
/// A model that, given a prefix and up to `k` draft tokens, returns
/// per-position distributions (as log-probs) in a single forward pass.