turboquant 0.1.1

Implementation of Google's TurboQuant algorithm for vector quantization
Documentation
use std::collections::HashMap;
use std::path::Path;

use safetensors::{Dtype, SafeTensors};
use serde::{Deserialize, Serialize};

use crate::error::{Result, TurboQuantError};

/// Descriptive metadata attached to a KV trace.
///
/// Built by `TraceMetadata::from_map` from the string key/value pairs stored in
/// the safetensors header. Every field is optional; a key that is absent from
/// the header leaves the corresponding field as `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TraceMetadata {
    /// Value of the `model` header key, copied verbatim.
    pub model: Option<String>,
    /// Value of the `benchmark` header key, copied verbatim.
    pub benchmark: Option<String>,
    /// Value of the `suite` header key, copied verbatim.
    pub suite: Option<String>,
    /// `layer` header key parsed as `usize`; `None` if absent or unparseable.
    pub layer: Option<usize>,
    /// `head` header key parsed as `usize`; `None` if absent or unparseable.
    pub head: Option<usize>,
    /// Value of the `note` header key, copied verbatim.
    pub note: Option<String>,
}

/// One sample of a trace: its key/value rows plus the queries recorded for it.
///
/// As produced by `KvTrace::load`, every row has the trace's `dim` elements and
/// every `query_positions` entry is a valid index into `keys`.
#[derive(Debug, Clone)]
pub struct TraceSample {
    /// Key vectors, one row per token.
    pub keys: Vec<Vec<f64>>,
    /// Value vectors, one row per token (same row count as `keys` when loaded).
    pub values: Vec<Vec<f64>>,
    /// Query vectors, one row per query.
    pub queries: Vec<Vec<f64>>,
    /// Token position associated with each query (one entry per query row);
    /// `KvTrace::load` rejects any position `>= keys.len()`.
    pub query_positions: Vec<usize>,
}

/// A KV-cache trace loaded from a safetensors file (see `KvTrace::load`).
#[derive(Debug, Clone)]
pub struct KvTrace {
    /// Header metadata (model, layer, ...) found in the file.
    pub metadata: TraceMetadata,
    /// Per-sample keys, values, queries and query positions.
    pub samples: Vec<TraceSample>,
    /// Feature dimension shared by every key, value and query row, taken from
    /// the last axis of the `keys` tensor.
    pub dim: usize,
}

impl KvTrace {
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let bytes = std::fs::read(path.as_ref()).map_err(|error| {
            TurboQuantError::Io(format!("{}: {error}", path.as_ref().display()))
        })?;
        let (_, header) = SafeTensors::read_metadata(&bytes)
            .map_err(|error| TurboQuantError::TraceFormat(error.to_string()))?;
        let metadata = TraceMetadata::from_map(header.metadata());

        let tensors = SafeTensors::deserialize(&bytes)
            .map_err(|error| TurboQuantError::TraceFormat(error.to_string()))?;

        let keys = tensor_to_f64(&tensors, "keys")?;
        let values = tensor_to_f64(&tensors, "values")?;
        let queries = tensor_to_f64(&tensors, "queries")?;
        let query_positions = match tensors.tensor("query_positions") {
            Ok(tensor) => Some(tensor_to_usize(
                "query_positions",
                tensor.dtype(),
                tensor.data(),
            )?),
            Err(_) => None,
        };

        let (num_samples, num_tokens, dim) = normalize_shape("keys", &keys.shape)?;
        let (value_samples, value_tokens, value_dim) = normalize_shape("values", &values.shape)?;
        if (value_samples, value_tokens, value_dim) != (num_samples, num_tokens, dim) {
            return Err(TurboQuantError::TraceFormat(format!(
                "values shape {:?} does not match keys shape {:?}",
                values.shape, keys.shape
            )));
        }

        let (query_samples, num_queries, query_dim) = normalize_shape("queries", &queries.shape)?;
        if query_samples != num_samples || query_dim != dim {
            return Err(TurboQuantError::TraceFormat(format!(
                "queries shape {:?} does not match keys sample/dim {:?}",
                queries.shape, keys.shape
            )));
        }
        if num_tokens == 0 && num_queries > 0 {
            return Err(TurboQuantError::TraceFormat(
                "queries require at least one key/value token".into(),
            ));
        }

