dynamo-mocker 1.1.0

Mock LLM scheduler and KV manager for testing
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::{Result, bail};

use super::{OfflineDisaggReplayConfig, ReplayArgsMode, ReplayRouterMode};
use crate::common::protocols::{MockEngineArgs, WorkerType};

pub fn validate_replay_args_mode(
    aggregated_args: Option<&MockEngineArgs>,
    prefill_args: Option<&MockEngineArgs>,
    decode_args: Option<&MockEngineArgs>,
    num_workers: usize,
    num_prefill_workers: usize,
    num_decode_workers: usize,
) -> Result<ReplayArgsMode> {
    if aggregated_args.is_some() && (prefill_args.is_some() || decode_args.is_some()) {
        bail!("extra_engine_args cannot be combined with prefill_engine_args/decode_engine_args");
    }

    match (aggregated_args, prefill_args, decode_args) {
        (Some(_), None, None) | (None, None, None) => {
            if num_prefill_workers != 1 || num_decode_workers != 1 {
                bail!(
                    "num_prefill_workers and num_decode_workers are only used for disagg replay; use num_workers for aggregated replay"
                );
            }
            Ok(ReplayArgsMode::Aggregated)
        }
        (None, Some(_), Some(_)) => {
            if num_workers != 1 {
                bail!(
                    "num_workers is only used for aggregated replay; use num_prefill_workers and num_decode_workers for disagg replay"
                );
            }
            Ok(ReplayArgsMode::Disagg)
        }
        (None, Some(_), None) | (None, None, Some(_)) => {
            bail!("prefill_engine_args and decode_engine_args must be provided together")
        }
        (Some(_), Some(_), _) | (Some(_), _, Some(_)) => unreachable!(),
    }
}

fn validate_replay_args(args: &MockEngineArgs, num_workers: usize, mode: &str) -> Result<()> {
    if num_workers == 0 {
        bail!("{mode} requires num_workers >= 1");
    }
    if args.worker_type != WorkerType::Aggregated {
        bail!(
            "{mode} only supports aggregated workers, got {:?}",
            args.worker_type,
        );
    }
    if args.dp_size != 1 {
        bail!(
            "{mode} only supports data_parallel_size=1, got {}",
            args.dp_size,
        );
    }

    Ok(())
}

fn validate_offline_router_mode(router_mode: ReplayRouterMode, num_workers: usize) -> Result<()> {
    if router_mode != ReplayRouterMode::KvRouter {
        return Ok(());
    }
    if num_workers > 1 {
        return Ok(());
    }

    bail!("offline replay only supports router_mode=kv_router when num_workers > 1");
}

pub(super) fn validate_offline_replay_args(
    args: &MockEngineArgs,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<()> {
    validate_offline_router_mode(router_mode, num_workers)?;
    validate_replay_args(args, num_workers, "trace replay")
}

pub(super) fn validate_offline_concurrency_args(
    args: &MockEngineArgs,
    num_workers: usize,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<()> {
    if max_in_flight == 0 {
        bail!("concurrency replay requires max_in_flight >= 1");
    }

    validate_offline_router_mode(router_mode, num_workers)?;
    validate_replay_args(args, num_workers, "concurrency replay")
}

pub(super) fn validate_online_replay_args(args: &MockEngineArgs, num_workers: usize) -> Result<()> {
    validate_replay_args(args, num_workers, "online replay")
}

pub(super) fn validate_online_concurrency_args(
    args: &MockEngineArgs,
    num_workers: usize,
    max_in_flight: usize,
) -> Result<()> {
    if max_in_flight == 0 {
        bail!("online concurrency replay requires max_in_flight >= 1");
    }

    validate_replay_args(args, num_workers, "online replay")
}

fn validate_disagg_args(config: &OfflineDisaggReplayConfig, mode: &str) -> Result<()> {
    if config.num_prefill_workers == 0 {
        bail!("{mode} requires num_prefill_workers >= 1");
    }
    if config.num_decode_workers == 0 {
        bail!("{mode} requires num_decode_workers >= 1");
    }
    if config.prefill_args.worker_type != WorkerType::Prefill {
        bail!(
            "{mode} requires prefill_engine_args.worker_type=prefill, got {:?}",
            config.prefill_args.worker_type,
        );
    }
    if config.decode_args.worker_type != WorkerType::Decode {
        bail!(
            "{mode} requires decode_engine_args.worker_type=decode, got {:?}",
            config.decode_args.worker_type,
        );
    }
    if config.prefill_args.dp_size != 1 {
        bail!(
            "{mode} only supports prefill data_parallel_size=1, got {}",
            config.prefill_args.dp_size,
        );
    }
    if config.decode_args.dp_size != 1 {
        bail!(
            "{mode} only supports decode data_parallel_size=1, got {}",
            config.decode_args.dp_size,
        );
    }
    if config.prefill_args.block_size != config.decode_args.block_size {
        bail!(
            "{mode} requires matching prefill/decode block_size, got {} and {}",
            config.prefill_args.block_size,
            config.decode_args.block_size,
        );
    }

    Ok(())
}

pub(super) fn validate_offline_disagg_replay_args(
    config: &OfflineDisaggReplayConfig,
    _router_mode: ReplayRouterMode,
) -> Result<()> {
    validate_disagg_args(config, "trace replay")
}

pub(super) fn validate_offline_disagg_concurrency_args(
    config: &OfflineDisaggReplayConfig,
    max_in_flight: usize,
    _router_mode: ReplayRouterMode,
) -> Result<()> {
    if max_in_flight == 0 {
        bail!("concurrency replay requires max_in_flight >= 1");
    }
    validate_disagg_args(config, "concurrency replay")
}