use std::sync::Arc;
use vil_llm::{ChatMessage, LlmProvider};
use vil_log::app_log;
use vil_llm::message::LlmError;
use crate::config::SpeculativeConfig;
use crate::draft::{DraftProvider, DraftError};
use crate::verifier::verify_draft;
/// Outcome of one `SpeculativeDecoder::decode` run: the generated text plus
/// counters describing how much of the draft output was accepted.
#[derive(Debug, Clone)]
pub struct SpeculativeResult {
/// Text accumulated over all iterations: accepted draft tokens, plus any
/// target content appended when verification diverged from the draft.
pub content: String,
/// Total number of tokens proposed by the draft provider, summed over all iterations.
pub draft_tokens: usize,
/// Total number of proposed tokens that verification reported as accepted.
pub accepted_tokens: usize,
/// `accepted_tokens / draft_tokens`, or `0.0` when nothing was drafted.
pub acceptance_rate: f32,
/// Accepted tokens divided by an estimated number of target calls
/// (`ceil(draft_tokens / max_draft_tokens)`); `1.0` when nothing was drafted
/// or nothing was accepted.
pub speedup: f32,
}
/// Error type returned by `SpeculativeDecoder::decode`.
///
/// Both variants have `From` conversions, so the `?` operator can lift the
/// wrapped error types into this enum.
#[derive(Debug)]
pub enum SpeculativeError {
/// Wraps a `DraftError`; rendered by `Display` as `draft error: <inner>`.
Draft(DraftError),
/// Wraps an `LlmError`; rendered by `Display` as `target error: <inner>`.
Target(LlmError),
}
impl std::fmt::Display for SpeculativeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Draft(e) => write!(f, "draft error: {}", e),
Self::Target(e) => write!(f, "target error: {}", e),
}
}
}
// No methods are overridden here, so `source()` keeps its default and returns
// `None`; the wrapped error is reachable by matching on the variant, and its
// message is included in this type's `Display` output.
impl std::error::Error for SpeculativeError {}
impl From<DraftError> for SpeculativeError {
fn from(e: DraftError) -> Self {
Self::Draft(e)
}
}
impl From<LlmError> for SpeculativeError {
fn from(e: LlmError) -> Self {
Self::Target(e)
}
}
/// Drives a draft-then-verify generation loop (see `decode`).
pub struct SpeculativeDecoder {
/// Provider asked for up to `config.max_draft_tokens` draft tokens per iteration.
pub draft: Arc<dyn DraftProvider>,
/// Provider handed to `verify_draft` to check each batch of draft tokens.
pub target: Arc<dyn LlmProvider>,
/// Limits read by `decode`: `max_iterations`, `max_total_tokens` and `max_draft_tokens`.
pub config: SpeculativeConfig,
}
impl SpeculativeDecoder {
    /// Creates a decoder from a draft provider, a target provider and the
    /// configuration holding the loop limits.
    pub fn new(
        draft: Arc<dyn DraftProvider>,
        target: Arc<dyn LlmProvider>,
        config: SpeculativeConfig,
    ) -> Self {
        Self { draft, target, config }
    }

    /// Runs the draft/verify loop over `messages` and returns the generated
    /// text together with acceptance statistics.
    ///
    /// Each iteration:
    /// 1. builds the context from `messages` plus, if any text has been
    ///    produced so far, one assistant message holding that text;
    /// 2. asks the draft provider for up to `config.max_draft_tokens` tokens;
    /// 3. passes the draft to `verify_draft` and appends the accepted prefix
    ///    of the draft to the output;
    /// 4. if only part of the draft was accepted, appends the verifier's
    ///    `target_content` (when non-empty).
    ///
    /// The loop stops when `config.max_iterations` is reached, when the output
    /// length reaches `config.max_total_tokens`, when the draft is empty, when
    /// the whole draft was accepted and the verifier returned no further
    /// content, or when an iteration made no progress at all.
    ///
    /// # Errors
    ///
    /// Errors from the draft call and from `verify_draft` are propagated with
    /// `?` and converted into [`SpeculativeError`] through its `From` impls.
    pub async fn decode(
        &self,
        messages: &[ChatMessage],
    ) -> Result<SpeculativeResult, SpeculativeError> {
        let mut output = String::new();
        let mut total_draft = 0usize;
        let mut total_accepted = 0usize;
        let mut iterations = 0usize;

        // Clone the caller's history once; every iteration truncates back to
        // this length instead of re-cloning all messages.
        let mut context = messages.to_vec();
        let history_len = context.len();

        loop {
            if iterations >= self.config.max_iterations {
                app_log!(Debug, "speculative_decode", { event: "max_iterations", max: self.config.max_iterations });
                break;
            }
            // NOTE(review): `output.len()` is a byte length, while the limit is
            // named in tokens — confirm which unit is intended.
            if output.len() >= self.config.max_total_tokens {
                app_log!(Debug, "speculative_decode", { event: "max_tokens", max: self.config.max_total_tokens });
                break;
            }
            iterations += 1;

            // Drop the assistant message from the previous round (if any) and
            // re-attach the text generated so far.
            context.truncate(history_len);
            if !output.is_empty() {
                context.push(ChatMessage::assistant(&output));
            }

            let draft_tokens = self
                .draft
                .draft(&context, self.config.max_draft_tokens)
                .await?;
            if draft_tokens.is_empty() {
                app_log!(Debug, "speculative_decode", { event: "draft_empty" });
                break;
            }
            let n_draft = draft_tokens.len();
            total_draft += n_draft;

            let verification = verify_draft(&self.target, &context, &draft_tokens).await?;

            // Clamp so that a verifier reporting more accepted tokens than were
            // drafted cannot cause an out-of-bounds slice panic or push the
            // acceptance rate above 1.
            let accepted = verification.accepted.min(n_draft);
            let target_silent = verification.target_content.is_empty();

            if accepted > 0 {
                output.push_str(&draft_tokens[..accepted].concat());
                total_accepted += accepted;
                app_log!(Debug, "speculative_decode", { accepted: accepted, drafted: n_draft });
            }

            if accepted < n_draft {
                // The draft diverged: continue from the verifier's own content.
                if !target_silent {
                    output.push_str(&verification.target_content);
                }
                app_log!(Debug, "speculative_decode", { event: "diverged", position: accepted });
            }

            if accepted == n_draft && target_silent {
                app_log!(Debug, "speculative_decode", { event: "all_accepted_done" });
                break;
            }
            if target_silent && accepted == 0 {
                // Nothing accepted and nothing from the target: another round
                // would start from the same context.
                app_log!(Debug, "speculative_decode", { event: "no_progress_done" });
                break;
            }
        }

        let acceptance_rate = if total_draft > 0 {
            total_accepted as f32 / total_draft as f32
        } else {
            0.0
        };

        // NOTE(review): the number of target calls is estimated from the draft
        // totals rather than counted; when drafts are shorter than
        // `max_draft_tokens` this under-counts calls — confirm this is intended.
        let speedup = if total_draft > 0 && total_accepted > 0 {
            let tokens_generated = total_accepted as f32;
            let target_calls =
                (total_draft as f32 / self.config.max_draft_tokens as f32).ceil();
            if target_calls > 0.0 {
                tokens_generated / target_calls
            } else {
                1.0
            }
        } else {
            1.0
        };

        app_log!(Info, "speculative_decode_complete", {
            content_len: output.len(),
            total_draft: total_draft,
            total_accepted: total_accepted
        });

        Ok(SpeculativeResult {
            content: output,
            draft_tokens: total_draft,
            accepted_tokens: total_accepted,
            acceptance_rate,
            speedup,
        })
    }
}