use ahash::{HashMap, HashMapExt};
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::path::Path;
use anyhow::{Context, Result, bail};
use super::tensor::{Tensor1D, Tensor2D};
/// Metadata for a single tensor as parsed from the safetensors JSON header:
/// its shape plus the `[start, end)` byte range of its raw data, expressed
/// relative to the start of the data section (i.e. after the 8-byte length
/// prefix and the JSON header itself).
#[derive(Debug, Clone)]
struct TensorMeta {
    // Dimension sizes, outermost first (empty for a scalar or when the
    // header omitted "shape").
    shape: Vec<usize>,
    // Inclusive start of the tensor's bytes within the data section.
    offset_start: usize,
    // Exclusive end of the tensor's bytes within the data section.
    offset_end: usize,
}
/// A single weight tensor held in memory: a flat f32 buffer plus its
/// logical shape.
pub struct WeightTensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

impl WeightTensor {
    /// Borrows the flat (row-major) f32 buffer.
    #[inline]
    pub fn data(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Borrows the tensor's shape (dimension sizes).
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Total number of elements in the buffer.
    #[inline]
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Views the tensor as a flat 1-D slice.
    #[inline]
    pub fn as_1d(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Views the tensor as `rows` row slices of `cols` elements each.
    /// In debug builds, asserts that `rows * cols` matches the element
    /// count.
    pub fn as_2d(&self, rows: usize, cols: usize) -> impl Iterator<Item = &[f32]> {
        debug_assert_eq!(rows * cols, self.data.len());
        (0..rows).map(move |row| {
            let begin = row * cols;
            &self.data[begin..begin + cols]
        })
    }
}
/// All named weight tensors loaded from one safetensors file,
/// keyed by tensor name.
pub struct Weights {
    pub(crate) tensors: HashMap<String, WeightTensor>,
}
impl Weights {
    /// Loads all tensors from a safetensors file into memory.
    ///
    /// File layout: an 8-byte little-endian header length, a JSON header
    /// describing each tensor (name, shape, data offsets), then the raw
    /// tensor bytes. The raw bytes are interpreted as little-endian f32;
    /// the header's dtype field is not inspected — TODO confirm the model
    /// files used here always store F32.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or read, the header is not valid
    /// UTF-8 or cannot be parsed, or a tensor's byte range is malformed
    /// (end before start, or a length that is not a multiple of 4).
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        let mut file = File::open(path)
            .with_context(|| format!("Failed to open weights file: {}", path.display()))?;
        // First 8 bytes: length of the JSON header as a little-endian u64.
        let mut header_len_bytes = [0u8; 8];
        file.read_exact(&mut header_len_bytes)
            .context("Failed to read safetensors header length")?;
        let header_len = u64::from_le_bytes(header_len_bytes) as usize;
        let mut header_bytes = vec![0u8; header_len];
        file.read_exact(&mut header_bytes)
            .context("Failed to read safetensors header")?;
        let header_str =
            std::str::from_utf8(&header_bytes).context("Invalid UTF-8 in safetensors header")?;
        let metas = parse_safetensors_header(header_str)?;
        // Tensor data begins immediately after the length prefix + header;
        // all header offsets are relative to this point.
        let data_offset = 8 + header_len;
        let mut tensors = HashMap::with_capacity(metas.len());
        for (name, meta) in metas {
            // Guard against a malformed header: end < start would otherwise
            // underflow, and a length not divisible by 4 would silently
            // drop trailing bytes in the f32 conversion.
            let byte_len = meta
                .offset_end
                .checked_sub(meta.offset_start)
                .with_context(|| format!("Invalid data_offsets for tensor '{}'", name))?;
            if byte_len % 4 != 0 {
                bail!(
                    "Tensor '{}' has byte length {} which is not a multiple of 4 (expected f32 data)",
                    name,
                    byte_len
                );
            }
            file.seek(SeekFrom::Start((data_offset + meta.offset_start) as u64))?;
            let mut raw_bytes = vec![0u8; byte_len];
            file.read_exact(&mut raw_bytes)
                .with_context(|| format!("Failed to read data for tensor '{}'", name))?;
            let data = bytes_to_f32(&raw_bytes);
            tensors.insert(
                name,
                WeightTensor {
                    data,
                    shape: meta.shape,
                },
            );
        }
        Ok(Self { tensors })
    }

    /// Looks up a tensor by name, returning `None` if absent.
    #[inline]
    pub fn get(&self, name: &str) -> Option<&WeightTensor> {
        self.tensors.get(name)
    }

    /// Looks up a tensor by name, failing with a descriptive error when it
    /// is missing.
    pub fn require(&self, name: &str) -> Result<&WeightTensor> {
        self.tensors
            .get(name)
            .with_context(|| format!("Missing required tensor: {}", name))
    }

    /// Returns the named tensor as a 1-D tensor.
    /// Note: clones the underlying buffer.
    pub fn get_1d(&self, name: &str) -> Result<Tensor1D> {
        let t = self.require(name)?;
        Ok(Tensor1D::from_vec(t.data.clone()))
    }

    /// Returns the named tensor as a 2-D tensor. A 1-D tensor of length n
    /// is promoted to shape (1, n); any other rank is an error.
    /// Note: clones the underlying buffer.
    pub fn get_2d(&self, name: &str) -> Result<Tensor2D> {
        let t = self.require(name)?;
        match t.shape.len() {
            1 => Ok(Tensor2D::from_vec(t.data.clone(), 1, t.shape[0])),
            2 => Ok(Tensor2D::from_vec(t.data.clone(), t.shape[0], t.shape[1])),
            _ => bail!(
                "Expected 1D or 2D tensor for '{}', got shape {:?}",
                name,
                t.shape
            ),
        }
    }

    /// Iterates over all tensor names (unordered — HashMap iteration order).
    pub fn tensor_names(&self) -> impl Iterator<Item = &str> {
        self.tensors.keys().map(|s| s.as_str())
    }

    /// Prints one line per tensor (name, shape, element count) to stdout,
    /// sorted by name for stable output.
    pub fn print_summary(&self) {
        let mut names: Vec<_> = self.tensors.keys().collect();
        names.sort();
        for name in names {
            let t = &self.tensors[name];
            println!(" {} {:?} ({} params)", name, t.shape, t.numel());
        }
    }
}
/// Parses the safetensors JSON header into per-tensor metadata.
///
/// This is a minimal hand-rolled parser for the specific JSON safetensors
/// emits: a single top-level object mapping tensor names to
/// `{"dtype": …, "shape": […], "data_offsets": [start, end]}` entries,
/// plus an optional "__metadata__" entry which is skipped.
///
/// NOTE(review): the dtype field is ignored downstream in
/// `parse_tensor_info`; loading assumes f32 data — confirm for the model
/// files in use.
fn parse_safetensors_header(json: &str) -> Result<HashMap<String, TensorMeta>> {
    let bytes = json.as_bytes();
    let mut pos = 0;
    let mut metas = HashMap::new();
    skip_whitespace(bytes, &mut pos);
    expect_char(bytes, &mut pos, b'{')?;
    loop {
        skip_whitespace(bytes, &mut pos);
        // End of the top-level object. `pos` is left on the '}' itself,
        // which is fine: nothing is parsed after it.
        if pos < bytes.len() && bytes[pos] == b'}' {
            break;
        }
        // Consume a separating comma, if present (also tolerates a
        // leading comma).
        if pos < bytes.len() && bytes[pos] == b',' {
            pos += 1;
            skip_whitespace(bytes, &mut pos);
        }
        let name = parse_string(bytes, &mut pos)?;
        // "__metadata__" carries free-form key/value pairs, not a tensor:
        // skip its value entirely.
        if name == "__metadata__" {
            skip_whitespace(bytes, &mut pos);
            expect_char(bytes, &mut pos, b':')?;
            skip_json_value(bytes, &mut pos)?;
            continue;
        }
        skip_whitespace(bytes, &mut pos);
        expect_char(bytes, &mut pos, b':')?;
        skip_whitespace(bytes, &mut pos);
        let meta = parse_tensor_info(bytes, &mut pos)?;
        metas.insert(name, meta);
    }
    Ok(metas)
}
/// Parses one tensor's info object:
/// `{"dtype": …, "shape": […], "data_offsets": [start, end]}`.
///
/// Unknown keys (including "dtype") are skipped without interpretation.
/// NOTE(review): missing fields fall back to an empty shape / zero
/// offsets rather than erroring, so a header entry without
/// "data_offsets" silently yields an empty tensor.
fn parse_tensor_info(bytes: &[u8], pos: &mut usize) -> Result<TensorMeta> {
    expect_char(bytes, pos, b'{')?;
    let mut shape: Option<Vec<usize>> = None;
    let mut offset_start: Option<usize> = None;
    let mut offset_end: Option<usize> = None;
    loop {
        skip_whitespace(bytes, pos);
        // End of this object: consume the '}' and stop.
        if *pos < bytes.len() && bytes[*pos] == b'}' {
            *pos += 1;
            break;
        }
        // Consume a separating comma, if present.
        if *pos < bytes.len() && bytes[*pos] == b',' {
            *pos += 1;
            skip_whitespace(bytes, pos);
        }
        let key = parse_string(bytes, pos)?;
        skip_whitespace(bytes, pos);
        expect_char(bytes, pos, b':')?;
        skip_whitespace(bytes, pos);
        match key.as_str() {
            "shape" => {
                shape = Some(parse_int_array(bytes, pos)?);
            }
            "data_offsets" => {
                // Expect exactly [start, end]; extra entries are ignored,
                // fewer than two leaves the offsets unset.
                let offsets = parse_int_array(bytes, pos)?;
                if offsets.len() >= 2 {
                    offset_start = Some(offsets[0]);
                    offset_end = Some(offsets[1]);
                }
            }
            // Any other key (e.g. "dtype") is skipped over.
            _ => {
                skip_json_value(bytes, pos)?;
            }
        }
    }
    Ok(TensorMeta {
        shape: shape.unwrap_or_default(),
        offset_start: offset_start.unwrap_or(0),
        offset_end: offset_end.unwrap_or(0),
    })
}
/// Parses a JSON string at `pos` (which must be on the opening quote)
/// and returns its raw contents, leaving `pos` just past the closing
/// quote.
///
/// Escape sequences are skipped over (so an escaped quote does not
/// terminate the string) but NOT decoded — the returned String still
/// contains the backslashes. That is adequate for safetensors tensor
/// names, which do not normally contain escapes; NOTE(review): revisit
/// if names with escapes ever appear.
fn parse_string(bytes: &[u8], pos: &mut usize) -> Result<String> {
    expect_char(bytes, pos, b'"')?;
    let start = *pos;
    while *pos < bytes.len() && bytes[*pos] != b'"' {
        // A backslash escapes the next byte; hop over it so an escaped
        // quote is not mistaken for the terminator.
        if bytes[*pos] == b'\\' {
            *pos += 1; }
        *pos += 1;
    }
    let end = *pos;
    expect_char(bytes, pos, b'"')?;
    String::from_utf8(bytes[start..end].to_vec()).context("Invalid UTF-8 in JSON string")
}
/// Parses a JSON array of non-negative integers, e.g. `[2, 768]`.
/// `pos` must be on the opening '[' and is left just past the closing
/// ']'.
fn parse_int_array(bytes: &[u8], pos: &mut usize) -> Result<Vec<usize>> {
    expect_char(bytes, pos, b'[')?;
    let mut values = Vec::new();
    loop {
        skip_whitespace(bytes, pos);
        match bytes.get(*pos) {
            // Closing bracket: consume it and return what we have
            // (handles the empty array too).
            Some(b']') => {
                *pos += 1;
                return Ok(values);
            }
            // Element separator: step over it and any following spaces.
            Some(b',') => {
                *pos += 1;
                skip_whitespace(bytes, pos);
            }
            _ => {}
        }
        values.push(parse_int(bytes, pos)?);
    }
}
/// Parses a run of ASCII digits at `pos` into a usize, advancing `pos`
/// past them. Errors if no digit is present or the value overflows.
fn parse_int(bytes: &[u8], pos: &mut usize) -> Result<usize> {
    let start = *pos;
    while bytes.get(*pos).map_or(false, |b| b.is_ascii_digit()) {
        *pos += 1;
    }
    if *pos == start {
        bail!("Expected integer at position {}", *pos);
    }
    // The scanned bytes are all ASCII digits, so from_utf8 cannot fail
    // in practice; `?` keeps the error path uniform anyway.
    let digits = std::str::from_utf8(&bytes[start..*pos])?;
    digits.parse().context("Failed to parse integer")
}
/// Skips over a single JSON value (string, object, array, or scalar)
/// starting at `pos`, advancing `pos` to just past it.
///
/// Used to ignore header entries we don't interpret (e.g. "dtype",
/// "__metadata__"). Nested objects/arrays are matched by depth counting
/// on the same bracket kind, and string contents (including escaped
/// quotes) are skipped without being interpreted.
fn skip_json_value(bytes: &[u8], pos: &mut usize) -> Result<()> {
    skip_whitespace(bytes, pos);
    if *pos >= bytes.len() {
        return Ok(());
    }
    match bytes[*pos] {
        b'"' => {
            *pos += 1; // opening quote
            skip_string_body(bytes, pos);
            *pos += 1; // closing quote
        }
        // Objects and arrays are skipped identically apart from the
        // bracket pair being counted.
        b'{' | b'[' => {
            let open = bytes[*pos];
            let close = if open == b'{' { b'}' } else { b']' };
            let mut depth = 1;
            *pos += 1;
            while *pos < bytes.len() && depth > 0 {
                let b = bytes[*pos];
                if b == open {
                    depth += 1;
                } else if b == close {
                    depth -= 1;
                } else if b == b'"' {
                    // Strings may contain brackets; skip their bodies so
                    // they can't perturb the depth count.
                    *pos += 1;
                    skip_string_body(bytes, pos);
                }
                *pos += 1;
            }
        }
        // Scalar (number, bool, null): run until the next delimiter.
        _ => {
            while *pos < bytes.len() && !matches!(bytes[*pos], b',' | b'}' | b']') {
                *pos += 1;
            }
        }
    }
    Ok(())
}

/// Advances `pos` from just after an opening quote up to the closing
/// quote (or end of input), honoring backslash escapes. Does not
/// consume the closing quote.
fn skip_string_body(bytes: &[u8], pos: &mut usize) {
    while *pos < bytes.len() && bytes[*pos] != b'"' {
        if bytes[*pos] == b'\\' {
            *pos += 1;
        }
        *pos += 1;
    }
}
/// Advances `pos` past any run of ASCII whitespace. A no-op when `pos`
/// is already at or beyond the end of `bytes`.
#[inline]
fn skip_whitespace(bytes: &[u8], pos: &mut usize) {
    while let Some(b) = bytes.get(*pos) {
        if !b.is_ascii_whitespace() {
            break;
        }
        *pos += 1;
    }
}
/// Requires the byte at `pos` to be exactly `expected`, consuming it on
/// success. Errors (without advancing) on a mismatch or at end of
/// input, reporting '\0' as the found character when out of bounds.
#[inline]
fn expect_char(bytes: &[u8], pos: &mut usize, expected: u8) -> Result<()> {
    match bytes.get(*pos) {
        Some(&b) if b == expected => {
            *pos += 1;
            Ok(())
        }
        found => bail!(
            "Expected '{}' at position {}, found '{}'",
            expected as char,
            *pos,
            found.map(|&b| b as char).unwrap_or('\0')
        ),
    }
}
/// Reinterprets a little-endian byte buffer as f32 values.
///
/// Any trailing bytes beyond the last complete 4-byte group are
/// ignored, matching the previous behavior. `chunks_exact` guarantees
/// every chunk is exactly 4 bytes, so the indexing below is always in
/// bounds and the loop vectorizes without per-element bounds checks —
/// this replaces a 40-line manual 4× unrolling with identical results.
#[inline]
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integration test: loads a real model file when one is available.
    /// Skips (rather than fails) if the file is absent, so environments
    /// without model weights still pass.
    #[test]
    fn test_load_weights() {
        // Model path is overridable via the RWKV_MODEL env var.
        let path =
            std::env::var("RWKV_MODEL").unwrap_or_else(|_| "rwkv-10m.safetensors".to_string());
        if !Path::new(&path).exists() {
            eprintln!("Skipping test: model file not found at {}", path);
            return;
        }
        let weights = Weights::load(&path).unwrap();
        assert!(weights.get("model.embeddings.weight").is_some());
        assert!(weights.get("lm_head.weight").is_some());
        // The test model's embedding table is expected to be 256x256.
        let emb = weights.get("model.embeddings.weight").unwrap();
        assert_eq!(emb.shape.len(), 2);
        assert_eq!(emb.shape[0], 256);
        assert_eq!(emb.shape[1], 256);
    }

    /// Unit test for the little-endian byte -> f32 conversion.
    #[test]
    fn test_bytes_to_f32() {
        // 0x3F800000 (LE bytes 00 00 80 3F) is 1.0f32.
        let bytes = [0x00, 0x00, 0x80, 0x3F]; let result = bytes_to_f32(&bytes);
        assert_eq!(result.len(), 1);
        assert!((result[0] - 1.0).abs() < 1e-6);
        // Two consecutive floats: 1.0 then 2.0 (0x40000000).
        let bytes = [
            0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40, ];
        let result = bytes_to_f32(&bytes);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[1] - 2.0).abs() < 1e-6);
    }
}