use std::{
collections::HashMap,
io::{Cursor, Read},
};
use serde::{Deserialize, Serialize};
use crate::error::{RealizarError, Result};
use crate::inference::simd_bf16_to_f32;
/// Finds a companion file (for example a config or tokenizer file) next to `model_path`.
///
/// Candidates are tried in order and the first one that exists is returned:
///
/// 1. `<dir>/<stem>.<companion_name>`, where `<stem>` is the model file name minus its last
///    extension (`model.safetensors` -> `model`). This candidate is skipped when the model
///    file name ends in `.index.json`, because the stem of a shard index
///    (`model.safetensors.index`) is not a meaningful prefix.
/// 2. `<dir>/<companion_name>`.
///
/// Returns `None` when `model_path` has no parent directory or no file name, or when
/// neither candidate exists.
///
/// The model file name does not have to be valid UTF-8. The prefixed candidate is built from
/// `OsStr` pieces, so a non-UTF-8 name no longer prevents the plain `<companion_name>` lookup.
pub fn find_sibling_file(
    model_path: &std::path::Path,
    companion_name: &str,
) -> Option<std::path::PathBuf> {
    let parent = model_path.parent()?;
    let file_name = model_path.file_name()?;

    // A lossy conversion is sufficient for an ASCII suffix test and cannot fail on
    // non-UTF-8 names (unlike `to_str()?`, which used to abort the whole lookup).
    if !file_name.to_string_lossy().ends_with(".index.json") {
        // `file_stem` is always `Some` when `file_name` is, so this `?` never skips step 2.
        let stem = model_path.file_stem()?;
        let mut prefixed_name = stem.to_os_string();
        prefixed_name.push(".");
        prefixed_name.push(companion_name);
        let prefixed = parent.join(prefixed_name);
        if prefixed.exists() {
            return Some(prefixed);
        }
    }

    let plain = parent.join(companion_name);
    if plain.exists() {
        Some(plain)
    } else {
        None
    }
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum SafetensorsDtype {
F32,
F16,
BF16,
I32,
I64,
U8,
Bool,
}
impl SafetensorsDtype {
#[must_use]
pub fn size_in_bytes(&self) -> usize {
match self {
SafetensorsDtype::F32 | SafetensorsDtype::I32 => 4,
SafetensorsDtype::F16 | SafetensorsDtype::BF16 => 2,
SafetensorsDtype::I64 => 8,
SafetensorsDtype::U8 | SafetensorsDtype::Bool => 1,
}
}
}
/// Per-tensor header entry as deserialized by serde. It is the private counterpart of
/// `SafetensorsTensorInfo` and has no `name` field.
///
/// NOTE(review): the tensor name is not stored here. Presumably it is the key of the
/// enclosing header map; confirm in the parser.
#[derive(Debug, Deserialize)]
struct TensorMetadata {
    /// Element data type of the tensor.
    dtype: SafetensorsDtype,
    /// Tensor dimensions.
    shape: Vec<usize>,
    /// Start and end byte offsets of the tensor's data.
    data_offsets: [usize; 2],
}
/// Description of one tensor in a safetensors model: its name, element type, shape and the
/// byte range holding its data.
#[derive(Debug, Clone, PartialEq)]
pub struct SafetensorsTensorInfo {
    /// Tensor name.
    pub name: String,
    /// Element data type.
    pub dtype: SafetensorsDtype,
    /// Tensor dimensions. An empty shape is treated as a scalar (one element) by
    /// `validate_shape_matches_bytes`.
    pub shape: Vec<usize>,
    /// `[start, end)` byte offsets of the tensor's data; `byte_len` returns `end - start`.
    /// NOTE(review): presumably relative to the start of the data section that follows the
    /// header, as in the safetensors format. Confirm against the parser.
    pub data_offsets: [usize; 2],
}
impl SafetensorsTensorInfo {
    /// Number of bytes spanned by `data_offsets` (`end - start`).
    ///
    /// Saturates to `0` when the offsets are inverted (`end < start`). Use
    /// [`Self::validate_shape_matches_bytes`] to reject such entries.
    #[must_use]
    pub fn byte_len(&self) -> usize {
        self.data_offsets[1].saturating_sub(self.data_offsets[0])
    }

    /// Checks that the byte range in `data_offsets` has exactly the size implied by `shape`
    /// and `dtype`: the product of the dimensions times the bytes per element.
    ///
    /// An empty `shape` counts as a scalar (one element).
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::UnsupportedOperation`] when:
    /// - the offsets are inverted (`end < start`);
    /// - the element count or the total byte size overflows `usize`;
    /// - the byte length differs from the size implied by the shape and dtype.
    pub fn validate_shape_matches_bytes(&self) -> Result<()> {
        let dtype_size = self.dtype.size_in_bytes();
        let fail = |reason: String| RealizarError::UnsupportedOperation {
            operation: "validate_shape_matches_bytes".to_string(),
            reason,
        };

        // `byte_len` saturates inverted offsets to 0, which would let a zero-element shape
        // look valid. Reject the inversion explicitly before comparing sizes.
        let [start, end] = self.data_offsets;
        if end < start {
            return Err(fail(format!(
                "Tensor '{}' has inverted data_offsets [{start}, {end}] (end < start)",
                self.name
            )));
        }

        let elem_count = self
            .shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .ok_or_else(|| {
                fail(format!(
                    "Tensor '{}' shape {:?} overflows usize",
                    self.name, self.shape
                ))
            })?;
        let expected = elem_count.checked_mul(dtype_size).ok_or_else(|| {
            fail(format!(
                "Tensor '{}' byte size (shape {:?} * {dtype_size}) overflows usize",
                self.name, self.shape
            ))
        })?;

        let actual = self.byte_len();
        if actual != expected {
            return Err(fail(format!(
                "Tensor '{}' byte length {actual} contradicts declared shape {:?} \
                 ({:?}, {dtype_size} bytes/elem => expected {expected} bytes)",
                self.name, self.shape, self.dtype
            )));
        }
        Ok(())
    }
}
/// A safetensors model that owns its raw bytes in a `Vec<u8>`, together with per-tensor
/// metadata.
#[derive(Debug, Clone)]
pub struct SafetensorsModel {
    /// Tensor metadata keyed by tensor name. NOTE(review): presumably each key equals the
    /// entry's `name` field; verify in the parser.
    pub tensors: HashMap<String, SafetensorsTensorInfo>,
    /// Owned raw bytes backing the tensors. NOTE(review): presumably the data section that
    /// the tensors' `data_offsets` index into (header excluded); confirm in the parser.
    pub data: Vec<u8>,
}
include!("safetensors_parser.rs");
include!("mapped_model.rs");
include!("shard.rs");