use std::path::{Path, PathBuf};
use crate::checkpoint::{Checkpoint, load_checkpoint, restore_parameter};
use crate::domain::DomainId;
use crate::model::Layer;
use crate::model::parameter::Parameter;
use crate::model_arch::load_arch;
use crate::moe_model::MoEModel;
use crate::moe_model::topology::{IN_DIM, N_EXPERTS, OUT_DIM};
use crate::object::{Dim, Shape, Tensor};
pub struct ModelSession {
arch_path: PathBuf,
checkpoint_path: PathBuf,
model: MoEModel,
}
impl ModelSession {
pub fn load(arch_path: &Path, checkpoint_path: &Path) -> Result<Self, String> {
let arch =
load_arch(arch_path).map_err(|e| format!("load_arch({}): {e}", arch_path.display()))?;
if arch.input_dim != IN_DIM {
return Err(format!(
"arch input_dim {} does not match topology IN_DIM {}",
arch.input_dim, IN_DIM
));
}
if arch.output_dim != OUT_DIM {
return Err(format!(
"arch output_dim {} does not match topology OUT_DIM {}",
arch.output_dim, OUT_DIM
));
}
let mut model = arch.build().map_err(|e| format!("arch.build(): {e}"))?;
let ckpt = load_checkpoint(checkpoint_path)
.map_err(|e| format!("load_checkpoint({}): {e}", checkpoint_path.display()))?;
restore_into_model(&mut model, &ckpt)?;
Ok(Self {
arch_path: arch_path.to_path_buf(),
checkpoint_path: checkpoint_path.to_path_buf(),
model,
})
}
pub fn arch_path(&self) -> &Path {
&self.arch_path
}
pub fn checkpoint_path(&self) -> &Path {
&self.checkpoint_path
}
pub fn forward(&mut self, batch: Vec<Vec<f32>>) -> Result<InferOutput, String> {
if batch.is_empty() {
return Err("batch is empty".to_string());
}
for (i, row) in batch.iter().enumerate() {
if row.len() != IN_DIM {
return Err(format!(
"batch row {i} has length {}, expected {IN_DIM}",
row.len()
));
}
}
let b = batch.len();
let b_padded = b.next_multiple_of(16).max(16);
let pad_row = batch.last().expect("non-empty").clone();
let mut data = Vec::with_capacity(b_padded * IN_DIM);
for row in &batch {
data.extend_from_slice(row);
}
while data.len() < b_padded * IN_DIM {
data.extend_from_slice(&pad_row);
}
let input = Tensor::dense_cpu(
DomainId::new("infer_input"),
Shape::from(vec![b_padded, IN_DIM]),
data,
);
let out = self
.model
.forward(&input)
.map_err(|e| format!("model.forward: {e}"))?;
let logits = slice_rows(&out.logits, b)?;
let router_weights = slice_rows(&out.router_weights, b)?;
Ok(InferOutput {
logits,
router_weights,
})
}
}
/// Outputs of [`ModelSession::forward`], trimmed to the caller's batch size.
///
/// Both tensors are 2-D, with one row per input row. Any padding rows added
/// internally have already been removed.
pub struct InferOutput {
    /// Model logits, one row per input row (presumably `OUT_DIM` columns; confirm).
    pub logits: Tensor<f32>,
    /// Router weights, one row per input row (presumably one column per expert; confirm).
    pub router_weights: Tensor<f32>,
}
/// Copies the checkpoint's parameter snapshots into `model`.
///
/// Live parameters are enumerated in declaration order: the router first,
/// then each expert in turn. They are named `router.param_{i}` and
/// `expert_{ei}.param_{i}`. The checkpoint must contain exactly one snapshot
/// per live parameter, in that same order and with those same names.
///
/// All names are validated before any parameter is written. On error the
/// model is left unmodified rather than partially restored.
///
/// # Errors
///
/// Returns an error if the parameter counts differ or any snapshot name does
/// not match its expected position. Either usually means the checkpoint was
/// saved for a different arch.
fn restore_into_model(model: &mut MoEModel, ckpt: &Checkpoint) -> Result<(), String> {
    let mut live: Vec<(String, &mut Parameter)> = Vec::with_capacity(ckpt.params.len());
    for (i, p) in model.router.parameters_mut().into_iter().enumerate() {
        live.push((format!("router.param_{i}"), p));
    }
    for (ei, expert) in model.experts.iter_mut().enumerate() {
        for (i, p) in expert.parameters_mut().into_iter().enumerate() {
            live.push((format!("expert_{ei}.param_{i}"), p));
        }
    }
    if live.len() != ckpt.params.len() {
        return Err(format!(
            "parameter count mismatch: model has {} live params, checkpoint has {} snapshots; \
             the checkpoint was probably saved for a different arch",
            live.len(),
            ckpt.params.len()
        ));
    }
    // Validation pass: check every name before mutating anything.
    for ((expected_name, _), snap) in live.iter().zip(ckpt.params.iter()) {
        if snap.name != *expected_name {
            return Err(format!(
                "snapshot name mismatch: expected `{expected_name}` (matching the model's \
                 declaration order), got `{}`; the checkpoint was probably saved for a \
                 different arch",
                snap.name
            ));
        }
    }
    // Restore pass: everything lines up, so write the values.
    for ((_, param), snap) in live.into_iter().zip(ckpt.params.iter()) {
        restore_parameter(param, snap);
    }
    Ok(())
}
/// Returns a new tensor holding the first `rows` rows of the 2-D tensor `t`.
///
/// `t.data` is read as row-major: row `r` occupies elements
/// `r * cols..(r + 1) * cols`. The leading `rows` rows are therefore the
/// first `rows * cols` elements, copied in one go. The result keeps `t`'s
/// domain.
///
/// # Errors
///
/// Returns an error if:
/// - `t` is not a 2-D tensor with static dims;
/// - `rows` exceeds its row count;
/// - its backing data is shorter than the shape implies (instead of
///   panicking on an out-of-range index).
fn slice_rows(t: &Tensor<f32>, rows: usize) -> Result<Tensor<f32>, String> {
    let dims = &t.meta.shape.dims;
    let [Dim::Static(cur_rows), Dim::Static(cols)] = dims.as_slice() else {
        return Err(format!(
            "slice_rows: expected 2-D static tensor, got {:?}",
            dims
        ));
    };
    let (cur_rows, cols) = (*cur_rows, *cols);
    if rows > cur_rows {
        return Err(format!(
            "slice_rows: asked for {rows} rows but tensor only has {cur_rows}"
        ));
    }
    let len = rows * cols;
    let Some(prefix) = t.data.get(..len) else {
        return Err(format!(
            "slice_rows: tensor data has {} elements, expected at least {len}",
            t.data.len()
        ));
    };
    Ok(Tensor::dense_cpu(
        t.meta.domain.clone(),
        Shape::from(vec![rows, cols]),
        prefix.to_vec(),
    ))
}
pub use crate::moe_model::topology::N_EXPERTS as INFER_N_EXPERTS;
pub use crate::moe_model::topology::{IN_DIM as INFER_IN_DIM, OUT_DIM as INFER_OUT_DIM};