        let expected_query_positions = num_samples * num_queries;
        let query_positions = if let Some(positions) = query_positions {
            if positions.len() != expected_query_positions {
                return Err(TurboQuantError::TraceFormat(format!(
                    "query_positions length {} does not match expected {}",
                    positions.len(),
                    expected_query_positions
                )));
            }
            positions
        } else if num_queries == num_tokens {
            (0..num_samples)
                .flat_map(|_| 0..num_queries)
                .collect::<Vec<_>>()
        } else {
            vec![num_tokens.saturating_sub(1); expected_query_positions]
        };
        if let Some((index, position)) = query_positions
            .iter()
            .copied()
            .enumerate()
            .find(|(_, position)| *position >= num_tokens)
        {
            return Err(TurboQuantError::TraceFormat(format!(
                "query_positions[{index}] = {position} is out of range for {num_tokens} tokens"
            )));
        }

        let mut samples = Vec::with_capacity(num_samples);
        for sample_index in 0..num_samples {
            let key_start = sample_index * num_tokens * dim;
            let value_start = sample_index * num_tokens * dim;
            let query_start = sample_index * num_queries * dim;
            let position_start = sample_index * num_queries;

            samples.push(TraceSample {
                keys: reshape_rows(&keys.data[key_start..key_start + num_tokens * dim], dim),
                values: reshape_rows(
                    &values.data[value_start..value_start + num_tokens * dim],
                    dim,
                ),
                queries: reshape_rows(
                    &queries.data[query_start..query_start + num_queries * dim],
                    dim,
                ),
                query_positions: query_positions[position_start..position_start + num_queries]
                    .to_vec(),
            });
        }

        Ok(Self {
            metadata,
            samples,
            dim,
        })
    }

    pub fn sample_count(&self) -> usize {
        self.samples.len()
    }

    pub fn total_tokens(&self) -> usize {
        self.samples.iter().map(|sample| sample.keys.len()).sum()
    }

    pub fn total_queries(&self) -> usize {
        self.samples.iter().map(|sample| sample.queries.len()).sum()
    }

    pub fn uncompressed_kv_bytes(&self) -> usize {
        self.total_tokens() * self.dim * std::mem::size_of::<f32>() * 2
    }
}

/// A tensor read from the trace file with its elements widened to `f64`.
#[derive(Debug)]
struct TensorDataF64 {
    /// Shape as declared in the safetensors header.
    shape: Vec<usize>,
    /// All elements in the order they are stored in the file; `KvTrace::load`
    /// slices this buffer assuming the last axis varies fastest.
    data: Vec<f64>,
}

/// Normalizes a 2-D or 3-D tensor shape to `(samples, tokens, dim)`.
///
/// A `[tokens, dim]` shape is treated as a single sample. `name` is used only in
/// error messages.
///
/// # Errors
///
/// `TurboQuantError::TraceFormat` if the shape has any rank other than 2 or 3,
/// or if its last (`dim`) axis is zero.
fn normalize_shape(name: &str, shape: &[usize]) -> Result<(usize, usize, usize)> {
    let (samples, tokens, dim) = match shape {
        &[tokens, dim] => (1, tokens, dim),
        &[samples, tokens, dim] => (samples, tokens, dim),
        _ => {
            return Err(TurboQuantError::TraceFormat(format!(
                "{name} tensor must have shape [tokens, dim] or [samples, tokens, dim], got {shape:?}"
            )))
        }
    };

    // A zero-width feature axis would make row reshaping impossible.
    if dim == 0 {
        return Err(TurboQuantError::TraceFormat(format!(
            "{name} tensor must have non-zero dimension, got {shape:?}"
        )));
    }

    Ok((samples, tokens, dim))
}

/// Splits a flat buffer into consecutive rows of `dim` elements.
///
/// If `flat.len()` is not a multiple of `dim`, the final row is shorter.
///
/// # Panics
///
/// Panics if `dim` is zero, even when `flat` is empty.
fn reshape_rows(flat: &[f64], dim: usize) -> Vec<Vec<f64>> {
    assert!(dim != 0, "chunk size must be non-zero");

    let mut rows = Vec::new();
    let mut remaining = flat;
    while !remaining.is_empty() {
        let (row, rest) = remaining.split_at(dim.min(remaining.len()));
        rows.push(row.to_vec());
        remaining = rest;
    }
    rows
}

