use std::collections::BTreeMap;
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::watch;
use crate::error::DynamoError;
pub use dynamo_llm::kv_router::publisher::KvEventPublisher;
pub use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
pub use dynamo_llm::protocols::common::preprocessor::{
BootstrapInfo, PrefillResult, PreprocessedRequest,
};
pub use dynamo_llm::protocols::common::{
FinishReason, OutputOptions, SamplingOptions, StopConditions,
};
pub use dynamo_protocols::types::CompletionUsage;
pub use dynamo_runtime::engine::AsyncEngineContext;
/// Per-request context passed to [`LLMEngine::generate`].
///
/// Bundles the runtime's [`AsyncEngineContext`] (also reachable directly through
/// this type's `Deref` impl), an optional first-token signal channel, and
/// string key/value request metadata.
pub struct GenerateContext {
    /// Shared handle to the underlying engine context.
    inner: Arc<dyn AsyncEngineContext>,
    /// Optional watch channel that `notify_first_token` sets to `true`.
    first_token: Option<watch::Sender<bool>>,
    /// Request metadata; empty when the context is built with `new`.
    metadata: BTreeMap<String, String>,
}
impl GenerateContext {
    /// Creates a context with no request metadata.
    ///
    /// Equivalent to [`GenerateContext::with_metadata`] with an empty map.
    pub fn new(
        inner: Arc<dyn AsyncEngineContext>,
        first_token: Option<watch::Sender<bool>>,
    ) -> Self {
        Self::with_metadata(inner, first_token, BTreeMap::new())
    }

    /// Creates a context that carries the given string key/value metadata.
    pub fn with_metadata(
        inner: Arc<dyn AsyncEngineContext>,
        first_token: Option<watch::Sender<bool>>,
        metadata: BTreeMap<String, String>,
    ) -> Self {
        Self {
            inner,
            first_token,
            metadata,
        }
    }

    /// Returns a new reference-counted handle to the underlying engine context.
    pub fn inner_arc(&self) -> Arc<dyn AsyncEngineContext> {
        Arc::clone(&self.inner)
    }

    /// Latches the first-token signal to `true`; a no-op when no sender was supplied.
    ///
    /// Receivers are woken only on the `false` -> `true` transition, so it is
    /// safe to call this for every generated token. Unlike
    /// [`watch::Sender::send`], the value is stored even when no receiver is
    /// currently subscribed. A receiver obtained later via
    /// [`first_token_sender`](Self::first_token_sender)`.subscribe()` therefore
    /// still observes `true`.
    pub fn notify_first_token(&self) {
        if let Some(tx) = &self.first_token {
            // `replace` yields the previous value; report "modified" (and thus
            // notify receivers) only if the flag was not already set.
            let _ = tx.send_if_modified(|seen| !std::mem::replace(seen, true));
        }
    }

    /// Borrows the raw first-token sender, if one was supplied (for example, to
    /// subscribe additional receivers).
    pub fn first_token_sender(&self) -> Option<&watch::Sender<bool>> {
        self.first_token.as_ref()
    }

