//! dynamo-mocker 1.1.0
//!
//! Mock LLM scheduler and KV manager for testing.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::VecDeque;
use std::time::Duration;

use dynamo_kv_router::protocols::WorkerId;
use uuid::Uuid;

use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, WorkerType};
use crate::kv_manager::SglangKvManager;
use crate::replay::TraceCollector;

use super::config::SglangConfig;
use super::decode::{cache_materialized_prefix, simulate_decode_step};
use super::policy::apply_schedule_policy;
use super::prefill::get_new_batch_prefill;
use super::request::{SglangRequest, direct_to_sglang};
use crate::scheduler::{
    CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility, build_fpm_snapshot,
    capture_router_event_sink,
};

/// Mock SGLang-style scheduler core: admits waiting requests into a running
/// batch, simulates prefill/decode passes, and tracks KV-cache usage.
pub(crate) struct SglangCore {
    pub(super) config: SglangConfig,
    // FIFO queue of requests not yet admitted (also receives requests pushed
    // back to the front after retraction or partial materialization).
    pub(super) waiting: VecDeque<SglangRequest>,
    // Requests currently being decoded.
    pub(super) running: Vec<SglangRequest>,
    // Admission ratio, reset to `init_new_token_ratio` on OOM/retraction and
    // decayed toward `min_new_token_ratio` each pass.
    pub(super) new_token_ratio: f64,
    pub(super) kv_manager: SglangKvManager,
    // When present, captured router KV events are drained into each
    // `EnginePassResult`; `None` means events are not captured locally.
    kv_event_buffer: Option<CapturedRouterEventBuffer>,
}

impl SglangCore {
    pub(crate) fn new(args: MockEngineArgs) -> Self {
        Self::new_internal(args, 0, None, KvEventPublishers::default())
    }

    pub(crate) fn new_with_kv_capture(args: MockEngineArgs, worker_id: WorkerId) -> Self {
        let (buffer, sink) = capture_router_event_sink(worker_id);
        Self::new_internal(
            args,
            worker_id as u32,
            Some(buffer),
            KvEventPublishers::new(Some(sink), None),
        )
    }

    pub(super) fn new_with_sink(
        args: MockEngineArgs,
        dp_rank: u32,
        kv_event_publishers: KvEventPublishers,
    ) -> Self {
        Self::new_internal(args, dp_rank, None, kv_event_publishers)
    }

    fn new_internal(
        args: MockEngineArgs,
        dp_rank: u32,
        kv_event_buffer: Option<CapturedRouterEventBuffer>,
        kv_event_publishers: KvEventPublishers,
    ) -> Self {
        let args = args.normalized().expect("invalid MockEngineArgs");
        let config = SglangConfig::from_args(&args);
        let total_tokens = args.num_gpu_blocks * args.block_size;

        Self {
            config,
            waiting: VecDeque::new(),
            running: Vec::new(),
            new_token_ratio: SglangConfig::from_args(&args).init_new_token_ratio,
            kv_manager: SglangKvManager::new(
                total_tokens,
                args.block_size,
                kv_event_publishers,
                dp_rank,
            ),
            kv_event_buffer,
        }
    }

    pub(crate) fn receive(&mut self, request: DirectRequest) -> Uuid {
        let request = direct_to_sglang(request);
        request.debug_assert_invariants(self.config.block_size);
        let uuid = request.uuid;
        self.waiting.push_back(request);
        uuid
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.waiting.is_empty() && self.running.is_empty()
    }

    pub(crate) fn num_requests(&self) -> usize {
        self.waiting.len() + self.running.len()
    }

    pub(crate) fn execute_pass(
        &mut self,
        collector: &mut TraceCollector,
        now_ms: f64,
    ) -> EnginePassResult {
        self.execute_pass_internal(Some(collector), now_ms)
    }

    pub(crate) fn execute_hidden_pass(&mut self, now_ms: f64) -> EnginePassResult {
        self.execute_pass_internal(None, now_ms)
    }

