1use crate::kv_router::publisher::WorkerMetricsPublisher;
10use crate::mocker::protocols::DirectRequest;
11use crate::mocker::protocols::{MockEngineArgs, OutputSignal};
12use crate::mocker::scheduler::Scheduler;
13use crate::protocols::TokenIdType;
14use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
15use dynamo_runtime::DistributedRuntime;
16use dynamo_runtime::protocols::annotated::Annotated;
17use tokio_util::sync::CancellationToken;
18
19use dynamo_runtime::{
20 Result,
21 component::Component,
22 engine::AsyncEngineContextProvider,
23 pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
24 traits::DistributedRuntimeProvider,
25};
26
27use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
28use crate::kv_router::publisher::KvEventPublisher;
29use futures::StreamExt;
30use rand::Rng;
31use std::collections::HashMap;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{Mutex, OnceCell, mpsc};
35use tokio_stream::wrappers::UnboundedReceiverStream;
36use uuid::Uuid;
37
/// Component name used for the mocker; the integration test in this file
/// registers and looks up the engine's component under this name.
pub const MOCKER_COMPONENT: &str = "mocker";
39
40fn generate_random_token() -> TokenIdType {
42 let mut rng = rand::rng();
43 rng.random_range(1000..5000)
44}
45
/// Mock engine that forwards requests to one [`Scheduler`] per DP rank and
/// streams randomly generated tokens back to the caller.
///
/// Cloning is shallow for the shared state: `active_requests` and
/// `request_senders` are behind `Arc`, so clones observe the same data.
#[derive(Clone)]
pub struct MockVllmEngine {
    /// Output channel of every in-flight request, keyed by request UUID.
    /// Entries are inserted by `generate` and removed when its streaming task ends.
    active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
    /// Request senders of the schedulers, indexed by DP rank. Set exactly once
    /// by `start_schedulers`; empty until then.
    request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
    /// Configuration the engine was constructed with.
    engine_args: MockEngineArgs,
}
53
54impl MockVllmEngine {
55 pub fn new(args: MockEngineArgs) -> Self {
57 Self {
58 active_requests: Arc::new(Mutex::new(HashMap::new())),
59 request_senders: Arc::new(OnceCell::new()),
60 engine_args: args,
61 }
62 }
63
64 pub async fn start(&self, component: Component) -> Result<()> {
65 let cancel_token = component.drt().runtime().child_token();
66
67 if let Some(startup_time_secs) = self.engine_args.startup_time {
69 tracing::info!("Simulating engine startup time: {:.2}s", startup_time_secs);
70 tokio::time::sleep(Duration::from_secs_f64(startup_time_secs)).await;
71 tracing::info!("Engine startup simulation completed");
72 }
73
74 let (schedulers, kv_event_receiver) = self.start_schedulers(
75 self.engine_args.clone(),
76 self.active_requests.clone(),
77 cancel_token.clone(),
78 );
79
80 Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
81 .await?;
82
83 if self.engine_args.enable_prefix_caching {
85 Self::start_kv_events_publishing(
86 kv_event_receiver,
87 Some(component.clone()),
88 self.engine_args.block_size,
89 cancel_token.clone(),
90 )
91 .await?;
92 }
93
94 Ok(())
95 }
96
97 pub fn direct(&self, request: DirectRequest, dp_rank: usize) {
98 let senders = self.request_senders.get().expect("Not initialized");
99 let _ = senders[dp_rank].send(request);
100 }
101
102 fn start_schedulers(
105 &self,
106 args: MockEngineArgs,
107 active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
108 cancel_token: CancellationToken,
109 ) -> (
110 Vec<Scheduler>,
111 Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
112 ) {
113 let mut schedulers = Vec::<Scheduler>::new();
114 let mut kv_event_receivers = Vec::new();
115 let mut senders = Vec::with_capacity(args.dp_size as usize);
116
117 for dp_rank in 0..args.dp_size {
119 let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
121
122 let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>();
124
125 let scheduler = Scheduler::new(
126 args.clone(),
127 Some(dp_rank),
128 Some(output_tx),
129 Some(kv_events_tx), Some(cancel_token.clone()),
131 );
132
133 senders.push(scheduler.request_sender());
134 schedulers.push(scheduler);
135 kv_event_receivers.push(kv_events_rx);
136
137 let active_requests_clone = active_requests.clone();
140 let cancel_token_cloned = cancel_token.clone();
141
142 tokio::spawn(async move {
143 loop {
144 tokio::select! {
145 signal_result = output_rx.recv() => {
146 let Some(signal) = signal_result else {
147 break; };
149
150 let active = active_requests_clone.lock().await;
152 if let Some(request_tx) = active.get(&signal.uuid) {
153 let _ = request_tx.send(signal);
154 }
155 }
156 _ = cancel_token_cloned.cancelled() => {
157 break;
158 }
159 }
160 }
161 });
162 }
163
164 self.request_senders
166 .set(senders)
167 .expect("Already initialized");
168
169 (schedulers, kv_event_receivers)
170 }
171
172 async fn start_metrics_publishing(
174 schedulers: &[Scheduler],
175 component: Option<Component>,
176 cancel_token: CancellationToken,
177 ) -> Result<()> {
178 tracing::debug!("Creating metrics publisher");
179 let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
180 tracing::debug!("Metrics publisher created");
181
182 if let Some(comp) = component {
183 tracing::debug!("Creating metrics endpoint");
184 tokio::spawn({
185 let publisher = metrics_publisher.clone();
186 async move {
187 if let Err(e) = publisher.create_endpoint(comp.clone(), None).await {
188 tracing::error!("Metrics endpoint failed: {e}");
189 }
190 }
191 });
192
193 tokio::time::sleep(Duration::from_millis(100)).await;
195 tracing::debug!("Metrics endpoint started (background)");
196 }
197
198 tracing::debug!("Starting metrics background tasks");
199 for (dp_rank, scheduler) in schedulers.iter().enumerate() {
200 let mut metrics_rx = scheduler.metrics_receiver();
201 let publisher = metrics_publisher.clone();
202 let dp_rank = dp_rank as u32;
203 let cancel_token = cancel_token.clone();
204
205 tokio::spawn(async move {
206 loop {
207 tokio::select! {
208 Ok(_) = metrics_rx.changed() => {
210 let metrics = metrics_rx.borrow().clone();
212
213 if let Err(e) = publisher.publish(Arc::new(metrics)) {
215 tracing::warn!("Failed to publish metrics for DP rank {dp_rank}: {e}");
216 } else {
217 tracing::trace!("Published metrics for DP rank {}", dp_rank);
218 }
219 }
220 _ = cancel_token.cancelled() => {
221 tracing::debug!("Metrics publishing cancelled for DP rank {dp_rank}");
222 break;
223 }
224 }
225 }
226 });
227 }
228 tracing::info!("Metrics background tasks started");
229 Ok(())
230 }
231
232 async fn start_kv_events_publishing(
234 kv_event_receivers: Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
235 component: Option<Component>,
236 block_size: usize,
237 cancel_token: CancellationToken,
238 ) -> Result<()> {
239 tracing::debug!("Starting KV events publishing");
240
241 let Some(comp) = component else {
243 tracing::warn!("No component provided, skipping KV events publishing");
244 return Ok(());
245 };
246 tracing::debug!("Component found for KV events publishing");
247
248 tracing::debug!("Getting worker_id");
249 let worker_id = comp
250 .drt()
251 .primary_lease()
252 .expect("Cannot publish KV events without lease") .id();
254 tracing::debug!("Worker_id set to: {worker_id}");
256
257 tracing::debug!("Creating KV event publisher");
258 let kv_event_publisher = Arc::new(KvEventPublisher::new(
259 comp.clone(),
260 worker_id,
261 block_size as u32,
262 None,
263 )?);
264 tracing::debug!("KV event publisher created");
265
266 tracing::debug!(
267 "Starting KV event background tasks for {} receivers",
268 kv_event_receivers.len()
269 );
270 for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() {
271 tracing::debug!("Starting background task for DP rank {dp_rank}");
272 let publisher = kv_event_publisher.clone();
273 let dp_rank = dp_rank as u32;
274 let cancel_token = cancel_token.clone();
275
276 tokio::spawn(async move {
277 tracing::debug!("Background task started for DP rank {dp_rank}");
278 loop {
279 tokio::select! {
280 Some(event_data) = kv_events_rx.recv() => {
282 let event = KvCacheEvent {
284 event_id: Uuid::new_v4().as_u128() as u64,
285 data: event_data,
286 };
287
288 if let Err(e) = publisher.publish(event) {
290 tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}");
291 } else {
292 tracing::trace!("Published KV event for DP rank {dp_rank}");
293 }
294 }
295 _ = cancel_token.cancelled() => {
296 tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}");
297 break;
298 }
299 }
300 }
301 });
302 }
303 tracing::info!("All KV event background tasks started");
304
305 Ok(())
306 }
307}
308
309#[async_trait]
310impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
311 for MockVllmEngine
312{
313 async fn generate(
314 &self,
315 input: SingleIn<PreprocessedRequest>,
316 ) -> Result<ManyOut<LLMEngineOutput>, Error> {
317 let (request, ctx) = input.into_parts();
318
319 let dp_rank = request
321 .annotations
322 .iter()
323 .find_map(|ann| {
324 if ann.starts_with("dp_rank:") {
325 ann.strip_prefix("dp_rank:").and_then(|s| s.parse().ok())
326 } else {
327 None
328 }
329 })
330 .unwrap_or(0);
331
332 if dp_rank >= self.engine_args.dp_size {
334 return Err(Error::msg(format!(
335 "dp_rank {} is out of bounds for dp_size {}",
336 dp_rank, self.engine_args.dp_size
337 )));
338 }
339
340 let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
341
342 let direct_request = DirectRequest {
344 tokens: request.token_ids.clone(),
345 max_output_tokens: request
346 .stop_conditions
347 .max_tokens
348 .expect("max_output_tokens must be specified for mocker")
349 as usize,
350 uuid: Some(request_uuid),
351 dp_rank: Some(dp_rank),
352 };
353
354 let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
355 {
356 let mut active = self.active_requests.lock().await;
357 active.insert(request_uuid, request_tx);
358 }
359
360 self.direct(direct_request, dp_rank as usize);
362
363 let (stream_tx, stream_rx) = mpsc::unbounded_channel::<LLMEngineOutput>();
365
366 let active_requests = self.active_requests.clone();
367 let async_context = ctx.context();
368 let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize;
369
370 tokio::spawn(async move {
372 let mut token_count = 0;
373
374 loop {
375 tokio::select! {
376 maybe_signal = request_rx.recv() => {
377 let Some(signal) = maybe_signal else {
378 let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string()));
379 break;
380 };
381
382 let token_id = generate_random_token();
384 token_count += 1;
385
386 let output = LLMEngineOutput {
387 token_ids: vec![token_id],
388 tokens: None, text: None,
390 cum_log_probs: None,
391 log_probs: None,
392 top_logprobs: None,
393 finish_reason: None,
394 index: None,
395 };
396
397 if signal.completed && token_count < max_tokens {
398 let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
399 break;
400 }
401
402 if signal.completed {
403 let _ = stream_tx.send(output);
404 let _ = stream_tx.send(LLMEngineOutput::length());
405 break;
406 }
407
408 if stream_tx.send(output).is_err() {
409 tracing::error!("Output stream receiver closed.");
410 break;
411 }
412 }
413
414 _ = async_context.stopped() => {
415 let _ = stream_tx.send(LLMEngineOutput::cancelled());
416 break;
417 }
418 }
419 }
420
421 let mut active = active_requests.lock().await;
423 active.remove(&request_uuid);
424 });
425
426 let stream = UnboundedReceiverStream::new(stream_rx);
428 Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
429 }
430}
431
/// Wrapper around [`MockVllmEngine`] whose `generate` yields each output
/// wrapped in [`Annotated`].
pub struct AnnotatedMockEngine {
    /// The wrapped engine; it is started in the background by `new`.
    inner: Arc<MockVllmEngine>,
}
435
436impl AnnotatedMockEngine {
437 pub fn new(
438 inner: MockVllmEngine,
439 distributed_runtime: DistributedRuntime,
440 endpoint_id: dynamo_runtime::protocols::EndpointId,
441 ) -> Self {
442 let inner = Arc::new(inner);
443 let inner_clone = inner.clone();
444
445 tokio::spawn(async move {
447 loop {
448 let Ok(namespace) = distributed_runtime.namespace(&endpoint_id.namespace) else {
450 tracing::debug!("Namespace not available yet, retrying...");
451 tokio::time::sleep(Duration::from_millis(100)).await;
452 continue;
453 };
454
455 let Ok(component) = namespace.component(&endpoint_id.component) else {
456 tracing::debug!("Component not available yet, retrying...");
457 tokio::time::sleep(Duration::from_millis(100)).await;
458 continue;
459 };
460
461 let Ok(instances) = component.list_instances().await else {
463 tracing::debug!("Cannot list instances yet, retrying...");
464 tokio::time::sleep(Duration::from_millis(100)).await;
465 continue;
466 };
467
468 if instances.is_empty() {
469 tracing::debug!("No instances available yet, retrying...");
470 tokio::time::sleep(Duration::from_millis(100)).await;
471 continue;
472 }
473
474 tracing::debug!("Component service is now available, starting mocker engine");
475
476 if let Err(e) = inner_clone.start(component).await {
478 tracing::error!("Failed to start mocker engine: {e}");
479 }
480 break;
481 }
482 });
483
484 Self { inner }
485 }
486}
487
488#[async_trait]
489impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
490 for AnnotatedMockEngine
491{
492 async fn generate(
493 &self,
494 input: SingleIn<PreprocessedRequest>,
495 ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
496 let stream = self.inner.generate(input).await?;
497 let context = stream.context();
498
499 let annotated_stream = stream.map(Annotated::from_data);
501
502 Ok(ResponseStream::new(Box::pin(annotated_stream), context))
503 }
504}
505
506pub async fn make_mocker_engine(
508 distributed_runtime: DistributedRuntime,
509 endpoint_id: dynamo_runtime::protocols::EndpointId,
510 args: MockEngineArgs,
511) -> Result<crate::backend::ExecutionContext, Error> {
512 tracing::debug!("Creating mocker engine with config: {args:?}");
514 let annotated_engine =
515 AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
516
517 Ok(Arc::new(annotated_engine))
518}
519
/// End-to-end test of [`MockVllmEngine`] against a distributed runtime.
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::kv_router::KV_EVENT_SUBJECT;
    use crate::kv_router::indexer::RouterEvent;
    use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
    use dynamo_runtime::{
        DistributedRuntime, Worker,
        pipeline::Context,
        pipeline::{PushRouter, network::Ingress},
        traits::events::EventSubscriber,
    };
    use futures::StreamExt;
    use tokio::time::timeout;

    /// Starts the engine behind a `generate` endpoint, sends four requests
    /// (two per DP rank) through a push router, and then checks that a KV
    /// event and at least one metrics endpoint can be observed.
    ///
    /// NOTE(review): ignored by default, presumably because it needs the
    /// external services behind `DistributedRuntime::from_settings` — confirm.
    #[tokio::test]
    #[ignore]
    async fn test_mock_vllm_engine_full_integration() -> Result<()> {
        const DP_SIZE: u32 = 2;
        const TOKENS_PER_REQUEST: usize = 20;
        const BLOCK_SIZE: usize = 2;

        // Runtime setup.
        let worker = Worker::from_settings()?;
        let runtime = worker.runtime();
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
        tracing::info!("✓ Runtime and distributed runtime created");

        // Component the engine publishes under and the endpoint is served from.
        let test_component = distributed
            .namespace("test")?
            .component(MOCKER_COMPONENT)?
            .service_builder()
            .create()
            .await?;
        tracing::info!("✓ Test component created");

        let args = MockEngineArgs::builder()
            .speedup_ratio(10.0)
            .dp_size(DP_SIZE)
            .block_size(BLOCK_SIZE)
            .build()
            .unwrap();

        // Start the engine, then give its background tasks time to come up.
        let engine = MockVllmEngine::new(args);
        engine.start(test_component.clone()).await?;
        tokio::time::sleep(Duration::from_millis(500)).await;
        let engine = Arc::new(engine);
        tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}");

        // Subscribe before any request is sent so the events they cause are seen.
        let mut kv_events_subscriber = test_component.subscribe(KV_EVENT_SUBJECT).await?;
        tracing::info!("✓ KV events subscriber created");

        let ingress = Ingress::for_engine(engine)?;
        tracing::info!("✓ Ingress wrapper created");

        // Serve the `generate` endpoint in the background; the handle is
        // awaited at the end of the test, after shutdown.
        let server_handle = tokio::spawn({
            let test_component = test_component.clone();
            async move {
                if let Err(e) = test_component
                    .endpoint("generate")
                    .endpoint_builder()
                    .handler(ingress)
                    .start()
                    .await
                {
                    eprintln!("❌ Generate endpoint failed: {e}");
                }
            }
        });
        tracing::info!("✓ Server started in background");

        // Give the endpoint time to register before a client is created.
        tokio::time::sleep(Duration::from_millis(500)).await;
        tracing::info!("✓ Server startup delay completed");

        // Diagnostic only: a listing failure is logged, not asserted on.
        match test_component.list_instances().await {
            Ok(instances) => {
                tracing::info!("📋 Found {} registered instances:", instances.len());
                for instance in instances {
                    tracing::info!(
                        "  • {}/{}/{} (ID: {})",
                        instance.namespace,
                        instance.component,
                        instance.endpoint,
                        instance.instance_id
                    );
                }
            }
            Err(e) => {
                tracing::error!("❌ Failed to list instances: {e}");
            }
        }

        let client = distributed
            .namespace("test")?
            .component(MOCKER_COMPONENT)?
            .endpoint("generate")
            .client()
            .await?;
        tracing::info!("✓ Client created");

        let router = PushRouter::from_client(client, Default::default()).await?;
        tracing::info!("✓ Router created");

        // Builds a request whose target DP rank is carried in a
        // `dp_rank:<n>` annotation, the form the engine's `generate` parses.
        let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| {
            PreprocessedRequest::builder()
                .model("mock".to_string())
                .token_ids(tokens)
                .stop_conditions(StopConditions {
                    max_tokens: Some(TOKENS_PER_REQUEST as u32),
                    ..Default::default()
                })
                .sampling_options(SamplingOptions::default())
                .output_options(OutputOptions::default())
                .eos_token_ids(vec![])
                .annotations(vec![format!("dp_rank:{dp_rank}")])
                .build()
                .unwrap()
        };

        // Two identical prompts per DP rank.
        let requests = vec![
            create_request(vec![1, 2, 3, 4, 5], 0),
            create_request(vec![1, 2, 3, 4, 5], 0),
            create_request(vec![1, 2, 3, 4, 5], 1),
            create_request(vec![1, 2, 3, 4, 5], 1),
        ];
        tracing::info!(
            "✓ Test requests created ({} requests total)",
            requests.len()
        );

        // Requests are sent one at a time; each stream is drained fully
        // before the next request is issued.
        for (i, request) in requests.into_iter().enumerate() {
            tracing::info!("Testing request {}", i + 1);

            let response_stream = router.generate(Context::new(request)).await?;
            let responses: Vec<LLMEngineOutput> = response_stream.collect().await;

            assert!(
                !responses.is_empty(),
                "Request {} should produce at least one response",
                i + 1
            );

            let mut total_tokens = 0;
            let mut has_finish_reason = false;

            for response in &responses {
                total_tokens += response.token_ids.len();
                if response.finish_reason.is_some() {
                    has_finish_reason = true;
                }
            }

            assert!(
                has_finish_reason,
                "Request {} should have a finish reason",
                i + 1
            );

            // Upper bound only: one token more than requested is tolerated.
            assert!(
                total_tokens <= TOKENS_PER_REQUEST + 1,
                "Request {} generated {} tokens, expected at most {}",
                i + 1,
                total_tokens,
                TOKENS_PER_REQUEST + 1
            );

            tracing::info!(
                "✓ Request {} completed successfully with {} tokens",
                i + 1,
                total_tokens
            );
        }

        tracing::info!("🎉 All requests completed successfully!");

        // The requests above should already have produced KV events, so a
        // short timeout is enough here.
        tracing::info!("Waiting for KV event with 100ms timeout...");
        let msg = timeout(Duration::from_millis(100), kv_events_subscriber.next())
            .await
            .map_err(|_| Error::msg("Timeout waiting for KV event"))?
            .ok_or_else(|| Error::msg("KV events stream ended unexpectedly"))?;

        // The payload must deserialize as a `RouterEvent`.
        match serde_json::from_slice::<RouterEvent>(&msg.payload) {
            Ok(event) => {
                tracing::info!("✓ Received KV event: {event:?}");
            }
            Err(e) => {
                return Err(Error::msg(format!("Failed to deserialize KV event: {e}")));
            }
        }

        // Metrics check: create an aggregator and give it time to collect.
        let cancel_token = test_component.drt().runtime().child_token();
        let metrics_aggregator = crate::kv_router::metrics_aggregator::KvMetricsAggregator::new(
            test_component.clone(),
            cancel_token,
        )
        .await;
        tokio::time::sleep(Duration::from_millis(500)).await;

        let processed_endpoints = metrics_aggregator.get_endpoints();
        tracing::info!(
            "Found {} metrics endpoints",
            processed_endpoints.endpoints.len()
        );

        assert!(
            !processed_endpoints.endpoints.is_empty(),
            "Should find at least one metrics endpoint"
        );
        tracing::info!(
            "✓ Successfully found {} metrics endpoints",
            processed_endpoints.endpoints.len()
        );

        for (worker_id, endpoint) in &processed_endpoints.endpoints {
            tracing::info!("✓ Worker {} metrics: {:?}", worker_id, endpoint.data);
        }

        tracing::info!("🎉 Event verification completed!");

        // Shut the runtime down, then wait for the server task to finish.
        distributed.shutdown();
        server_handle.await?;

        Ok(())
    }
}