dynamo-mocker 1.1.0

Mock LLM scheduler and KV manager for testing
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;
use std::time::Instant;

use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

use crate::common::protocols::{
    DirectRequest, FpmPublisher, KvEventPublishers, MockEngineArgs, OutputSignal,
};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
    AdmissionEvent, DeferredFpmBuffer, MockerMetrics, RouterEventVisibility, SchedulerHandle,
    capture_deferred_kv_publish_sink, publish_deferred_fpm, publish_deferred_kv_events,
};

use super::core::SglangCore;
use super::request::{SglangRequest, direct_to_sglang};

/// Handle to a background SGLang mock-scheduler task.
///
/// The handle itself only holds the channel endpoints used to talk to the
/// spawned loop plus a shared cancellation guard. Clones share the same
/// request channel, metrics channel and guard.
#[derive(Clone)]
pub struct SglangScheduler {
    /// Sending half of the unbounded channel that feeds requests to the background loop.
    request_tx: mpsc::UnboundedSender<DirectRequest>,
    /// Watch receiver for the `MockerMetrics` value the background loop sends after each pass.
    metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
    /// Shared guard; when the last `Arc` reference is dropped, `CancelGuard::drop`
    /// cancels the token that the background loop observes.
    _cancel_guard: Arc<CancelGuard>,
}

/// Drop guard around a [`CancellationToken`]: dropping the guard cancels the token.
struct CancelGuard(CancellationToken);

impl Drop for CancelGuard {
    /// Cancels the wrapped token so tasks observing it (or a clone of it) can stop.
    fn drop(&mut self) {
        let Self(token) = self;
        token.cancel();
    }
}

impl SglangScheduler {
    /// Creates a scheduler and spawns its background loop, without admission reporting.
    ///
    /// Equivalent to [`Self::new_with_admission`] with `admission_tx` set to `None`.
    ///
    /// # Panics
    ///
    /// Construction calls `tokio::spawn`, which panics when invoked outside a Tokio runtime.
    pub fn new(
        args: MockEngineArgs,
        dp_rank: u32,
        output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
        kv_event_publishers: KvEventPublishers,
        cancellation_token: Option<CancellationToken>,
        fpm_publisher: FpmPublisher,
    ) -> Self {
        Self::new_internal(
            args,
            dp_rank,
            output_tx,
            kv_event_publishers,
            cancellation_token,
            None,
            fpm_publisher,
        )
    }

    /// Creates a scheduler whose background loop also forwards every admission
    /// reported by a pass to `admission_tx` (when one is supplied).
    ///
    /// # Panics
    ///
    /// Construction calls `tokio::spawn`, which panics when invoked outside a Tokio runtime.
    pub(crate) fn new_with_admission(
        args: MockEngineArgs,
        dp_rank: u32,
        output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
        kv_event_publishers: KvEventPublishers,
        cancellation_token: Option<CancellationToken>,
        admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
        fpm_publisher: FpmPublisher,
    ) -> Self {
        Self::new_internal(
            args,
            dp_rank,
            output_tx,
            kv_event_publishers,
            cancellation_token,
            admission_tx,
            fpm_publisher,
        )
    }

    /// Shared constructor: wires up the channels and spawns the scheduling loop.
    ///
    /// The spawned task repeats the following until `receive_requests` returns `None`:
    ///
    /// 1. collect pending requests into the core's waiting queue;
    /// 2. run one pass via `SglangCore::execute_pass_internal`;
    /// 3. forward the pass's admissions to `admission_tx`, if present;
    /// 4. buffer the pass's FPM value, if any;
    /// 5. publish deferred KV events / FPM before the simulated delay when the pass asks
    ///    for `RouterEventVisibility::PassStart`;
    /// 6. sleep until `iteration_start + pass.end_ms` (milliseconds);
    /// 7. publish deferred KV events / FPM after the delay when the pass asks for
    ///    `RouterEventVisibility::PassEnd`;
    /// 8. send the pass's output signals, publish anything still deferred, and send
    ///    fresh metrics on the watch channel.
    ///
    /// If `cancellation_token` is `None`, a default token is created. The token is wrapped
    /// in a `CancelGuard` owned by the returned handle, and a clone is given to the loop.
    fn new_internal(
        args: MockEngineArgs,
        dp_rank: u32,
        output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
        kv_event_publishers: KvEventPublishers,
        cancellation_token: Option<CancellationToken>,
        admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
        fpm_publisher: FpmPublisher,
    ) -> Self {
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
        let total_blocks = args.num_gpu_blocks as u64;
        // Metrics start with zero active decode blocks until the first pass completes.
        let initial_metrics = MockerMetrics::new(dp_rank, 0, total_blocks);
        let (metrics_tx, metrics_rx) =
            tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);

