1use crate::error::{ModelError, ModelResult};
29use crate::gguf::{GgufFile, GgufQuantType, GgufTensorInfo};
30use std::collections::HashMap;
31use std::io::{Read, Seek, SeekFrom};
32use std::path::Path;
33
/// Abstraction over a weight-file format that can enumerate tensors and
/// lazily load each one as a flat `f32` buffer.
///
/// Implementations in this file keep an open file handle and read tensor
/// payloads on demand, which is why `load_tensor` takes `&mut self`.
pub trait WeightSource: Send + Sync {
    /// Returns all tensor names available in this source.
    fn tensor_names(&self) -> Vec<String>;

    /// Reads the named tensor and converts/dequantizes it to `f32`.
    ///
    /// `&mut self` because implementations seek within an open file.
    fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>>;

    /// Returns `true` if a tensor with this exact name exists.
    fn contains(&self, name: &str) -> bool;

    /// Approximate total size of the underlying data in bytes. The
    /// implementations below report the on-disk file size.
    fn total_bytes_estimate(&self) -> u64;
}
67
/// Cached location/size metadata for one tensor inside a GGUF file,
/// precomputed at open time so `load_tensor` can seek and read directly.
#[derive(Debug, Clone)]
struct GgufTensorMeta {
    // Absolute byte offset of the tensor payload within the file.
    data_offset: u64,
    // Quantization format of the stored payload.
    quant_type: GgufQuantType,
    // Number of logical elements after dequantization.
    n_elements: usize,
    // On-disk payload length, derived from quant_type and n_elements.
    byte_len: usize,
}
84
85impl GgufTensorMeta {
86 fn from_info(info: &GgufTensorInfo) -> ModelResult<Self> {
88 let n_elements = info.n_elements() as usize;
89 let byte_len = compute_gguf_byte_len(&info.quant_type, n_elements, &info.name)?;
90 Ok(Self {
91 data_offset: info.data_offset,
92 quant_type: info.quant_type,
93 n_elements,
94 byte_len,
95 })
96 }
97}
98
/// Lazy tensor source backed by a GGUF file on disk.
///
/// The file is parsed once at open time; tensor payloads are read and
/// dequantized only when requested via `load_tensor`.
pub struct GgufFileSource {
    // Open handle used for per-tensor seek + read.
    file: std::fs::File,
    // Name -> cached offset/size/quant metadata for every tensor.
    tensor_infos: HashMap<String, GgufTensorMeta>,
    // Total on-disk size, reported by total_bytes_estimate().
    file_size: u64,
}
116
117impl GgufFileSource {
118 pub fn open(path: &Path) -> ModelResult<Self> {
122 let gguf = GgufFile::open(path)?;
125
126 let file_size = std::fs::metadata(path)
127 .map_err(|e| {
128 ModelError::simple_load_error(format!("Failed to stat GGUF file {:?}: {}", path, e))
129 })?
130 .len();
131
132 let mut tensor_infos = HashMap::with_capacity(gguf.tensors.len());
134 for info in &gguf.tensors {
135 let meta = GgufTensorMeta::from_info(info)?;
136 tensor_infos.insert(info.name.clone(), meta);
137 }
138
139 let file = std::fs::File::open(path).map_err(|e| {
141 ModelError::simple_load_error(format!("Failed to open GGUF file {:?}: {}", path, e))
142 })?;
143
144 Ok(Self {
145 file,
146 tensor_infos,
147 file_size,
148 })
149 }
150}
151
152impl WeightSource for GgufFileSource {
153 fn tensor_names(&self) -> Vec<String> {
154 let mut names: Vec<String> = self.tensor_infos.keys().cloned().collect();
155 names.sort();
156 names
157 }
158
159 fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>> {
160 let meta = self.tensor_infos.get(name).ok_or_else(|| {
161 ModelError::simple_load_error(format!("GgufFileSource: tensor '{}' not found", name))
162 })?;
163
164 let data_offset = meta.data_offset;
166 let quant_type = meta.quant_type;
167 let n_elements = meta.n_elements;
168 let byte_len = meta.byte_len;
169
170 self.file.seek(SeekFrom::Start(data_offset)).map_err(|e| {
171 ModelError::simple_load_error(format!(
172 "GgufFileSource: seek to tensor '{}' at offset {} failed: {}",
173 name, data_offset, e
174 ))
175 })?;
176
177 let mut raw = vec![0u8; byte_len];
178 self.file.read_exact(&mut raw).map_err(|e| {
179 ModelError::simple_load_error(format!(
180 "GgufFileSource: read {} bytes for tensor '{}' failed: {}",
181 byte_len, name, e
182 ))
183 })?;
184
185 dequantize_gguf(&raw, &quant_type, n_elements, name)
186 }
187
188 fn contains(&self, name: &str) -> bool {
189 self.tensor_infos.contains_key(name)
190 }
191
192 fn total_bytes_estimate(&self) -> u64 {
193 self.file_size
194 }
195}
196
/// Element types this loader supports from safetensors headers.
/// Everything is widened/narrowed to `f32` on load.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SafeTensorDtype {
    F32,
    F16,
    Bf16,
    F64,
}
209
210impl SafeTensorDtype {
211 fn from_str(s: &str) -> ModelResult<Self> {
213 match s {
214 "F32" => Ok(Self::F32),
215 "F16" => Ok(Self::F16),
216 "BF16" => Ok(Self::Bf16),
217 "F64" => Ok(Self::F64),
218 other => Err(ModelError::simple_load_error(format!(
219 "SafeTensorsSource: unsupported dtype '{}'",
220 other
221 ))),
222 }
223 }
224
225 fn bytes_per_element(&self) -> usize {
227 match self {
228 Self::F32 => 4,
229 Self::F16 | Self::Bf16 => 2,
230 Self::F64 => 8,
231 }
232 }
233}
234
/// Parsed header entry for one tensor in a safetensors file.
#[derive(Debug, Clone)]
struct SafeTensorInfo {
    // Element type as declared in the JSON header.
    dtype: SafeTensorDtype,
    // Dimensions; an empty shape denotes a scalar (one element).
    shape: Vec<usize>,
    // (begin, end) byte offsets relative to the start of the data section.
    data_offsets: (u64, u64),
}
245
/// Lazy tensor source backed by a safetensors file on disk.
///
/// The JSON header is parsed once at open time; tensor payloads are read
/// on demand via `load_tensor`.
pub struct SafeTensorsSource {
    // Open handle used for per-tensor seek + read.
    file: std::fs::File,
    // Name -> parsed header entry (dtype, shape, byte range).
    header: HashMap<String, SafeTensorInfo>,
    // Absolute file offset where tensor data begins
    // (8-byte length prefix + JSON header length).
    data_start_offset: u64,
    // Total on-disk size, reported by total_bytes_estimate().
    file_size: u64,
}
262
263impl SafeTensorsSource {
264 pub fn open(path: &Path) -> ModelResult<Self> {
273 let mut file = std::fs::File::open(path).map_err(|e| {
274 ModelError::simple_load_error(format!(
275 "SafeTensorsSource: cannot open {:?}: {}",
276 path, e
277 ))
278 })?;
279
280 let file_size = file
281 .seek(SeekFrom::End(0))
282 .map_err(|e| ModelError::simple_load_error(format!("seek to end failed: {}", e)))?;
283
284 file.seek(SeekFrom::Start(0))
286 .map_err(|e| ModelError::simple_load_error(format!("seek to start failed: {}", e)))?;
287
288 let mut size_buf = [0u8; 8];
290 file.read_exact(&mut size_buf).map_err(|e| {
291 ModelError::simple_load_error(format!(
292 "SafeTensorsSource: failed to read header size: {}",
293 e
294 ))
295 })?;
296 let header_size = u64::from_le_bytes(size_buf);
297
298 let mut json_buf = vec![0u8; header_size as usize];
300 file.read_exact(&mut json_buf).map_err(|e| {
301 ModelError::simple_load_error(format!(
302 "SafeTensorsSource: failed to read {} bytes of JSON header: {}",
303 header_size, e
304 ))
305 })?;
306
307 let data_start_offset = 8 + header_size;
308
309 let json_str = std::str::from_utf8(&json_buf).map_err(|e| {
311 ModelError::simple_load_error(format!(
312 "SafeTensorsSource: JSON header is not valid UTF-8: {}",
313 e
314 ))
315 })?;
316
317 let root: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
318 ModelError::simple_load_error(format!(
319 "SafeTensorsSource: failed to parse JSON header: {}",
320 e
321 ))
322 })?;
323
324 let obj = root.as_object().ok_or_else(|| {
325 ModelError::simple_load_error("SafeTensorsSource: JSON root is not an object")
326 })?;
327
328 let mut header = HashMap::with_capacity(obj.len());
329 for (key, val) in obj {
330 if key == "__metadata__" {
332 continue;
333 }
334
335 let dtype_str = val.get("dtype").and_then(|v| v.as_str()).ok_or_else(|| {
336 ModelError::simple_load_error(format!(
337 "SafeTensorsSource: tensor '{}' missing 'dtype'",
338 key
339 ))
340 })?;
341
342 let dtype = SafeTensorDtype::from_str(dtype_str)?;
343
344 let shape_arr = val.get("shape").and_then(|v| v.as_array()).ok_or_else(|| {
345 ModelError::simple_load_error(format!(
346 "SafeTensorsSource: tensor '{}' missing 'shape'",
347 key
348 ))
349 })?;
350
351 let shape = shape_arr
352 .iter()
353 .map(|v| {
354 v.as_u64().ok_or_else(|| {
355 ModelError::simple_load_error(format!(
356 "SafeTensorsSource: tensor '{}' shape element is not a u64",
357 key
358 ))
359 })
360 })
361 .collect::<ModelResult<Vec<u64>>>()?
362 .into_iter()
363 .map(|d| d as usize)
364 .collect();
365
366 let offsets_arr = val
367 .get("data_offsets")
368 .and_then(|v| v.as_array())
369 .ok_or_else(|| {
370 ModelError::simple_load_error(format!(
371 "SafeTensorsSource: tensor '{}' missing 'data_offsets'",
372 key
373 ))
374 })?;
375
376 if offsets_arr.len() != 2 {
377 return Err(ModelError::simple_load_error(format!(
378 "SafeTensorsSource: tensor '{}' data_offsets must have 2 elements, got {}",
379 key,
380 offsets_arr.len()
381 )));
382 }
383
384 let begin = offsets_arr[0].as_u64().ok_or_else(|| {
385 ModelError::simple_load_error(format!(
386 "SafeTensorsSource: tensor '{}' data_offsets[0] is not a u64",
387 key
388 ))
389 })?;
390
391 let end = offsets_arr[1].as_u64().ok_or_else(|| {
392 ModelError::simple_load_error(format!(
393 "SafeTensorsSource: tensor '{}' data_offsets[1] is not a u64",
394 key
395 ))
396 })?;
397
398 header.insert(
399 key.clone(),
400 SafeTensorInfo {
401 dtype,
402 shape,
403 data_offsets: (begin, end),
404 },
405 );
406 }
407
408 Ok(Self {
409 file,
410 header,
411 data_start_offset,
412 file_size,
413 })
414 }
415}
416
417impl WeightSource for SafeTensorsSource {
418 fn tensor_names(&self) -> Vec<String> {
419 let mut names: Vec<String> = self.header.keys().cloned().collect();
420 names.sort();
421 names
422 }
423
424 fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>> {
425 let info = self.header.get(name).ok_or_else(|| {
426 ModelError::simple_load_error(format!("SafeTensorsSource: tensor '{}' not found", name))
427 })?;
428
429 let (begin, end) = info.data_offsets;
430 let byte_len = (end - begin) as usize;
431 let dtype = info.dtype.clone();
432 let n_elements: usize = if info.shape.is_empty() {
433 1
434 } else {
435 info.shape.iter().product()
436 };
437
438 let expected_bytes = n_elements * dtype.bytes_per_element();
440 if byte_len != expected_bytes {
441 return Err(ModelError::simple_load_error(format!(
442 "SafeTensorsSource: tensor '{}' byte range [{}, {}) has {} bytes, expected {} (shape={:?}, dtype={:?})",
443 name, begin, end, byte_len, expected_bytes, info.shape, dtype
444 )));
445 }
446
447 let abs_offset = self.data_start_offset + begin;
448 self.file.seek(SeekFrom::Start(abs_offset)).map_err(|e| {
449 ModelError::simple_load_error(format!(
450 "SafeTensorsSource: seek to tensor '{}' at {} failed: {}",
451 name, abs_offset, e
452 ))
453 })?;
454
455 let mut raw = vec![0u8; byte_len];
456 self.file.read_exact(&mut raw).map_err(|e| {
457 ModelError::simple_load_error(format!(
458 "SafeTensorsSource: read {} bytes for tensor '{}' failed: {}",
459 byte_len, name, e
460 ))
461 })?;
462
463 convert_safetensors_bytes_to_f32(&raw, &dtype, n_elements, name)
464 }
465
466 fn contains(&self, name: &str) -> bool {
467 self.header.contains_key(name)
468 }
469
470 fn total_bytes_estimate(&self) -> u64 {
471 self.file_size
472 }
473}
474
/// Pseudo-prefix used to group tensors that do not belong to any
/// `layers.N.` block (e.g. embeddings, heads). Always sorted last.
const MISC_PREFIX: &str = "_misc.";
484
/// Loads model weights layer-by-layer from any `WeightSource`, so that
/// only one layer's tensors need to be resident at a time.
pub struct IncrementalModelLoader<S: WeightSource> {
    // Underlying lazy tensor source.
    source: S,
    // Ordered group prefixes: sorted `layers.N.` prefixes, then
    // MISC_PREFIX last if any ungrouped tensors exist.
    layer_prefixes: Vec<String>,
}
497
498impl<S: WeightSource> IncrementalModelLoader<S> {
499 pub fn new(source: S) -> Self {
503 let names = source.tensor_names();
504 let mut prefixes: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
505 let mut has_misc = false;
506
507 for name in &names {
508 if let Some(prefix) = extract_layer_prefix(name) {
509 prefixes.insert(prefix);
510 } else {
511 has_misc = true;
512 }
513 }
514
515 let mut layer_prefixes: Vec<String> = prefixes.into_iter().collect();
516 if has_misc {
517 layer_prefixes.push(MISC_PREFIX.to_string());
518 }
519
520 Self {
521 source,
522 layer_prefixes,
523 }
524 }
525
526 pub fn load_layer(&mut self, prefix: &str) -> ModelResult<HashMap<String, Vec<f32>>> {
533 let names: Vec<String> = if prefix == MISC_PREFIX {
534 self.source
536 .tensor_names()
537 .into_iter()
538 .filter(|n| extract_layer_prefix(n).is_none())
539 .collect()
540 } else {
541 self.source
542 .tensor_names()
543 .into_iter()
544 .filter(|n| n.starts_with(prefix))
545 .collect()
546 };
547
548 let mut result = HashMap::with_capacity(names.len());
549 for name in names {
550 let tensor = self.source.load_tensor(&name)?;
551 result.insert(name, tensor);
552 }
553 Ok(result)
554 }
555
556 pub fn load_all_streaming<F>(&mut self, mut callback: F) -> ModelResult<()>
565 where
566 F: FnMut(&str, HashMap<String, Vec<f32>>) -> ModelResult<()>,
567 {
568 let prefixes = self.layer_prefixes.clone();
569 for prefix in &prefixes {
570 let tensors = self.load_layer(prefix)?;
571 callback(prefix, tensors)?;
572 }
573 Ok(())
574 }
575
576 pub fn layer_prefixes(&self) -> &[String] {
581 &self.layer_prefixes
582 }
583
584 pub fn source(&self) -> &S {
586 &self.source
587 }
588
589 pub fn into_source(self) -> S {
591 self.source
592 }
593}
594
/// Extracts a `layers.N.` grouping prefix from a tensor name, e.g.
/// `"layers.3.weight"` -> `Some("layers.3.")`.
///
/// Returns `None` for names not of the form `layers.<digits>.<rest>`.
fn extract_layer_prefix(name: &str) -> Option<String> {
    let remainder = name.strip_prefix("layers.")?;
    let (index, _rest) = remainder.split_once('.')?;

    // The layer index must be one or more ASCII digits.
    if index.is_empty() || !index.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }

    Some(format!("layers.{}.", index))
}
617
618fn dequantize_gguf(
620 raw: &[u8],
621 quant_type: &GgufQuantType,
622 n_elements: usize,
623 tensor_name: &str,
624) -> ModelResult<Vec<f32>> {
625 use crate::gguf::dequant;
626 dequant::dequantize(raw, quant_type, n_elements).map_err(|e| {
627 ModelError::simple_load_error(format!(
628 "GgufFileSource: dequantize failed for tensor '{}': {}",
629 tensor_name, e
630 ))
631 })
632}
633
634fn convert_safetensors_bytes_to_f32(
636 raw: &[u8],
637 dtype: &SafeTensorDtype,
638 n_elements: usize,
639 tensor_name: &str,
640) -> ModelResult<Vec<f32>> {
641 match dtype {
642 SafeTensorDtype::F32 => {
643 if raw.len() != n_elements * 4 {
644 return Err(ModelError::simple_load_error(format!(
645 "SafeTensorsSource: F32 tensor '{}' has {} bytes, expected {}",
646 tensor_name,
647 raw.len(),
648 n_elements * 4
649 )));
650 }
651 Ok(raw
652 .chunks_exact(4)
653 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
654 .collect())
655 }
656 SafeTensorDtype::F16 => {
657 if raw.len() != n_elements * 2 {
658 return Err(ModelError::simple_load_error(format!(
659 "SafeTensorsSource: F16 tensor '{}' has {} bytes, expected {}",
660 tensor_name,
661 raw.len(),
662 n_elements * 2
663 )));
664 }
665 Ok(raw
666 .chunks_exact(2)
667 .map(|b| {
668 let bits = u16::from_le_bytes([b[0], b[1]]);
669 half::f16::from_bits(bits).to_f32()
670 })
671 .collect())
672 }
673 SafeTensorDtype::Bf16 => {
674 if raw.len() != n_elements * 2 {
675 return Err(ModelError::simple_load_error(format!(
676 "SafeTensorsSource: BF16 tensor '{}' has {} bytes, expected {}",
677 tensor_name,
678 raw.len(),
679 n_elements * 2
680 )));
681 }
682 Ok(raw
683 .chunks_exact(2)
684 .map(|b| {
685 let bits = u16::from_le_bytes([b[0], b[1]]);
686 half::bf16::from_bits(bits).to_f32()
687 })
688 .collect())
689 }
690 SafeTensorDtype::F64 => {
691 if raw.len() != n_elements * 8 {
692 return Err(ModelError::simple_load_error(format!(
693 "SafeTensorsSource: F64 tensor '{}' has {} bytes, expected {}",
694 tensor_name,
695 raw.len(),
696 n_elements * 8
697 )));
698 }
699 Ok(raw
700 .chunks_exact(8)
701 .map(|b| {
702 f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
703 })
704 .collect())
705 }
706 }
707}
708
709fn compute_gguf_byte_len(
711 quant_type: &GgufQuantType,
712 n_elements: usize,
713 tensor_name: &str,
714) -> ModelResult<usize> {
715 let block_check = |block_elems: usize, block_bytes: usize| -> ModelResult<usize> {
717 if n_elements == 0 || !n_elements.is_multiple_of(block_elems) {
718 return Err(ModelError::simple_load_error(format!(
719 "GgufFileSource: tensor '{}' has {} elements, not a multiple of {}",
720 tensor_name, n_elements, block_elems
721 )));
722 }
723 Ok((n_elements / block_elems) * block_bytes)
724 };
725
726 match quant_type {
727 GgufQuantType::F32 => Ok(n_elements * 4),
728 GgufQuantType::F16 | GgufQuantType::BF16 => Ok(n_elements * 2),
729 GgufQuantType::Q4_0 => block_check(32, 18),
730 GgufQuantType::Q4_1 => block_check(32, 20),
731 GgufQuantType::Q5_0 => block_check(32, 22),
732 GgufQuantType::Q5_1 => block_check(32, 24),
733 GgufQuantType::Q8_0 => block_check(32, 34),
734 GgufQuantType::Q8_1 => block_check(32, 36),
735 GgufQuantType::Q2K => block_check(256, 84),
736 GgufQuantType::Q3K => block_check(256, 110),
737 GgufQuantType::Q4K => block_check(256, 144),
738 GgufQuantType::Q5K => block_check(256, 176),
739 GgufQuantType::Q6K => block_check(256, 210),
740 GgufQuantType::Q8K => block_check(256, 292),
741 qt => Err(ModelError::simple_load_error(format!(
742 "GgufFileSource: cannot compute byte size for unsupported quant type {:?} (tensor '{}')",
743 qt, tensor_name
744 ))),
745 }
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory safetensors file from `(name, f32 values)`
    /// pairs: 8-byte LE header length, JSON header, then contiguous
    /// little-endian F32 payloads.
    fn make_synthetic_safetensors(tensors: &[(&str, Vec<f32>)]) -> Vec<u8> {
        let mut data_bytes: Vec<u8> = Vec::new();
        // (name, begin, end, n_elements); offsets are relative to the
        // start of the data section, as the format requires.
        let mut tensor_metas: Vec<(&str, usize, usize, usize)> = Vec::new();
        for (name, vals) in tensors {
            let begin = data_bytes.len();
            for v in vals.iter() {
                data_bytes.extend_from_slice(&v.to_le_bytes());
            }
            let end = data_bytes.len();
            tensor_metas.push((name, begin, end, vals.len()));
        }

        let mut header_map = serde_json::Map::new();
        for (name, begin, end, n) in &tensor_metas {
            let entry = serde_json::json!({
                "dtype": "F32",
                "shape": [n],
                "data_offsets": [begin, end]
            });
            header_map.insert((*name).to_string(), entry);
        }
        let header_json = serde_json::Value::Object(header_map).to_string();
        let header_bytes = header_json.as_bytes();
        let header_len = header_bytes.len() as u64;

        // Assemble: length prefix, JSON header, data section.
        let mut out: Vec<u8> = Vec::new();
        out.extend_from_slice(&header_len.to_le_bytes());
        out.extend_from_slice(header_bytes);
        out.extend_from_slice(&data_bytes);
        out
    }

    /// Builds a minimal GGUF file with a single F32 tensor and no
    /// metadata key/value pairs, padding the data section to a 32-byte
    /// alignment boundary as the parser expects.
    fn make_synthetic_gguf_f32(tensor_name: &str, values: &[f32]) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();

        // Magic + version 2, then tensor count = 1 and kv count = 0.
        buf.extend_from_slice(b"GGUF");
        buf.extend_from_slice(&2u32.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());

        // Tensor info record: name (len-prefixed), n_dims = 1, the single
        // dimension, type id 0 (presumably F32 in the GGUF type table —
        // matches what GgufFileSource decodes below), data offset 0.
        let name_bytes = tensor_name.as_bytes();
        buf.extend_from_slice(&(name_bytes.len() as u64).to_le_bytes());
        buf.extend_from_slice(name_bytes);
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&(values.len() as u64).to_le_bytes());
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());

        // Pad up to the next 32-byte boundary before the data section.
        let current_len = buf.len();
        let aligned = (current_len + 31) & !31;
        let pad = aligned - current_len;
        buf.extend(std::iter::repeat_n(0u8, pad));

        for v in values {
            buf.extend_from_slice(&v.to_le_bytes());
        }

        buf
    }

    // Round-trip: a single F32 tensor written then read back verbatim.
    #[test]
    fn test_safetensors_source_single_tensor() {
        let tensors = &[("weight", vec![1.0f32, 2.0, 3.0, 4.0])];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_single.safetensors");
        std::fs::write(&path, &data).expect("write test file");

        let mut src = SafeTensorsSource::open(&path).expect("open SafeTensorsSource");
        assert!(src.contains("weight"), "tensor 'weight' should be present");
        let loaded = src.load_tensor("weight").expect("load_tensor weight");
        assert_eq!(loaded, vec![1.0f32, 2.0, 3.0, 4.0]);

        let _ = std::fs::remove_file(&path);
    }

    // contains() answers positively for present names and negatively
    // for absent ones.
    #[test]
    fn test_weight_source_contains() {
        let tensors = &[("alpha", vec![0.5f32, 1.5]), ("beta", vec![2.0f32, 3.0])];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_contains.safetensors");
        std::fs::write(&path, &data).expect("write test file");

        let src = SafeTensorsSource::open(&path).expect("open");
        assert!(src.contains("alpha"));
        assert!(src.contains("beta"));
        assert!(
            !src.contains("gamma"),
            "should not contain non-existent tensor"
        );

        let _ = std::fs::remove_file(&path);
    }

    // Layer prefixes are discovered per layer index, with the misc group
    // appended last for ungrouped tensors like "embed".
    #[test]
    fn test_incremental_loader_layer_prefixes() {
        let tensors = &[
            ("layers.0.weight", vec![1.0f32, 2.0]),
            ("layers.0.bias", vec![0.1f32]),
            ("layers.1.weight", vec![3.0f32, 4.0]),
            ("embed", vec![0.5f32]),
        ];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_layer_prefixes.safetensors");
        std::fs::write(&path, &data).expect("write test file");

        let src = SafeTensorsSource::open(&path).expect("open");
        let loader = IncrementalModelLoader::new(src);

        let prefixes = loader.layer_prefixes();
        assert!(
            prefixes.contains(&"layers.0.".to_string()),
            "expected 'layers.0.' in prefixes, got {:?}",
            prefixes
        );
        assert!(
            prefixes.contains(&"layers.1.".to_string()),
            "expected 'layers.1.' in prefixes, got {:?}",
            prefixes
        );
        assert!(
            prefixes.contains(&MISC_PREFIX.to_string()),
            "expected '{}' in prefixes for 'embed', got {:?}",
            MISC_PREFIX,
            prefixes
        );
        assert_eq!(
            prefixes.last().map(String::as_str),
            Some(MISC_PREFIX),
            "_misc. prefix should be last"
        );

        let _ = std::fs::remove_file(&path);
    }

    // The streaming callback fires once per prefix group.
    #[test]
    fn test_incremental_loader_streaming_callback() {
        let tensors = &[
            ("layers.0.weight", vec![1.0f32]),
            ("layers.0.bias", vec![0.0f32]),
            ("layers.1.weight", vec![2.0f32]),
            ("lm_head", vec![3.0f32]),
        ];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_streaming.safetensors");
        std::fs::write(&path, &data).expect("write test file");

        let src = SafeTensorsSource::open(&path).expect("open");
        let mut loader = IncrementalModelLoader::new(src);

        let mut invocation_count = 0usize;
        loader
            .load_all_streaming(|_prefix, _tensors| {
                invocation_count += 1;
                Ok(())
            })
            .expect("streaming failed");

        assert_eq!(
            invocation_count, 3,
            "expected 3 callbacks (layers.0., layers.1., _misc.), got {}",
            invocation_count
        );

        let _ = std::fs::remove_file(&path);
    }

    // GGUF round-trip: write one F32 tensor, read it back via the lazy
    // source, and compare element-wise with a small tolerance.
    #[test]
    fn test_gguf_file_source_lazy_load() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let data = make_synthetic_gguf_f32("test_tensor", &values);
        let path = std::env::temp_dir().join("kizzasi_test_gguf_source.gguf");
        std::fs::write(&path, &data).expect("write test gguf file");

        let mut src = GgufFileSource::open(&path).expect("open GgufFileSource");
        assert!(src.contains("test_tensor"), "tensor should be present");

        let loaded = src.load_tensor("test_tensor").expect("load_tensor");
        assert_eq!(loaded.len(), values.len(), "element count mismatch");
        for (i, (&got, &expected)) in loaded.iter().zip(values.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-5,
                "element {}: expected {}, got {}",
                i,
                expected,
                got
            );
        }
        assert!(
            !src.contains("nonexistent"),
            "nonexistent tensor should not be present"
        );

        let _ = std::fs::remove_file(&path);
    }

    // Multiple tensors in one file load independently with exact values.
    #[test]
    fn test_safetensors_source_multiple_tensors_values() {
        let tensors = &[("a", vec![10.0f32, 20.0, 30.0]), ("b", vec![-1.0f32, -2.0])];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_multi.safetensors");
        std::fs::write(&path, &data).expect("write test file");

        let mut src = SafeTensorsSource::open(&path).expect("open");

        let a = src.load_tensor("a").expect("load a");
        assert_eq!(a, vec![10.0f32, 20.0, 30.0]);

        let b = src.load_tensor("b").expect("load b");
        assert_eq!(b, vec![-1.0f32, -2.0]);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_extract_layer_prefix_valid() {
        assert_eq!(
            extract_layer_prefix("layers.0.weight"),
            Some("layers.0.".to_string())
        );
        assert_eq!(
            extract_layer_prefix("layers.123.bias"),
            Some("layers.123.".to_string())
        );
    }

    // Non-layer names, wrong prefixes, and non-numeric indices must all
    // map to None (and thus to the misc group).
    #[test]
    fn test_extract_layer_prefix_invalid() {
        assert_eq!(extract_layer_prefix("embed"), None);
        assert_eq!(extract_layer_prefix("lm_head.weight"), None);
        assert_eq!(extract_layer_prefix("layers_bad.0.weight"), None);
        assert_eq!(extract_layer_prefix("layers.abc.weight"), None);
    }

    // total_bytes_estimate reports the exact on-disk size for safetensors.
    #[test]
    fn test_weight_source_total_bytes_estimate() {
        let tensors = &[("x", vec![1.0f32, 2.0])];
        let data = make_synthetic_safetensors(tensors);
        let expected_size = data.len() as u64;
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_bytes_estimate.safetensors");
        std::fs::write(&path, &data).expect("write");

        let src = SafeTensorsSource::open(&path).expect("open");
        assert_eq!(src.total_bytes_estimate(), expected_size);

        let _ = std::fs::remove_file(&path);
    }

    // Requesting a missing tensor yields an error rather than a panic.
    #[test]
    fn test_safetensors_source_missing_tensor_error() {
        let tensors = &[("existing", vec![1.0f32])];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_missing.safetensors");
        std::fs::write(&path, &data).expect("write");

        let mut src = SafeTensorsSource::open(&path).expect("open");
        assert!(src.load_tensor("nonexistent").is_err());

        let _ = std::fs::remove_file(&path);
    }

    // load_layer returns exactly the tensors under the given prefix.
    #[test]
    fn test_incremental_loader_load_layer() {
        let tensors = &[
            ("layers.0.weight", vec![5.0f32, 6.0]),
            ("layers.0.bias", vec![0.5f32]),
            ("layers.1.weight", vec![7.0f32]),
        ];
        let data = make_synthetic_safetensors(tensors);
        let path = std::env::temp_dir().join("kizzasi_test_safetensors_load_layer.safetensors");
        std::fs::write(&path, &data).expect("write");

        let src = SafeTensorsSource::open(&path).expect("open");
        let mut loader = IncrementalModelLoader::new(src);

        let layer0 = loader
            .load_layer("layers.0.")
            .expect("load_layer layers.0.");
        assert!(layer0.contains_key("layers.0.weight"));
        assert!(layer0.contains_key("layers.0.bias"));
        assert!(!layer0.contains_key("layers.1.weight"));

        let _ = std::fs::remove_file(&path);
    }
}