use anyhow::{Context, Result};
use ndarray::{Array1, Array2};
use ndarray_interp::InterpolateError;
use ndarray_interp::interp1d::{Interp1DBuilder, Linear};
use ndarray_interp::interp2d::{Bilinear, Interp2DBuilder};
use std::path::Path;
use std::sync::Arc;
/// Object-safe interface for a one-dimensional interpolator used for prefill
/// predictions.
///
/// The `Send + Sync` supertraits let implementations be shared behind an
/// `Arc<dyn PrefillInterpolator>`, which is how `PerfModel::Interpolated`
/// stores them.
pub trait PrefillInterpolator: Send + Sync {
/// Evaluates the interpolant at `x`.
///
/// Within this file `x` is the total number of new prefill tokens in a batch
/// and the result is consumed as a prefill time. The unit is presumably
/// milliseconds (the NPZ loader feeds values named `prefill_ttft_ms`) --
/// confirm against the profiling data.
///
/// # Errors
/// Returns an [`InterpolateError`] when the implementation cannot evaluate
/// the interpolant at `x`.
fn interp(&self, x: f64) -> Result<f64, InterpolateError>;
}
/// Object-safe interface for a two-dimensional interpolator used for decode
/// predictions.
///
/// The `Send + Sync` supertraits let implementations be shared behind an
/// `Arc<dyn DecodeInterpolator>`, which is how `PerfModel::Interpolated`
/// stores them.
pub trait DecodeInterpolator: Send + Sync {
/// Evaluates the interpolant at the point `(x, y)`.
///
/// Within this file `x` is the number of active KV tokens and `y` is the
/// context length; the result is consumed as a decode step time.
///
/// # Errors
/// Returns an [`InterpolateError`] when the implementation cannot evaluate
/// the interpolant at `(x, y)`.
fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError>;
}
/// Externally supplied predictor used by `PerfModel::Aiconfigurator`.
///
/// The `Send + Sync` supertraits let implementations be shared behind an
/// `Arc<dyn AicCallback>`.
// NOTE(review): "Aic" presumably abbreviates "AI Configurator"; the expected
// unit of both return values is presumably milliseconds -- confirm with the
// callback provider.
pub trait AicCallback: Send + Sync {
/// Predicts the prefill time for a batch.
///
/// `PerfModel::predict_prefill_time` calls this with the batch size, the
/// per-request number of new tokens (`isl` minus `prefix`, saturating at
/// zero) as `effective_isl`, and the unmodified `prefix`. Negative results
/// are clamped to `0.0` by that caller.
fn predict_prefill(&self, batch_size: usize, effective_isl: usize, prefix: usize) -> f64;
/// Predicts the decode time for a batch.
///
/// `PerfModel::predict_decode_time` calls this with the batch size, the
/// context length as `isl`, and the constant `2` as `osl`. Results below
/// `1.0` are raised to `1.0` by that caller.
fn predict_decode(&self, batch_size: usize, isl: usize, osl: usize) -> f64;
}
/// Private adapter that exposes an `ndarray_interp` 1-D interpolator through
/// the [`PrefillInterpolator`] trait.
struct PrefillInterp1D {
/// Wrapped interpolator: owned `f64` storage, one-dimensional (`Ix1`) data,
/// `Linear` strategy.
inner: ndarray_interp::interp1d::Interp1D<
ndarray::OwnedRepr<f64>,
ndarray::OwnedRepr<f64>,
ndarray::Ix1,
Linear,
>,
}
impl PrefillInterpolator for PrefillInterp1D {
    /// Forwards the scalar query `x` to the wrapped linear interpolator and
    /// returns its result (value or [`InterpolateError`]) unchanged.
    fn interp(&self, x: f64) -> Result<f64, InterpolateError> {
        let Self { inner } = self;
        inner.interp_scalar(x)
    }
}
/// Private adapter that exposes an `ndarray_interp` 2-D interpolator through
/// the [`DecodeInterpolator`] trait.
struct DecodeInterp2D {
/// Wrapped interpolator: owned `f64` storage, two-dimensional (`Ix2`) data,
/// `Bilinear` strategy.
inner: ndarray_interp::interp2d::Interp2D<
ndarray::OwnedRepr<f64>,
ndarray::OwnedRepr<f64>,
ndarray::OwnedRepr<f64>,
ndarray::Ix2,
Bilinear,
>,
}
impl DecodeInterpolator for DecodeInterp2D {
    /// Forwards the query point `(x, y)` to the wrapped bilinear interpolator
    /// and returns its result (value or [`InterpolateError`]) unchanged.
    fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError> {
        let Self { inner } = self;
        inner.interp_scalar(x, y)
    }
}
/// Selects how prefill and decode times are predicted.
///
/// `Polynomial` is the default variant. The trait-object payloads are held in
/// `Arc`s, so cloning a model shares them instead of copying them.
#[derive(Default)]
pub enum PerfModel {
/// Built-in closed-form polynomial fit; carries no data.
#[default]
Polynomial,
/// Interpolation over tabulated data, as constructed by `PerfModel::from_npz`.
Interpolated {
/// 1-D interpolator queried with the total new-token count of a prefill batch.
prefill_interp: Arc<dyn PrefillInterpolator>,
/// 2-D interpolator queried with `(active_kv_tokens, context_length)`.
decode_interp: Arc<dyn DecodeInterpolator>,
},
/// Delegates every prediction to an externally supplied callback.
Aiconfigurator { callback: Arc<dyn AicCallback> },
}
impl Clone for PerfModel {
fn clone(&self) -> Self {
match self {
PerfModel::Polynomial => PerfModel::Polynomial,
PerfModel::Interpolated {
prefill_interp,
decode_interp,
} => PerfModel::Interpolated {
prefill_interp: Arc::clone(prefill_interp),
decode_interp: Arc::clone(decode_interp),
},
PerfModel::Aiconfigurator { callback } => PerfModel::Aiconfigurator {
callback: Arc::clone(callback),
},
}
}
}
impl std::fmt::Debug for PerfModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PerfModel::Polynomial => write!(f, "PerfModel::Polynomial"),
PerfModel::Interpolated { .. } => write!(f, "PerfModel::Interpolated {{ .. }}"),
PerfModel::Aiconfigurator { .. } => write!(f, "PerfModel::Aiconfigurator"),
}
}
}
impl PerfModel {
    /// Loads an interpolation-based performance model from an NPZ archive.
    ///
    /// The archive must contain the following `f64` arrays:
    ///
    /// - `prefill_isl` and `prefill_ttft_ms`: 1-D arrays of equal length giving
    ///   the sample coordinates and values of the prefill interpolator.
    /// - `decode_active_kv_tokens` and `decode_context_length`: 1-D axes of the
    ///   decode grid.
    /// - `decode_itl`: 2-D array whose row count equals the length of
    ///   `decode_active_kv_tokens` and whose column count equals the length of
    ///   `decode_context_length`.
    ///
    /// Both interpolators are built with extrapolation enabled.
    ///
    /// # Errors
    /// Returns an error when the file cannot be opened or parsed as NPZ, when
    /// any of the arrays above cannot be loaded, when the array dimensions
    /// disagree, or when either interpolator fails to build.
    pub fn from_npz(path: &Path) -> Result<Self> {
        use ndarray_npy::NpzReader;
        use std::fs::File;

        tracing::info!("Loading performance model from NPZ file: {:?}", path);

        let file =
            File::open(path).with_context(|| format!("Failed to open NPZ file: {:?}", path))?;
        let mut npz = NpzReader::new(file)
            .with_context(|| format!("Failed to create NPZ reader for: {:?}", path))?;

        let prefill_isl: Array1<f64> = npz
            .by_name("prefill_isl")
            .context("Failed to load prefill_isl from NPZ")?;
        let prefill_ttft_ms: Array1<f64> = npz
            .by_name("prefill_ttft_ms")
            .context("Failed to load prefill_ttft_ms from NPZ")?;
        let decode_active_kv_tokens: Array1<f64> = npz
            .by_name("decode_active_kv_tokens")
            .context("Failed to load decode_active_kv_tokens from NPZ")?;
        let decode_context_length: Array1<f64> = npz
            .by_name("decode_context_length")
            .context("Failed to load decode_context_length from NPZ")?;
        let decode_itl: Array2<f64> = npz
            .by_name("decode_itl")
            .context("Failed to load decode_itl from NPZ")?;

        // Validate shapes up front so that a malformed archive is reported in
        // terms of the array names rather than as a builder failure.
        if prefill_isl.len() != prefill_ttft_ms.len() {
            anyhow::bail!(
                "Prefill array length mismatch: isl={}, ttft={}",
                prefill_isl.len(),
                prefill_ttft_ms.len()
            );
        }
        if decode_itl.nrows() != decode_active_kv_tokens.len()
            || decode_itl.ncols() != decode_context_length.len()
        {
            anyhow::bail!(
                "Decode array dimension mismatch: itl shape=({}, {}), active_kv={}, context={}",
                decode_itl.nrows(),
                decode_itl.ncols(),
                decode_active_kv_tokens.len(),
                decode_context_length.len()
            );
        }

        tracing::info!(
            "Loaded performance model: prefill_points={}, decode_grid={}x{}",
            prefill_isl.len(),
            decode_itl.nrows(),
            decode_itl.ncols()
        );

        let prefill_interp = Interp1DBuilder::new(prefill_ttft_ms)
            .x(prefill_isl)
            .strategy(Linear::new().extrapolate(true))
            .build()
            .context("Failed to build prefill interpolator")?;
        let decode_interp = Interp2DBuilder::new(decode_itl)
            .x(decode_active_kv_tokens)
            .y(decode_context_length)
            .strategy(Bilinear::new().extrapolate(true))
            .build()
            .context("Failed to build decode interpolator")?;

        Ok(PerfModel::Interpolated {
            prefill_interp: Arc::new(PrefillInterp1D {
                inner: prefill_interp,
            }),
            decode_interp: Arc::new(DecodeInterp2D {
                inner: decode_interp,
            }),
        })
    }

