use crate::error::{ModelError, ModelResult};
use crate::gguf::{GgufFile, GgufQuantType, GgufTensorInfo};
use std::collections::HashMap;
use std::io::{Read, Seek, SeekFrom};
use std::path::Path;
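/// Abstraction over a model weight container (GGUF, safetensors, ...) that can
/// enumerate its tensors and load any one of them on demand as a flat `f32`
/// vector, so callers never have to hold the whole file in memory at once.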
pub trait WeightSource: Send + Sync {
fn tensor_names(&self) -> Vec<String>;
fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>>;
fn contains(&self, name: &str) -> bool;
fn total_bytes_estimate(&self) -> u64;
}
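/// Per-tensor metadata cached from the GGUF header so `load_tensor` can seek
/// straight to the data and read exactly `byte_len` bytes.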
#[derive(Debug, Clone)]
struct GgufTensorMeta {
data_offset: u64,
quant_type: GgufQuantType,
n_elements: usize,
byte_len: usize,
}
impl GgufTensorMeta {
fn from_info(info: &GgufTensorInfo) -> ModelResult<Self> {
let n_elements = info.n_elements() as usize;
let byte_len = compute_gguf_byte_len(&info.quant_type, n_elements, &info.name)?;
Ok(Self {
data_offset: info.data_offset,
quant_type: info.quant_type,
n_elements,
byte_len,
})
}
}
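/// Lazily reads tensors from a GGUF file: the header is parsed once in
/// `open`, and raw (possibly quantized) tensor data is read and dequantized
/// only when a tensor is requested.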
pub struct GgufFileSource {
file: std::fs::File,
tensor_infos: HashMap<String, GgufTensorMeta>,
file_size: u64,
}
impl GgufFileSource {
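/// Parses the GGUF header via `GgufFile::open`, caches per-tensor metadata,
/// and keeps a separate file handle for the later lazy reads.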
pub fn open(path: &Path) -> ModelResult<Self> {
let gguf = GgufFile::open(path)?;
let file_size = std::fs::metadata(path)
.map_err(|e| {
ModelError::simple_load_error(format!("Failed to stat GGUF file {:?}: {}", path, e))
})?
.len();
let mut tensor_infos = HashMap::with_capacity(gguf.tensors.len());
for info in &gguf.tensors {
let meta = GgufTensorMeta::from_info(info)?;
tensor_infos.insert(info.name.clone(), meta);
}
let file = std::fs::File::open(path).map_err(|e| {
ModelError::simple_load_error(format!("Failed to open GGUF file {:?}: {}", path, e))
})?;
Ok(Self {
file,
tensor_infos,
file_size,
})
}
}
impl WeightSource for GgufFileSource {
fn tensor_names(&self) -> Vec<String> {
let mut names: Vec<String> = self.tensor_infos.keys().cloned().collect();
names.sort();
names
}
fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>> {
let meta = self.tensor_infos.get(name).ok_or_else(|| {
ModelError::simple_load_error(format!("GgufFileSource: tensor '{}' not found", name))
})?;
let data_offset = meta.data_offset;
let quant_type = meta.quant_type;
let n_elements = meta.n_elements;
let byte_len = meta.byte_len;
self.file.seek(SeekFrom::Start(data_offset)).map_err(|e| {
ModelError::simple_load_error(format!(
"GgufFileSource: seek to tensor '{}' at offset {} failed: {}",
name, data_offset, e
))
})?;
let mut raw = vec![0u8; byte_len];
self.file.read_exact(&mut raw).map_err(|e| {
ModelError::simple_load_error(format!(
"GgufFileSource: read {} bytes for tensor '{}' failed: {}",
byte_len, name, e
))
})?;
dequantize_gguf(&raw, &quant_type, n_elements, name)
}
fn contains(&self, name: &str) -> bool {
self.tensor_infos.contains_key(name)
}
fn total_bytes_estimate(&self) -> u64 {
self.file_size
}
}
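/// Floating-point dtypes from the safetensors header that this source knows
/// how to convert to `f32`; any other dtype string is rejected in `from_str`.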
#[derive(Debug, Clone, PartialEq, Eq)]
enum SafeTensorDtype {
F32,
F16,
Bf16,
F64,
}
impl SafeTensorDtype {
fn from_str(s: &str) -> ModelResult<Self> {
match s {
"F32" => Ok(Self::F32),
"F16" => Ok(Self::F16),
"BF16" => Ok(Self::Bf16),
"F64" => Ok(Self::F64),
other => Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: unsupported dtype '{}'",
other
))),
}
}
fn bytes_per_element(&self) -> usize {
match self {
Self::F32 => 4,
Self::F16 | Self::Bf16 => 2,
Self::F64 => 8,
}
}
}
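/// One tensor entry from the safetensors JSON header; `data_offsets` are
/// relative to the start of the data section.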
#[derive(Debug, Clone)]
struct SafeTensorInfo {
dtype: SafeTensorDtype,
shape: Vec<usize>,
data_offsets: (u64, u64),
}
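/// Lazily reads tensors from a safetensors file. The on-disk layout is an
/// 8-byte little-endian header length, a JSON header such as
/// `{"weight":{"dtype":"F32","shape":[4],"data_offsets":[0,16]}}`, and then
/// the raw tensor data, with each tensor's `data_offsets` relative to the
/// start of that data section.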
pub struct SafeTensorsSource {
file: std::fs::File,
header: HashMap<String, SafeTensorInfo>,
data_start_offset: u64,
file_size: u64,
}
impl SafeTensorsSource {
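/// Reads the 8-byte header length and the JSON header, validates each tensor
/// entry, and records where the data section begins; tensor bytes themselves
/// are only read later in `load_tensor`.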
pub fn open(path: &Path) -> ModelResult<Self> {
let mut file = std::fs::File::open(path).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: cannot open {:?}: {}",
path, e
))
})?;
let file_size = file
.seek(SeekFrom::End(0))
.map_err(|e| ModelError::simple_load_error(format!("seek to end failed: {}", e)))?;
file.seek(SeekFrom::Start(0))
.map_err(|e| ModelError::simple_load_error(format!("seek to start failed: {}", e)))?;
let mut size_buf = [0u8; 8];
file.read_exact(&mut size_buf).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: failed to read header size: {}",
e
))
})?;
let header_size = u64::from_le_bytes(size_buf);
if header_size.saturating_add(8) > file_size {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: declared header size {} exceeds file size {}",
header_size, file_size
)));
}
let mut json_buf = vec![0u8; header_size as usize];
file.read_exact(&mut json_buf).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: failed to read {} bytes of JSON header: {}",
header_size, e
))
})?;
let data_start_offset = 8 + header_size;
let json_str = std::str::from_utf8(&json_buf).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: JSON header is not valid UTF-8: {}",
e
))
})?;
let root: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: failed to parse JSON header: {}",
e
))
})?;
let obj = root.as_object().ok_or_else(|| {
ModelError::simple_load_error("SafeTensorsSource: JSON root is not an object")
})?;
let mut header = HashMap::with_capacity(obj.len());
for (key, val) in obj {
if key == "__metadata__" {
continue;
}
let dtype_str = val.get("dtype").and_then(|v| v.as_str()).ok_or_else(|| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' missing 'dtype'",
key
))
})?;
let dtype = SafeTensorDtype::from_str(dtype_str)?;
let shape_arr = val.get("shape").and_then(|v| v.as_array()).ok_or_else(|| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' missing 'shape'",
key
))
})?;
let shape = shape_arr
.iter()
.map(|v| {
v.as_u64().ok_or_else(|| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' shape element is not a u64",
key
))
})
})
.collect::<ModelResult<Vec<u64>>>()?
.into_iter()
.map(|d| d as usize)
.collect();
let offsets_arr = val
.get("data_offsets")
.and_then(|v| v.as_array())
.ok_or_else(|| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' missing 'data_offsets'",
key
))
})?;
if offsets_arr.len() != 2 {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' data_offsets must have 2 elements, got {}",
key,
offsets_arr.len()
)));
}
let begin = offsets_arr[0].as_u64().ok_or_else(|| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' data_offsets[0] is not a u64",
key
))
})?;
let end = offsets_arr[1].as_u64().ok_or_else(|| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' data_offsets[1] is not a u64",
key
))
})?;
header.insert(
key.clone(),
SafeTensorInfo {
dtype,
shape,
data_offsets: (begin, end),
},
);
}
Ok(Self {
file,
header,
data_start_offset,
file_size,
})
}
}
impl WeightSource for SafeTensorsSource {
fn tensor_names(&self) -> Vec<String> {
let mut names: Vec<String> = self.header.keys().cloned().collect();
names.sort();
names
}
fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>> {
let info = self.header.get(name).ok_or_else(|| {
ModelError::simple_load_error(format!("SafeTensorsSource: tensor '{}' not found", name))
})?;
let (begin, end) = info.data_offsets;
if end < begin {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' has inverted data_offsets [{}, {})",
name, begin, end
)));
}
let byte_len = (end - begin) as usize;
let dtype = info.dtype.clone();
let n_elements: usize = if info.shape.is_empty() {
1
} else {
info.shape.iter().product()
};
let expected_bytes = n_elements * dtype.bytes_per_element();
if byte_len != expected_bytes {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: tensor '{}' byte range [{}, {}) has {} bytes, expected {} (shape={:?}, dtype={:?})",
name, begin, end, byte_len, expected_bytes, info.shape, dtype
)));
}
let abs_offset = self.data_start_offset + begin;
self.file.seek(SeekFrom::Start(abs_offset)).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: seek to tensor '{}' at {} failed: {}",
name, abs_offset, e
))
})?;
let mut raw = vec![0u8; byte_len];
self.file.read_exact(&mut raw).map_err(|e| {
ModelError::simple_load_error(format!(
"SafeTensorsSource: read {} bytes for tensor '{}' failed: {}",
byte_len, name, e
))
})?;
convert_safetensors_bytes_to_f32(&raw, &dtype, n_elements, name)
}
fn contains(&self, name: &str) -> bool {
self.header.contains_key(name)
}
fn total_bytes_estimate(&self) -> u64 {
self.file_size
}
}
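/// Pseudo-prefix used to group tensors that do not belong to any `layers.N.`
/// block (e.g. `embed`, `lm_head`); this group is always emitted last.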
const MISC_PREFIX: &str = "_misc.";
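/// Streams a model's weights one layer group at a time: tensor names of the
/// form `layers.N.*` are grouped by their `layers.N.` prefix, and everything
/// else goes into a trailing `_misc.` group. A minimal usage sketch (the file
/// path and error handling are illustrative only):
///
/// ```ignore
/// let source = SafeTensorsSource::open(Path::new("model.safetensors"))?;
/// let mut loader = IncrementalModelLoader::new(source);
/// loader.load_all_streaming(|prefix, tensors| {
///     // Build one layer from `tensors`, then let the map drop to free memory.
///     println!("{}: {} tensors", prefix, tensors.len());
///     Ok(())
/// })?;
/// ```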
pub struct IncrementalModelLoader<S: WeightSource> {
source: S,
layer_prefixes: Vec<String>,
}
impl<S: WeightSource> IncrementalModelLoader<S> {
pub fn new(source: S) -> Self {
let names = source.tensor_names();
let mut prefixes: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
let mut has_misc = false;
for name in &names {
if let Some(prefix) = extract_layer_prefix(name) {
prefixes.insert(prefix);
} else {
has_misc = true;
}
}
let mut layer_prefixes: Vec<String> = prefixes.into_iter().collect();
if has_misc {
layer_prefixes.push(MISC_PREFIX.to_string());
}
Self {
source,
layer_prefixes,
}
}
pub fn load_layer(&mut self, prefix: &str) -> ModelResult<HashMap<String, Vec<f32>>> {
let names: Vec<String> = if prefix == MISC_PREFIX {
self.source
.tensor_names()
.into_iter()
.filter(|n| extract_layer_prefix(n).is_none())
.collect()
} else {
self.source
.tensor_names()
.into_iter()
.filter(|n| n.starts_with(prefix))
.collect()
};
let mut result = HashMap::with_capacity(names.len());
for name in names {
let tensor = self.source.load_tensor(&name)?;
result.insert(name, tensor);
}
Ok(result)
}
pub fn load_all_streaming<F>(&mut self, mut callback: F) -> ModelResult<()>
where
F: FnMut(&str, HashMap<String, Vec<f32>>) -> ModelResult<()>,
{
let prefixes = self.layer_prefixes.clone();
for prefix in &prefixes {
let tensors = self.load_layer(prefix)?;
callback(prefix, tensors)?;
}
Ok(())
}
pub fn layer_prefixes(&self) -> &[String] {
&self.layer_prefixes
}
pub fn source(&self) -> &S {
&self.source
}
pub fn into_source(self) -> S {
self.source
}
}
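/// Returns the `layers.N.` prefix for names of the form
/// `layers.<digits>.<rest>` (e.g. `layers.0.weight` -> `layers.0.`), or
/// `None` for anything else.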
fn extract_layer_prefix(name: &str) -> Option<String> {
let rest = name.strip_prefix("layers.")?;
let dot_pos = rest.find('.')?;
let idx_str = &rest[..dot_pos];
if idx_str.is_empty() || !idx_str.chars().all(|c| c.is_ascii_digit()) {
return None;
}
Some(format!("layers.{}.", idx_str))
}
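/// Thin wrapper around `crate::gguf::dequant::dequantize` that attaches the
/// tensor name to any dequantization error.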
fn dequantize_gguf(
raw: &[u8],
quant_type: &GgufQuantType,
n_elements: usize,
tensor_name: &str,
) -> ModelResult<Vec<f32>> {
use crate::gguf::dequant;
dequant::dequantize(raw, quant_type, n_elements).map_err(|e| {
ModelError::simple_load_error(format!(
"GgufFileSource: dequantize failed for tensor '{}': {}",
tensor_name, e
))
})
}
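/// Converts little-endian raw bytes into `f32` values, widening or narrowing
/// from F16/BF16/F64 as needed; the byte length is validated against
/// `n_elements` for every dtype before conversion.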
fn convert_safetensors_bytes_to_f32(
raw: &[u8],
dtype: &SafeTensorDtype,
n_elements: usize,
tensor_name: &str,
) -> ModelResult<Vec<f32>> {
match dtype {
SafeTensorDtype::F32 => {
if raw.len() != n_elements * 4 {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: F32 tensor '{}' has {} bytes, expected {}",
tensor_name,
raw.len(),
n_elements * 4
)));
}
Ok(raw
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect())
}
SafeTensorDtype::F16 => {
if raw.len() != n_elements * 2 {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: F16 tensor '{}' has {} bytes, expected {}",
tensor_name,
raw.len(),
n_elements * 2
)));
}
Ok(raw
.chunks_exact(2)
.map(|b| {
let bits = u16::from_le_bytes([b[0], b[1]]);
half::f16::from_bits(bits).to_f32()
})
.collect())
}
SafeTensorDtype::Bf16 => {
if raw.len() != n_elements * 2 {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: BF16 tensor '{}' has {} bytes, expected {}",
tensor_name,
raw.len(),
n_elements * 2
)));
}
Ok(raw
.chunks_exact(2)
.map(|b| {
let bits = u16::from_le_bytes([b[0], b[1]]);
half::bf16::from_bits(bits).to_f32()
})
.collect())
}
SafeTensorDtype::F64 => {
if raw.len() != n_elements * 8 {
return Err(ModelError::simple_load_error(format!(
"SafeTensorsSource: F64 tensor '{}' has {} bytes, expected {}",
tensor_name,
raw.len(),
n_elements * 8
)));
}
Ok(raw
.chunks_exact(8)
.map(|b| {
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
})
.collect())
}
}
}
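/// Computes the on-disk byte length of a GGUF tensor. Plain float types scale
/// with the element count; block-quantized types require the element count to
/// be a non-zero multiple of the block size (32 elements for the legacy
/// Q4_0..Q8_1 formats, 256 for the K-quants). For example, a Q4_0 tensor with
/// 64 elements occupies 2 blocks * 18 bytes = 36 bytes.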
fn compute_gguf_byte_len(
quant_type: &GgufQuantType,
n_elements: usize,
tensor_name: &str,
) -> ModelResult<usize> {
let block_check = |block_elems: usize, block_bytes: usize| -> ModelResult<usize> {
if n_elements == 0 || !n_elements.is_multiple_of(block_elems) {
return Err(ModelError::simple_load_error(format!(
"GgufFileSource: tensor '{}' has {} elements, not a multiple of {}",
tensor_name, n_elements, block_elems
)));
}
Ok((n_elements / block_elems) * block_bytes)
};
match quant_type {
GgufQuantType::F32 => Ok(n_elements * 4),
GgufQuantType::F16 | GgufQuantType::BF16 => Ok(n_elements * 2),
GgufQuantType::Q4_0 => block_check(32, 18),
GgufQuantType::Q4_1 => block_check(32, 20),
GgufQuantType::Q5_0 => block_check(32, 22),
GgufQuantType::Q5_1 => block_check(32, 24),
GgufQuantType::Q8_0 => block_check(32, 34),
GgufQuantType::Q8_1 => block_check(32, 36),
GgufQuantType::Q2K => block_check(256, 84),
GgufQuantType::Q3K => block_check(256, 110),
GgufQuantType::Q4K => block_check(256, 144),
GgufQuantType::Q5K => block_check(256, 176),
GgufQuantType::Q6K => block_check(256, 210),
GgufQuantType::Q8K => block_check(256, 292),
qt => Err(ModelError::simple_load_error(format!(
"GgufFileSource: cannot compute byte size for unsupported quant type {:?} (tensor '{}')",
qt, tensor_name
))),
}
}
#[cfg(test)]
mod tests {
use super::*;
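// Builds a minimal in-memory safetensors blob: an 8-byte little-endian header
// length, a JSON header mapping each name to dtype/shape/data_offsets, and the
// concatenated F32 payloads, with offsets relative to the start of the data.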
fn make_synthetic_safetensors(tensors: &[(&str, Vec<f32>)]) -> Vec<u8> {
let mut data_bytes: Vec<u8> = Vec::new();
let mut tensor_metas: Vec<(&str, usize, usize, usize)> = Vec::new();
for (name, vals) in tensors {
let begin = data_bytes.len();
for v in vals.iter() {
data_bytes.extend_from_slice(&v.to_le_bytes());
}
let end = data_bytes.len();
tensor_metas.push((name, begin, end, vals.len()));
}
let mut header_map = serde_json::Map::new();
for (name, begin, end, n) in &tensor_metas {
let entry = serde_json::json!({
"dtype": "F32",
"shape": [n],
"data_offsets": [begin, end]
});
header_map.insert((*name).to_string(), entry);
}
let header_json = serde_json::Value::Object(header_map).to_string();
let header_bytes = header_json.as_bytes();
let header_len = header_bytes.len() as u64;
let mut out: Vec<u8> = Vec::new();
out.extend_from_slice(&header_len.to_le_bytes());
out.extend_from_slice(header_bytes);
out.extend_from_slice(&data_bytes);
out
}
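// Builds a minimal GGUF v2 file containing a single F32 tensor: magic,
// version, tensor count (1), metadata KV count (0), one tensor-info record,
// padding to the default 32-byte alignment, then the raw little-endian data.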
fn make_synthetic_gguf_f32(tensor_name: &str, values: &[f32]) -> Vec<u8> {
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(b"GGUF");
buf.extend_from_slice(&2u32.to_le_bytes());
buf.extend_from_slice(&1u64.to_le_bytes());
buf.extend_from_slice(&0u64.to_le_bytes());
let name_bytes = tensor_name.as_bytes();
buf.extend_from_slice(&(name_bytes.len() as u64).to_le_bytes());
buf.extend_from_slice(name_bytes);
buf.extend_from_slice(&1u32.to_le_bytes());
buf.extend_from_slice(&(values.len() as u64).to_le_bytes());
buf.extend_from_slice(&0u32.to_le_bytes());
buf.extend_from_slice(&0u64.to_le_bytes());
let current_len = buf.len();
let aligned = (current_len + 31) & !31;
let pad = aligned - current_len;
buf.extend(std::iter::repeat_n(0u8, pad));
for v in values {
buf.extend_from_slice(&v.to_le_bytes());
}
buf
}
#[test]
fn test_safetensors_source_single_tensor() {
let tensors = &[("weight", vec![1.0f32, 2.0, 3.0, 4.0])];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_single.safetensors");
std::fs::write(&path, &data).expect("write test file");
let mut src = SafeTensorsSource::open(&path).expect("open SafeTensorsSource");
assert!(src.contains("weight"), "tensor 'weight' should be present");
let loaded = src.load_tensor("weight").expect("load_tensor weight");
assert_eq!(loaded, vec![1.0f32, 2.0, 3.0, 4.0]);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_weight_source_contains() {
let tensors = &[("alpha", vec![0.5f32, 1.5]), ("beta", vec![2.0f32, 3.0])];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_contains.safetensors");
std::fs::write(&path, &data).expect("write test file");
let src = SafeTensorsSource::open(&path).expect("open");
assert!(src.contains("alpha"));
assert!(src.contains("beta"));
assert!(
!src.contains("gamma"),
"should not contain non-existent tensor"
);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_incremental_loader_layer_prefixes() {
let tensors = &[
("layers.0.weight", vec![1.0f32, 2.0]),
("layers.0.bias", vec![0.1f32]),
("layers.1.weight", vec![3.0f32, 4.0]),
("embed", vec![0.5f32]),
];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_layer_prefixes.safetensors");
std::fs::write(&path, &data).expect("write test file");
let src = SafeTensorsSource::open(&path).expect("open");
let loader = IncrementalModelLoader::new(src);
let prefixes = loader.layer_prefixes();
assert!(
prefixes.contains(&"layers.0.".to_string()),
"expected 'layers.0.' in prefixes, got {:?}",
prefixes
);
assert!(
prefixes.contains(&"layers.1.".to_string()),
"expected 'layers.1.' in prefixes, got {:?}",
prefixes
);
assert!(
prefixes.contains(&MISC_PREFIX.to_string()),
"expected '{}' in prefixes for 'embed', got {:?}",
MISC_PREFIX,
prefixes
);
assert_eq!(
prefixes.last().map(String::as_str),
Some(MISC_PREFIX),
"_misc. prefix should be last"
);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_incremental_loader_streaming_callback() {
let tensors = &[
("layers.0.weight", vec![1.0f32]),
("layers.0.bias", vec![0.0f32]),
("layers.1.weight", vec![2.0f32]),
("lm_head", vec![3.0f32]),
];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_streaming.safetensors");
std::fs::write(&path, &data).expect("write test file");
let src = SafeTensorsSource::open(&path).expect("open");
let mut loader = IncrementalModelLoader::new(src);
let mut invocation_count = 0usize;
loader
.load_all_streaming(|_prefix, _tensors| {
invocation_count += 1;
Ok(())
})
.expect("streaming failed");
assert_eq!(
invocation_count, 3,
"expected 3 callbacks (layers.0., layers.1., _misc.), got {}",
invocation_count
);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_gguf_file_source_lazy_load() {
let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let data = make_synthetic_gguf_f32("test_tensor", &values);
let path = std::env::temp_dir().join("kizzasi_test_gguf_source.gguf");
std::fs::write(&path, &data).expect("write test gguf file");
let mut src = GgufFileSource::open(&path).expect("open GgufFileSource");
assert!(src.contains("test_tensor"), "tensor should be present");
let loaded = src.load_tensor("test_tensor").expect("load_tensor");
assert_eq!(loaded.len(), values.len(), "element count mismatch");
for (i, (&got, &expected)) in loaded.iter().zip(values.iter()).enumerate() {
assert!(
(got - expected).abs() < 1e-5,
"element {}: expected {}, got {}",
i,
expected,
got
);
}
assert!(
!src.contains("nonexistent"),
"nonexistent tensor should not be present"
);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_safetensors_source_multiple_tensors_values() {
let tensors = &[("a", vec![10.0f32, 20.0, 30.0]), ("b", vec![-1.0f32, -2.0])];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_multi.safetensors");
std::fs::write(&path, &data).expect("write test file");
let mut src = SafeTensorsSource::open(&path).expect("open");
let a = src.load_tensor("a").expect("load a");
assert_eq!(a, vec![10.0f32, 20.0, 30.0]);
let b = src.load_tensor("b").expect("load b");
assert_eq!(b, vec![-1.0f32, -2.0]);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_extract_layer_prefix_valid() {
assert_eq!(
extract_layer_prefix("layers.0.weight"),
Some("layers.0.".to_string())
);
assert_eq!(
extract_layer_prefix("layers.123.bias"),
Some("layers.123.".to_string())
);
}
#[test]
fn test_extract_layer_prefix_invalid() {
assert_eq!(extract_layer_prefix("embed"), None);
assert_eq!(extract_layer_prefix("lm_head.weight"), None);
assert_eq!(extract_layer_prefix("layers_bad.0.weight"), None);
assert_eq!(extract_layer_prefix("layers.abc.weight"), None);
}
#[test]
fn test_weight_source_total_bytes_estimate() {
let tensors = &[("x", vec![1.0f32, 2.0])];
let data = make_synthetic_safetensors(tensors);
let expected_size = data.len() as u64;
let path = std::env::temp_dir().join("kizzasi_test_safetensors_bytes_estimate.safetensors");
std::fs::write(&path, &data).expect("write");
let src = SafeTensorsSource::open(&path).expect("open");
assert_eq!(src.total_bytes_estimate(), expected_size);
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_safetensors_source_missing_tensor_error() {
let tensors = &[("existing", vec![1.0f32])];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_missing.safetensors");
std::fs::write(&path, &data).expect("write");
let mut src = SafeTensorsSource::open(&path).expect("open");
assert!(src.load_tensor("nonexistent").is_err());
let _ = std::fs::remove_file(&path);
}
#[test]
fn test_incremental_loader_load_layer() {
let tensors = &[
("layers.0.weight", vec![5.0f32, 6.0]),
("layers.0.bias", vec![0.5f32]),
("layers.1.weight", vec![7.0f32]),
];
let data = make_synthetic_safetensors(tensors);
let path = std::env::temp_dir().join("kizzasi_test_safetensors_load_layer.safetensors");
std::fs::write(&path, &data).expect("write");
let src = SafeTensorsSource::open(&path).expect("open");
let mut loader = IncrementalModelLoader::new(src);
let layer0 = loader
.load_layer("layers.0.")
.expect("load_layer layers.0.");
assert!(layer0.contains_key("layers.0.weight"));
assert!(layer0.contains_key("layers.0.bias"));
assert!(!layer0.contains_key("layers.1.weight"));
let _ = std::fs::remove_file(&path);
}
}