use std::pin::Pin;
use std::sync::Arc;
use anyhow::Result;
use dynamo_async_openai::types::{
ChatCompletionMessageContent, ChatCompletionRequestAssistantMessage,
ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
};
use futures::Stream;
use futures::stream::StreamExt;
use minijinja::value::Value;
use dynamo_runtime::engine::AsyncEngine;
use dynamo_runtime::pipeline::{Context as PipelineContext, Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use crate::preprocessor::prompt::{OAIChatLikeRequest, OAIPromptFormatter};
use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::tokenizers::traits::Tokenizer;
/// Wrapper request used to re-render a finished conversation — the original
/// messages plus the completed assistant turn — through the chat template
/// so the next turn's prefix can be speculatively prefilled.
pub struct SpeculativePrefillRequest {
    // Full conversation history, including the appended assistant response.
    messages: Vec<ChatCompletionRequestMessage>,
}
impl SpeculativePrefillRequest {
pub fn new(messages: Vec<ChatCompletionRequestMessage>) -> Self {
Self { messages }
}
}
impl OAIChatLikeRequest for SpeculativePrefillRequest {
    /// Marker model name; this synthetic request never targets a real model.
    fn model(&self) -> String {
        "speculative_prefill".to_string()
    }

    /// Exposes the message list to the template engine.
    ///
    /// Serializes directly into a minijinja [`Value`] instead of round-tripping
    /// through `serde_json::to_value(...).unwrap()`, which removes both the
    /// intermediate JSON allocation and the panic path on serialization failure.
    fn messages(&self) -> Value {
        Value::from_serialize(&self.messages)
    }

    /// Typed access to the same messages for callers that bypass templating.
    fn typed_messages(&self) -> Option<&[ChatCompletionRequestMessage]> {
        Some(&self.messages)
    }

    /// The conversation already ends with a complete assistant turn, so the
    /// chat template must not append a generation prompt suffix.
    fn should_add_generation_prompt(&self) -> bool {
        false
    }
}
/// Optionally wraps a chat-completion response stream with speculative-prefill
/// bookkeeping.
///
/// When the request opts in via `nvext.agent_hints.speculative_prefill`, the
/// returned stream passes every item through untouched while accumulating the
/// assistant's delta text; once a choice reports a `finish_reason`, the full
/// response text is handed to a background task that prefills the next turn's
/// prefix. When the hint is absent or false, the original stream is returned
/// unchanged.
pub fn maybe_wrap_stream(
    stream: Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>,
    request: &NvCreateChatCompletionRequest,
    next: &Arc<
        dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
    >,
    formatter: &Arc<dyn OAIPromptFormatter>,
    tokenizer: &Arc<dyn Tokenizer>,
) -> Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>> {
    // Speculative prefill is strictly opt-in; a missing hint means disabled.
    let hint_enabled = request
        .nvext
        .as_ref()
        .and_then(|ext| ext.agent_hints.as_ref())
        .and_then(|hints| hints.speculative_prefill)
        .unwrap_or(false);
    if !hint_enabled {
        return stream;
    }

    // One-shot channel carrying the fully accumulated assistant response from
    // the stream adapter to the background prefill task.
    let (completion_tx, completion_rx) = tokio::sync::oneshot::channel::<String>();

    let engine = next.clone();
    let prompt_formatter = formatter.clone();
    let tok = tokenizer.clone();
    let history = request.inner.messages.clone();
    tokio::spawn(async move {
        // If the sender is dropped without firing (stream ended before any
        // finish_reason), there is nothing to prefill.
        let response_text = match completion_rx.await {
            Ok(text) => text,
            Err(_) => return,
        };
        if let Err(e) = prefill_task(engine, prompt_formatter, tok, history, response_text).await {
            tracing::warn!(error = %e, "Speculative prefill failed");
        }
    });

    let mut collected = String::new();
    let mut sender_slot = Some(completion_tx);
    Box::pin(stream.map(move |item| {
        if let Some(response) = item.data.as_ref() {
            for choice in response.choices.iter() {
                if let Some(ChatCompletionMessageContent::Text(delta)) =
                    choice.delta.content.as_ref()
                {
                    collected.push_str(delta);
                }
                // Fire the channel exactly once, on the first finish_reason.
                if choice.finish_reason.is_some() {
                    if let Some(sender) = sender_slot.take() {
                        let _ = sender.send(collected.clone());
                    }
                }
            }
        }
        item
    }))
}
async fn prefill_task(
next: Arc<
dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>,
original_messages: Vec<ChatCompletionRequestMessage>,
response_text: String,
) -> Result<()> {
let assistant_msg =
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
response_text,
)),
..Default::default()
});
let mut messages = original_messages;
messages.push(assistant_msg);
let prefill_request = SpeculativePrefillRequest::new(messages);
let formatted_prompt = formatter.render(&prefill_request)?;
let encoding = tokenizer.encode(&formatted_prompt)?;
let token_ids = encoding.token_ids().to_vec();
tracing::info!(
num_tokens = token_ids.len(),
"Speculative prefill: sending next-turn prefix"
);
let preprocessed = PreprocessedRequest::builder()
.model("speculative_prefill".to_string())
.token_ids(token_ids)
.stop_conditions(StopConditions {
max_tokens: Some(1),
..Default::default()
})
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.eos_token_ids(vec![])
.annotations(vec![])
.build()?;
let context = PipelineContext::with_id(preprocessed, uuid::Uuid::new_v4().to_string());
if let Ok(mut stream) = next.generate(context).await {
while stream.next().await.is_some() {}
}
Ok(())
}