stprobe 0.2.3

A minimal CLI for inspecting safetensors headers
Documentation
use std::collections::HashMap;
use std::fs;
use std::path::Path;

use safetensors::tensor::{Dtype, TensorView};

#[allow(dead_code)]
pub fn write_sample_file(path: &Path) {
    fs::write(path, sample_safetensors_bytes()).expect("write test safetensors fixture");
}

pub fn sample_safetensors_bytes() -> Vec<u8> {
    let weight = [1.0_f32.to_le_bytes(), 2.0_f32.to_le_bytes()].concat();
    let ids = [1_i64.to_le_bytes(), 2_i64.to_le_bytes()].concat();

    let weight_view = TensorView::new(Dtype::F32, vec![2], &weight).expect("valid weight tensor");
    let ids_view = TensorView::new(Dtype::I64, vec![2], &ids).expect("valid ids tensor");

    safetensors::serialize(
        vec![
            ("embedding.ids", ids_view),
            ("embedding.weight", weight_view),
        ],
        Some(HashMap::from([("format".to_owned(), "pt".to_owned())])),
    )
    .expect("serialize test safetensors fixture")
}

#[allow(dead_code)]
pub fn split_header(bytes: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let prefix = bytes[..8].to_vec();
    let header_len = u64::from_le_bytes(prefix.as_slice().try_into().expect("8-byte prefix"));
    let header_end = 8 + header_len as usize;
    let header = bytes[8..header_end].to_vec();
    (prefix, header)
}