use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::Instant;
use ort::{
memory::Allocator,
session::{Session, SessionInputValue, SessionOutputs},
tensor::TensorElementType,
value::{DynValue, Tensor, ValueType},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokenizers::Tokenizer;
use crate::error::{Result, TurboQuantError};
use crate::kv_cache::{KVCacheConfig, MultiHeadConfig, MultiHeadKVCache, QuantStrategy};
use crate::utils::{norm, normalize};
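/// ONNX decoder file names probed inside a model directory, in priority
/// order: merged decoders first, a plain `model.onnx` last.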
const SUPPORTED_MODEL_FILES: &[&str] = &[
"decoder_model_merged.onnx",
"decoder_with_past_model.onnx",
"decoder_model.onnx",
"model.onnx",
];
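/// Model presets whose ONNX exports have been verified with this runner.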
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SupportedRealModel {
DistilGpt2,
SmolLlm2_135MInstruct,
}
impl SupportedRealModel {
pub fn model_id(self) -> &'static str {
match self {
Self::DistilGpt2 => "distilgpt2",
Self::SmolLlm2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
}
}
pub fn preset_name(self) -> &'static str {
match self {
Self::DistilGpt2 => "distilgpt2",
Self::SmolLlm2_135MInstruct => "smollm2-135m-instruct",
}
}
pub fn description(self) -> &'static str {
match self {
Self::DistilGpt2 => "small GPT-2-family baseline verified on CPU",
Self::SmolLlm2_135MInstruct => "small SmolLM2 instruct model verified on CPU",
}
}
pub fn all() -> &'static [SupportedRealModel] {
&[
SupportedRealModel::DistilGpt2,
SupportedRealModel::SmolLlm2_135MInstruct,
]
}
}
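/// Greedy-decoding controls: the new-token budget and whether generation
/// halts once an end-of-sequence token appears.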
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealModelGenerationConfig {
pub max_new_tokens: usize,
pub stop_on_eos: bool,
}
impl Default for RealModelGenerationConfig {
fn default() -> Self {
Self {
max_new_tokens: 16,
stop_on_eos: true,
}
}
}
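/// Per-layer KV quantization settings. The base `seed` is offset by a
/// per-layer constant so each layer derives a distinct quantizer seed.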
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealModelQuantizationConfig {
pub key_bits: u8,
pub value_bits: u8,
pub key_strategy: QuantStrategy,
pub seed: u64,
}
impl Default for RealModelQuantizationConfig {
fn default() -> Self {
Self {
key_bits: 4,
value_bits: 4,
key_strategy: QuantStrategy::Prod,
seed: 42,
}
}
}
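/// Memory accounting for one run's KV cache. `exact_bytes` is the f32
/// baseline; the quantized fields are `None` for exact runs.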
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvCacheUsage {
pub num_layers: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub stored_tokens: usize,
pub exact_bytes: usize,
pub quantized_bytes: Option<usize>,
pub compression_ratio: Option<f64>,
}
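/// Complete record of a generation run: token ids, per-step logits,
/// separate prefill/decode timings, and cache usage.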
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealModelTrace {
pub model_id: String,
pub prompt: String,
pub prompt_tokens: Vec<u32>,
pub generated_tokens: Vec<u32>,
pub generated_text: String,
pub step_logits: Vec<Vec<f32>>,
pub prefill_seconds: f64,
pub decode_seconds: f64,
pub kv_cache: KvCacheUsage,
}
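/// Runs a local ONNX decoder while managing the KV cache outside the
/// graph, either exactly (f32) or through the quantized cache.
///
/// A minimal usage sketch (the model path is illustrative):
///
/// ```ignore
/// let runner = RealModelRunner::load("models/distilgpt2")?;
/// let trace = runner.generate_exact("Hello", &RealModelGenerationConfig::default())?;
/// println!("{}", trace.generated_text);
/// ```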
pub struct RealModelRunner {
model_dir: PathBuf,
spec: DecoderSpec,
tokenizer: Tokenizer,
session: Mutex<Session>,
io: DecoderIoLayout,
}
impl RealModelRunner {
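/// Loads `config.json`, `tokenizer.json`, and the first ONNX decoder
/// file found in `model_dir`. The session is built with memory patterns
/// disabled, since input shapes change on every call as the cache grows.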
pub fn load<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
let model_dir = model_dir.as_ref().to_path_buf();
if !model_dir.is_dir() {
return Err(TurboQuantError::Io(format!(
"{} is not a directory",
model_dir.display()
)));
}
let spec = DecoderSpec::load(&model_dir)?;
let tokenizer_path = model_dir.join("tokenizer.json");
if !tokenizer_path.is_file() {
return Err(TurboQuantError::ModelConfig(format!(
"missing tokenizer.json in {}",
model_dir.display()
)));
}
let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|error| {
TurboQuantError::Tokenizer(format!("{}: {error}", tokenizer_path.display()))
})?;
let onnx_path = resolve_onnx_path(&model_dir)?;
let session = Session::builder()
.map_err(onnx_error)?
.with_memory_pattern(false)
.map_err(onnx_error)?
.commit_from_file(&onnx_path)
.map_err(onnx_error)?;
let io = DecoderIoLayout::infer(&session, &spec)?;
Ok(Self {
model_dir,
spec,
tokenizer,
session: Mutex::new(session),
io,
})
}
pub fn model_id(&self) -> &str {
&self.spec.model_id
}
pub fn model_dir(&self) -> &Path {
&self.model_dir
}
pub fn generate_exact(
&self,
prompt: &str,
config: &RealModelGenerationConfig,
) -> Result<RealModelTrace> {
self.generate(prompt, config, CacheMode::Exact)
}
pub fn generate_quantized(
&self,
prompt: &str,
generation: &RealModelGenerationConfig,
quantization: &RealModelQuantizationConfig,
) -> Result<RealModelTrace> {
self.generate(
prompt,
generation,
CacheMode::Quantized(quantization.clone()),
)
}
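/// Greedy generation: one prefill pass over the whole prompt, then one
/// decode pass per token with the externally held past key-values fed
/// back in. Wall-clock time is split between the two phases, and the raw
/// logits of every step are preserved in the trace.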
fn generate(
&self,
prompt: &str,
generation: &RealModelGenerationConfig,
cache_mode: CacheMode,
) -> Result<RealModelTrace> {
if generation.max_new_tokens == 0 {
return Err(TurboQuantError::InvalidDimension(0));
}
let prompt_tokens = self.encode_prompt(prompt)?;
let mut cache = CacheState::new(&self.spec, cache_mode)?;
let mut generated_tokens = Vec::with_capacity(generation.max_new_tokens);
let mut step_logits = Vec::with_capacity(generation.max_new_tokens);
let prefill_start = Instant::now();
let prefill = self.run_model(&prompt_tokens, 0, cache.as_past_tensors(&self.io)?)?;
let prefill_seconds = prefill_start.elapsed().as_secs_f64();
cache.ingest(&prefill, prompt_tokens.len())?;
let mut current_context_len = prompt_tokens.len();
let next_logits = prefill.last_step_logits()?;
let mut next_token = argmax_u32(&next_logits);
step_logits.push(next_logits);
generated_tokens.push(next_token);
let mut decode_seconds = 0.0;
for _ in 1..generation.max_new_tokens {
if generation.stop_on_eos && self.spec.is_eos(next_token) {
break;
}
let decode_start = Instant::now();
let outputs = self.run_model(
&[next_token],
current_context_len,
cache.as_past_tensors(&self.io)?,
)?;
decode_seconds += decode_start.elapsed().as_secs_f64();
cache.ingest(&outputs, 1)?;
current_context_len += 1;
let logits = outputs.last_step_logits()?;
next_token = argmax_u32(&logits);
step_logits.push(logits);
generated_tokens.push(next_token);
}
let generated_text = self
.tokenizer
.decode(&generated_tokens, true)
.map_err(|error| TurboQuantError::Tokenizer(error.to_string()))?;
Ok(RealModelTrace {
model_id: self.spec.model_id.clone(),
prompt: prompt.to_string(),
prompt_tokens,
generated_tokens,
generated_text,
step_logits,
prefill_seconds,
decode_seconds,
kv_cache: cache.usage(&self.spec),
})
}
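/// Tokenizes the prompt, falling back to a lone BOS token when encoding
/// yields nothing and the model declares a `bos_token_id`.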
fn encode_prompt(&self, prompt: &str) -> Result<Vec<u32>> {
let encoding = self
.tokenizer
.encode(prompt, true)
.map_err(|error| TurboQuantError::Tokenizer(error.to_string()))?;
if !encoding.get_ids().is_empty() {
return Ok(encoding.get_ids().to_vec());
}
if let Some(bos) = self.spec.bos_token_id {
return Ok(vec![bos]);
}
Err(TurboQuantError::Tokenizer(
"prompt encoded to zero tokens and no bos_token_id fallback is available".into(),
))
}
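/// Runs one forward pass. Regular inputs are synthesised from the
/// current position: the attention mask spans past plus new tokens,
/// position ids continue from `past_len`, and `use_cache_branch` is true
/// once any past exists. Past key-value tensors are passed through by name.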
fn run_model(
&self,
input_ids: &[u32],
past_len: usize,
past_tensors: Vec<(String, Tensor<f32>)>,
) -> Result<ModelRunOutputs> {
let mut inputs = Vec::with_capacity(self.io.input_slots.len() + past_tensors.len());
for slot in &self.io.input_slots {
let value: SessionInputValue<'static> = match &slot.role {
InputRole::InputIds => {
let values: Vec<i64> = input_ids.iter().map(|value| *value as i64).collect();
Tensor::from_array(([1_usize, input_ids.len()], values))
.map_err(onnx_error)?
.into()
}
InputRole::AttentionMask => {
let total = past_len + input_ids.len();
let values = vec![1_i64; total];
Tensor::from_array(([1_usize, total], values))
.map_err(onnx_error)?
.into()
}
InputRole::PositionIds => {
let values: Vec<i64> = (0..input_ids.len())
.map(|offset| (past_len + offset) as i64)
.collect();
Tensor::from_array(([1_usize, input_ids.len()], values))
.map_err(onnx_error)?
.into()
}
InputRole::TokenTypeIds => {
let values = vec![0_i64; input_ids.len()];
Tensor::from_array(([1_usize, input_ids.len()], values))
.map_err(onnx_error)?
.into()
}
InputRole::CachePosition { rank } => {
let values: Vec<i64> = (0..input_ids.len())
.map(|offset| (past_len + offset) as i64)
.collect();
if *rank == 1 {
Tensor::from_array(([input_ids.len()], values))
.map_err(onnx_error)?
.into()
} else {
Tensor::from_array(([1_usize, input_ids.len()], values))
.map_err(onnx_error)?
.into()
}
}
InputRole::UseCacheBranch { datum } => {
let enabled = past_len > 0;
match datum {
ScalarDatum::Bool => Tensor::from_array(((), vec![enabled]))
.map_err(onnx_error)?
.into(),
ScalarDatum::I32 => Tensor::from_array(((), vec![enabled as i32]))
.map_err(onnx_error)?
.into(),
ScalarDatum::I64 => Tensor::from_array(((), vec![enabled as i64]))
.map_err(onnx_error)?
.into(),
}
}
};
inputs.push((slot.name.clone(), value));
}
for (name, tensor) in past_tensors {
inputs.push((name, tensor.into()));
}
let mut session = self
.session
.lock()
.map_err(|_| TurboQuantError::Internal("onnxruntime session lock poisoned".into()))?;
let outputs = session.run(inputs).map_err(onnx_error)?;
ModelRunOutputs::from_outputs(&outputs, &self.io, &self.spec)
}
}
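/// Chooses which cache backend a generation run uses.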
#[derive(Debug, Clone)]
enum CacheMode {
Exact,
Quantized(RealModelQuantizationConfig),
}
#[derive(Debug)]
enum CacheState {
Exact(ExactPastKeyValues),
Quantized(QuantizedPastKeyValues),
}
impl CacheState {
fn new(spec: &DecoderSpec, mode: CacheMode) -> Result<Self> {
match mode {
CacheMode::Exact => Ok(Self::Exact(ExactPastKeyValues::new(spec))),
CacheMode::Quantized(config) => {
Ok(Self::Quantized(QuantizedPastKeyValues::new(spec, &config)?))
}
}
}
fn as_past_tensors(&self, io: &DecoderIoLayout) -> Result<Vec<(String, Tensor<f32>)>> {
match self {
Self::Exact(cache) => cache.as_past_tensors(io),
Self::Quantized(cache) => cache.as_past_tensors(io),
}
}
fn ingest(&mut self, outputs: &ModelRunOutputs, input_len: usize) -> Result<()> {
match self {
Self::Exact(cache) => cache.ingest(outputs, input_len),
Self::Quantized(cache) => cache.ingest(outputs, input_len),
}
}
fn usage(&self, spec: &DecoderSpec) -> KvCacheUsage {
match self {
Self::Exact(cache) => cache.usage(spec),
Self::Quantized(cache) => cache.usage(spec),
}
}
}
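/// Exact f32 baseline cache: token-major (token, head, dim) key and
/// value buffers, one pair per decoder layer.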
#[derive(Debug)]
struct ExactPastKeyValues {
layers: Vec<ExactLayerCache>,
}
impl ExactPastKeyValues {
fn new(spec: &DecoderSpec) -> Self {
Self {
layers: (0..spec.num_hidden_layers)
.map(|_| ExactLayerCache::default())
.collect(),
}
}
fn as_past_tensors(&self, io: &DecoderIoLayout) -> Result<Vec<(String, Tensor<f32>)>> {
io.past_inputs
.iter()
.map(|slot| {
let layer = self.layers.get(slot.layer).ok_or_else(|| {
TurboQuantError::Internal(format!(
"missing exact cache layer {} for past tensor build",
slot.layer
))
})?;
let tensor = layer.build_tensor(slot)?;
Ok((slot.name.clone(), tensor))
})
.collect()
}
fn ingest(&mut self, outputs: &ModelRunOutputs, input_len: usize) -> Result<()> {
for present in &outputs.layers {
let layer = self.layers.get_mut(present.layer).ok_or_else(|| {
TurboQuantError::Internal(format!(
"present layer {} is out of range for exact cache",
present.layer
))
})?;
layer.ingest(present, input_len)?;
}
Ok(())
}
fn usage(&self, spec: &DecoderSpec) -> KvCacheUsage {
let stored_tokens = self.layers.first().map_or(0, |layer| layer.seq_len);
KvCacheUsage {
num_layers: spec.num_hidden_layers,
num_key_value_heads: spec.num_key_value_heads,
head_dim: spec.head_dim,
stored_tokens,
exact_bytes: spec.exact_cache_bytes(stored_tokens),
quantized_bytes: None,
compression_ratio: None,
}
}
}
#[derive(Debug, Default)]
struct ExactLayerCache {
key_token_major: Vec<f32>,
value_token_major: Vec<f32>,
seq_len: usize,
observed_key_layout: Option<CacheTensorLayout>,
observed_value_layout: Option<CacheTensorLayout>,
}
impl ExactLayerCache {
fn ingest(&mut self, present: &LayerPresent, input_len: usize) -> Result<()> {
let update = PresentUpdate::resolve(self.seq_len, input_len, present.key.seq_len)?;
if present.key.seq_len != present.value.seq_len {
return Err(TurboQuantError::ModelFormat(format!(
"layer {} present key/value length mismatch ({} vs {})",
present.layer, present.key.seq_len, present.value.seq_len
)));
}
self.observed_key_layout = Some(present.key.source_layout);
self.observed_value_layout = Some(present.value.source_layout);
if update.output_is_full_prefix {
self.key_token_major = present.key.token_major.clone();
self.value_token_major = present.value.token_major.clone();
self.seq_len = present.key.seq_len;
} else {
self.key_token_major
.extend_from_slice(&present.key.token_major);
self.value_token_major
.extend_from_slice(&present.value.token_major);
self.seq_len += update.new_tokens;
}
Ok(())
}
fn build_tensor(&self, slot: &PastInputSlot) -> Result<Tensor<f32>> {
let (token_major, layout_hint) = match slot.kind {
CacheTensorKind::Key => (&self.key_token_major, self.observed_key_layout),
CacheTensorKind::Value => (&self.value_token_major, self.observed_value_layout),
};
let layout = slot
.layout
.or(layout_hint)
.unwrap_or(CacheTensorLayout::BatchHeadSeq);
token_major_to_tensor(
token_major,
self.seq_len,
slot.num_heads,
slot.head_dim,
layout,
)
}
}
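/// Quantized counterpart of the exact cache; every layer owns a
/// `MultiHeadKVCache` constructed with its own derived seed.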
#[derive(Debug)]
struct QuantizedPastKeyValues {
layers: Vec<QuantizedLayerCache>,
}
impl QuantizedPastKeyValues {
fn new(spec: &DecoderSpec, config: &RealModelQuantizationConfig) -> Result<Self> {
let layers = (0..spec.num_hidden_layers)
.map(|layer_index| {
QuantizedLayerCache::new(
spec,
config,
config.seed.wrapping_add(layer_index as u64 * 100_000),
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
fn as_past_tensors(&self, io: &DecoderIoLayout) -> Result<Vec<(String, Tensor<f32>)>> {
io.past_inputs
.iter()
.map(|slot| {
let layer = self.layers.get(slot.layer).ok_or_else(|| {
TurboQuantError::Internal(format!(
"missing quantized cache layer {} for past tensor build",
slot.layer
))
})?;
let tensor = layer.build_tensor(slot)?;
Ok((slot.name.clone(), tensor))
})
.collect()
}
fn ingest(&mut self, outputs: &ModelRunOutputs, input_len: usize) -> Result<()> {
for present in &outputs.layers {
let layer = self.layers.get_mut(present.layer).ok_or_else(|| {
TurboQuantError::Internal(format!(
"present layer {} is out of range for quantized cache",
present.layer
))
})?;
layer.ingest(present, input_len)?;
}
Ok(())
}
fn usage(&self, spec: &DecoderSpec) -> KvCacheUsage {
let stored_tokens = self
.layers
.first()
.map_or(0, |layer| layer.cache.num_tokens());
let quantized_bytes = self
.layers
.iter()
.map(QuantizedLayerCache::total_bytes)
.sum::<usize>();
let exact_bytes = spec.exact_cache_bytes(stored_tokens);
KvCacheUsage {
num_layers: spec.num_hidden_layers,
num_key_value_heads: spec.num_key_value_heads,
head_dim: spec.head_dim,
stored_tokens,
exact_bytes,
quantized_bytes: Some(quantized_bytes),
compression_ratio: if quantized_bytes == 0 {
None
} else {
Some(exact_bytes as f64 / quantized_bytes as f64)
},
}
}
}
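/// One layer of quantized KV storage. Each incoming row is split into
/// direction and magnitude: the vector is normalized before entering the
/// quantizer, and its original L2 norm is stored as an f32 so
/// reconstruction can rescale it. Those norms are included in the byte
/// accounting reported by `total_bytes`.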
#[derive(Debug)]
struct QuantizedLayerCache {
cache: MultiHeadKVCache,
key_norms: Vec<Vec<f32>>,
value_norms: Vec<Vec<f32>>,
observed_key_layout: Option<CacheTensorLayout>,
observed_value_layout: Option<CacheTensorLayout>,
}
impl QuantizedLayerCache {
fn new(spec: &DecoderSpec, config: &RealModelQuantizationConfig, seed: u64) -> Result<Self> {
let head_config = KVCacheConfig::new(spec.head_dim)
.with_key_bits(config.key_bits)
.with_value_bits(config.value_bits)
.with_key_strategy(config.key_strategy)
.with_seed(seed);
let cache =
MultiHeadKVCache::new(MultiHeadConfig::new(spec.num_key_value_heads, head_config))?;
Ok(Self {
cache,
key_norms: vec![Vec::new(); spec.num_key_value_heads],
value_norms: vec![Vec::new(); spec.num_key_value_heads],
observed_key_layout: None,
observed_value_layout: None,
})
}
fn ingest(&mut self, present: &LayerPresent, input_len: usize) -> Result<()> {
let existing_tokens = self.cache.num_tokens();
let update = PresentUpdate::resolve(existing_tokens, input_len, present.key.seq_len)?;
if present.key.seq_len != present.value.seq_len {
return Err(TurboQuantError::ModelFormat(format!(
"layer {} present key/value length mismatch ({} vs {})",
present.layer, present.key.seq_len, present.value.seq_len
)));
}
self.observed_key_layout = Some(present.key.source_layout);
self.observed_value_layout = Some(present.value.source_layout);
let start_token = if update.output_is_full_prefix {
existing_tokens
} else {
0
};
let new_keys = extract_quantizable_rows(
&present.key.token_major,
present.key.seq_len,
present.key.num_heads,
present.key.head_dim,
start_token,
update.new_tokens,
"key",
)?;
let new_values = extract_quantizable_rows(
&present.value.token_major,
present.value.seq_len,
present.value.num_heads,
present.value.head_dim,
start_token,
update.new_tokens,
"value",
)?;
for head in 0..self.cache.num_heads() {
for vector in &new_keys[head] {
self.key_norms[head].push(norm(vector) as f32);
}
for vector in &new_values[head] {
self.value_norms[head].push(norm(vector) as f32);
}
}
let normalized_keys = normalize_rows_per_head(new_keys)?;
let normalized_values = normalize_rows_per_head(new_values)?;
self.cache
.append_all(&normalized_keys, &normalized_values)?;
Ok(())
}
fn build_tensor(&self, slot: &PastInputSlot) -> Result<Tensor<f32>> {
let layout = slot
.layout
.or(match slot.kind {
CacheTensorKind::Key => self.observed_key_layout,
CacheTensorKind::Value => self.observed_value_layout,
})
.unwrap_or(CacheTensorLayout::BatchHeadSeq);
let token_major = self.reconstruct_token_major(slot.kind)?;
token_major_to_tensor(
&token_major,
self.cache.num_tokens(),
slot.num_heads,
slot.head_dim,
layout,
)
}
fn reconstruct_token_major(&self, kind: CacheTensorKind) -> Result<Vec<f32>> {
let per_head = match kind {
CacheTensorKind::Key => self.cache.reconstruct_keys_all()?,
CacheTensorKind::Value => self.cache.reconstruct_values_all()?,
};
let norms = match kind {
CacheTensorKind::Key => &self.key_norms,
CacheTensorKind::Value => &self.value_norms,
};
let token_count = self.cache.num_tokens();
let head_count = self.cache.num_heads();
let head_dim = self.cache.head_dim();
let mut token_major = Vec::with_capacity(token_count * head_count * head_dim);
for token in 0..token_count {
for head in 0..head_count {
let vector = per_head
.get(head)
.and_then(|rows| rows.get(token))
.ok_or_else(|| {
TurboQuantError::Internal(format!(
"missing reconstructed {:?} vector for head {head}, token {token}",
kind
))
})?;
let scale = norms
.get(head)
.and_then(|values| values.get(token))
.copied()
.ok_or_else(|| {
TurboQuantError::Internal(format!(
"missing stored norm for {:?} head {head}, token {token}",
kind
))
})?;
token_major.extend(vector.iter().map(|value| (*value as f32) * scale));
}
}
Ok(token_major)
}
fn total_bytes(&self) -> usize {
let cache_bytes = self.cache.stats().total_bytes;
let norm_bytes = (self.key_norms.iter().map(Vec::len).sum::<usize>()
+ self.value_norms.iter().map(Vec::len).sum::<usize>())
* std::mem::size_of::<f32>();
cache_bytes + norm_bytes
}
}
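/// Decoder geometry and special-token metadata, parsed from
/// `config.json` (accepting GPT-2 style aliases) and, when present,
/// `generation_config.json`.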
#[derive(Debug, Clone)]
struct DecoderSpec {
model_id: String,
num_hidden_layers: usize,
num_key_value_heads: usize,
head_dim: usize,
bos_token_id: Option<u32>,
eos_token_ids: Vec<u32>,
}
impl DecoderSpec {
fn load(model_dir: &Path) -> Result<Self> {
let config_path = model_dir.join("config.json");
let bytes = fs::read(&config_path)
.map_err(|error| TurboQuantError::Io(format!("{}: {error}", config_path.display())))?;
let config: HfConfig = serde_json::from_slice(&bytes).map_err(|error| {
TurboQuantError::ModelConfig(format!("{}: {error}", config_path.display()))
})?;
let num_hidden_layers = config.num_hidden_layers.ok_or_else(|| {
TurboQuantError::ModelConfig(format!(
"{} is missing num_hidden_layers",
config_path.display()
))
})?;
let num_attention_heads = config.num_attention_heads.ok_or_else(|| {
TurboQuantError::ModelConfig(format!(
"{} is missing num_attention_heads",
config_path.display()
))
})?;
let num_key_value_heads = config.num_key_value_heads.unwrap_or(num_attention_heads);
if num_hidden_layers == 0 || num_attention_heads == 0 || num_key_value_heads == 0 {
return Err(TurboQuantError::ModelConfig(format!(
"{} declares zero-sized decoder metadata",
config_path.display()
)));
}
let head_dim = if let Some(head_dim) = config.head_dim {
head_dim
} else {
let hidden_size = config.hidden_size.ok_or_else(|| {
TurboQuantError::ModelConfig(format!(
"{} is missing hidden_size/head_dim metadata",
config_path.display()
))
})?;
if hidden_size % num_attention_heads != 0 {
return Err(TurboQuantError::ModelConfig(format!(
"hidden_size {hidden_size} is not divisible by num_attention_heads {num_attention_heads}"
)));
}
hidden_size / num_attention_heads
};
let generation_config_path = model_dir.join("generation_config.json");
let generation_config = if generation_config_path.is_file() {
let bytes = fs::read(&generation_config_path).map_err(|error| {
TurboQuantError::Io(format!("{}: {error}", generation_config_path.display()))
})?;
Some(
serde_json::from_slice::<GenerationConfig>(&bytes).map_err(|error| {
TurboQuantError::ModelConfig(format!(
"{}: {error}",
generation_config_path.display()
))
})?,
)
} else {
None
};
let eos_token_ids = generation_config
.as_ref()
.and_then(|config| config.eos_token_id.as_ref())
.map(parse_token_ids)
.unwrap_or_else(|| parse_token_ids(&config.eos_token_id));
if eos_token_ids.is_empty() {
return Err(TurboQuantError::ModelConfig(format!(
"missing eos_token_id in {} or generation_config.json",
config_path.display()
)));
}
let model_id = config
.name_or_path
.or_else(|| infer_model_id(model_dir))
.unwrap_or_else(|| model_dir.display().to_string());
Ok(Self {
model_id,
num_hidden_layers,
num_key_value_heads,
head_dim,
bos_token_id: generation_config
.as_ref()
.and_then(|cfg| cfg.bos_token_id)
.or(config.bos_token_id),
eos_token_ids,
})
}
fn exact_cache_bytes(&self, stored_tokens: usize) -> usize {
self.num_hidden_layers
* self.num_key_value_heads
* stored_tokens
* self.head_dim
* std::mem::size_of::<f32>()
* 2
}
fn is_eos(&self, token_id: u32) -> bool {
self.eos_token_ids.contains(&token_id)
}
}
#[derive(Debug, Deserialize)]
struct HfConfig {
#[serde(rename = "_name_or_path")]
name_or_path: Option<String>,
#[serde(alias = "n_layer")]
num_hidden_layers: Option<usize>,
#[serde(alias = "n_head")]
num_attention_heads: Option<usize>,
num_key_value_heads: Option<usize>,
#[serde(alias = "n_embd")]
hidden_size: Option<usize>,
head_dim: Option<usize>,
bos_token_id: Option<u32>,
// Default to `Value::Null` so a config without `eos_token_id` still parses;
// `DecoderSpec::load` then reports the friendlier "missing eos_token_id"
// error (or falls back to generation_config.json) instead of a serde error.
#[serde(default)]
eos_token_id: Value,
}
#[derive(Debug, Deserialize)]
struct GenerationConfig {
bos_token_id: Option<u32>,
eos_token_id: Option<Value>,
}
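/// The resolved contract between runner and ONNX graph: recognised
/// regular inputs, past key-value inputs, the logits output, and the
/// present key-value outputs. Models with unrecognised inputs are
/// rejected up front.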
#[derive(Debug, Clone)]
struct DecoderIoLayout {
input_slots: Vec<ModelInputSlot>,
past_inputs: Vec<PastInputSlot>,
logits_name: String,
present_outputs: Vec<PresentOutputSlot>,
}
impl DecoderIoLayout {
fn infer(session: &Session, spec: &DecoderSpec) -> Result<Self> {
let mut input_slots = Vec::new();
let mut past_inputs = Vec::new();
let mut unknown_inputs = Vec::new();
let mut saw_input_ids = false;
for input in &session.inputs {
let name = input.name.clone();
let role = if name == "input_ids" {
saw_input_ids = true;
Some(InputRole::InputIds)
} else if name == "attention_mask" {
Some(InputRole::AttentionMask)
} else if name == "position_ids" {
Some(InputRole::PositionIds)
} else if name == "token_type_ids" {
Some(InputRole::TokenTypeIds)
} else if name == "cache_position" {
Some(InputRole::CachePosition {
rank: tensor_rank(&input.input_type, "cache_position")?,
})
} else if name == "use_cache_branch" {
Some(InputRole::UseCacheBranch {
datum: scalar_datum(&input.input_type)?,
})
} else if let Some((layer, kind)) = parse_past_input(&name) {
past_inputs.push(PastInputSlot {
name: name.clone(),
layer,
kind,
layout: infer_cache_layout_from_value_type(
&input.input_type,
spec.num_key_value_heads,
spec.head_dim,
),
num_heads: spec.num_key_value_heads,
head_dim: spec.head_dim,
});
None
} else {
unknown_inputs.push(name.clone());
None
};
if let Some(role) = role {
input_slots.push(ModelInputSlot { name, role });
}
}
if !unknown_inputs.is_empty() {
return Err(TurboQuantError::UnsupportedModel(format!(
"unsupported ONNX inputs: {}",
unknown_inputs.join(", ")
)));
}
if !saw_input_ids {
return Err(TurboQuantError::UnsupportedModel(
"ONNX model is missing input_ids".into(),
));
}
let mut logits_name = None;
let mut present_outputs = Vec::new();
for output in &session.outputs {
let name = output.name.clone();
if name == "logits" {
logits_name = Some(name);
} else if let Some((layer, kind)) = parse_present_output(&name) {
present_outputs.push(PresentOutputSlot { name, layer, kind });
}
}
let logits_name = logits_name.ok_or_else(|| {
TurboQuantError::UnsupportedModel("ONNX model is missing logits output".into())
})?;
validate_cache_io(&past_inputs, &present_outputs, spec.num_hidden_layers)?;
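// The name sort is lexicographic (e.g. "present.10" before "present.2");
// it only fixes iteration order, since outputs are matched by layer index.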
present_outputs.sort_by(|left, right| left.name.cmp(&right.name));
Ok(Self {
input_slots,
past_inputs,
logits_name,
present_outputs,
})
}
}
#[derive(Debug, Clone)]
struct ModelInputSlot {
name: String,
role: InputRole,
}
#[derive(Debug, Clone)]
enum InputRole {
InputIds,
AttentionMask,
PositionIds,
TokenTypeIds,
CachePosition { rank: usize },
UseCacheBranch { datum: ScalarDatum },
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarDatum {
Bool,
I32,
I64,
}
#[derive(Debug, Clone)]
struct PastInputSlot {
name: String,
layer: usize,
kind: CacheTensorKind,
layout: Option<CacheTensorLayout>,
num_heads: usize,
head_dim: usize,
}
#[derive(Debug, Clone)]
struct PresentOutputSlot {
name: String,
layer: usize,
kind: CacheTensorKind,
}
#[derive(Debug)]
struct ModelRunOutputs {
logits: Vec<f32>,
logits_seq_len: usize,
vocab_size: usize,
layers: Vec<LayerPresent>,
}
impl ModelRunOutputs {
fn from_outputs(
outputs: &SessionOutputs<'_>,
io: &DecoderIoLayout,
spec: &DecoderSpec,
) -> Result<Self> {
let logits_value = outputs.get(&io.logits_name).ok_or_else(|| {
TurboQuantError::Internal(format!("missing logits output {}", io.logits_name))
})?;
let (logits, logits_seq_len, vocab_size) = extract_logits(logits_value)?;
let mut keys = vec![None; spec.num_hidden_layers];
let mut values_by_layer = vec![None; spec.num_hidden_layers];
for slot in &io.present_outputs {
let value = outputs.get(&slot.name).ok_or_else(|| {
TurboQuantError::Internal(format!("missing present output {}", slot.name))
})?;
let normalized =
normalize_cache_tensor(value, spec.num_key_value_heads, spec.head_dim)?;
match slot.kind {
CacheTensorKind::Key => keys[slot.layer] = Some(normalized),
CacheTensorKind::Value => values_by_layer[slot.layer] = Some(normalized),
}
}
let mut layers = Vec::with_capacity(spec.num_hidden_layers);
for layer in 0..spec.num_hidden_layers {
let key = keys[layer].take().ok_or_else(|| {
TurboQuantError::UnsupportedModel(format!(
"missing present key output for decoder layer {layer}"
))
})?;
let value = values_by_layer[layer].take().ok_or_else(|| {
TurboQuantError::UnsupportedModel(format!(
"missing present value output for decoder layer {layer}"
))
})?;
layers.push(LayerPresent { layer, key, value });
}
Ok(Self {
logits,
logits_seq_len,
vocab_size,
layers,
})
}
fn last_step_logits(&self) -> Result<Vec<f32>> {
if self.logits_seq_len == 0 || self.vocab_size == 0 {
return Err(TurboQuantError::ModelFormat(
"logits tensor has an empty sequence or vocabulary axis".into(),
));
}
let start = (self.logits_seq_len - 1) * self.vocab_size;
Ok(self.logits[start..start + self.vocab_size].to_vec())
}
}
#[derive(Debug)]
struct LayerPresent {
layer: usize,
key: NormalizedCacheTensor,
value: NormalizedCacheTensor,
}
#[derive(Debug, Clone)]
struct NormalizedCacheTensor {
token_major: Vec<f32>,
seq_len: usize,
num_heads: usize,
head_dim: usize,
source_layout: CacheTensorLayout,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CacheTensorKind {
Key,
Value,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CacheTensorLayout {
BatchHeadSeq,
BatchSeqHead,
HeadSeq,
SeqHead,
}
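/// Distinguishes the two present-output conventions this runner accepts:
/// graphs that emit rows only for the newly fed tokens, and graphs that
/// echo the full prefix (past plus new). `resolve` decides from the
/// sequence lengths alone.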
#[derive(Debug)]
struct PresentUpdate {
output_is_full_prefix: bool,
new_tokens: usize,
}
impl PresentUpdate {
fn resolve(existing_tokens: usize, input_len: usize, output_seq_len: usize) -> Result<Self> {
if output_seq_len == input_len {
return Ok(Self {
output_is_full_prefix: false,
new_tokens: input_len,
});
}
if output_seq_len == existing_tokens + input_len {
return Ok(Self {
output_is_full_prefix: true,
new_tokens: input_len,
});
}
Err(TurboQuantError::ModelFormat(format!(
"present cache length {output_seq_len} is incompatible with existing_tokens={existing_tokens} and input_len={input_len}"
)))
}
}
fn resolve_onnx_path(model_dir: &Path) -> Result<PathBuf> {
for file_name in SUPPORTED_MODEL_FILES {
let candidate = model_dir.join(file_name);
if candidate.is_file() {
return Ok(candidate);
}
}
Err(TurboQuantError::ModelConfig(format!(
"missing ONNX decoder file in {} (looked for {})",
model_dir.display(),
SUPPORTED_MODEL_FILES.join(", ")
)))
}
fn infer_model_id(model_dir: &Path) -> Option<String> {
model_dir
.file_name()
.map(|name| name.to_string_lossy().to_string())
}
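/// Accepts `eos_token_id` given as a single number or an array of
/// numbers; any other JSON shape yields an empty list.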
fn parse_token_ids(value: &Value) -> Vec<u32> {
match value {
Value::Number(number) => number
.as_u64()
.and_then(|value| u32::try_from(value).ok())
.into_iter()
.collect(),
Value::Array(values) => values
.iter()
.filter_map(|entry| entry.as_u64())
.filter_map(|value| u32::try_from(value).ok())
.collect(),
_ => Vec::new(),
}
}
fn tensor_rank(value_type: &ValueType, name: &str) -> Result<usize> {
value_type
.tensor_shape()
.map(|shape| shape.len())
.ok_or_else(|| TurboQuantError::UnsupportedModel(format!("{name} input must be a tensor")))
}
fn scalar_datum(value_type: &ValueType) -> Result<ScalarDatum> {
match value_type.tensor_type() {
Some(TensorElementType::Bool) => Ok(ScalarDatum::Bool),
Some(TensorElementType::Int32) => Ok(ScalarDatum::I32),
Some(TensorElementType::Int64) => Ok(ScalarDatum::I64),
Some(other) => Err(TurboQuantError::UnsupportedModel(format!(
"unsupported use_cache_branch input type {other}"
))),
None => Err(TurboQuantError::UnsupportedModel(
"use_cache_branch input must be a tensor".into(),
)),
}
}
fn infer_cache_layout_from_value_type(
value_type: &ValueType,
expected_heads: usize,
head_dim: usize,
) -> Option<CacheTensorLayout> {
let shape = value_type.tensor_shape()?;
let dims: &[i64] = shape;
let expected_heads = expected_heads as i64;
let head_dim = head_dim as i64;
match dims {
[_, heads, _, dim] if *heads == expected_heads && *dim == head_dim => {
Some(CacheTensorLayout::BatchHeadSeq)
}
[_, _, heads, dim] if *heads == expected_heads && *dim == head_dim => {
Some(CacheTensorLayout::BatchSeqHead)
}
[heads, _, dim] if *heads == expected_heads && *dim == head_dim => {
Some(CacheTensorLayout::HeadSeq)
}
[_, heads, dim] if *heads == expected_heads && *dim == head_dim => {
Some(CacheTensorLayout::SeqHead)
}
_ => None,
}
}
fn validate_cache_io(
past_inputs: &[PastInputSlot],
present_outputs: &[PresentOutputSlot],
layer_count: usize,
) -> Result<()> {
for layer in 0..layer_count {
let past_key = past_inputs
.iter()
.any(|slot| slot.layer == layer && slot.kind == CacheTensorKind::Key);
let past_value = past_inputs
.iter()
.any(|slot| slot.layer == layer && slot.kind == CacheTensorKind::Value);
let present_key = present_outputs
.iter()
.any(|slot| slot.layer == layer && slot.kind == CacheTensorKind::Key);
let present_value = present_outputs
.iter()
.any(|slot| slot.layer == layer && slot.kind == CacheTensorKind::Value);
if !(past_key && past_value && present_key && present_value) {
return Err(TurboQuantError::UnsupportedModel(format!(
"decoder layer {layer} is missing past/present key-value I/O"
)));
}
}
Ok(())
}
fn parse_past_input(name: &str) -> Option<(usize, CacheTensorKind)> {
if !name.contains("past_key_values.") {
return None;
}
let layer = parse_layer_index(name, "past_key_values.")?;
let kind = parse_cache_tensor_kind(name)?;
Some((layer, kind))
}
fn parse_present_output(name: &str) -> Option<(usize, CacheTensorKind)> {
if let Some(layer) = parse_layer_index(name, "present.") {
return parse_cache_tensor_kind(name).map(|kind| (layer, kind));
}
if let Some(layer) = parse_layer_index(name, "present_key_values.") {
return parse_cache_tensor_kind(name).map(|kind| (layer, kind));
}
None
}
fn parse_layer_index(name: &str, prefix: &str) -> Option<usize> {
let start = name.find(prefix)? + prefix.len();
let digits = name[start..]
.chars()
.take_while(|char| char.is_ascii_digit())
.collect::<String>();
if digits.is_empty() {
None
} else {
digits.parse::<usize>().ok()
}
}
fn parse_cache_tensor_kind(name: &str) -> Option<CacheTensorKind> {
if name.ends_with(".key") || name.ends_with("_key") || name.contains(".decoder.key") {
Some(CacheTensorKind::Key)
} else if name.ends_with(".value")
|| name.ends_with("_value")
|| name.contains(".decoder.value")
{
Some(CacheTensorKind::Value)
} else {
None
}
}
fn extract_logits(tensor: &DynValue) -> Result<(Vec<f32>, usize, usize)> {
let (shape, data) = tensor.try_extract_tensor::<f32>().map_err(onnx_error)?;
let dims: &[i64] = shape;
match dims {
[1, seq_len, vocab_size] => Ok((
data.to_vec(),
dim_to_usize(*seq_len, "logits sequence axis")?,
dim_to_usize(*vocab_size, "logits vocabulary axis")?,
)),
[seq_len, vocab_size] => Ok((
data.to_vec(),
dim_to_usize(*seq_len, "logits sequence axis")?,
dim_to_usize(*vocab_size, "logits vocabulary axis")?,
)),
_ => Err(TurboQuantError::ModelFormat(format!(
"logits tensor must have shape [1, seq, vocab] or [seq, vocab], got {dims:?}"
))),
}
}
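/// Detects the cache tensor layout by matching the shape against the
/// expected head count and head dimension, then rewrites the payload in
/// token-major order. When the sequence length coincidentally equals the
/// head count, the earlier match arm wins, so the batch-head-seq reading
/// is preferred.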
fn normalize_cache_tensor(
tensor: &DynValue,
expected_heads: usize,
head_dim: usize,
) -> Result<NormalizedCacheTensor> {
let (shape, data) = tensor.try_extract_tensor::<f32>().map_err(onnx_error)?;
let dims: &[i64] = shape;
let (layout, seq_len) = match dims {
[1, heads, seq_len, dim]
if dim_to_usize(*heads, "cache heads axis")? == expected_heads
&& dim_to_usize(*dim, "cache head_dim axis")? == head_dim =>
{
(
CacheTensorLayout::BatchHeadSeq,
dim_to_usize(*seq_len, "cache sequence axis")?,
)
}
[1, seq_len, heads, dim]
if dim_to_usize(*heads, "cache heads axis")? == expected_heads
&& dim_to_usize(*dim, "cache head_dim axis")? == head_dim =>
{
(
CacheTensorLayout::BatchSeqHead,
dim_to_usize(*seq_len, "cache sequence axis")?,
)
}
[heads, seq_len, dim]
if dim_to_usize(*heads, "cache heads axis")? == expected_heads
&& dim_to_usize(*dim, "cache head_dim axis")? == head_dim =>
{
(
CacheTensorLayout::HeadSeq,
dim_to_usize(*seq_len, "cache sequence axis")?,
)
}
[seq_len, heads, dim]
if dim_to_usize(*heads, "cache heads axis")? == expected_heads
&& dim_to_usize(*dim, "cache head_dim axis")? == head_dim =>
{
(
CacheTensorLayout::SeqHead,
dim_to_usize(*seq_len, "cache sequence axis")?,
)
}
_ => {
return Err(TurboQuantError::ModelFormat(format!(
"cache tensor shape {dims:?} does not match expected KV layout with {expected_heads} heads and head_dim {head_dim}"
)))
}
};
let token_major = cache_data_to_token_major(data, seq_len, expected_heads, head_dim, layout);
Ok(NormalizedCacheTensor {
token_major,
seq_len,
num_heads: expected_heads,
head_dim,
source_layout: layout,
})
}
fn dim_to_usize(dim: i64, context: &str) -> Result<usize> {
usize::try_from(dim).map_err(|_| {
TurboQuantError::ModelFormat(format!(
"{context} must be a concrete non-negative dimension, got {dim}"
))
})
}
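/// Flattens cache data into token-major (token, head, dim) order;
/// seq-major layouts are already contiguous and are copied verbatim.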
fn cache_data_to_token_major(
data: &[f32],
seq_len: usize,
num_heads: usize,
head_dim: usize,
layout: CacheTensorLayout,
) -> Vec<f32> {
let mut token_major = Vec::with_capacity(seq_len * num_heads * head_dim);
match layout {
CacheTensorLayout::BatchSeqHead | CacheTensorLayout::SeqHead => {
token_major.extend_from_slice(data);
}
CacheTensorLayout::BatchHeadSeq | CacheTensorLayout::HeadSeq => {
for token in 0..seq_len {
for head in 0..num_heads {
let start = (head * seq_len + token) * head_dim;
token_major.extend_from_slice(&data[start..start + head_dim]);
}
}
}
}
token_major
}
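/// Inverse of `cache_data_to_token_major`: rebuilds an ONNX tensor in
/// the requested layout. Zero-length sequences take the allocator path
/// (`Tensor::new`) instead of `Tensor::from_array`, which expects data
/// to copy.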
fn token_major_to_tensor(
token_major: &[f32],
seq_len: usize,
num_heads: usize,
head_dim: usize,
layout: CacheTensorLayout,
) -> Result<Tensor<f32>> {
let allocator = Allocator::default();
let data = match layout {
CacheTensorLayout::BatchSeqHead | CacheTensorLayout::SeqHead => token_major.to_vec(),
CacheTensorLayout::BatchHeadSeq | CacheTensorLayout::HeadSeq => {
let mut reordered = Vec::with_capacity(token_major.len());
for head in 0..num_heads {
for token in 0..seq_len {
let start = (token * num_heads + head) * head_dim;
reordered.extend_from_slice(&token_major[start..start + head_dim]);
}
}
reordered
}
};
match layout {
CacheTensorLayout::BatchHeadSeq => {
if seq_len == 0 {
Tensor::new(&allocator, [1_usize, num_heads, seq_len, head_dim]).map_err(onnx_error)
} else {
Tensor::from_array(([1_usize, num_heads, seq_len, head_dim], data))
.map_err(onnx_error)
}
}
CacheTensorLayout::BatchSeqHead => {
if seq_len == 0 {
Tensor::new(&allocator, [1_usize, seq_len, num_heads, head_dim]).map_err(onnx_error)
} else {
Tensor::from_array(([1_usize, seq_len, num_heads, head_dim], data))
.map_err(onnx_error)
}
}
CacheTensorLayout::HeadSeq => {
if seq_len == 0 {
Tensor::new(&allocator, [num_heads, seq_len, head_dim]).map_err(onnx_error)
} else {
Tensor::from_array(([num_heads, seq_len, head_dim], data)).map_err(onnx_error)
}
}
CacheTensorLayout::SeqHead => {
if seq_len == 0 {
Tensor::new(&allocator, [seq_len, num_heads, head_dim]).map_err(onnx_error)
} else {
Tensor::from_array(([seq_len, num_heads, head_dim], data)).map_err(onnx_error)
}
}
}
}
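/// Slices `token_count` fresh rows per head out of a token-major buffer,
/// widening to f64 for the quantizer. `start_token` is non-zero when the
/// present output echoed the full prefix.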
fn extract_quantizable_rows(
token_major: &[f32],
total_seq_len: usize,
num_heads: usize,
head_dim: usize,
start_token: usize,
token_count: usize,
label: &str,
) -> Result<Vec<Vec<Vec<f64>>>> {
if start_token + token_count > total_seq_len {
return Err(TurboQuantError::ModelFormat(format!(
"cannot slice {label} rows from token range [{start_token}, {}) with total_seq_len={total_seq_len}",
start_token + token_count
)));
}
// `vec![elem; n]` clones the element, and a cloned empty Vec loses its
// capacity hint, so build each per-head row buffer explicitly.
let mut rows: Vec<Vec<Vec<f64>>> = (0..num_heads)
.map(|_| Vec::with_capacity(token_count))
.collect();
for token in start_token..start_token + token_count {
for (head_index, head_rows) in rows.iter_mut().enumerate() {
let start = (token * num_heads + head_index) * head_dim;
let slice = token_major.get(start..start + head_dim).ok_or_else(|| {
TurboQuantError::ModelFormat(format!(
"token-major {label} tensor is truncated at token={token}, head={head_index}"
))
})?;
head_rows.push(slice.iter().map(|value| *value as f64).collect());
}
}
Ok(rows)
}
fn normalize_rows_per_head(rows: Vec<Vec<Vec<f64>>>) -> Result<Vec<Vec<Vec<f64>>>> {
rows.into_iter()
.map(|head_rows| {
head_rows
.into_iter()
.map(|row| normalize(&row))
.collect::<Result<Vec<_>>>()
})
.collect()
}
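/// Greedy token selection; `total_cmp` imposes a total order, so NaN
/// logits cannot panic the comparison. An empty slice maps to token 0.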
fn argmax_u32(values: &[f32]) -> u32 {
values
.iter()
.enumerate()
.max_by(|(_, left), (_, right)| left.total_cmp(right))
.map(|(index, _)| index as u32)
.unwrap_or(0)
}
fn onnx_error(error: impl std::fmt::Display) -> TurboQuantError {
TurboQuantError::Onnx(error.to_string())
}
#[cfg(test)]
mod tests {
use super::{
infer_cache_layout_from_value_type, scalar_datum, CacheTensorLayout, DecoderSpec,
ScalarDatum,
};
use ort::{
tensor::{Shape, SymbolicDimensions, TensorElementType},
value::ValueType,
};
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
#[test]
fn decoder_spec_accepts_gpt2_config_aliases() {
let fixture_dir = temp_fixture_dir("real-model-gpt2-config");
fs::create_dir_all(&fixture_dir).expect("fixture directory should be created");
fs::write(
fixture_dir.join("config.json"),
r#"{
"_name_or_path": "distilgpt2",
"n_layer": 6,
"n_head": 12,
"n_embd": 768,
"bos_token_id": 50256,
"eos_token_id": 50256
}
"#,
)
.expect("fixture config should be written");
let spec = DecoderSpec::load(&fixture_dir).expect("gpt2 config aliases should load");
assert_eq!(spec.model_id, "distilgpt2");
assert_eq!(spec.num_hidden_layers, 6);
assert_eq!(spec.num_key_value_heads, 12);
assert_eq!(spec.head_dim, 64);
assert_eq!(spec.bos_token_id, Some(50256));
assert_eq!(spec.eos_token_ids, vec![50256]);
fs::remove_dir_all(&fixture_dir).expect("fixture cleanup should succeed");
}
#[test]
fn infer_cache_layout_accepts_dynamic_batch_shapes() {
let value_type = ValueType::Tensor {
ty: TensorElementType::Float32,
shape: Shape::new([-1, 12, -1, 64]),
dimension_symbols: SymbolicDimensions::empty(4),
};
assert_eq!(
infer_cache_layout_from_value_type(&value_type, 12, 64),
Some(CacheTensorLayout::BatchHeadSeq)
);
}
#[test]
fn scalar_datum_accepts_bool_tensor_inputs() {
let value_type = ValueType::Tensor {
ty: TensorElementType::Bool,
shape: Shape::new([]),
dimension_symbols: SymbolicDimensions::empty(0),
};
assert_eq!(scalar_datum(&value_type).unwrap(), ScalarDatum::Bool);
}
fn temp_fixture_dir(prefix: &str) -> PathBuf {
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock should be after unix epoch")
.as_nanos();
std::env::temp_dir().join(format!("{prefix}-{}-{unique}", std::process::id()))
}
}