    /// Wraps an externally supplied callback in a [`PerfModel::Aiconfigurator`].
    pub fn from_aic_callback(callback: Arc<dyn AicCallback>) -> Self {
        PerfModel::Aiconfigurator { callback }
    }

    /// Predicts the prefill time for a batch; the result is never negative.
    ///
    /// * `batch_size` - number of requests in the batch.
    /// * `isl` - input sequence length per request.
    /// * `prefix` - tokens per request that are excluded from the new-token
    ///   count (`isl - prefix`, saturating at zero).
    ///
    /// The polynomial and interpolated models are evaluated at the total number
    /// of new tokens, `batch_size * (isl - prefix)`. The callback model instead
    /// receives `batch_size`, the per-request new-token count, and `prefix`.
    ///
    /// If the interpolator reports an error, a warning is logged and `0.0` is
    /// used as the prediction.
    pub fn predict_prefill_time(&self, batch_size: usize, isl: usize, prefix: usize) -> f64 {
        let new_tokens_per_req = isl.saturating_sub(prefix);
        // Saturate instead of overflowing: a plain `*` on `usize` panics in
        // debug builds and silently wraps in release builds.
        let total_new_tokens = batch_size.saturating_mul(new_tokens_per_req) as f64;

        let time = match self {
            PerfModel::Polynomial => {
                4.209989e-07 * total_new_tokens.powi(2)
                    + 1.518344e-02 * total_new_tokens
                    + 1.650142e+01
            }
            PerfModel::Interpolated { prefill_interp, .. } => prefill_interp
                .interp(total_new_tokens)
                .unwrap_or_else(|err| {
                    // Keep the best-effort fallback, but make the failure visible.
                    tracing::warn!(
                        "Prefill interpolation failed for {total_new_tokens} tokens, falling back to 0.0: {err}"
                    );
                    0.0
                }),
            PerfModel::Aiconfigurator { callback } => {
                callback.predict_prefill(batch_size, new_tokens_per_req, prefix)
            }
        };

        time.max(0.0)
    }

