use anyhow::{Context, Result};
use ndarray::{Array1, Array2};
use ndarray_interp::InterpolateError;
use ndarray_interp::interp1d::{Interp1DBuilder, Linear};
use ndarray_interp::interp2d::{Bilinear, Interp2DBuilder};
use std::path::Path;
use std::sync::Arc;
/// One-variable interpolator that can be shared across threads (`Send + Sync`).
///
/// `PerfModel::Interpolated` stores one of these behind an `Arc` and queries it
/// from `predict_prefill_time`.
pub trait PrefillInterpolator: Send + Sync {
/// Evaluates the interpolator at `x`.
///
/// Returns the interpolated value, or an `InterpolateError` if the underlying
/// interpolator rejects the query.
fn interp(&self, x: f64) -> Result<f64, InterpolateError>;
}
/// Two-variable interpolator that can be shared across threads (`Send + Sync`).
///
/// `PerfModel::Interpolated` stores one of these behind an `Arc` and queries it
/// from `predict_decode_time` with `(active_kv_tokens, context_length)`.
pub trait DecodeInterpolator: Send + Sync {
/// Evaluates the interpolator at the point `(x, y)`.
///
/// Returns the interpolated value, or an `InterpolateError` if the underlying
/// interpolator rejects the query.
fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError>;
}
/// Private adapter that exposes a concrete `Interp1D` (owned `f64` storage,
/// one-dimensional, `Linear` strategy) through the `PrefillInterpolator` trait.
struct PrefillInterp1D {
/// The wrapped `ndarray_interp` interpolator that performs the actual lookup.
inner: ndarray_interp::interp1d::Interp1D<
ndarray::OwnedRepr<f64>,
ndarray::OwnedRepr<f64>,
ndarray::Ix1,
Linear,
>,
}
impl PrefillInterpolator for PrefillInterp1D {
    /// Forwards the query to the wrapped interpolator's `interp_scalar`.
    fn interp(&self, x: f64) -> Result<f64, InterpolateError> {
        let Self { inner } = self;
        inner.interp_scalar(x)
    }
}
/// Private adapter that exposes a concrete `Interp2D` (owned `f64` storage,
/// two-dimensional, `Bilinear` strategy) through the `DecodeInterpolator` trait.
struct DecodeInterp2D {
/// The wrapped `ndarray_interp` interpolator that performs the actual lookup.
inner: ndarray_interp::interp2d::Interp2D<
ndarray::OwnedRepr<f64>,
ndarray::OwnedRepr<f64>,
ndarray::OwnedRepr<f64>,
ndarray::Ix2,
Bilinear,
>,
}
impl DecodeInterpolator for DecodeInterp2D {
    /// Forwards the `(x, y)` query to the wrapped interpolator's `interp_scalar`.
    fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError> {
        let Self { inner } = self;
        inner.interp_scalar(x, y)
    }
}
/// Strategy used to predict prefill and decode timings.
///
/// Cloning is cheap: the `Interpolated` variant only holds `Arc` handles, so a
/// clone bumps two reference counts and shares the same interpolators. `Clone`
/// is derived because `Arc<dyn Trait>` is itself `Clone`; a hand-written impl
/// would be identical to what the derive generates.
#[derive(Clone, Default)]
pub enum PerfModel {
    /// Default model: predictions come from fixed polynomial formulas.
    #[default]
    Polynomial,
    /// Predictions come from interpolators, e.g. those built by `PerfModel::from_npz`.
    Interpolated {
        /// One-variable interpolator queried with the number of new tokens.
        prefill_interp: Arc<dyn PrefillInterpolator>,
        /// Two-variable interpolator queried with `(active_kv_tokens, context_length)`.
        decode_interp: Arc<dyn DecodeInterpolator>,
    },
}
impl std::fmt::Debug for PerfModel {
    /// Writes the variant name only; the interpolator trait objects are not
    /// `Debug`, so the `Interpolated` payload is elided as `{ .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            PerfModel::Polynomial => "PerfModel::Polynomial",
            PerfModel::Interpolated { .. } => "PerfModel::Interpolated { .. }",
        };
        f.write_str(label)
    }
}
impl PerfModel {
    /// Builds an [`PerfModel::Interpolated`] model from the arrays stored in an NPZ file.
    ///
    /// The archive is expected to contain these `f64` arrays:
    /// - `prefill_isl` and `prefill_ttft_ms` (1-D, equal length): the sample axis and
    ///   the values for the prefill interpolator.
    /// - `decode_active_kv_tokens` and `decode_context_length` (1-D) plus `decode_itl`
    ///   (2-D, with `rows == decode_active_kv_tokens.len()` and
    ///   `cols == decode_context_length.len()`): the axes and the grid for the decode
    ///   interpolator.
    ///
    /// The prefill interpolator is linear and the decode interpolator is bilinear;
    /// both are configured to extrapolate outside the sampled range.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, the NPZ reader cannot be
    /// created, any of the named arrays cannot be loaded, the array dimensions are
    /// inconsistent, or either interpolator builder rejects the data.
    pub fn from_npz(path: &Path) -> Result<Self> {
        use ndarray_npy::NpzReader;
        use std::fs::File;

        tracing::info!("Loading performance model from NPZ file: {:?}", path);
        let file =
            File::open(path).with_context(|| format!("Failed to open NPZ file: {:?}", path))?;
        let mut npz = NpzReader::new(file)
            .with_context(|| format!("Failed to create NPZ reader for: {:?}", path))?;

        // Static messages need no closure, so plain `context` is used for the array loads.
        let prefill_isl: Array1<f64> = npz
            .by_name("prefill_isl")
            .context("Failed to load prefill_isl from NPZ")?;
        let prefill_ttft_ms: Array1<f64> = npz
            .by_name("prefill_ttft_ms")
            .context("Failed to load prefill_ttft_ms from NPZ")?;
        let decode_active_kv_tokens: Array1<f64> = npz
            .by_name("decode_active_kv_tokens")
            .context("Failed to load decode_active_kv_tokens from NPZ")?;
        let decode_context_length: Array1<f64> = npz
            .by_name("decode_context_length")
            .context("Failed to load decode_context_length from NPZ")?;
        let decode_itl: Array2<f64> = npz
            .by_name("decode_itl")
            .context("Failed to load decode_itl from NPZ")?;

        // Check shapes up front so a malformed file produces a message naming the
        // offending dimensions.
        if prefill_isl.len() != prefill_ttft_ms.len() {
            anyhow::bail!(
                "Prefill array length mismatch: isl={}, ttft={}",
                prefill_isl.len(),
                prefill_ttft_ms.len()
            );
        }
        if decode_itl.nrows() != decode_active_kv_tokens.len()
            || decode_itl.ncols() != decode_context_length.len()
        {
            anyhow::bail!(
                "Decode array dimension mismatch: itl shape=({}, {}), active_kv={}, context={}",
                decode_itl.nrows(),
                decode_itl.ncols(),
                decode_active_kv_tokens.len(),
                decode_context_length.len()
            );
        }

        tracing::info!(
            "Loaded performance model: prefill_points={}, decode_grid={}x{}",
            prefill_isl.len(),
            decode_itl.nrows(),
            decode_itl.ncols()
        );

        let prefill_interp = Interp1DBuilder::new(prefill_ttft_ms)
            .x(prefill_isl)
            .strategy(Linear::new().extrapolate(true))
            .build()
            .context("Failed to build prefill interpolator")?;
        let decode_interp = Interp2DBuilder::new(decode_itl)
            .x(decode_active_kv_tokens)
            .y(decode_context_length)
            .strategy(Bilinear::new().extrapolate(true))
            .build()
            .context("Failed to build decode interpolator")?;

        Ok(PerfModel::Interpolated {
            prefill_interp: Arc::new(PrefillInterp1D {
                inner: prefill_interp,
            }),
            decode_interp: Arc::new(DecodeInterp2D {
                inner: decode_interp,
            }),
        })
    }

