use std::collections::HashMap;
use std::path::Path;
use anamnesis::{NpzDtype, NpzTensor, parse_npz};
use candle_core::{Device, Tensor};
use crate::error::{MIError, Result};
fn npz_tensor_to_candle(npy: &NpzTensor, device: &Device) -> Result<Tensor> {
match npy.dtype {
NpzDtype::F32 => {
#[allow(clippy::indexing_slicing)]
let f32_data: Vec<f32> = npy
.data
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
Ok(Tensor::from_vec(f32_data, &*npy.shape, device)?)
}
NpzDtype::F64 => {
#[allow(
clippy::indexing_slicing,
clippy::cast_possible_truncation,
clippy::as_conversions
)]
let f32_data: Vec<f32> = npy
.data
.chunks_exact(8)
.map(|c| {
let v = f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]);
v as f32
})
.collect();
Ok(Tensor::from_vec(f32_data, &*npy.shape, device)?)
}
other @ (NpzDtype::Bool
| NpzDtype::U8
| NpzDtype::I8
| NpzDtype::U16
| NpzDtype::I16
| NpzDtype::U32
| NpzDtype::I32
| NpzDtype::U64
| NpzDtype::I64
| NpzDtype::F16
| NpzDtype::BF16
| _) => Err(MIError::Config(format!(
"unsupported NPZ dtype {other} for SAE weights (expected F32 or F64)"
))),
}
}
/// Loads every array stored in the NPZ file at `path` as a candle [`Tensor`] on `device`.
///
/// The returned map is keyed by the array names yielded by [`parse_npz`].
///
/// # Errors
///
/// Propagates any error from [`parse_npz`], and the error from the first array that
/// `npz_tensor_to_candle` fails to convert.
pub fn load_npz(path: &Path, device: &Device) -> Result<HashMap<String, Tensor>> {
    let parsed = parse_npz(path)?;

    // One output entry per parsed array, so size the map up front.
    let mut by_name = HashMap::with_capacity(parsed.len());
    for (key, array) in &parsed {
        let tensor = npz_tensor_to_candle(array, device)?;
        by_name.insert(key.clone(), tensor);
    }

    Ok(by_name)
}