use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::backend::ExecutionContext;
use crate::kv_router::publisher::{KvEventPublisher, KvEventSourceConfig, WorkerMetricsPublisher};
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use anyhow::Result;
use bytes::Bytes;
use dashmap::DashMap;
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use dynamo_mocker::common::bootstrap::{BootstrapServer, connect_to_prefill};
use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal, RawKvEvent,
RawKvEventSink,
};
use dynamo_mocker::common::utils::sleep_precise;
use dynamo_mocker::engine::create_engine;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{
component::Component,
engine::AsyncEngineContextProvider,
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
};
use futures::StreamExt;
use rand::Rng;
use serde::Serialize;
use tokio::sync::{Notify, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::utils::zmq::{
bind_pub_socket, bind_router_socket, multipart_message, send_multipart, send_multipart_direct,
};
pub const MOCKER_COMPONENT: &str = "mocker";
struct KvEventSinkAdapter(KvEventPublisher);
impl KvCacheEventSink for KvEventSinkAdapter {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.0
.publish(event)
.map_err(|e| anyhow::anyhow!("Failed to send KV event: {}", e))
}
}
#[derive(Serialize)]
#[serde(tag = "type")]
enum ZmqRawKvEvent {
BlockStored {
block_hashes: Vec<u64>,
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: u32,
},
BlockRemoved {
block_hashes: Vec<u64>,
},
}
struct ZmqKvEventSink {
tx: mpsc::UnboundedSender<RawKvEvent>,
}
const REPLAY_BUFFER_CAPACITY: usize = 10_000;
impl ZmqKvEventSink {
async fn new(
port: u16,
replay_port: Option<u16>,
dp_rank: u32,
block_size: u32,
) -> Result<Self> {
let (tx, mut rx) = mpsc::unbounded_channel::<RawKvEvent>();
let endpoint = format!("tcp://0.0.0.0:{port}");
let pub_socket = bind_pub_socket(&endpoint)
.await
.map_err(|e| anyhow::anyhow!("ZMQ PUB bind to {endpoint} failed: {e}"))?;
tracing::info!("ZmqKvEventSink bound to {endpoint} for dp_rank {dp_rank}");
let mut router_socket = if let Some(rp) = replay_port {
let replay_ep = format!("tcp://0.0.0.0:{rp}");
let sock = bind_router_socket(&replay_ep)
.await
.map_err(|e| anyhow::anyhow!("ZMQ ROUTER bind to {replay_ep} failed: {e}"))?;
tracing::info!(
"ZmqKvEventSink replay ROUTER bound to {replay_ep} for dp_rank {dp_rank}"
);
Some(sock)
} else {
None
};
tokio::spawn(async move {
let mut seq_num: u64 = 0;
let mut ring_buffer: VecDeque<(u64, Bytes)> = VecDeque::new();
loop {
tokio::select! {
biased;
replay_result = async {
match router_socket.as_mut() {
Some(socket) => socket.next().await,
None => std::future::pending().await,
}
} => {
let req_msg = match replay_result {
Some(Ok(req_msg)) => multipart_message(req_msg),
Some(Err(error)) => {
tracing::warn!("Replay ROUTER recv error: {error}");
router_socket = None;
continue;
}
None => {
tracing::warn!("Replay ROUTER stream ended");
router_socket = None;
continue;
}
};
if req_msg.len() < 3 {
tracing::warn!("Unexpected replay request frame count: {}", req_msg.len());
continue;
}
let identity: Bytes = Bytes::copy_from_slice(req_msg.first().unwrap());
let start_seq_bytes = req_msg.get(2).unwrap();
if start_seq_bytes.len() != 8 {
tracing::warn!("Invalid replay start_seq length: {}", start_seq_bytes.len());
continue;
}
let start_seq = u64::from_be_bytes(start_seq_bytes[..8].try_into().unwrap());
tracing::debug!(dp_rank, start_seq, buffer_len = ring_buffer.len(), "Replay request received");
let start_idx = ring_buffer.front()
.map(|(first_seq, _)| start_seq.saturating_sub(*first_seq) as usize)
.unwrap_or(0)
.min(ring_buffer.len());
let sock = router_socket.as_mut().unwrap();
for (seq, payload) in ring_buffer.iter().skip(start_idx) {
let frames = vec![
identity.clone().to_vec(),
Vec::new(),
seq.to_be_bytes().to_vec(),
payload.to_vec(),
];
if let Err(e) = send_multipart_direct(sock, frames).await {
tracing::warn!("Replay send error: {e}");
break;
}
}
let sentinel_frames = vec![
identity.to_vec(),
Vec::new(),
(-1i64).to_be_bytes().to_vec(),
Vec::new(),
];
let _ = send_multipart_direct(sock, sentinel_frames).await;
}
msg_opt = rx.recv() => {
let Some(msg) = msg_opt else { break };
let events = convert_to_zmq_events(
&msg.event,
msg.block_token_ids.as_deref(),
block_size,
);
if events.is_empty() {
continue;
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
let batch: (f64, Vec<ZmqRawKvEvent>, Option<i32>) =
(timestamp, events, Some(dp_rank as i32));
let payload: Bytes = match rmp_serde::to_vec(&batch) {
Ok(p) => p.into(),
Err(e) => {
tracing::warn!("Failed to serialize ZMQ KV event: {e}");
continue;
}
};
if router_socket.is_some() {
if ring_buffer.len() >= REPLAY_BUFFER_CAPACITY {
ring_buffer.pop_front();
}
ring_buffer.push_back((seq_num, payload.clone()));
}
let frames = vec![
Vec::new(),
seq_num.to_be_bytes().to_vec(),
payload.to_vec(),
];
if let Err(e) = send_multipart(&pub_socket, frames).await {
tracing::warn!("Failed to send ZMQ KV event: {e}");
}
seq_num += 1;
}
}
}
});
Ok(Self { tx })
}
}
impl RawKvEventSink for ZmqKvEventSink {
fn publish(&self, event: RawKvEvent) -> anyhow::Result<()> {
self.tx
.send(event)
.map_err(|_| anyhow::anyhow!("ZMQ event sink channel closed"))
}
}
fn convert_to_zmq_events(
event: &KvCacheEvent,
block_token_ids: Option<&[Vec<u32>]>,
block_size: u32,
) -> Vec<ZmqRawKvEvent> {
match &event.data {
KvCacheEventData::Stored(store_data) => {
let block_hashes: Vec<u64> = store_data.blocks.iter().map(|b| b.block_hash.0).collect();
let parent_block_hash = store_data.parent_hash.map(|h| h.0);
let token_ids: Vec<u32> = block_token_ids
.map(|tids| tids.iter().flatten().copied().collect())
.unwrap_or_default();
assert_eq!(
token_ids.len(),
block_hashes.len() * block_size as usize,
"token_ids length ({}) must equal block_hashes.len() ({}) * block_size ({block_size})",
token_ids.len(),
block_hashes.len(),
);
vec![ZmqRawKvEvent::BlockStored {
block_hashes,
parent_block_hash,
token_ids,
block_size,
}]
}
KvCacheEventData::Removed(remove_data) => {
let block_hashes: Vec<u64> = remove_data.block_hashes.iter().map(|h| h.0).collect();
vec![ZmqRawKvEvent::BlockRemoved { block_hashes }]
}
KvCacheEventData::Cleared => vec![],
}
}
fn generate_random_token() -> TokenIdType {
let mut rng = rand::rng();
rng.random_range(1000..2000)
}
pub struct MockEngine {
active_requests: Arc<DashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>,
request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify,
engine_args: MockEngineArgs,
unset_dp_rank_counter: AtomicU32,
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
_schedulers: OnceCell<Vec<Box<dyn SchedulerHandle>>>,
_fpm_publisher: OnceCell<crate::fpm_publisher::FpmDirectPublisher>,
}
impl MockEngine {
pub fn new(engine_args: MockEngineArgs) -> Self {
Self {
active_requests: Arc::new(DashMap::new()),
request_senders: OnceCell::new(),
senders_ready: Notify::new(),
engine_args,
unset_dp_rank_counter: AtomicU32::new(0),
bootstrap_server: Arc::new(OnceCell::new()),
_schedulers: OnceCell::new(),
_fpm_publisher: OnceCell::new(),
}
}
fn resolve_dp_rank(&self, request: &PreprocessedRequest) -> u32 {
if let Some(dp_rank) = request.routing.as_ref().and_then(|routing| routing.dp_rank) {
return dp_rank;
}
self.unset_dp_rank_counter.fetch_add(1, Ordering::Relaxed) % self.engine_args.dp_size
}
pub async fn start(&self, component: Component) -> Result<()> {
let cancel_token = component.drt().primary_token();
if let Some(startup_time_secs) = self.engine_args.startup_time {
tracing::info!("Simulating engine startup time: {:.2}s", startup_time_secs);
tokio::time::sleep(Duration::from_secs_f64(startup_time_secs)).await;
tracing::info!("Engine startup simulation completed");
}
if self.engine_args.is_prefill()
&& let Some(port) = self.engine_args.bootstrap_port
{
let server = BootstrapServer::start(port, cancel_token.clone()).await?;
let _ = self.bootstrap_server.set(server);
tracing::info!(port = port, "Bootstrap server started for prefill worker");
}
let kv_component = if self.engine_args.needs_kv_publisher() {
tracing::info!(
"Initializing KV event publisher with block_size {}, enable_local_indexer={}",
self.engine_args.block_size,
self.engine_args.enable_local_indexer
);
Some(&component)
} else {
None
};
let worker_id = component.drt().connection_id().to_string();
let fpm_sinks = match crate::fpm_publisher::FpmDirectPublisher::new(
component.clone(),
worker_id,
self.engine_args.dp_size,
)
.await
{
Ok((publisher, sinks)) => {
let _ = self._fpm_publisher.set(publisher);
sinks
}
Err(e) => {
tracing::error!("Failed to start FPM publisher: {e}");
(0..self.engine_args.dp_size)
.map(|_| dynamo_mocker::common::protocols::FpmPublisher::default())
.collect()
}
};
let schedulers = self
.start_schedulers(kv_component, cancel_token.clone(), fpm_sinks)
.await;
Self::start_metrics_publishing(&schedulers, component.clone(), cancel_token.clone())
.await?;
let _ = self._schedulers.set(schedulers);
Ok(())
}
pub async fn direct(&self, request: DirectRequest, dp_rank: usize) {
if let Some(senders) = self.request_senders.get() {
let _ = senders[dp_rank].send(request);
return;
}
let notified = self.senders_ready.notified();
if let Some(senders) = self.request_senders.get() {
let _ = senders[dp_rank].send(request);
return;
}
notified.await;
let senders = self
.request_senders
.get()
.expect("must be set after notify");
let _ = senders[dp_rank].send(request);
}
async fn start_schedulers(
&self,
component: Option<&Component>,
cancel_token: CancellationToken,
fpm_sinks: Vec<dynamo_mocker::common::protocols::FpmPublisher>,
) -> Vec<Box<dyn SchedulerHandle>> {
let args = &self.engine_args;
let mut schedulers = Vec::<Box<dyn SchedulerHandle>>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
for (dp_rank, fpm_publisher) in (0..args.dp_size).zip(fpm_sinks) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (kv_event_publishers, relay_publisher): (
KvEventPublishers,
Option<KvEventPublisher>,
) = match component {
Some(comp) if args.zmq_kv_events_port.is_some() => {
let zmq_port = args.zmq_kv_events_port.unwrap() + dp_rank as u16;
let replay_port = args.zmq_replay_port.map(|p| p + dp_rank as u16);
match ZmqKvEventSink::new(
zmq_port,
replay_port,
dp_rank,
args.block_size as u32,
)
.await
{
Ok(sink) => {
let source_config = Some(KvEventSourceConfig::Zmq {
endpoint: format!("tcp://127.0.0.1:{zmq_port}"),
topic: String::new(),
});
match KvEventPublisher::new_with_local_indexer(
comp.clone(),
args.block_size as u32,
source_config,
args.enable_local_indexer,
dp_rank,
None,
) {
Ok(publisher) => (
KvEventPublishers::new(
None,
Some(Arc::new(sink) as Arc<dyn RawKvEventSink>),
),
Some(publisher),
),
Err(e) => {
tracing::error!(
"Failed to create KV event relay for dp_rank {dp_rank}: {e}"
);
(KvEventPublishers::default(), None)
}
}
}
Err(e) => {
tracing::error!(
"Failed to create ZMQ KV event sink for dp_rank {dp_rank}: {e}"
);
(KvEventPublishers::default(), None)
}
}
}
Some(comp) => {
match KvEventPublisher::new_with_local_indexer(
comp.clone(),
args.block_size as u32,
None,
args.enable_local_indexer,
dp_rank,
None,
) {
Ok(publisher) => (
KvEventPublishers::new(
Some(Arc::new(KvEventSinkAdapter(publisher))
as Arc<dyn KvCacheEventSink>),
None,
),
None,
),
Err(e) => {
tracing::error!(
"Failed to create KV event publisher for dp_rank {dp_rank}: {e}"
);
(KvEventPublishers::default(), None)
}
}
}
None => (KvEventPublishers::default(), None),
};
let scheduler = create_engine(
args.clone(),
dp_rank,
Some(output_tx),
kv_event_publishers,
Some(cancel_token.clone()),
fpm_publisher,
);
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
let active_requests_clone = self.active_requests.clone();
let cancel_token_cloned = cancel_token.clone();
tokio::spawn(async move {
let _relay_publisher = relay_publisher;
loop {
tokio::select! {
signal_result = output_rx.recv() => {
let Some(output_batch) = signal_result else {
break; };
for signal in output_batch {
if let Some(request_tx) = active_requests_clone.get(&signal.uuid) {
let _ = request_tx.send(signal);
}
}
}
_ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests");
active_requests_clone.clear();
break;
}
}
}
});
}
self.request_senders
.set(senders)
.expect("Already initialized");
self.senders_ready.notify_waiters();
schedulers
}
async fn start_metrics_publishing(
schedulers: &[Box<dyn SchedulerHandle>],
component: Component,
cancel_token: CancellationToken,
) -> Result<()> {
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
if let Err(e) = metrics_publisher.create_endpoint(component).await {
tracing::error!("Metrics endpoint failed: {e}");
}
for scheduler in schedulers.iter() {
let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone();
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
Ok(_) = metrics_rx.changed() => {
let metrics = metrics_rx.borrow().clone();
if let Err(e) = publisher.publish(
Some(metrics.dp_rank),
None,
Some(metrics.active_decode_blocks),
) {
tracing::warn!("Failed to publish metrics for DP rank {}: {e}", metrics.dp_rank);
} else {
tracing::trace!("Published metrics for DP rank {}", metrics.dp_rank);
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Metrics publishing cancelled");
break;
}
}
}
});
}
tracing::info!("Metrics background tasks started");
Ok(())
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> for MockEngine {
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts();
let dp_rank = self.resolve_dp_rank(&request);
if dp_rank >= self.engine_args.dp_size {
return Err(Error::msg(format!(
"dp_rank {} is out of bounds for dp_size {}",
dp_rank, self.engine_args.dp_size
)));
}
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.is_decode()
{
connect_to_prefill(
&bootstrap_info.bootstrap_host,
bootstrap_info.bootstrap_port,
bootstrap_info.bootstrap_room,
)
.await
.map_err(|e| Error::msg(format!("Bootstrap connection failed: {e}")))?;
}
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
let is_prefill = self.engine_args.is_prefill();
let max_output_tokens = if is_prefill {
1
} else {
request
.stop_conditions
.max_tokens
.ok_or_else(|| Error::msg("max_output_tokens must be specified for mocker"))?
as usize
};
let direct_request = DirectRequest {
tokens: request.token_ids.clone(),
max_output_tokens,
uuid: Some(request_uuid),
dp_rank,
arrival_timestamp_ms: request.request_timestamp_ms,
};
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
self.active_requests.insert(request_uuid, request_tx);
self.direct(direct_request, dp_rank as usize).await;
let (stream_tx, stream_rx) = mpsc::unbounded_channel::<LLMEngineOutput>();
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let bootstrap_server = self.bootstrap_server.clone();
let reasoning = self.engine_args.reasoning.clone();
tokio::spawn(async move {
let mut token_count = 0;
let think_len = reasoning
.as_ref()
.map(|cfg| cfg.num_thinking_tokens(max_output_tokens))
.unwrap_or(0);
loop {
tokio::select! {
maybe_signal = request_rx.recv() => {
let Some(signal) = maybe_signal else {
let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string()));
break;
};
let token_id = if token_count == 0 && think_len > 0 {
reasoning.as_ref().unwrap().start_thinking_token_id
} else if think_len > 0 && token_count == think_len - 1 {
reasoning.as_ref().unwrap().end_thinking_token_id
} else {
generate_random_token()
};
token_count += 1;
let output = LLMEngineOutput {
token_ids: vec![token_id],
disaggregated_params: is_prefill.then(|| serde_json::json!("dummy")),
..Default::default()
};
if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break;
}
if signal.completed {
let _ = stream_tx.send(output);
if is_prefill
&& let Some(delay_ms) = signal.handoff_delay_ms
{
sleep_precise(Duration::from_secs_f64(delay_ms / 1000.0)).await;
}
if is_prefill
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{
server.complete_room(room_id);
}
let _ = stream_tx.send(LLMEngineOutput::length());
break;
}
if stream_tx.send(output).is_err() {
tracing::error!("Output stream receiver closed.");
break;
}
}
_ = async_context.stopped() => {
let _ = stream_tx.send(LLMEngineOutput::cancelled());
break;
}
}
}
active_requests.remove(&request_uuid);
});
let stream = UnboundedReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
pub struct AnnotatedMockEngine {
inner: Arc<MockEngine>,
}
impl AnnotatedMockEngine {
pub fn new(
inner: MockEngine,
distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId,
) -> Self {
let inner = Arc::new(inner);
let inner_clone = inner.clone();
let cancel_token = distributed_runtime.primary_token();
tokio::spawn(async move {
let component = loop {
if cancel_token.is_cancelled() {
tracing::debug!("Mocker engine startup cancelled");
return;
}
let ready = distributed_runtime
.namespace(&endpoint_id.namespace)
.and_then(|ns| ns.component(&endpoint_id.component))
.ok();
if let Some(comp) = ready
&& let Ok(instances) = comp.list_instances().await
&& !instances.is_empty()
{
break comp;
}
tracing::debug!("Component service not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
};
tracing::debug!("Component service is now available, starting mocker engine");
if let Err(e) = inner_clone.start(component).await {
tracing::error!("Failed to start mocker engine: {e}");
}
});
Self { inner }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for AnnotatedMockEngine
{
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let stream = self.inner.generate(input).await?;
let context = stream.context();
let annotated_stream = stream.map(Annotated::from_data);
Ok(ResponseStream::new(Box::pin(annotated_stream), context))
}
}
pub async fn make_mocker_engine(
distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId,
args: MockEngineArgs,
) -> Result<ExecutionContext, Error> {
tracing::info!("Creating mocker engine with config: {args:?}");
let annotated_engine =
AnnotatedMockEngine::new(MockEngine::new(args), distributed_runtime, endpoint_id);
Ok(Arc::new(annotated_engine))
}