    pub(super) fn execute_pass_internal(
        &mut self,
        mut collector: Option<&mut TraceCollector>,
        now_ms: f64,
    ) -> EnginePassResult {
        apply_schedule_policy(&mut self.waiting, &self.kv_manager, &self.config);

        let admit = get_new_batch_prefill(
            &mut self.waiting,
            &mut self.kv_manager,
            &self.config,
            self.new_token_ratio,
            &self.running,
        );

        if admit.oom {
            self.new_token_ratio = self.config.init_new_token_ratio;
        }

        for admission in &admit.admissions {
            if let Some(collector) = collector.as_deref_mut() {
                collector.on_admit(admission.uuid, now_ms, admission.reused_input_tokens);
            }
        }

        // Capture per-request prefill FPM data before dispersing can_run.
        let prefill_fpm = admit.prefill_fpm;

        let batch_size = admit.can_run.len();
        let mean_isl = if batch_size > 0 {
            admit.total_isl / batch_size
        } else {
            0
        };
        let mean_prefix = if batch_size > 0 {
            admit.total_prefix / batch_size
        } else {
            0
        };
        let prefill_time =
            simulate_prefill_duration(batch_size, mean_isl, mean_prefix, &self.config, true);

        for mut req in admit.can_run {
            if req.materialized_tokens < req.current_sequence_len() {
                cache_materialized_prefix(&mut req, &mut self.kv_manager, &self.config);
                self.waiting.push_front(req);
            } else {
                self.running.push(req);
            }
        }

        // Capture scheduled decode data before the decode step modifies running.
        let scheduled_decode_lens: Vec<u64> = self
            .running
            .iter()
            .map(|req| req.current_sequence_len() as u64)
            .collect();

        let decode_start_ms = now_ms + prefill_time.as_secs_f64() * 1000.0;
        let mut decode = simulate_decode_step(
            &mut self.running,
            &mut self.kv_manager,
            &self.config,
            decode_start_ms,
            true,
        );

        if let Some(collector) = collector {
            for signal in &decode.output_signals {
                collector.on_token(signal.uuid, decode.end_ms);
            }
        }

        for req in decode.requests.drain(..).rev() {
            self.waiting.push_front(req);
        }

        if decode.retracted_any {
            self.new_token_ratio = self.config.init_new_token_ratio;
        }
        self.new_token_ratio = (self.new_token_ratio - self.config.new_token_ratio_decay_step)
            .max(self.config.min_new_token_ratio);

        // Build FPM snapshot now that all state has settled.
        let fpm = build_fpm_snapshot(
            prefill_fpm.iter().map(|p| {
                (
                    p.prompt_len as u64,
                    p.prefix_tokens as u64,
                    p.tokens_computed as u64,
                )
            }),
            scheduled_decode_lens.into_iter(),
            self.waiting
                .iter()
                .filter(|req| req.output_len() == 0)
                .map(|req| req.prompt_len() as u64),
            self.waiting
                .iter()
                .filter(|req| req.output_len() > 0)
                .map(|req| req.current_sequence_len() as u64),
            (decode.end_ms - now_ms) / 1000.0,
        );

        debug_assert_sglang_scheduler_state(&self.waiting, &self.running, self.config.block_size);
        EnginePassResult {
            end_ms: decode.end_ms,
            completed_requests: decode
                .output_signals
                .iter()
                .filter(|signal| signal.completed)
                .count(),
            output_signals: decode.output_signals,
            admissions: admit.admissions,
            active_decode_blocks: self.active_kv_blocks(),
            router_event_visibility: RouterEventVisibility::PassEnd,
            kv_events: self
                .kv_event_buffer
                .as_ref()
                .map(CapturedRouterEventBuffer::drain)
                .unwrap_or_default(),
            fpm: Some(fpm),
        }
    }

    fn active_kv_blocks(&self) -> u64 {
        let active_reserved = self
            .waiting
            .iter()
            .map(SglangRequest::extra_reserved_tokens)
            .sum::<usize>()
            + self
                .running
                .iter()
                .map(SglangRequest::extra_reserved_tokens)
                .sum::<usize>();
        let actual_used =
            self.kv_manager.cache().total_tokens() - self.kv_manager.cache().available_tokens();
        (actual_used + active_reserved).div_ceil(self.config.block_size) as u64
    }
}

/// Estimate the wall-clock duration of a prefill batch.
///
/// Returns `Duration::ZERO` for an empty batch or for a decode-only worker.
/// Otherwise the configured performance model predicts the prefill time in
/// milliseconds, which is divided by `config.speedup_ratio` when
/// `apply_speedup` is set and the ratio/base time make scaling meaningful.
fn simulate_prefill_duration(
    batch_size: usize,
    mean_isl: usize,
    mean_prefix: usize,
    config: &SglangConfig,
    apply_speedup: bool,
) -> Duration {
    // Decode workers never run prefill, and an empty batch costs nothing.
    if batch_size == 0 || config.worker_type == WorkerType::Decode {
        return Duration::ZERO;
    }

    // The perf model reports milliseconds; convert into a Duration.
    let predicted_ms = config
        .perf_model
        .predict_prefill_time(batch_size, mean_isl, mean_prefix);
    let base = Duration::from_secs_f64(predicted_ms / 1000.0);

    // Skip scaling when not requested, or when the ratio or base duration
    // would make the division meaningless.
    let skip_scaling = !apply_speedup || config.speedup_ratio <= 0.0 || base <= Duration::ZERO;
    if skip_scaling {
        base
    } else {
        Duration::from_secs_f64(base.as_secs_f64() / config.speedup_ratio)
    }
}

/// Debug-build sanity check: every request appears at most once across the
/// waiting and running queues, and each request satisfies its own invariants.
/// Compiles to a no-op in release builds (parameters are underscored so they
/// are not flagged as unused there).
fn debug_assert_sglang_scheduler_state(
    _waiting: &VecDeque<SglangRequest>,
    _running: &[SglangRequest],
    _block_size: usize,
) {
    #[cfg(debug_assertions)]
    {
        let mut seen = std::collections::HashSet::new();
        // Walk both queues in one pass; `insert` returns false on a duplicate.
        for req in _waiting.iter().chain(_running.iter()) {
            debug_assert!(
                seen.insert(req.uuid),
                "request {} appears multiple times across waiting/running queues",
                req.uuid
            );
            req.debug_assert_invariants(_block_size);
        }
    }
}