/// Reads tensor `name` from `tensors` and widens its elements to `f64`.
///
/// Elements are decoded one at a time from little-endian bytes, which is the
/// byte order safetensors stores. The raw buffer is not reinterpreted in place.
/// This means a tensor whose byte offset is not aligned for `f32`/`f64` cannot
/// cause a panic, and values are read correctly on big-endian hosts as well.
///
/// # Errors
///
/// `TurboQuantError::TraceFormat` if the tensor is missing or its dtype is not
/// F32 or F64.
fn tensor_to_f64(tensors: &SafeTensors<'_>, name: &str) -> Result<TensorDataF64> {
    let tensor = tensors
        .tensor(name)
        .map_err(|error| TurboQuantError::TraceFormat(format!("missing tensor {name}: {error}")))?;
    let shape = tensor.shape().to_vec();
    let bytes = tensor.data();
    // `chunks_exact(N)` only yields N-byte slices, so the array conversions
    // below cannot fail.
    let data = match tensor.dtype() {
        Dtype::F32 => bytes
            .chunks_exact(4)
            .map(|chunk| f64::from(f32::from_le_bytes(chunk.try_into().expect("4-byte chunk"))))
            .collect(),
        Dtype::F64 => bytes
            .chunks_exact(8)
            .map(|chunk| f64::from_le_bytes(chunk.try_into().expect("8-byte chunk")))
            .collect(),
        other => {
            return Err(TurboQuantError::TraceFormat(format!(
                "{name} tensor must be F32 or F64, got {other:?}"
            )))
        }
    };

    Ok(TensorDataF64 { shape, data })
}

fn tensor_to_usize(name: &str, dtype: Dtype, data: &[u8]) -> Result<Vec<usize>> {
    match dtype {
        Dtype::I32 => bytemuck::cast_slice::<u8, i32>(data)
            .iter()
            .copied()
            .enumerate()
            .map(|(index, value)| {
                usize::try_from(value).map_err(|_| {
                    TurboQuantError::TraceFormat(format!(
                        "{name}[{index}] must be non-negative, got {value}"
                    ))
                })
            })
            .collect(),
        Dtype::I64 => bytemuck::cast_slice::<u8, i64>(data)
            .iter()
            .copied()
            .enumerate()
            .map(|(index, value)| {
                usize::try_from(value).map_err(|_| {
                    TurboQuantError::TraceFormat(format!(
                        "{name}[{index}] must be non-negative, got {value}"
                    ))
                })
            })
            .collect(),
        Dtype::U32 => bytemuck::cast_slice::<u8, u32>(data)
            .iter()
            .copied()
            .enumerate()
            .map(|(index, value)| {
                usize::try_from(value).map_err(|_| {
                    TurboQuantError::TraceFormat(format!(
                        "{name}[{index}] = {value} does not fit in usize"
                    ))
                })
            })
            .collect(),
        Dtype::U64 => bytemuck::cast_slice::<u8, u64>(data)
            .iter()
            .copied()
            .enumerate()
            .map(|(index, value)| {
                usize::try_from(value).map_err(|_| {
                    TurboQuantError::TraceFormat(format!(
                        "{name}[{index}] = {value} does not fit in usize"
                    ))
                })
            })
            .collect(),
        other => Err(TurboQuantError::TraceFormat(format!(
            "query_positions tensor must be I32/I64/U32/U64, got {other:?}"
        ))),
    }
}