    /// Returns the request metadata (empty when constructed via [`GenerateContext::new`]).
    pub fn metadata(&self) -> &BTreeMap<String, String> {
        &self.metadata
    }
}
impl Deref for GenerateContext {
type Target = dyn AsyncEngineContext;
fn deref(&self) -> &Self::Target {
&*self.inner
}
}
/// Configuration an engine reports back from [`LLMEngine::start`].
///
/// Every optional field defaults to `None` and `runtime_data` defaults to an
/// empty map (via `#[derive(Default)]`). A `None` presumably means "not
/// reported by this engine"; confirm with consumers.
#[derive(Clone, Debug, Default)]
pub struct EngineConfig {
    /// Model identifier (presumably a path or hub name — confirm).
    pub model: String,
    /// Optional public model name; presumably falls back to `model` when unset — verify with callers.
    pub served_model_name: Option<String>,
    /// Presumably the maximum context window in tokens — confirm.
    pub context_length: Option<u32>,
    /// Presumably tokens per KV-cache block — confirm.
    pub kv_cache_block_size: Option<u32>,
    /// Presumably total number of KV-cache blocks available — confirm.
    pub total_kv_blocks: Option<u64>,
    /// Presumably the scheduler's max concurrent sequences — confirm.
    pub max_num_seqs: Option<u64>,
    /// Presumably the per-step batched-token budget — confirm.
    pub max_num_batched_tokens: Option<u64>,
    /// Presumably the number of data-parallel ranks served — confirm.
    pub data_parallel_size: Option<u32>,
    /// Presumably the first data-parallel rank owned by this worker — confirm.
    pub data_parallel_start_rank: Option<u32>,
    /// Host advertised for bootstrap connections, if any (presumably for disaggregated serving — confirm).
    pub bootstrap_host: Option<String>,
    /// Port paired with `bootstrap_host`, if any.
    pub bootstrap_port: Option<u16>,
    /// Free-form engine-specific values keyed by name; schema not defined here.
    pub runtime_data: HashMap<String, serde_json::Value>,
}
#[async_trait]
pub trait LLMEngine: Send + Sync + 'static {
async fn start(&self, worker_id: u64) -> Result<EngineConfig, DynamoError>;
async fn generate(
&self,
request: PreprocessedRequest,
ctx: GenerateContext,
) -> Result<BoxStream<'static, Result<LLMEngineOutput, DynamoError>>, DynamoError>;
async fn abort(&self, _ctx: Arc<dyn AsyncEngineContext>) {}
async fn drain(&self) -> Result<(), DynamoError> {
Ok(())
}
async fn cleanup(&self) -> Result<(), DynamoError>;
async fn kv_event_sources(&self) -> Result<Vec<KvEventSource>, DynamoError> {
Ok(Vec::new())
}
async fn setup_metrics(&self, _ctx: MetricsCtx<'_>) -> Result<MetricsBindings, DynamoError> {
Ok(MetricsBindings::default())
}
async fn health_check_payload(&self) -> Result<Option<serde_json::Value>, DynamoError> {
Ok(None)
}
async fn supported_controls(&self) -> Result<Vec<String>, DynamoError> {
Ok(Vec::new())
}
async fn engine_control(
&self,
control: String,
_body: serde_json::Value,
) -> Result<serde_json::Value, DynamoError> {
Ok(serde_json::json!({
"status": "error",
"message": format!("unsupported engine control: {control}"),
}))
}
}
/// Reserved key string `"_HEALTH_CHECK"`; presumably marks health-check
/// requests or metadata — confirm against its users elsewhere in the crate.
pub const HEALTH_CHECK_KEY: &str = "_HEALTH_CHECK";
/// One-shot callback that receives the shared [`KvEventPublisher`] (used by
/// [`KvEventSource::Push`]).
///
/// Being `FnOnce`, it can run at most once; presumably it is invoked after the
/// publisher has been created — confirm with the caller that consumes it.
pub type OnPublisherReady =
    Box<dyn FnOnce(Arc<KvEventPublisher>) -> Result<(), DynamoError> + Send + 'static>;
/// Where KV-cache events for one data-parallel rank come from.
pub enum KvEventSource {
    /// Events read from a ZMQ endpoint/topic (presumably a SUB socket — confirm).
    Zmq {
        /// ZMQ endpoint address.
        endpoint: String,
        /// Topic to subscribe to.
        topic: String,
        /// Data-parallel rank these events belong to.
        dp_rank: u32,
    },
    /// Events pushed by the engine itself through the publisher handed to `on_ready`.
    Push {
        /// Callback that receives the publisher.
        on_ready: OnPublisherReady,
        /// Data-parallel rank these events belong to.
        dp_rank: u32,
    },
}
impl KvEventSource {
pub fn dp_rank(&self) -> u32 {
match self {
KvEventSource::Zmq { dp_rank, .. } | KvEventSource::Push { dp_rank, .. } => *dp_rank,
}
}
}
/// A metrics sample; currently only KV-block usage.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Metrics {
    /// KV-cache blocks in use; presumably `None` when not reported — confirm.
    pub kv_used_blocks: Option<u64>,
}
/// Point-in-time KV-cache statistics tagged with a data-parallel rank.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct ComponentSnapshot {
    /// KV-cache blocks currently in use.
    pub kv_used_blocks: u64,
    /// Total KV-cache blocks.
    pub kv_total_blocks: u64,
    /// Cache usage value; presumably a fraction in [0, 1] — not enforced here, confirm.
    pub gpu_cache_usage: f32,
    /// Prefix-cache hit rate, if available (range not enforced here).
    pub kv_cache_hit_rate: Option<f32>,
    /// Data-parallel rank this snapshot describes.
    pub dp_rank: u32,
}
/// Borrowed inputs handed to [`LLMEngine::setup_metrics`].
pub struct MetricsCtx<'a> {
    /// Model name (presumably used as a metric label — confirm).
    pub model: &'a str,
    /// Component name (presumably used as a metric label — confirm).
    pub component: &'a str,
    /// Model load duration in seconds, per the field name.
    pub model_load_time_seconds: f64,
    /// Shared engine metrics handle.
    pub metrics: &'a crate::metrics::EngineMetrics,
}
/// One-shot callback that receives the shared snapshot publisher (see
/// [`MetricsBindings::on_publisher_ready`]).
///
/// Being `FnOnce`, it can run at most once; presumably it is invoked once the
/// publisher exists — confirm with the consumer.
pub type OnSnapshotPublisherReady = Box<
    dyn FnOnce(Arc<crate::snapshot_publisher::SnapshotPublisher>) -> Result<(), DynamoError>
        + Send
        + 'static,
