use std::collections::HashMap;
use std::path::Path;
use safetensors::{Dtype, SafeTensors};
use serde::{Deserialize, Serialize};
use crate::error::{Result, TurboQuantError};
/// Optional descriptive metadata parsed from the safetensors header's
/// free-form string map (see `TraceMetadata::from_map`).
///
/// Every field may be absent; `layer` and `head` are parsed from their
/// string values and silently become `None` if parsing fails.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TraceMetadata {
    /// Header key "model": name of the model the trace came from.
    pub model: Option<String>,
    /// Header key "benchmark".
    pub benchmark: Option<String>,
    /// Header key "suite".
    pub suite: Option<String>,
    /// Header key "layer", parsed as an integer.
    pub layer: Option<usize>,
    /// Header key "head", parsed as an integer.
    pub head: Option<usize>,
    /// Header key "note": free-form annotation.
    pub note: Option<String>,
}
/// One independent sample of a KV trace.
///
/// `keys`, `values`, and `queries` are row vectors, each of length
/// `KvTrace::dim` (enforced during loading). `keys` and `values` always
/// have the same number of rows; `queries` may have a different count.
#[derive(Debug, Clone)]
pub struct TraceSample {
    /// Per-token key vectors.
    pub keys: Vec<Vec<f64>>,
    /// Per-token value vectors (same row count as `keys`).
    pub values: Vec<Vec<f64>>,
    /// Query vectors; row count is independent of `keys`.
    pub queries: Vec<Vec<f64>>,
    /// For each query, an index into `keys`/`values`; validated to be
    /// strictly less than the token count at load time.
    pub query_positions: Vec<usize>,
}
/// A full KV trace: header metadata plus one `TraceSample` per sample,
/// all sharing the vector dimensionality `dim`.
#[derive(Debug, Clone)]
pub struct KvTrace {
    /// Metadata parsed from the safetensors header.
    pub metadata: TraceMetadata,
    /// Per-sample key/value/query data.
    pub samples: Vec<TraceSample>,
    /// Common length of every key, value, and query vector.
    pub dim: usize,
}
impl KvTrace {
    /// Load a KV trace from a safetensors file at `path`.
    ///
    /// Expects F32 or F64 tensors named `keys`, `values`, and `queries`
    /// with shape `[tokens, dim]` or `[samples, tokens, dim]`, plus an
    /// optional integer tensor `query_positions` of one position per
    /// query. Shapes are cross-validated and positions are range-checked
    /// before the flat data is split into per-sample rows.
    ///
    /// # Errors
    /// `TurboQuantError::Io` when the file cannot be read;
    /// `TurboQuantError::TraceFormat` for any malformed or inconsistent
    /// tensor content.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let bytes = std::fs::read(path.as_ref()).map_err(|error| {
            TurboQuantError::Io(format!("{}: {error}", path.as_ref().display()))
        })?;
        // Parse the header separately so the string metadata map is
        // available before (and independent of) tensor validation.
        let (_, header) = SafeTensors::read_metadata(&bytes)
            .map_err(|error| TurboQuantError::TraceFormat(error.to_string()))?;
        let metadata = TraceMetadata::from_map(header.metadata());
        let tensors = SafeTensors::deserialize(&bytes)
            .map_err(|error| TurboQuantError::TraceFormat(error.to_string()))?;
        let keys = tensor_to_f64(&tensors, "keys")?;
        let values = tensor_to_f64(&tensors, "values")?;
        let queries = tensor_to_f64(&tensors, "queries")?;
        // `query_positions` is optional: a failed lookup is treated as
        // "tensor absent", not as a format error.
        let query_positions = match tensors.tensor("query_positions") {
            Ok(tensor) => Some(tensor_to_usize(
                "query_positions",
                tensor.dtype(),
                tensor.data(),
            )?),
            Err(_) => None,
        };
        // Normalize 2-D ([tokens, dim]) and 3-D ([samples, tokens, dim])
        // layouts to a common (samples, tokens, dim) triple.
        let (num_samples, num_tokens, dim) = normalize_shape("keys", &keys.shape)?;
        let (value_samples, value_tokens, value_dim) = normalize_shape("values", &values.shape)?;
        if (value_samples, value_tokens, value_dim) != (num_samples, num_tokens, dim) {
            return Err(TurboQuantError::TraceFormat(format!(
                "values shape {:?} does not match keys shape {:?}",
                values.shape, keys.shape
            )));
        }
        // Queries must agree with keys on sample count and dim, but may
        // carry their own per-sample query count.
        let (query_samples, num_queries, query_dim) = normalize_shape("queries", &queries.shape)?;
        if query_samples != num_samples || query_dim != dim {
            return Err(TurboQuantError::TraceFormat(format!(
                "queries shape {:?} does not match keys sample/dim {:?}",
                queries.shape, keys.shape
            )));
        }
        // Queries are meaningless without at least one key/value token
        // to attend to.
        if num_tokens == 0 && num_queries > 0 {
            return Err(TurboQuantError::TraceFormat(
                "queries require at least one key/value token".into(),
            ));
        }
        let expected_query_positions = num_samples * num_queries;
        let query_positions = if let Some(positions) = query_positions {
            // Explicit positions must cover every (sample, query) pair.
            if positions.len() != expected_query_positions {
                return Err(TurboQuantError::TraceFormat(format!(
                    "query_positions length {} does not match expected {}",
                    positions.len(),
                    expected_query_positions
                )));
            }
            positions
        } else if num_queries == num_tokens {
            // One query per token: default query i to position i for
            // every sample.
            (0..num_samples)
                .flat_map(|_| 0..num_queries)
                .collect::<Vec<_>>()
        } else {
            // Query/token counts differ: default every query to the last
            // token position (saturating to 0 when there are no tokens —
            // only reachable with num_queries == 0, so the vec is empty).
            vec![num_tokens.saturating_sub(1); expected_query_positions]
        };
        // Reject any position that points past the available tokens.
        if let Some((index, position)) = query_positions
            .iter()
            .copied()
            .enumerate()
            .find(|(_, position)| *position >= num_tokens)
        {
            return Err(TurboQuantError::TraceFormat(format!(
                "query_positions[{index}] = {position} is out of range for {num_tokens} tokens"
            )));
        }
        // Slice the flat row-major buffers into per-sample row vectors.
        let mut samples = Vec::with_capacity(num_samples);
        for sample_index in 0..num_samples {
            let key_start = sample_index * num_tokens * dim;
            let value_start = sample_index * num_tokens * dim;
            let query_start = sample_index * num_queries * dim;
            let position_start = sample_index * num_queries;
            samples.push(TraceSample {
                keys: reshape_rows(&keys.data[key_start..key_start + num_tokens * dim], dim),
                values: reshape_rows(
                    &values.data[value_start..value_start + num_tokens * dim],
                    dim,
                ),
                queries: reshape_rows(
                    &queries.data[query_start..query_start + num_queries * dim],
                    dim,
                ),
                query_positions: query_positions[position_start..position_start + num_queries]
                    .to_vec(),
            });
        }
        Ok(Self {
            metadata,
            samples,
            dim,
        })
    }

    /// Number of independent samples in the trace.
    pub fn sample_count(&self) -> usize {
        self.samples.len()
    }

    /// Total key/value token count summed over all samples.
    pub fn total_tokens(&self) -> usize {
        self.samples.iter().map(|sample| sample.keys.len()).sum()
    }

    /// Total query count summed over all samples.
    pub fn total_queries(&self) -> usize {
        self.samples.iter().map(|sample| sample.queries.len()).sum()
    }

    /// Byte size of the KV cache if stored unquantized as F32
    /// (keys + values, hence the factor of 2).
    pub fn uncompressed_kv_bytes(&self) -> usize {
        self.total_tokens() * self.dim * std::mem::size_of::<f32>() * 2
    }
}
/// A tensor widened to `f64`: its declared shape plus the flat,
/// row-major payload (in the order the bytes appear on disk).
#[derive(Debug)]
struct TensorDataF64 {
    /// Shape as declared in the safetensors header.
    shape: Vec<usize>,
    /// Flat element data; length equals the product of `shape`.
    data: Vec<f64>,
}
/// Normalize a trace tensor shape to a `(samples, tokens, dim)` triple.
///
/// A 2-D `[tokens, dim]` shape is treated as a single sample; a 3-D
/// `[samples, tokens, dim]` shape is passed through. The trailing `dim`
/// must be non-zero (zero tokens/samples are allowed here and validated
/// by the caller). Any other rank is rejected.
fn normalize_shape(name: &str, shape: &[usize]) -> Result<(usize, usize, usize)> {
    let (samples, tokens, dim) = match *shape {
        [tokens, dim] => (1, tokens, dim),
        [samples, tokens, dim] => (samples, tokens, dim),
        _ => {
            return Err(TurboQuantError::TraceFormat(format!(
                "{name} tensor must have shape [tokens, dim] or [samples, tokens, dim], got {shape:?}"
            )))
        }
    };
    if dim == 0 {
        return Err(TurboQuantError::TraceFormat(format!(
            "{name} tensor must have non-zero dimension, got {shape:?}"
        )));
    }
    Ok((samples, tokens, dim))
}
/// Split a flat row-major buffer into rows of `dim` elements each.
///
/// Mirrors `slice::chunks`: a trailing partial chunk becomes a shorter
/// final row, and `dim == 0` panics (callers guarantee a non-zero dim).
fn reshape_rows(flat: &[f64], dim: usize) -> Vec<Vec<f64>> {
    let mut rows = Vec::new();
    for chunk in flat.chunks(dim) {
        rows.push(chunk.to_vec());
    }
    rows
}
/// Fetch tensor `name` from `tensors` and widen its payload to `f64`.
///
/// Accepts F32 or F64 tensors; any other dtype, or a missing tensor, is a
/// `TraceFormat` error.
///
/// Elements are decoded with `from_le_bytes` rather than
/// `bytemuck::cast_slice`: the safetensors payload is not guaranteed to be
/// aligned for `f32`/`f64` (cast_slice panics on misalignment), and the
/// on-disk format is little-endian by specification, so this also decodes
/// correctly on big-endian targets. A trailing remainder shorter than one
/// element cannot occur for a buffer that passed safetensors validation.
fn tensor_to_f64(tensors: &SafeTensors<'_>, name: &str) -> Result<TensorDataF64> {
    let tensor = tensors
        .tensor(name)
        .map_err(|error| TurboQuantError::TraceFormat(format!("missing tensor {name}: {error}")))?;
    let shape = tensor.shape().to_vec();
    let data = match tensor.dtype() {
        Dtype::F32 => tensor
            .data()
            .chunks_exact(4)
            .map(|raw| f64::from(f32::from_le_bytes(raw.try_into().expect("chunk is 4 bytes"))))
            .collect(),
        Dtype::F64 => tensor
            .data()
            .chunks_exact(8)
            .map(|raw| f64::from_le_bytes(raw.try_into().expect("chunk is 8 bytes")))
            .collect(),
        other => {
            return Err(TurboQuantError::TraceFormat(format!(
                "{name} tensor must be F32 or F64, got {other:?}"
            )))
        }
    };
    Ok(TensorDataF64 { shape, data })
}
/// Decode an integer `query_positions` payload into `usize` indices.
///
/// Accepts I32/I64/U32/U64. Signed values must be non-negative and every
/// value must fit in `usize`; the per-dtype error wording is preserved
/// (signed: "must be non-negative", unsigned: "does not fit in usize").
///
/// Elements are decoded with `from_le_bytes` rather than
/// `bytemuck::cast_slice`: the payload is not guaranteed to be aligned for
/// 4/8-byte integers (cast_slice panics on misalignment), and safetensors
/// data is little-endian by specification, so this also decodes correctly
/// on big-endian targets.
fn tensor_to_usize(name: &str, dtype: Dtype, data: &[u8]) -> Result<Vec<usize>> {
    // Shared per-dtype conversion: decode each W-byte little-endian chunk
    // with `decode`, then convert to usize, reporting failures via
    // `describe` so each dtype keeps its original error message.
    fn convert<T, const W: usize>(
        data: &[u8],
        decode: impl Fn([u8; W]) -> T,
        describe: impl Fn(usize, T) -> String,
    ) -> Result<Vec<usize>>
    where
        T: Copy,
        usize: TryFrom<T>,
    {
        data.chunks_exact(W)
            .map(|raw| decode(raw.try_into().expect("chunk width matches W")))
            .enumerate()
            .map(|(index, value)| {
                usize::try_from(value)
                    .map_err(|_| TurboQuantError::TraceFormat(describe(index, value)))
            })
            .collect()
    }
    match dtype {
        Dtype::I32 => convert(data, i32::from_le_bytes, |index, value| {
            format!("{name}[{index}] must be non-negative, got {value}")
        }),
        Dtype::I64 => convert(data, i64::from_le_bytes, |index, value| {
            format!("{name}[{index}] must be non-negative, got {value}")
        }),
        Dtype::U32 => convert(data, u32::from_le_bytes, |index, value| {
            format!("{name}[{index}] = {value} does not fit in usize")
        }),
        Dtype::U64 => convert(data, u64::from_le_bytes, |index, value| {
            format!("{name}[{index}] = {value} does not fit in usize")
        }),
        other => Err(TurboQuantError::TraceFormat(format!(
            "query_positions tensor must be I32/I64/U32/U64, got {other:?}"
        ))),
    }
}
impl TraceMetadata {
    /// Build metadata from the optional safetensors header map.
    ///
    /// A missing map yields the all-`None` default. String fields are
    /// cloned as-is; `layer` and `head` are parsed as integers, with
    /// unparsable values silently dropped to `None`.
    fn from_map(map: &Option<HashMap<String, String>>) -> Self {
        match map {
            None => Self::default(),
            Some(entries) => {
                let text = |key: &str| entries.get(key).cloned();
                let number = |key: &str| entries.get(key).and_then(|raw| raw.parse::<usize>().ok());
                Self {
                    model: text("model"),
                    benchmark: text("benchmark"),
                    suite: text("suite"),
                    layer: number("layer"),
                    head: number("head"),
                    note: text("note"),
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, HashMap};
    use safetensors::{serialize_to_file, tensor::TensorView, Dtype};
    use super::KvTrace;

    /// Write a minimal 1-sample / 2-token / dim-2 trace to `temp_path`,
    /// with the given I32 `query_positions` payload and two header
    /// metadata entries ("model" and "layer").
    fn write_trace_file(temp_path: &std::path::Path, positions: &[i32]) {
        // Flat row-major payloads for shape [1, 2, 2].
        let keys = vec![1.0f32, 0.0, 0.0, 1.0];
        let values = vec![0.5f32, 0.5, 0.25, 0.75];
        let queries = vec![1.0f32, 0.0, 0.0, 1.0];
        let tensors = BTreeMap::from([
            (
                "keys".to_string(),
                TensorView::new(Dtype::F32, vec![1, 2, 2], bytemuck::cast_slice(&keys)).unwrap(),
            ),
            (
                "values".to_string(),
                TensorView::new(Dtype::F32, vec![1, 2, 2], bytemuck::cast_slice(&values)).unwrap(),
            ),
            (
                "queries".to_string(),
                TensorView::new(Dtype::F32, vec![1, 2, 2], bytemuck::cast_slice(&queries)).unwrap(),
            ),
            (
                "query_positions".to_string(),
                TensorView::new(
                    Dtype::I32,
                    vec![1, positions.len()],
                    bytemuck::cast_slice(positions),
                )
                .unwrap(),
            ),
        ]);
        serialize_to_file(
            tensors.iter().map(|(name, tensor)| (name.as_str(), tensor)),
            Some(HashMap::from([
                ("model".to_string(), "toy-model".to_string()),
                ("layer".to_string(), "7".to_string()),
            ])),
            temp_path,
        )
        .unwrap();
    }

    /// Happy path: tensors, metadata, and explicit in-range positions all
    /// round-trip through `KvTrace::load`.
    #[test]
    fn load_trace_from_safetensors() {
        // Per-test filename prefixes keep the temp files of the three
        // tests (which share a process id) from colliding.
        let temp_path = std::env::temp_dir().join(format!(
            "turboquant-trace-{}.safetensors",
            std::process::id()
        ));
        write_trace_file(&temp_path, &[0, 1]);
        let trace = KvTrace::load(&temp_path).unwrap();
        // Remove the file before asserting so a failed assert does not
        // leak it.
        std::fs::remove_file(&temp_path).unwrap();
        assert_eq!(trace.sample_count(), 1);
        assert_eq!(trace.dim, 2);
        assert_eq!(trace.metadata.model.as_deref(), Some("toy-model"));
        assert_eq!(trace.metadata.layer, Some(7));
        assert_eq!(trace.samples[0].keys.len(), 2);
        assert_eq!(trace.samples[0].query_positions, vec![0, 1]);
    }

    /// A negative I32 position must be rejected with a TraceFormat error
    /// mentioning "non-negative" (from `tensor_to_usize`).
    #[test]
    fn load_trace_rejects_negative_query_positions() {
        let temp_path = std::env::temp_dir().join(format!(
            "turboquant-trace-negative-{}.safetensors",
            std::process::id()
        ));
        write_trace_file(&temp_path, &[-1, 1]);
        let error = KvTrace::load(&temp_path).unwrap_err();
        std::fs::remove_file(&temp_path).unwrap();
        assert!(matches!(
            error,
            crate::error::TurboQuantError::TraceFormat(_)
        ));
        assert!(error.to_string().contains("non-negative"));
    }

    /// A position equal to the token count (2) must be rejected with the
    /// "out of range" bounds error from `KvTrace::load`.
    #[test]
    fn load_trace_rejects_query_positions_past_token_count() {
        let temp_path = std::env::temp_dir().join(format!(
            "turboquant-trace-oob-{}.safetensors",
            std::process::id()
        ));
        write_trace_file(&temp_path, &[0, 2]);
        let error = KvTrace::load(&temp_path).unwrap_err();
        std::fs::remove_file(&temp_path).unwrap();
        assert!(matches!(
            error,
            crate::error::TurboQuantError::TraceFormat(_)
        ));
        assert!(error.to_string().contains("out of range"));
    }
}