use std::sync::Arc;
use vil_llm::message::LlmError;
use vil_llm::{ChatMessage, ChatResponse, LlmProvider};
use vil_log::app_log;
/// Outcome of checking a speculative draft against the target model.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (e.g. in tests
/// or when de-duplicating verification outcomes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerificationResult {
    /// Number of leading draft tokens that matched the start of `target_content`.
    pub accepted: usize,
    /// Text content of the target provider's reply (empty when no draft was supplied).
    pub target_content: String,
}
/// Sends `messages` plus the concatenated draft (as an assistant message) to the
/// `target` provider and reports how many leading draft tokens match the start of
/// the provider's reply.
///
/// When `draft_tokens` is empty the provider is not called at all; the result is
/// `accepted: 0` with empty `target_content`.
///
/// # Errors
/// Propagates any [`LlmError`] returned by `target.chat`.
///
/// NOTE(review): the draft is appended as an assistant turn and then compared
/// against the target's *reply*. If the provider treats a trailing assistant
/// message as a prefill and returns only the continuation, the reply will not
/// normally repeat the draft — confirm against `LlmProvider::chat` semantics.
pub async fn verify_draft(
    target: &Arc<dyn LlmProvider>,
    messages: &[ChatMessage],
    draft_tokens: &[String],
) -> Result<VerificationResult, LlmError> {
    // Nothing to verify: skip the round-trip to the target entirely.
    if draft_tokens.is_empty() {
        return Ok(VerificationResult { accepted: 0, target_content: String::new() });
    }

    let draft_text: String = draft_tokens.concat();

    // Caller's history followed by the draft as a candidate assistant turn.
    let verify_messages: Vec<ChatMessage> = messages
        .iter()
        .cloned()
        .chain(std::iter::once(ChatMessage::assistant(&draft_text)))
        .collect();

    let reply: ChatResponse = target.chat(&verify_messages).await?;
    let target_content = reply.content;

    let accepted = count_matching_prefix(draft_tokens, &target_content);
    app_log!(Debug, "speculative_verify", { draft_count: draft_tokens.len(), accepted: accepted });

    Ok(VerificationResult { accepted, target_content })
}
/// Counts how many of the leading `draft_tokens` appear back-to-back, in order,
/// at the very start of `target_text`.
///
/// Matching stops at the first token that is not a prefix of the still-unmatched
/// remainder of `target_text`. An empty token always matches (it consumes nothing).
///
/// Comparison is on string content. Unlike byte-offset slicing, this never panics
/// when a token's byte length would end inside a multi-byte UTF-8 character of
/// `target_text`; such a token simply does not match.
fn count_matching_prefix(draft_tokens: &[String], target_text: &str) -> usize {
    // Unmatched tail of the target. It always starts on a char boundary because we
    // only advance past content that equals a whole (valid UTF-8) token.
    let mut remaining = target_text;
    draft_tokens
        .iter()
        .take_while(|token| match remaining.strip_prefix(token.as_str()) {
            Some(rest) => {
                remaining = rest;
                true
            }
            None => false,
        })
        .count()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned token list `count_matching_prefix` expects.
    fn toks(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    #[test]
    fn test_matching_prefix_all_match() {
        let draft = toks(&["Hello", " world", "!"]);
        assert_eq!(count_matching_prefix(&draft, "Hello world!"), 3);
    }

    #[test]
    fn test_matching_prefix_partial() {
        let draft = toks(&["Hello", " world", "!"]);
        assert_eq!(count_matching_prefix(&draft, "Hello xyz"), 1);
    }

    #[test]
    fn test_matching_prefix_none() {
        let draft = toks(&["Hello"]);
        assert_eq!(count_matching_prefix(&draft, "Goodbye"), 0);
    }

    #[test]
    fn test_matching_prefix_empty() {
        let draft = toks(&[]);
        assert_eq!(count_matching_prefix(&draft, "anything"), 0);
    }
}