Skip to main content

dynamo_mocker/common/
perf_model.rs

// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Performance model for timing simulations in the mocker.
//!
//! This module provides two timing models:
//! 1. Polynomial: Hardcoded polynomial formulas (default, backward compatible)
//! 2. Interpolated: Grid-based interpolation from profiler data (loaded from NPZ files)
9
10use anyhow::{Context, Result};
11use ndarray::{Array1, Array2};
12use ndarray_interp::InterpolateError;
13use ndarray_interp::interp1d::{Interp1DBuilder, Linear};
14use ndarray_interp::interp2d::{Bilinear, Interp2DBuilder};
15use std::path::Path;
16use std::sync::Arc;
17
18/// Trait to abstract over 1D interpolation for prefill timing
19pub trait PrefillInterpolator: Send + Sync {
20    fn interp(&self, x: f64) -> Result<f64, InterpolateError>;
21}
22
23/// Trait to abstract over 2D interpolation for decode timing
24pub trait DecodeInterpolator: Send + Sync {
25    fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError>;
26}
27
28/// Wrapper to implement PrefillInterpolator for the concrete Interp1D type
29struct PrefillInterp1D {
30    inner: ndarray_interp::interp1d::Interp1D<
31        ndarray::OwnedRepr<f64>,
32        ndarray::OwnedRepr<f64>,
33        ndarray::Ix1,
34        Linear,
35    >,
36}
37
38impl PrefillInterpolator for PrefillInterp1D {
39    fn interp(&self, x: f64) -> Result<f64, InterpolateError> {
40        self.inner.interp_scalar(x)
41    }
42}
43
44/// Wrapper to implement DecodeInterpolator for the concrete Interp2D type
45struct DecodeInterp2D {
46    inner: ndarray_interp::interp2d::Interp2D<
47        ndarray::OwnedRepr<f64>,
48        ndarray::OwnedRepr<f64>,
49        ndarray::OwnedRepr<f64>,
50        ndarray::Ix2,
51        Bilinear,
52    >,
53}
54
55impl DecodeInterpolator for DecodeInterp2D {
56    fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError> {
57        self.inner.interp_scalar(x, y)
58    }
59}
60
61/// Performance model for predicting prefill and decode timing
62#[derive(Default)]
63pub enum PerfModel {
64    /// Default polynomial-based model using hardcoded formulas
65    #[default]
66    Polynomial,
67    /// Interpolation-based model using profiler data
68    /// Interpolators are built once and stored as trait objects
69    Interpolated {
70        prefill_interp: Arc<dyn PrefillInterpolator>,
71        decode_interp: Arc<dyn DecodeInterpolator>,
72    },
73}
74
75impl Clone for PerfModel {
76    fn clone(&self) -> Self {
77        match self {
78            PerfModel::Polynomial => PerfModel::Polynomial,
79            PerfModel::Interpolated {
80                prefill_interp,
81                decode_interp,
82            } => PerfModel::Interpolated {
83                prefill_interp: Arc::clone(prefill_interp),
84                decode_interp: Arc::clone(decode_interp),
85            },
86        }
87    }
88}
89
90impl std::fmt::Debug for PerfModel {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        match self {
93            PerfModel::Polynomial => write!(f, "PerfModel::Polynomial"),
94            PerfModel::Interpolated { .. } => write!(f, "PerfModel::Interpolated {{ .. }}"),
95        }
96    }
97}
98
99impl PerfModel {
100    /// Load performance model from NPZ file
101    ///
102    /// Expected arrays in NPZ file:
103    /// - prefill_isl: 1D array of input sequence lengths
104    /// - prefill_ttft_ms: 1D array of time to first token in milliseconds
105    /// - decode_active_kv_tokens: 1D array of active KV token counts
106    /// - decode_context_length: 1D array of context lengths
107    /// - decode_itl: 2D array of inter-token latencies in milliseconds
108    pub fn from_npz(path: &Path) -> Result<Self> {
109        use ndarray_npy::NpzReader;
110        use std::fs::File;
111
112        tracing::info!("Loading performance model from NPZ file: {:?}", path);
113
114        let file =
115            File::open(path).with_context(|| format!("Failed to open NPZ file: {:?}", path))?;
116
117        let mut npz = NpzReader::new(file)
118            .with_context(|| format!("Failed to create NPZ reader for: {:?}", path))?;
119
120        // Load prefill arrays
121        let prefill_isl: Array1<f64> = npz
122            .by_name("prefill_isl")
123            .with_context(|| "Failed to load prefill_isl from NPZ")?;
124        let prefill_ttft_ms: Array1<f64> = npz
125            .by_name("prefill_ttft_ms")
126            .with_context(|| "Failed to load prefill_ttft_ms from NPZ")?;
127
128        // Load decode arrays
129        let decode_active_kv_tokens: Array1<f64> = npz
130            .by_name("decode_active_kv_tokens")
131            .with_context(|| "Failed to load decode_active_kv_tokens from NPZ")?;
132        let decode_context_length: Array1<f64> = npz
133            .by_name("decode_context_length")
134            .with_context(|| "Failed to load decode_context_length from NPZ")?;
135        let decode_itl: Array2<f64> = npz
136            .by_name("decode_itl")
137            .with_context(|| "Failed to load decode_itl from NPZ")?;
138
139        // Validate dimensions
140        if prefill_isl.len() != prefill_ttft_ms.len() {
141            anyhow::bail!(
142                "Prefill array length mismatch: isl={}, ttft={}",
143                prefill_isl.len(),
144                prefill_ttft_ms.len()
145            );
146        }
147
148        if decode_itl.nrows() != decode_active_kv_tokens.len()
149            || decode_itl.ncols() != decode_context_length.len()
150        {
151            anyhow::bail!(
152                "Decode array dimension mismatch: itl shape=({}, {}), active_kv={}, context={}",
153                decode_itl.nrows(),
154                decode_itl.ncols(),
155                decode_active_kv_tokens.len(),
156                decode_context_length.len()
157            );
158        }
159
160        tracing::info!(
161            "Loaded performance model: prefill_points={}, decode_grid={}x{}",
162            prefill_isl.len(),
163            decode_itl.nrows(),
164            decode_itl.ncols()
165        );
166
167        // Build interpolators once during loading
168        let prefill_interp = Interp1DBuilder::new(prefill_ttft_ms)
169            .x(prefill_isl)
170            .strategy(Linear::new().extrapolate(true))
171            .build()
172            .with_context(|| "Failed to build prefill interpolator")?;
173
174        let decode_interp = Interp2DBuilder::new(decode_itl)
175            .x(decode_active_kv_tokens)
176            .y(decode_context_length)
177            .strategy(Bilinear::new().extrapolate(true))
178            .build()
179            .with_context(|| "Failed to build decode interpolator")?;
180
181        Ok(PerfModel::Interpolated {
182            prefill_interp: Arc::new(PrefillInterp1D {
183                inner: prefill_interp,
184            }),
185            decode_interp: Arc::new(DecodeInterp2D {
186                inner: decode_interp,
187            }),
188        })
189    }
190
191    /// Predict prefill time in milliseconds given the number of new tokens
192    pub fn predict_prefill_time(&self, new_tokens: usize) -> f64 {
193        let time = match self {
194            PerfModel::Polynomial => {
195                // Original polynomial formula
196                let tokens = new_tokens as f64;
197                4.209989e-07 * tokens.powi(2) + 1.518344e-02 * tokens + 1.650142e+01
198            }
199            PerfModel::Interpolated { prefill_interp, .. } => {
200                // Use pre-built interpolator
201                let query = new_tokens as f64;
202                prefill_interp.interp(query).unwrap_or(0.0)
203            }
204        };
205        // Ensure non-negative timing
206        let result = time.max(0.0);
207        tracing::trace!("Prefill time prediction: new_tokens={new_tokens}, time={result:.2}ms");
208        result
209    }
210
211    /// Predict decode time in milliseconds given active KV tokens and context length
212    ///
213    /// For the Polynomial variant, this computes active percentage as active_kv_tokens / 16384.
214    /// For the Interpolated variant, this performs 2D bilinear interpolation.
215    pub fn predict_decode_time(&self, active_kv_tokens: usize, context_length: usize) -> f64 {
216        let time = match self {
217            PerfModel::Polynomial => {
218                // Compute active percentage using default capacity
219                let active_perc = active_kv_tokens as f64 / 16384.0;
220                // Original polynomial formula
221                -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74
222            }
223            PerfModel::Interpolated { decode_interp, .. } => {
224                // Use pre-built interpolator
225                let query_x = active_kv_tokens as f64;
226                let query_y = context_length as f64;
227                decode_interp.interp(query_x, query_y).unwrap_or(0.0)
228            }
229        };
230        // Ensure non-negative timing
231        let result = time.max(0.0);
232        tracing::trace!(
233            "Decode time prediction: active_kv_tokens={active_kv_tokens}, context_length={context_length}, time={result:.2}ms"
234        );
235        result
236    }
237}