use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::watch;
use crate::error::DynamoError;
pub use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
pub use dynamo_llm::protocols::common::preprocessor::{
BootstrapInfo, PrefillResult, PreprocessedRequest,
};
pub use dynamo_llm::protocols::common::{
FinishReason, OutputOptions, SamplingOptions, StopConditions,
};
pub use dynamo_protocols::types::CompletionUsage;
pub use dynamo_runtime::engine::AsyncEngineContext;
/// Per-request context passed by value to [`LLMEngine::generate`].
///
/// Bundles a shared [`AsyncEngineContext`] handle with an optional `watch`
/// sender that [`GenerateContext::notify_first_token`] publishes `true` on.
/// The type dereferences to the wrapped `dyn AsyncEngineContext`.
pub struct GenerateContext {
/// Shared handle to the underlying engine context; reachable through
/// `Deref` and cloned out by `inner_arc`.
inner: Arc<dyn AsyncEngineContext>,
/// Optional first-token channel. When `None`, `notify_first_token` is a no-op.
first_token: Option<watch::Sender<bool>>,
}
impl GenerateContext {
pub fn new(
inner: Arc<dyn AsyncEngineContext>,
first_token: Option<watch::Sender<bool>>,
) -> Self {
Self { inner, first_token }
}
pub fn inner_arc(&self) -> Arc<dyn AsyncEngineContext> {
self.inner.clone()
}
pub fn notify_first_token(&self) {
if let Some(tx) = &self.first_token {
let _ = tx.send(true);
}
}
pub fn first_token_sender(&self) -> Option<&watch::Sender<bool>> {
self.first_token.as_ref()
}
}
/// Lets a `GenerateContext` be used wherever a `&dyn AsyncEngineContext`
/// is expected by forwarding to the wrapped handle.
impl Deref for GenerateContext {
    type Target = dyn AsyncEngineContext;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}
/// Engine settings returned by [`LLMEngine::start`].
///
/// Every field except `model` is optional. The derived `Default` yields an
/// empty `model` string and `None` for all other fields.
#[derive(Clone, Debug, Default)]
pub struct EngineConfig {
/// Model identifier. NOTE(review): whether this is a name or a path is not
/// established in this file; confirm against implementors.
pub model: String,
/// Optional alternative name; presumably the name the model is served
/// under (TODO confirm).
pub served_model_name: Option<String>,
/// Optional context length; presumably measured in tokens (TODO confirm).
pub context_length: Option<u32>,
/// Optional KV-cache block size; units are not shown here (TODO confirm).
pub kv_cache_block_size: Option<u32>,
/// Optional total number of KV blocks.
pub total_kv_blocks: Option<u64>,
/// Optional limit on sequences; presumably the maximum concurrently
/// scheduled (TODO confirm).
pub max_num_seqs: Option<u64>,
/// Optional limit on batched tokens; presumably per scheduling step
/// (TODO confirm).
pub max_num_batched_tokens: Option<u64>,
/// Optional bootstrap host. NOTE(review): likely related to the
/// re-exported `BootstrapInfo`; verify against the caller.
pub bootstrap_host: Option<String>,
/// Optional bootstrap port paired with `bootstrap_host`.
pub bootstrap_port: Option<u16>,
}
/// Asynchronous interface implemented by an LLM engine.
///
/// Implementors must be `Send + Sync + 'static`. `start`, `generate` and
/// `cleanup` are required; `abort` and `drain` ship with default bodies.
/// NOTE(review): the order in which a host calls these methods is not
/// visible in this file; confirm against the caller.
#[async_trait]
pub trait LLMEngine: Send + Sync + 'static {
/// Starts the engine for `worker_id` and reports its [`EngineConfig`].
///
/// # Errors
/// Returns a [`DynamoError`] when the implementation cannot start.
async fn start(&self, worker_id: u64) -> Result<EngineConfig, DynamoError>;
/// Handles one preprocessed request.
///
/// Takes ownership of `request` and `ctx` and, on success, returns a boxed
/// `'static` stream whose items are either an [`LLMEngineOutput`] or a
/// [`DynamoError`].
///
/// # Errors
/// The outer `Result` lets an implementation fail before any stream is
/// produced; failures after that are reported as stream items.
async fn generate(
&self,
request: PreprocessedRequest,
ctx: GenerateContext,
) -> Result<BoxStream<'static, Result<LLMEngineOutput, DynamoError>>, DynamoError>;
/// Abort hook receiving a shared engine context. The default
/// implementation does nothing.
///
/// NOTE(review): presumably invoked when the request tied to `_ctx` is
/// cancelled; confirm against the caller.
async fn abort(&self, _ctx: Arc<dyn AsyncEngineContext>) {}
/// Drain hook. The default implementation returns `Ok(())` immediately.
///
/// NOTE(review): presumably gives in-flight work a chance to finish
/// before `cleanup`; confirm against the caller.
async fn drain(&self) -> Result<(), DynamoError> {
Ok(())
}
/// Required teardown hook.
///
/// # Errors
/// Returns a [`DynamoError`] when the implementation's cleanup fails.
async fn cleanup(&self) -> Result<(), DynamoError>;
}
/// Constructors for individual output chunks.
pub mod chunk {
    use super::LLMEngineOutput;