        let cancel_token = cancellation_token.unwrap_or_default();
        let loop_token = cancel_token.clone();
        let cancel_guard = Arc::new(CancelGuard(cancel_token));

        tokio::spawn(async move {
            // The core writes KV events into a buffering sink; this loop decides when the
            // buffered events are handed to the real publishers.
            let (deferred_kv_events, buffering_publishers) =
                capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
            let deferred_fpm = DeferredFpmBuffer::default();
            let mut core = SglangCore::new_with_sink(args, dp_rank, buffering_publishers);

            loop {
                let received =
                    receive_requests(&mut core.waiting, &mut request_rx, &loop_token, &core.running)
                        .await;
                if received.is_none() {
                    // Cancelled, or the request channel was closed while idle.
                    break;
                }

                let iteration_start = Instant::now();
                let pass = core.execute_pass_internal(None, 0.0);

                if let Some(admission_tx) = admission_tx.as_ref() {
                    for admission in &pass.admissions {
                        // A closed admission channel is not an error for the loop.
                        let _ = admission_tx.send(admission.clone());
                    }
                }
                if let Some(fpm) = pass.fpm {
                    deferred_fpm.push(fpm);
                }

                if pass.router_event_visibility == RouterEventVisibility::PassStart {
                    publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
                    publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
                }

                // `Duration::from_secs_f64` panics on negative, non-finite or overflowing
                // input, and a panic here would silently terminate this detached task.
                // Treat an unusable duration as "no simulated delay" instead.
                let total_time = std::time::Duration::try_from_secs_f64(pass.end_ms / 1000.0)
                    .unwrap_or(std::time::Duration::ZERO);
                if total_time > std::time::Duration::ZERO {
                    sleep_until_precise(iteration_start + total_time).await;
                }

                if pass.router_event_visibility == RouterEventVisibility::PassEnd {
                    publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
                    publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
                }

                let active_decode_blocks = pass.active_decode_blocks;
                flush_output_signals(&output_tx, pass.output_signals);
                // Unconditional flush: publishes whatever is still buffered at this point.
                publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
                publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
                // Sending fails only when every receiver is gone; that is ignored.
                let _ = metrics_tx.send(MockerMetrics::new(
                    dp_rank,
                    active_decode_blocks,
                    total_blocks,
                ));
            }
        });

        Self {
            request_tx,
            metrics_rx,
            _cancel_guard: cancel_guard,
        }
    }
}

impl SchedulerHandle for SglangScheduler {
    /// Queues `request` for the background loop; a failed send is ignored.
    fn receive(&self, request: DirectRequest) {
        let Self { request_tx, .. } = self;
        let _ = request_tx.send(request);
    }

    /// Returns a new sender attached to the same request channel.
    fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        mpsc::UnboundedSender::clone(&self.request_tx)
    }

    /// Returns a new watch receiver attached to the same metrics channel.
    fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
        tokio::sync::watch::Receiver::clone(&self.metrics_rx)
    }
}

async fn receive_requests(
    waiting: &mut std::collections::VecDeque<SglangRequest>,
    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
    cancel_token: &CancellationToken,
    running: &[SglangRequest],
) -> Option<()> {
    if cancel_token.is_cancelled() {
        return None;
    }

    if waiting.is_empty() && running.is_empty() {
        tokio::select! {
            biased;
            _ = cancel_token.cancelled() => return None,
            result = request_rx.recv() => {
                let request = result?;
                waiting.push_back(direct_to_sglang(request));
            }
        }
    }

    while let Ok(request) = request_rx.try_recv() {
        waiting.push_back(direct_to_sglang(request));
    }

    Some(())
}

/// Sends one pass's output signals as a single batch on `output_tx`.
///
/// Does nothing when there is no sender or when `output_signals` is empty.
/// A failed send is ignored.
fn flush_output_signals(
    output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
    output_signals: Vec<OutputSignal>,
) {
    match output_tx {
        Some(sender) if !output_signals.is_empty() => {
            let _ = sender.send(output_signals);
        }
        _ => {}
    }
}