    /// Predicts the prefill time for `new_tokens` tokens.
    ///
    /// The `Polynomial` variant evaluates a fixed quadratic in `new_tokens`; the
    /// `Interpolated` variant queries the prefill interpolator. The result is
    /// clamped to be non-negative (the trace output labels it as milliseconds).
    ///
    /// If the interpolator reports an error, a warning is logged and `0.0` is
    /// returned, so callers never see a failure from this method.
    pub fn predict_prefill_time(&self, new_tokens: usize) -> f64 {
        let time = match self {
            PerfModel::Polynomial => {
                let tokens = new_tokens as f64;
                4.209989e-07 * tokens.powi(2) + 1.518344e-02 * tokens + 1.650142e+01
            }
            PerfModel::Interpolated { prefill_interp, .. } => prefill_interp
                .interp(new_tokens as f64)
                .unwrap_or_else(|err| {
                    // Keep the best-effort 0.0 fallback, but make the failure visible.
                    tracing::warn!(
                        "Prefill interpolation failed for new_tokens={new_tokens}: {err:?}; falling back to 0.0"
                    );
                    0.0
                }),
        };
        // `f64::max` returns the non-NaN operand, so a NaN prediction also becomes 0.0.
        let result = time.max(0.0);
        tracing::trace!("Prefill time prediction: new_tokens={new_tokens}, time={result:.2}ms");
        result
    }

    /// Predicts the decode time for the given KV-cache occupancy and context length.
    ///
    /// The `Polynomial` variant evaluates a fixed quadratic in
    /// `active_kv_tokens / 16384.0` and ignores `context_length`; the
    /// `Interpolated` variant queries the decode interpolator at
    /// `(active_kv_tokens, context_length)`. The result is clamped to be
    /// non-negative (the trace output labels it as milliseconds).
    ///
    /// If the interpolator reports an error, a warning is logged and `0.0` is
    /// returned, so callers never see a failure from this method.
    pub fn predict_decode_time(&self, active_kv_tokens: usize, context_length: usize) -> f64 {
        let time = match self {
            PerfModel::Polynomial => {
                // NOTE(review): 16384.0 looks like a reference KV capacity in tokens — confirm.
                let active_perc = active_kv_tokens as f64 / 16384.0;
                -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74
            }
            PerfModel::Interpolated { decode_interp, .. } => decode_interp
                .interp(active_kv_tokens as f64, context_length as f64)
                .unwrap_or_else(|err| {
                    // Keep the best-effort 0.0 fallback, but make the failure visible.
                    tracing::warn!(
                        "Decode interpolation failed for active_kv_tokens={active_kv_tokens}, context_length={context_length}: {err:?}; falling back to 0.0"
                    );
                    0.0
                }),
        };
        // `f64::max` returns the non-NaN operand, so a NaN prediction also becomes 0.0.
        let result = time.max(0.0);
        tracing::trace!(
            "Decode time prediction: active_kv_tokens={active_kv_tokens}, context_length={context_length}, time={result:.2}ms"
        );
        result
    }
}