    /// Predicts the time of one decode step for a batch.
    ///
    /// Returns `0.0` when `batch_size` is zero; otherwise the prediction is
    /// raised to at least `1.0`.
    ///
    /// * `batch_size` - number of requests in the batch; only the callback
    ///   model uses the value beyond the zero check.
    /// * `active_kv_tokens` - KV tokens currently in use.
    /// * `context_length` - context length passed to the interpolated and
    ///   callback models.
    /// * `total_kv_tokens` - KV capacity; only the polynomial model uses it, as
    ///   the denominator of the utilisation ratio. When it is zero a warning is
    ///   logged and the ratio is taken to be `1.0`.
    ///
    /// If the interpolator reports an error, a warning is logged and the
    /// prediction falls back to `0.0` before the lower bound is applied.
    pub fn predict_decode_time(
        &self,
        batch_size: usize,
        active_kv_tokens: usize,
        context_length: usize,
        total_kv_tokens: usize,
    ) -> f64 {
        if batch_size == 0 {
            return 0.0;
        }

        let time = match self {
            PerfModel::Polynomial => {
                let active_perc = if total_kv_tokens > 0 {
                    active_kv_tokens as f64 / total_kv_tokens as f64
                } else {
                    tracing::warn!("Total KV tokens is 0, using 1.0 as capacity");
                    1.0
                };
                -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74
            }
            PerfModel::Interpolated { decode_interp, .. } => decode_interp
                .interp(active_kv_tokens as f64, context_length as f64)
                .unwrap_or_else(|err| {
                    // Keep the best-effort fallback, but make the failure visible.
                    tracing::warn!(
                        "Decode interpolation failed for active_kv_tokens={active_kv_tokens}, context_length={context_length}, falling back to 0.0: {err}"
                    );
                    0.0
                }),
            PerfModel::Aiconfigurator { callback } => {
                // NOTE(review): `osl` is fixed at 2, presumably to request the
                // cost of a single decode step -- confirm against the
                // callback's contract.
                callback.predict_decode(batch_size, context_length, 2)
            }
        };

        let result = time.max(1.0);
        tracing::trace!(
            "Decode time prediction: batch_size={batch_size}, active_kv_tokens={active_kv_tokens}, context_length={context_length}, time={result:.2}ms"
        );
        result
    }
}