pub mod config;
pub mod decoder;
pub mod draft;
pub mod verifier;
pub use config::SpeculativeConfig;
pub use decoder::{SpeculativeDecoder, SpeculativeError, SpeculativeResult};
pub use draft::{DraftError, DraftProvider};
pub use verifier::{verify_draft, VerificationResult};
pub mod handlers;
pub mod pipeline_sse;
pub mod plugin;
pub mod semantic;
pub use plugin::SpeculativePlugin;
pub use semantic::{SpeculativeEvent, SpeculativeFault, SpeculativeState};
#[cfg(test)]
mod tests {
    //! Unit tests for the speculative decoding pipeline, driven by deterministic
    //! in-memory mock draft and target models.

    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use vil_llm::message::{ChatMessage, ChatResponse, LlmError};
    use vil_llm::provider::LlmProvider;

    /// Draft model that replays a scripted sequence of token batches.
    ///
    /// The `i`-th call to `draft` returns `tokens[i]` truncated to at most
    /// `n_tokens` entries; once the script is exhausted every further call
    /// returns an empty batch.
    struct MockDraft {
        /// Scripted batches, consumed one per `draft` call.
        tokens: Vec<Vec<String>>,
        /// Number of `draft` calls made so far; selects the next scripted batch.
        call_count: AtomicUsize,
    }

