// dynamo_mocker/common/perf_model.rs
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;

use anyhow::{Context, Result};
use ndarray::{Array1, Array2};
use ndarray_interp::InterpolateError;
use ndarray_interp::interp1d::{Interp1DBuilder, Linear};
use ndarray_interp::interp2d::{Bilinear, Interp2DBuilder};
17
18pub trait PrefillInterpolator: Send + Sync {
20 fn interp(&self, x: f64) -> Result<f64, InterpolateError>;
21}
22
23pub trait DecodeInterpolator: Send + Sync {
25 fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError>;
26}
27
28struct PrefillInterp1D {
30 inner: ndarray_interp::interp1d::Interp1D<
31 ndarray::OwnedRepr<f64>,
32 ndarray::OwnedRepr<f64>,
33 ndarray::Ix1,
34 Linear,
35 >,
36}
37
38impl PrefillInterpolator for PrefillInterp1D {
39 fn interp(&self, x: f64) -> Result<f64, InterpolateError> {
40 self.inner.interp_scalar(x)
41 }
42}
43
44struct DecodeInterp2D {
46 inner: ndarray_interp::interp2d::Interp2D<
47 ndarray::OwnedRepr<f64>,
48 ndarray::OwnedRepr<f64>,
49 ndarray::OwnedRepr<f64>,
50 ndarray::Ix2,
51 Bilinear,
52 >,
53}
54
55impl DecodeInterpolator for DecodeInterp2D {
56 fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError> {
57 self.inner.interp_scalar(x, y)
58 }
59}
60
61#[derive(Default)]
63pub enum PerfModel {
64 #[default]
66 Polynomial,
67 Interpolated {
70 prefill_interp: Arc<dyn PrefillInterpolator>,
71 decode_interp: Arc<dyn DecodeInterpolator>,
72 },
73}
74
75impl Clone for PerfModel {
76 fn clone(&self) -> Self {
77 match self {
78 PerfModel::Polynomial => PerfModel::Polynomial,
79 PerfModel::Interpolated {
80 prefill_interp,
81 decode_interp,
82 } => PerfModel::Interpolated {
83 prefill_interp: Arc::clone(prefill_interp),
84 decode_interp: Arc::clone(decode_interp),
85 },
86 }
87 }
88}
89
90impl std::fmt::Debug for PerfModel {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 match self {
93 PerfModel::Polynomial => write!(f, "PerfModel::Polynomial"),
94 PerfModel::Interpolated { .. } => write!(f, "PerfModel::Interpolated {{ .. }}"),
95 }
96 }
97}
98
99impl PerfModel {
100 pub fn from_npz(path: &Path) -> Result<Self> {
109 use ndarray_npy::NpzReader;
110 use std::fs::File;
111
112 tracing::info!("Loading performance model from NPZ file: {:?}", path);
113
114 let file =
115 File::open(path).with_context(|| format!("Failed to open NPZ file: {:?}", path))?;
116
117 let mut npz = NpzReader::new(file)
118 .with_context(|| format!("Failed to create NPZ reader for: {:?}", path))?;
119
120 let prefill_isl: Array1<f64> = npz
122 .by_name("prefill_isl")
123 .with_context(|| "Failed to load prefill_isl from NPZ")?;
124 let prefill_ttft_ms: Array1<f64> = npz
125 .by_name("prefill_ttft_ms")
126 .with_context(|| "Failed to load prefill_ttft_ms from NPZ")?;
127
128 let decode_active_kv_tokens: Array1<f64> = npz
130 .by_name("decode_active_kv_tokens")
131 .with_context(|| "Failed to load decode_active_kv_tokens from NPZ")?;
132 let decode_context_length: Array1<f64> = npz
133 .by_name("decode_context_length")
134 .with_context(|| "Failed to load decode_context_length from NPZ")?;
135 let decode_itl: Array2<f64> = npz
136 .by_name("decode_itl")
137 .with_context(|| "Failed to load decode_itl from NPZ")?;
138
139 if prefill_isl.len() != prefill_ttft_ms.len() {
141 anyhow::bail!(
142 "Prefill array length mismatch: isl={}, ttft={}",
143 prefill_isl.len(),
144 prefill_ttft_ms.len()
145 );
146 }
147
148 if decode_itl.nrows() != decode_active_kv_tokens.len()
149 || decode_itl.ncols() != decode_context_length.len()
150 {
151 anyhow::bail!(
152 "Decode array dimension mismatch: itl shape=({}, {}), active_kv={}, context={}",
153 decode_itl.nrows(),
154 decode_itl.ncols(),
155 decode_active_kv_tokens.len(),
156 decode_context_length.len()
157 );
158 }
159
160 tracing::info!(
161 "Loaded performance model: prefill_points={}, decode_grid={}x{}",
162 prefill_isl.len(),
163 decode_itl.nrows(),
164 decode_itl.ncols()
165 );
166
167 let prefill_interp = Interp1DBuilder::new(prefill_ttft_ms)
169 .x(prefill_isl)
170 .strategy(Linear::new().extrapolate(true))
171 .build()
172 .with_context(|| "Failed to build prefill interpolator")?;
173
174 let decode_interp = Interp2DBuilder::new(decode_itl)
175 .x(decode_active_kv_tokens)
176 .y(decode_context_length)
177 .strategy(Bilinear::new().extrapolate(true))
178 .build()
179 .with_context(|| "Failed to build decode interpolator")?;
180
181 Ok(PerfModel::Interpolated {
182 prefill_interp: Arc::new(PrefillInterp1D {
183 inner: prefill_interp,
184 }),
185 decode_interp: Arc::new(DecodeInterp2D {
186 inner: decode_interp,
187 }),
188 })
189 }
190
191 pub fn predict_prefill_time(&self, new_tokens: usize) -> f64 {
193 let time = match self {
194 PerfModel::Polynomial => {
195 let tokens = new_tokens as f64;
197 4.209989e-07 * tokens.powi(2) + 1.518344e-02 * tokens + 1.650142e+01
198 }
199 PerfModel::Interpolated { prefill_interp, .. } => {
200 let query = new_tokens as f64;
202 prefill_interp.interp(query).unwrap_or(0.0)
203 }
204 };
205 let result = time.max(0.0);
207 tracing::trace!("Prefill time prediction: new_tokens={new_tokens}, time={result:.2}ms");
208 result
209 }
210
211 pub fn predict_decode_time(&self, active_kv_tokens: usize, context_length: usize) -> f64 {
216 let time = match self {
217 PerfModel::Polynomial => {
218 let active_perc = active_kv_tokens as f64 / 16384.0;
220 -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74
222 }
223 PerfModel::Interpolated { decode_interp, .. } => {
224 let query_x = active_kv_tokens as f64;
226 let query_y = context_length as f64;
227 decode_interp.interp(query_x, query_y).unwrap_or(0.0)
228 }
229 };
230 let result = time.max(0.0);
232 tracing::trace!(
233 "Decode time prediction: active_kv_tokens={active_kv_tokens}, context_length={context_length}, time={result:.2}ms"
234 );
235 result
236 }
237}