dynamo-mocker 1.1.0

Mock LLM scheduler and KV manager for testing
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::time::Duration;

use crate::common::protocols::OutputSignal;
use crate::common::utils::compute_prefill_handoff_delay_ms;
use crate::kv_manager::SglangKvManager;

use super::config::{SglangConfig, floor_to_block};
use super::request::SglangRequest;

#[derive(Default)]
pub(super) struct DecodeResult {
    pub(super) requests: Vec<SglangRequest>,
    pub(super) output_signals: Vec<OutputSignal>,
    pub(super) retracted_any: bool,
    pub(super) end_ms: f64,
}

fn decode_capacity_state(
    running: &[SglangRequest],
    kv_manager: &SglangKvManager,
    config: &SglangConfig,
) -> (usize, usize) {
    let actual_available =
        kv_manager.cache().available_tokens() + kv_manager.cache().evictable_size;
    let reserved_tokens = running
        .iter()
        .map(SglangRequest::extra_reserved_tokens)
        .sum::<usize>();
    let logical_available = actual_available.saturating_sub(reserved_tokens);
    let page_growth_needed = running
        .iter()
        .map(|req| {
            if req.current_sequence_len() + 1 > req.allocated_tokens {
                config.block_size
            } else {
                0
            }
        })
        .sum();

    (
        actual_available,
        logical_available.saturating_sub(page_growth_needed),
    )
}

pub(super) fn cache_materialized_prefix(
    req: &mut SglangRequest,
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
) {
    let aligned_tokens = req.page_aligned_materialized_tokens(config.block_size);
    if aligned_tokens == 0 || aligned_tokens <= req.cached_tokens {
        return;
    }

    let Some(last_node) = req.last_node else {
        return;
    };

    let sequence = req.sequence_prefix(aligned_tokens);
    let new_last =
        kv_manager.cache_unfinished_req(&sequence, &req.kv_indices[..aligned_tokens], last_node);
    req.last_node = Some(new_last);
    req.cached_tokens = aligned_tokens;
    req.debug_assert_invariants(config.block_size);
}

pub(super) fn check_decode_mem(
    running: &mut Vec<SglangRequest>,
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
) -> Vec<SglangRequest> {
    let mut retracted = Vec::new();

    loop {
        let (actual_available, logical_available_after_growth) =
            decode_capacity_state(running, kv_manager, config);
        if actual_available >= running.len() && logical_available_after_growth > 0 {
            break;
        }
        if running.len() <= 1 {
            break;
        }

        let Some((idx, _)) = running
            .iter()
            .enumerate()
            .min_by_key(|(_, req)| req.output_len())
        else {
            break;
        };

        let mut req = running.swap_remove(idx);
        kv_manager.free_indices(&req.kv_indices[req.cached_tokens..]);
        if let Some(last_node) = req.last_node.take() {
            kv_manager.free_request(last_node);
        }
        req.reset_for_retract();
        req.debug_assert_invariants(config.block_size);
        retracted.push(req);
    }

    let available = kv_manager.cache().token_pool.available();
    let needed = running.len();
    if available < needed {
        kv_manager.evict(needed - available);
    }

    if !retracted.is_empty() {
        tracing::warn!(
            num_retracted = retracted.len(),
            remaining = running.len(),
            "SGLang decode retract requests because KV pool is full"
        );
    }

    retracted
}

pub(super) fn simulate_decode_step(
    running: &mut Vec<SglangRequest>,
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
    current_time_ms: f64,
    apply_speedup: bool,
) -> DecodeResult {
    if running.is_empty() {
        return DecodeResult {
            end_ms: current_time_ms,
            ..DecodeResult::default()
        };
    }

    let total_context: usize = running
        .iter()
        .map(SglangRequest::current_sequence_len)
        .sum();
    let avg_context = total_context / running.len();
    let active_kv_tokens = total_context.min(config.total_kv_tokens);
    let decode_time = config.perf_model.predict_decode_time(
        running.len(),
        active_kv_tokens,
        avg_context,
        config.total_kv_tokens,
    );
    let unscaled_time = Duration::from_secs_f64(decode_time / 1000.0);
    let effective_ratio = config.speedup_ratio * config.decode_speedup_ratio;
    let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
        Duration::from_secs_f64(unscaled_time.as_secs_f64() / effective_ratio)
    } else {
        unscaled_time
    };

    let retracted = check_decode_mem(running, kv_manager, config);
    let retracted_any = !retracted.is_empty();
    let mut output_signals = Vec::with_capacity(running.len());
    let mut completed_indices = Vec::new();

    for (idx, req) in running.iter_mut().enumerate() {
        if kv_manager.cache().token_pool.available() == 0 {
            kv_manager.evict(1);
        }

        let crossing_page_boundary = req.current_sequence_len() + 1 > req.allocated_tokens;
        let last_idx = req.kv_indices.last().copied();
        let Some(new_idx) = kv_manager.allocate_decode_token(last_idx) else {
            tracing::warn!(uuid = %req.uuid, "Failed to allocate decode token, skipping output");
            continue;
        };

        req.kv_indices.push(new_idx);
        if crossing_page_boundary {
            req.allocated_tokens += config.block_size;
        }
        req.append_output_token(req.next_output_token());
        req.debug_assert_invariants(config.block_size);

        let is_complete = req.output_len() >= req.max_output_tokens;
        output_signals.push(OutputSignal {
            uuid: req.uuid,
            completed: is_complete,
            handoff_delay_ms: compute_prefill_handoff_delay_ms(
                config.worker_type,
                is_complete,
                req.prompt_len(),
                config.kv_transfer_bandwidth,
                config.kv_bytes_per_token,
            ),
        });

        if is_complete {
            let sequence = req.sequence_tokens();
            let tokens_to_cache = floor_to_block(sequence.len(), config.block_size);
            if req.kv_indices.len() > tokens_to_cache {
                kv_manager.free_indices(&req.kv_indices[tokens_to_cache..]);
            }

            if let Some(last_node) = req.last_node.take() {
                if tokens_to_cache > 0 {
                    kv_manager.cache_finished_req(
                        &sequence[..tokens_to_cache],
                        &req.kv_indices[..tokens_to_cache],
                        last_node,
                    );
                } else {
                    kv_manager.free_request(last_node);
                }
            }

            completed_indices.push(idx);
            continue;
        }

        cache_materialized_prefix(req, kv_manager, config);
        req.debug_assert_invariants(config.block_size);
    }

    for &idx in completed_indices.iter().rev() {
        running.swap_remove(idx);
    }

    DecodeResult {
        requests: retracted,
        output_signals,
        retracted_any,
        end_ms: current_time_ms + total_time.as_secs_f64() * 1000.0,
    }
}