    impl MockDraft {
        fn new(tokens: Vec<Vec<String>>) -> Self {
            Self {
                tokens,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl DraftProvider for MockDraft {
        async fn draft(
            &self,
            _messages: &[ChatMessage],
            n_tokens: usize,
        ) -> Result<Vec<String>, DraftError> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            // Past the end of the script -> empty batch rather than an error.
            let batch: Vec<String> = self
                .tokens
                .get(idx)
                .map(|scripted| scripted.iter().take(n_tokens).cloned().collect())
                .unwrap_or_default();
            Ok(batch)
        }

        fn model_name(&self) -> &str {
            "mock-draft"
        }
    }

    /// Target model that replays scripted responses, one per `chat` call.
    ///
    /// Calls beyond the end of the script yield a response with empty content.
    struct MockTarget {
        /// Scripted response bodies, consumed one per `chat` call.
        responses: Vec<String>,
        /// Number of `chat` calls made so far; selects the next scripted response.
        call_count: AtomicUsize,
    }

    impl MockTarget {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockTarget {
        async fn chat(&self, _messages: &[ChatMessage]) -> Result<ChatResponse, LlmError> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let content = self.responses.get(idx).cloned().unwrap_or_default();
            Ok(ChatResponse {
                content,
                model: "mock-target".to_string(),
                tool_calls: None,
                usage: None,
                finish_reason: Some("stop".to_string()),
            })
        }

        fn model(&self) -> &str {
            "mock-target"
        }

        fn provider_name(&self) -> &str {
            "mock"
        }
    }

    /// Smoke test: decoding completes and at least one draft token is counted,
    /// even when the target returns empty content.
    #[tokio::test]
    async fn test_full_decode_with_mock() {
        let draft = Arc::new(MockDraft::new(vec![vec!["Hello".into(), " world".into()]]));
        let target = Arc::new(MockTarget::new(vec!["".into()]));
        let decoder = SpeculativeDecoder::new(
            draft,
            target,
            SpeculativeConfig::new()
                .max_draft_tokens(2)
                .max_iterations(2),
        );
        let result = decoder
            .decode(&[ChatMessage::user("Say hello")])
            .await
            .unwrap();
        assert!(result.draft_tokens > 0);
    }

    /// Target output equals the concatenated draft, so every draft token is
    /// expected to be accepted.
    #[tokio::test]
    async fn test_high_acceptance_rate() {
        let draft = Arc::new(MockDraft::new(vec![vec!["The".into(), " cat".into()]]));
        let target = Arc::new(MockTarget::new(vec!["The cat".into()]));
        let decoder = SpeculativeDecoder::new(
            draft,
            target,
            SpeculativeConfig::new()
                .max_draft_tokens(2)
                .max_iterations(2),
        );
        let result = decoder
            .decode(&[ChatMessage::user("Complete the sentence")])
            .await
            .unwrap();
        assert_eq!(result.accepted_tokens, 2);
        assert_eq!(result.draft_tokens, 2);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
        assert!(result.content.contains("The"));
        assert!(result.content.contains("cat"));
    }

    /// Target output shares nothing with the draft: no tokens are expected to be
    /// accepted and the output should come from the target.
    #[tokio::test]
    async fn test_zero_acceptance_all_rejected() {
        let draft = Arc::new(MockDraft::new(vec![vec!["foo".into(), "bar".into()]]));
        let target = Arc::new(MockTarget::new(vec!["completely different".into()]));
        let decoder = SpeculativeDecoder::new(
            draft,
            target,
            SpeculativeConfig::new()
                .max_draft_tokens(2)
                .max_iterations(2),
        );
        let result = decoder.decode(&[ChatMessage::user("test")]).await.unwrap();
        assert_eq!(result.accepted_tokens, 0);
        assert_eq!(result.draft_tokens, 2);
        assert!(result.acceptance_rate.abs() < f32::EPSILON);
        assert!(result.content.contains("completely different"));
    }

    /// Decoding with an empty message history still produces output.
    #[tokio::test]
    async fn test_empty_input() {
        let draft = Arc::new(MockDraft::new(vec![vec!["Hello".into()]]));
        let target = Arc::new(MockTarget::new(vec!["Hello".into()]));
        let decoder = SpeculativeDecoder::new(
            draft,
            target,
            SpeculativeConfig::new()
                .max_draft_tokens(1)
                .max_iterations(2),
        );
        let result = decoder.decode(&[]).await.unwrap();
        assert!(!result.content.is_empty());
    }

    /// Builder setters store their values. Synchronous: no runtime needed.
    #[test]
    fn test_config_builder() {
        let config = SpeculativeConfig::new()
            .max_draft_tokens(8)
            .max_total_tokens(512)
            .max_iterations(50);
        assert_eq!(config.max_draft_tokens, 8);
        assert_eq!(config.max_total_tokens, 512);
        assert_eq!(config.max_iterations, 50);
    }

    /// Pins the documented default configuration. Synchronous: no runtime needed.
    #[test]
    fn test_config_defaults() {
        let config = SpeculativeConfig::default();
        assert_eq!(config.max_draft_tokens, 5);
        assert_eq!(config.max_total_tokens, 256);
        assert_eq!(config.max_iterations, 100);
    }

    /// Draft matches the target for "The quick" but diverges at " fox";
    /// the first two draft tokens are expected to be accepted.
    #[tokio::test]
    async fn test_partial_acceptance() {
        let draft = Arc::new(MockDraft::new(vec![vec![
            "The".into(),
            " quick".into(),
            " fox".into(),
        ]]));
        let target = Arc::new(MockTarget::new(vec!["The quick brown".into()]));
        let decoder = SpeculativeDecoder::new(
            draft,
            target,
            SpeculativeConfig::new()
                .max_draft_tokens(3)
                .max_iterations(2),
        );
        let result = decoder.decode(&[ChatMessage::user("test")]).await.unwrap();
        assert_eq!(result.accepted_tokens, 2);
        assert_eq!(result.draft_tokens, 3);
        assert!(result.content.contains("The quick"));
    }

    /// Two scripted rounds, each fully matching its target response; totals are
    /// expected to accumulate across iterations.
    #[tokio::test]
    async fn test_multi_iteration_decode() {
        let draft = Arc::new(MockDraft::new(vec![
            vec!["Hello".into(), " ".into()],
            vec!["world".into(), "!".into()],
        ]));
        let target = Arc::new(MockTarget::new(vec!["Hello ".into(), "world!".into()]));
        let decoder = SpeculativeDecoder::new(
            draft,
            target,
            SpeculativeConfig::new()
                .max_draft_tokens(2)
                .max_iterations(5),
        );
        let result = decoder.decode(&[ChatMessage::user("greet")]).await.unwrap();
        assert_eq!(result.accepted_tokens, 4);
        assert_eq!(result.draft_tokens, 4);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }
}