>;
/// What an engine returns from [`LLMEngine::setup_metrics`].
///
/// The `Default` value (returned by that method's default implementation) has
/// no ranks and no callback.
#[derive(Default)]
pub struct MetricsBindings {
    /// Data-parallel ranks the engine reports metrics for.
    pub dp_ranks: Vec<u32>,
    /// Optional callback that receives the snapshot publisher.
    pub on_publisher_ready: Option<OnSnapshotPublisherReady>,
}
/// Shorthand constructors for individual streamed output chunks.
pub mod chunk {
    use super::LLMEngineOutput;

    /// Builds a chunk whose `token_ids` is exactly `[id]`.
    ///
    /// Every other field comes from `LLMEngineOutput::default()`.
    pub fn token(id: u32) -> LLMEngineOutput {
        let token_ids = vec![id];
        LLMEngineOutput {
            token_ids,
            ..LLMEngineOutput::default()
        }
    }
}
/// Builder-style helpers that update an output value and hand it back.
///
/// Each method consumes `self` and returns the updated value. Discarding the
/// return value therefore silently drops the update, which is why both
/// methods are `#[must_use]`.
pub trait LLMEngineOutputExt: Sized {
    /// Returns `self` with its token ids replaced by `tokens`.
    #[must_use = "`with_tokens` consumes the output and returns the updated one"]
    fn with_tokens(self, tokens: Vec<u32>) -> Self;

    /// Returns `self` with its completion usage set to `Some(usage)`.
    #[must_use = "`with_usage` consumes the output and returns the updated one"]
    fn with_usage(self, usage: CompletionUsage) -> Self;
}
/// Field-replacing builders for [`LLMEngineOutput`]; all other fields are carried over unchanged.
impl LLMEngineOutputExt for LLMEngineOutput {
    /// Replaces `token_ids` with `tokens`.
    fn with_tokens(self, tokens: Vec<u32>) -> Self {
        Self {
            token_ids: tokens,
            ..self
        }
    }

    /// Sets `completion_usage` to `Some(usage)`, overwriting any previous value.
    fn with_usage(self, usage: CompletionUsage) -> Self {
        Self {
            completion_usage: Some(usage),
            ..self
        }
    }
}
/// Builds a [`CompletionUsage`] from prompt and completion token counts.
///
/// `total_tokens` is their sum, clamped at `u32::MAX` instead of overflowing.
/// Both detail fields are left as `None`.
pub fn usage(prompt_tokens: u32, completion_tokens: u32) -> CompletionUsage {
    let total_tokens = u32::saturating_add(prompt_tokens, completion_tokens);
    CompletionUsage {
        prompt_tokens,
        completion_tokens,
        total_tokens,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunk_token_sets_only_token_ids() {
        let c = chunk::token(42);
        assert_eq!(c.token_ids, vec![42]);
        assert!(c.finish_reason.is_none());
        assert!(c.completion_usage.is_none());
    }

    #[test]
    fn ext_with_tokens_and_with_usage() {
        let terminal = LLMEngineOutput::length()
            .with_tokens(vec![1, 2, 3])
            .with_usage(usage(10, 3));
        assert_eq!(terminal.token_ids, vec![1, 2, 3]);
        assert!(matches!(terminal.finish_reason, Some(FinishReason::Length)));
        assert_eq!(terminal.completion_usage.unwrap().total_tokens, 13);
    }

    #[test]
    fn ext_with_usage_overwrites_previous_usage() {
        let out = chunk::token(1)
            .with_usage(usage(1, 1))
            .with_usage(usage(2, 3));
        assert_eq!(out.token_ids, vec![1]);
        assert_eq!(out.completion_usage.unwrap().total_tokens, 5);
    }

    #[test]
    fn usage_sums_totals() {
        let u = usage(7, 11);
        assert_eq!(u.total_tokens, 18);
    }

    #[test]
    fn usage_handles_zero_counts() {
        let u = usage(0, 0);
        assert_eq!(u.prompt_tokens, 0);
        assert_eq!(u.completion_tokens, 0);
        assert_eq!(u.total_tokens, 0);
        assert!(u.prompt_tokens_details.is_none());
        assert!(u.completion_tokens_details.is_none());
    }

    #[test]
    fn usage_saturates_on_overflow() {
        let u = usage(u32::MAX, 10);
        assert_eq!(u.total_tokens, u32::MAX);
    }

    #[test]
    fn kv_event_source_dp_rank_for_each_variant() {
        let zmq = KvEventSource::Zmq {
            endpoint: "tcp://127.0.0.1:5557".to_string(),
            topic: String::new(),
            dp_rank: 3,
        };
        assert_eq!(zmq.dp_rank(), 3);

        let push = KvEventSource::Push {
            on_ready: Box::new(
                |_publisher: Arc<KvEventPublisher>| -> Result<(), DynamoError> { Ok(()) },
            ),
            dp_rank: 7,
        };
        assert_eq!(push.dp_rank(), 7);
    }

    #[test]
    fn metrics_bindings_default_is_empty() {
        let bindings = MetricsBindings::default();
        assert!(bindings.dp_ranks.is_empty());
        assert!(bindings.on_publisher_ready.is_none());
    }
}