    /// Builds an output whose `token_ids` holds exactly `id`; every other
    /// field takes its `Default` value.
    pub fn token(id: u32) -> LLMEngineOutput {
        let single = LLMEngineOutput {
            token_ids: vec![id],
            ..LLMEngineOutput::default()
        };
        single
    }
}
/// Builder-style helpers: each method takes `self` by value and returns
/// `Self`, so calls can be chained.
pub trait LLMEngineOutputExt: Sized {
/// Returns `self` with `tokens` applied. The implementation for
/// `LLMEngineOutput` in this file overwrites `token_ids` with `tokens`.
fn with_tokens(self, tokens: Vec<u32>) -> Self;
/// Returns `self` with `usage` applied. The implementation for
/// `LLMEngineOutput` in this file stores it as `Some(usage)` in
/// `completion_usage`.
fn with_usage(self, usage: CompletionUsage) -> Self;
}
impl LLMEngineOutputExt for LLMEngineOutput {
    /// Replaces `token_ids` with `tokens`, leaving all other fields as they were.
    fn with_tokens(self, tokens: Vec<u32>) -> Self {
        let mut updated = self;
        updated.token_ids = tokens;
        updated
    }

    /// Sets `completion_usage` to `Some(usage)`, leaving all other fields as they were.
    fn with_usage(self, usage: CompletionUsage) -> Self {
        let mut updated = self;
        updated.completion_usage = Some(usage);
        updated
    }
}
/// Builds a [`CompletionUsage`] from prompt and completion token counts.
///
/// `total_tokens` is the saturating sum of the two counts, so it clamps at
/// `u32::MAX` instead of overflowing. Both detail fields are left as `None`.
pub fn usage(prompt_tokens: u32, completion_tokens: u32) -> CompletionUsage {
    let total_tokens = prompt_tokens.saturating_add(completion_tokens);
    CompletionUsage {
        prompt_tokens,
        completion_tokens,
        total_tokens,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `chunk::token` fills in the token id and leaves the terminal fields unset.
    #[test]
    fn chunk_token_sets_only_token_ids() {
        let out = chunk::token(42);
        assert_eq!(out.token_ids, vec![42]);
        assert!(out.finish_reason.is_none());
        assert!(out.completion_usage.is_none());
    }

    /// The extension methods chain on top of a `length()` terminal output.
    #[test]
    fn ext_with_tokens_and_with_usage() {
        let base = LLMEngineOutput::length();
        let terminal = base.with_tokens(vec![1, 2, 3]).with_usage(usage(10, 3));
        assert_eq!(terminal.token_ids, vec![1, 2, 3]);
        assert!(matches!(terminal.finish_reason, Some(FinishReason::Length)));
        let total = terminal.completion_usage.map(|u| u.total_tokens);
        assert_eq!(total, Some(13));
    }

    /// `usage` adds prompt and completion counts into `total_tokens`.
    #[test]
    fn usage_sums_totals() {
        assert_eq!(usage(7, 11).total_tokens, 18);
    }

    /// The total clamps at `u32::MAX` rather than wrapping.
    #[test]
    fn usage_saturates_on_overflow() {
        let clamped = usage(u32::MAX, 10);
        assert_eq!(clamped.total_tokens, u32::MAX);
    }
}