impl TraceMetadata {
    /// Builds metadata from the optional safetensors header string map.
    ///
    /// Text fields are copied verbatim. `layer` and `head` are parsed as `usize`,
    /// and a value that does not parse becomes `None`. With no map at all,
    /// every field is `None`.
    fn from_map(map: &Option<HashMap<String, String>>) -> Self {
        match map {
            None => Self::default(),
            Some(entries) => {
                let text = |key: &str| entries.get(key).cloned();
                let number = |key: &str| {
                    entries
                        .get(key)
                        .and_then(|raw| raw.parse::<usize>().ok())
                };

                Self {
                    model: text("model"),
                    benchmark: text("benchmark"),
                    suite: text("suite"),
                    layer: number("layer"),
                    head: number("head"),
                    note: text("note"),
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, HashMap};
    use std::path::{Path, PathBuf};

    use safetensors::{serialize_to_file, tensor::TensorView, Dtype};

    use super::KvTrace;

    /// A trace file in the system temp dir that is deleted when dropped.
    ///
    /// Cleanup therefore happens even when a `load(...).unwrap()` or an
    /// assertion fails, instead of only on the success path.
    struct TempTrace {
        path: PathBuf,
    }

    impl TempTrace {
        /// Writes the toy trace to a per-process temp file tagged with `label`,
        /// so tests running in parallel never share a path.
        fn new(label: &str, positions: &[i32]) -> Self {
            // Create the guard before writing so a partially written file is also removed.
            let trace = Self {
                path: std::env::temp_dir().join(format!(
                    "turboquant-trace-{label}-{}.safetensors",
                    std::process::id()
                )),
            };
            write_trace_file(&trace.path, positions);
            trace
        }
    }

    impl Drop for TempTrace {
        fn drop(&mut self) {
            // Best effort: panicking in `drop` while already unwinding would abort.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Writes a one-sample trace with two tokens of dim 2, two queries, the
    /// given `query_positions`, and `model`/`layer` header metadata.
    fn write_trace_file(temp_path: &Path, positions: &[i32]) {
        let keys = vec![1.0f32, 0.0, 0.0, 1.0];
        let values = vec![0.5f32, 0.5, 0.25, 0.75];
        let queries = vec![1.0f32, 0.0, 0.0, 1.0];

        let tensors = BTreeMap::from([
            (
                "keys".to_string(),
                TensorView::new(Dtype::F32, vec![1, 2, 2], bytemuck::cast_slice(&keys)).unwrap(),
            ),
            (
                "values".to_string(),
                TensorView::new(Dtype::F32, vec![1, 2, 2], bytemuck::cast_slice(&values)).unwrap(),
            ),
            (
                "queries".to_string(),
                TensorView::new(Dtype::F32, vec![1, 2, 2], bytemuck::cast_slice(&queries)).unwrap(),
            ),
            (
                "query_positions".to_string(),
                TensorView::new(
                    Dtype::I32,
                    vec![1, positions.len()],
                    bytemuck::cast_slice(positions),
                )
                .unwrap(),
            ),
        ]);

        serialize_to_file(
            tensors.iter().map(|(name, tensor)| (name.as_str(), tensor)),
            Some(HashMap::from([
                ("model".to_string(), "toy-model".to_string()),
                ("layer".to_string(), "7".to_string()),
            ])),
            temp_path,
        )
        .unwrap();
    }

    #[test]
    fn load_trace_from_safetensors() {
        let file = TempTrace::new("basic", &[0, 1]);

        let trace = KvTrace::load(&file.path).unwrap();

        assert_eq!(trace.sample_count(), 1);
        assert_eq!(trace.dim, 2);
        assert_eq!(trace.metadata.model.as_deref(), Some("toy-model"));
        assert_eq!(trace.metadata.layer, Some(7));
        assert_eq!(trace.samples[0].keys.len(), 2);
        assert_eq!(trace.samples[0].query_positions, vec![0, 1]);
    }

    #[test]
    fn load_trace_rejects_negative_query_positions() {
        let file = TempTrace::new("negative", &[-1, 1]);

        let error = KvTrace::load(&file.path).unwrap_err();

        assert!(matches!(
            error,
            crate::error::TurboQuantError::TraceFormat(_)
        ));
        assert!(error.to_string().contains("non-negative"));
    }

    #[test]
    fn load_trace_rejects_query_positions_past_token_count() {
        let file = TempTrace::new("oob", &[0, 2]);

        let error = KvTrace::load(&file.path).unwrap_err();

        assert!(matches!(
            error,
            crate::error::TurboQuantError::TraceFormat(_)
        ));
        assert!(error.to_string().contains("out of range"));
    }
}