// kizzasi_model/incremental_loader.rs
1//! Incremental / streaming weight loading for large model (7B+) support.
2//!
3//! This module provides the [`WeightSource`] trait and two concrete implementations:
4//!
5//! - [`GgufFileSource`]: Streams weights from a GGUF file via seek-based lazy loading,
6//!   dequantizing quantized types to f32 on the fly.
7//! - [`SafeTensorsSource`]: Streams weights from a `.safetensors` file, converting
8//!   BF16/F16/F32 data to f32 as each tensor is requested.
9//!
10//! The [`IncrementalModelLoader`] wraps any `WeightSource` and provides layer-by-layer
11//! streaming via [`IncrementalModelLoader::load_all_streaming`], which is the primary
12//! API for loading 7B+ models without holding all weights in RAM simultaneously.
13//!
14//! # Layer Prefix Extraction
15//!
16//! Tensor names following the `"layers.N.<rest>"` convention are grouped into
17//! per-layer buckets identified by the prefix `"layers.N."`. Tensors that do not
18//! match this pattern (e.g. `"embed"`, `"lm_head"`) are grouped under a synthetic
19//! `"_misc."` prefix, which is always presented last in streaming order.
20//!
21//! # COOLJAPAN Policy Compliance
22//!
23//! - Pure Rust — no C or Fortran dependencies.
24//! - No `unwrap()` anywhere; all error paths use `?` or `ok_or_else`.
25//! - `serde_json` (workspace dep) for SafeTensors header parsing.
26//! - `half` (workspace dep) for BF16/F16 → F32 conversion.
27
28use crate::error::{ModelError, ModelResult};
29use crate::gguf::{GgufFile, GgufQuantType, GgufTensorInfo};
30use std::collections::HashMap;
31use std::io::{Read, Seek, SeekFrom};
32use std::path::Path;
33
34// ─────────────────────────────────────────────────────────────────────────────
35// WeightSource trait
36// ─────────────────────────────────────────────────────────────────────────────
37
/// A streaming/incremental source of model weight tensors.
///
/// Implementations must be `Send + Sync` so that they can be passed across
/// thread boundaries when used with multi-threaded inference runtimes.
///
/// The primary operations are:
/// - [`tensor_names`](WeightSource::tensor_names): enumerate all available tensor keys.
/// - [`load_tensor`](WeightSource::load_tensor): load and dequantize one tensor to `Vec<f32>`.
/// - [`contains`](WeightSource::contains): membership check without loading.
/// - [`total_bytes_estimate`](WeightSource::total_bytes_estimate): rough file-size hint for
///   progress reporting.
pub trait WeightSource: Send + Sync {
    /// Return the names of all tensors available in this source.
    ///
    /// Both implementations in this module return the names in sorted order.
    fn tensor_names(&self) -> Vec<String>;

    /// Load and dequantize the tensor identified by `name`, returning a flat `Vec<f32>`.
    ///
    /// Takes `&mut self` because implementations seek an underlying file handle.
    ///
    /// # Errors
    /// Returns [`ModelError`] if the tensor is not found, the file cannot be read,
    /// or dequantization fails.
    fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>>;

    /// Return `true` if the source contains a tensor with the given `name`.
    fn contains(&self, name: &str) -> bool;

    /// Return a rough estimate of the total number of bytes occupied by all
    /// tensor data in the underlying file (used for progress reporting).
    fn total_bytes_estimate(&self) -> u64;
}
67
68// ─────────────────────────────────────────────────────────────────────────────
69// GgufTensorMeta — lightweight per-tensor metadata stored by GgufFileSource
70// ─────────────────────────────────────────────────────────────────────────────
71
/// Compact metadata record stored by [`GgufFileSource`] for each tensor.
///
/// Kept deliberately small: only what is needed to seek, read, and
/// dequantize one tensor without re-parsing the GGUF header.
#[derive(Debug, Clone)]
struct GgufTensorMeta {
    /// Absolute byte offset within the file where this tensor's data begins.
    data_offset: u64,
    /// Quantization type (determines dequantization path and byte size).
    quant_type: GgufQuantType,
    /// Total number of scalar elements in the tensor.
    n_elements: usize,
    /// Raw byte length of the tensor's on-disk data (sized by a read buffer).
    byte_len: usize,
}
84
85impl GgufTensorMeta {
86    /// Build from a parsed [`GgufTensorInfo`] descriptor.
87    fn from_info(info: &GgufTensorInfo) -> ModelResult<Self> {
88        let n_elements = info.n_elements() as usize;
89        let byte_len = compute_gguf_byte_len(&info.quant_type, n_elements, &info.name)?;
90        Ok(Self {
91            data_offset: info.data_offset,
92            quant_type: info.quant_type,
93            n_elements,
94            byte_len,
95        })
96    }
97}
98
99// ─────────────────────────────────────────────────────────────────────────────
100// GgufFileSource
101// ─────────────────────────────────────────────────────────────────────────────
102
/// [`WeightSource`] backed by a GGUF file, using seek-based lazy tensor loading.
///
/// Each call to [`load_tensor`](WeightSource::load_tensor) seeks to the tensor's
/// data region, reads only the required bytes, and dequantizes them. This means
/// only one tensor's worth of raw data is in memory at a time.
pub struct GgufFileSource {
    /// Open file handle used for all seeks and reads (hence `&mut self` on loads).
    file: std::fs::File,
    /// Tensor metadata indexed by tensor name, built once at open time.
    tensor_infos: HashMap<String, GgufTensorMeta>,
    /// Total size of the underlying file in bytes (progress-reporting hint).
    file_size: u64,
}
116
117impl GgufFileSource {
118    /// Open a GGUF file and parse its header + tensor index.
119    ///
120    /// The file is kept open after construction for subsequent lazy reads.
121    pub fn open(path: &Path) -> ModelResult<Self> {
122        // Parse header using GgufFile (reads entire file into memory temporarily
123        // to parse the header; after this call the header data is dropped).
124        let gguf = GgufFile::open(path)?;
125
126        let file_size = std::fs::metadata(path)
127            .map_err(|e| {
128                ModelError::simple_load_error(format!("Failed to stat GGUF file {:?}: {}", path, e))
129            })?
130            .len();
131
132        // Build metadata index
133        let mut tensor_infos = HashMap::with_capacity(gguf.tensors.len());
134        for info in &gguf.tensors {
135            let meta = GgufTensorMeta::from_info(info)?;
136            tensor_infos.insert(info.name.clone(), meta);
137        }
138
139        // Open file handle for subsequent reads
140        let file = std::fs::File::open(path).map_err(|e| {
141            ModelError::simple_load_error(format!("Failed to open GGUF file {:?}: {}", path, e))
142        })?;
143
144        Ok(Self {
145            file,
146            tensor_infos,
147            file_size,
148        })
149    }
150}
151
152impl WeightSource for GgufFileSource {
153    fn tensor_names(&self) -> Vec<String> {
154        let mut names: Vec<String> = self.tensor_infos.keys().cloned().collect();
155        names.sort();
156        names
157    }
158
159    fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>> {
160        let meta = self.tensor_infos.get(name).ok_or_else(|| {
161            ModelError::simple_load_error(format!("GgufFileSource: tensor '{}' not found", name))
162        })?;
163
164        // Copy fields to avoid borrow conflicts with self.file
165        let data_offset = meta.data_offset;
166        let quant_type = meta.quant_type;
167        let n_elements = meta.n_elements;
168        let byte_len = meta.byte_len;
169
170        self.file.seek(SeekFrom::Start(data_offset)).map_err(|e| {
171            ModelError::simple_load_error(format!(
172                "GgufFileSource: seek to tensor '{}' at offset {} failed: {}",
173                name, data_offset, e
174            ))
175        })?;
176
177        let mut raw = vec![0u8; byte_len];
178        self.file.read_exact(&mut raw).map_err(|e| {
179            ModelError::simple_load_error(format!(
180                "GgufFileSource: read {} bytes for tensor '{}' failed: {}",
181                byte_len, name, e
182            ))
183        })?;
184
185        dequantize_gguf(&raw, &quant_type, n_elements, name)
186    }
187
188    fn contains(&self, name: &str) -> bool {
189        self.tensor_infos.contains_key(name)
190    }
191
192    fn total_bytes_estimate(&self) -> u64 {
193        self.file_size
194    }
195}
196
197// ─────────────────────────────────────────────────────────────────────────────
198// SafeTensorsSource
199// ─────────────────────────────────────────────────────────────────────────────
200
/// Data type tag parsed from a SafeTensors header.
///
/// Only floating-point dtypes are supported; integer dtypes in the header
/// are rejected by [`SafeTensorDtype::from_str`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum SafeTensorDtype {
    F32,
    F16,
    Bf16,
    F64,
}
209
210impl SafeTensorDtype {
211    /// Parse dtype string as it appears in the SafeTensors JSON header.
212    fn from_str(s: &str) -> ModelResult<Self> {
213        match s {
214            "F32" => Ok(Self::F32),
215            "F16" => Ok(Self::F16),
216            "BF16" => Ok(Self::Bf16),
217            "F64" => Ok(Self::F64),
218            other => Err(ModelError::simple_load_error(format!(
219                "SafeTensorsSource: unsupported dtype '{}'",
220                other
221            ))),
222        }
223    }
224
225    /// Number of bytes per scalar element.
226    fn bytes_per_element(&self) -> usize {
227        match self {
228            Self::F32 => 4,
229            Self::F16 | Self::Bf16 => 2,
230            Self::F64 => 8,
231        }
232    }
233}
234
/// Per-tensor metadata as parsed from the SafeTensors JSON header.
#[derive(Debug, Clone)]
struct SafeTensorInfo {
    /// Parsed dtype.
    dtype: SafeTensorDtype,
    /// Shape dimensions (outermost first, standard row-major order).
    shape: Vec<usize>,
    /// `[begin, end)` byte offsets relative to the start of the data region
    /// (i.e. relative to `data_start_offset`, not to the file start).
    data_offsets: (u64, u64),
}
245
/// [`WeightSource`] backed by a `.safetensors` file.
///
/// The JSON header is parsed once at construction time; subsequent calls to
/// [`load_tensor`](WeightSource::load_tensor) seek directly to each tensor's
/// data region and convert it to f32.
pub struct SafeTensorsSource {
    /// Open file handle (seeked on every `load_tensor` call).
    file: std::fs::File,
    /// Per-tensor metadata keyed by tensor name.
    header: HashMap<String, SafeTensorInfo>,
    /// Absolute file offset where the raw data region begins (immediately after
    /// the JSON header, i.e. at byte `8 + header_size`).
    data_start_offset: u64,
    /// Total file size in bytes.
    file_size: u64,
}
262
263impl SafeTensorsSource {
264    /// Open a `.safetensors` file and parse its JSON header.
265    ///
266    /// The binary layout is:
267    /// ```text
268    /// [0..8]           header_size   — u64 LE: length of the JSON string
269    /// [8..8+header_size] JSON header  — tensor metadata
270    /// [8+header_size..]  raw data     — tensor bytes, BF16/F16/F32
271    /// ```
272    pub fn open(path: &Path) -> ModelResult<Self> {
273        let mut file = std::fs::File::open(path).map_err(|e| {
274            ModelError::simple_load_error(format!(
275                "SafeTensorsSource: cannot open {:?}: {}",
276                path, e
277            ))
278        })?;
279
280        let file_size = file
281            .seek(SeekFrom::End(0))
282            .map_err(|e| ModelError::simple_load_error(format!("seek to end failed: {}", e)))?;
283
284        // Rewind to start
285        file.seek(SeekFrom::Start(0))
286            .map_err(|e| ModelError::simple_load_error(format!("seek to start failed: {}", e)))?;
287
288        // Read 8-byte header size
289        let mut size_buf = [0u8; 8];
290        file.read_exact(&mut size_buf).map_err(|e| {
291            ModelError::simple_load_error(format!(
292                "SafeTensorsSource: failed to read header size: {}",
293                e
294            ))
295        })?;
296        let header_size = u64::from_le_bytes(size_buf);
297
298        // Read JSON header
299        let mut json_buf = vec![0u8; header_size as usize];
300        file.read_exact(&mut json_buf).map_err(|e| {
301            ModelError::simple_load_error(format!(
302                "SafeTensorsSource: failed to read {} bytes of JSON header: {}",
303                header_size, e
304            ))
305        })?;
306
307        let data_start_offset = 8 + header_size;
308
309        // Parse JSON
310        let json_str = std::str::from_utf8(&json_buf).map_err(|e| {
311            ModelError::simple_load_error(format!(
312                "SafeTensorsSource: JSON header is not valid UTF-8: {}",
313                e
314            ))
315        })?;
316
317        let root: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
318            ModelError::simple_load_error(format!(
319                "SafeTensorsSource: failed to parse JSON header: {}",
320                e
321            ))
322        })?;
323
324        let obj = root.as_object().ok_or_else(|| {
325            ModelError::simple_load_error("SafeTensorsSource: JSON root is not an object")
326        })?;
327
328        let mut header = HashMap::with_capacity(obj.len());
329        for (key, val) in obj {
330            // Skip the special `__metadata__` key
331            if key == "__metadata__" {
332                continue;
333            }
334
335            let dtype_str = val.get("dtype").and_then(|v| v.as_str()).ok_or_else(|| {
336                ModelError::simple_load_error(format!(
337                    "SafeTensorsSource: tensor '{}' missing 'dtype'",
338                    key
339                ))
340            })?;
341
342            let dtype = SafeTensorDtype::from_str(dtype_str)?;
343
344            let shape_arr = val.get("shape").and_then(|v| v.as_array()).ok_or_else(|| {
345                ModelError::simple_load_error(format!(
346                    "SafeTensorsSource: tensor '{}' missing 'shape'",
347                    key
348                ))
349            })?;
350
351            let shape = shape_arr
352                .iter()
353                .map(|v| {
354                    v.as_u64().ok_or_else(|| {
355                        ModelError::simple_load_error(format!(
356                            "SafeTensorsSource: tensor '{}' shape element is not a u64",
357                            key
358                        ))
359                    })
360                })
361                .collect::<ModelResult<Vec<u64>>>()?
362                .into_iter()
363                .map(|d| d as usize)
364                .collect();
365
366            let offsets_arr = val
367                .get("data_offsets")
368                .and_then(|v| v.as_array())
369                .ok_or_else(|| {
370                    ModelError::simple_load_error(format!(
371                        "SafeTensorsSource: tensor '{}' missing 'data_offsets'",
372                        key
373                    ))
374                })?;
375
376            if offsets_arr.len() != 2 {
377                return Err(ModelError::simple_load_error(format!(
378                    "SafeTensorsSource: tensor '{}' data_offsets must have 2 elements, got {}",
379                    key,
380                    offsets_arr.len()
381                )));
382            }
383
384            let begin = offsets_arr[0].as_u64().ok_or_else(|| {
385                ModelError::simple_load_error(format!(
386                    "SafeTensorsSource: tensor '{}' data_offsets[0] is not a u64",
387                    key
388                ))
389            })?;
390
391            let end = offsets_arr[1].as_u64().ok_or_else(|| {
392                ModelError::simple_load_error(format!(
393                    "SafeTensorsSource: tensor '{}' data_offsets[1] is not a u64",
394                    key
395                ))
396            })?;
397
398            header.insert(
399                key.clone(),
400                SafeTensorInfo {
401                    dtype,
402                    shape,
403                    data_offsets: (begin, end),
404                },
405            );
406        }
407
408        Ok(Self {
409            file,
410            header,
411            data_start_offset,
412            file_size,
413        })
414    }
415}
416
417impl WeightSource for SafeTensorsSource {
418    fn tensor_names(&self) -> Vec<String> {
419        let mut names: Vec<String> = self.header.keys().cloned().collect();
420        names.sort();
421        names
422    }
423
424    fn load_tensor(&mut self, name: &str) -> ModelResult<Vec<f32>> {
425        let info = self.header.get(name).ok_or_else(|| {
426            ModelError::simple_load_error(format!("SafeTensorsSource: tensor '{}' not found", name))
427        })?;
428
429        let (begin, end) = info.data_offsets;
430        let byte_len = (end - begin) as usize;
431        let dtype = info.dtype.clone();
432        let n_elements: usize = if info.shape.is_empty() {
433            1
434        } else {
435            info.shape.iter().product()
436        };
437
438        // Validate
439        let expected_bytes = n_elements * dtype.bytes_per_element();
440        if byte_len != expected_bytes {
441            return Err(ModelError::simple_load_error(format!(
442                "SafeTensorsSource: tensor '{}' byte range [{}, {}) has {} bytes, expected {} (shape={:?}, dtype={:?})",
443                name, begin, end, byte_len, expected_bytes, info.shape, dtype
444            )));
445        }
446
447        let abs_offset = self.data_start_offset + begin;
448        self.file.seek(SeekFrom::Start(abs_offset)).map_err(|e| {
449            ModelError::simple_load_error(format!(
450                "SafeTensorsSource: seek to tensor '{}' at {} failed: {}",
451                name, abs_offset, e
452            ))
453        })?;
454
455        let mut raw = vec![0u8; byte_len];
456        self.file.read_exact(&mut raw).map_err(|e| {
457            ModelError::simple_load_error(format!(
458                "SafeTensorsSource: read {} bytes for tensor '{}' failed: {}",
459                byte_len, name, e
460            ))
461        })?;
462
463        convert_safetensors_bytes_to_f32(&raw, &dtype, n_elements, name)
464    }
465
466    fn contains(&self, name: &str) -> bool {
467        self.header.contains_key(name)
468    }
469
470    fn total_bytes_estimate(&self) -> u64 {
471        self.file_size
472    }
473}
474
475// ─────────────────────────────────────────────────────────────────────────────
476// IncrementalModelLoader
477// ─────────────────────────────────────────────────────────────────────────────
478
/// Layer prefix extracted from a tensor name.
///
/// Tensor names matching `"layers.<N>.<rest>"` get prefix `"layers.<N>."`.
/// All other tensor names are grouped under the synthetic `"_misc."` prefix,
/// which is appended after the sorted layer prefixes rather than sorted in.
const MISC_PREFIX: &str = "_misc.";
484
/// Wraps a [`WeightSource`] and provides layer-by-layer streaming iteration.
///
/// On construction, all tensor names are scanned to extract unique `"layers.N."`
/// prefixes. The [`load_all_streaming`](Self::load_all_streaming) method then
/// iterates over these prefixes in sorted order, loading one layer at a time
/// and invoking a user-supplied callback. Miscellaneous tensors (not matching
/// the `layers.N.` pattern) are presented last, grouped under `"_misc."`.
pub struct IncrementalModelLoader<S: WeightSource> {
    /// The underlying weight source; mutated by loads (file seeks/reads).
    source: S,
    /// Sorted unique layer prefixes (e.g. `["layers.0.", "layers.1.", …, "_misc."]`).
    /// Note the sort is lexicographic, so `"layers.10."` precedes `"layers.2."`.
    layer_prefixes: Vec<String>,
}
497
498impl<S: WeightSource> IncrementalModelLoader<S> {
499    /// Construct a new loader from a weight source.
500    ///
501    /// Tensor names are scanned once to build the list of unique layer prefixes.
502    pub fn new(source: S) -> Self {
503        let names = source.tensor_names();
504        let mut prefixes: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
505        let mut has_misc = false;
506
507        for name in &names {
508            if let Some(prefix) = extract_layer_prefix(name) {
509                prefixes.insert(prefix);
510            } else {
511                has_misc = true;
512            }
513        }
514
515        let mut layer_prefixes: Vec<String> = prefixes.into_iter().collect();
516        if has_misc {
517            layer_prefixes.push(MISC_PREFIX.to_string());
518        }
519
520        Self {
521            source,
522            layer_prefixes,
523        }
524    }
525
526    /// Load all tensors whose names start with `prefix`.
527    ///
528    /// The special prefix `"_misc."` loads all tensors that do **not** match
529    /// the `"layers.N."` pattern.
530    ///
531    /// Returns a `HashMap<tensor_name, Vec<f32>>` for the group.
532    pub fn load_layer(&mut self, prefix: &str) -> ModelResult<HashMap<String, Vec<f32>>> {
533        let names: Vec<String> = if prefix == MISC_PREFIX {
534            // Collect tensors that do not belong to any regular layer prefix
535            self.source
536                .tensor_names()
537                .into_iter()
538                .filter(|n| extract_layer_prefix(n).is_none())
539                .collect()
540        } else {
541            self.source
542                .tensor_names()
543                .into_iter()
544                .filter(|n| n.starts_with(prefix))
545                .collect()
546        };
547
548        let mut result = HashMap::with_capacity(names.len());
549        for name in names {
550            let tensor = self.source.load_tensor(&name)?;
551            result.insert(name, tensor);
552        }
553        Ok(result)
554    }
555
556    /// Stream through all layers in order, invoking `callback` once per layer prefix.
557    ///
558    /// The callback receives:
559    /// - `prefix`: the layer prefix string (e.g. `"layers.0."` or `"_misc."`)
560    /// - `tensors`: a `HashMap<tensor_name, Vec<f32>>` for that layer
561    ///
562    /// If `callback` returns an `Err`, iteration stops immediately and the error
563    /// is propagated.
564    pub fn load_all_streaming<F>(&mut self, mut callback: F) -> ModelResult<()>
565    where
566        F: FnMut(&str, HashMap<String, Vec<f32>>) -> ModelResult<()>,
567    {
568        let prefixes = self.layer_prefixes.clone();
569        for prefix in &prefixes {
570            let tensors = self.load_layer(prefix)?;
571            callback(prefix, tensors)?;
572        }
573        Ok(())
574    }
575
576    /// Return the list of unique layer prefixes discovered in the weight source.
577    ///
578    /// The list is sorted lexicographically; the `"_misc."` bucket (if present)
579    /// always appears last.
580    pub fn layer_prefixes(&self) -> &[String] {
581        &self.layer_prefixes
582    }
583
584    /// Return a shared reference to the underlying weight source.
585    pub fn source(&self) -> &S {
586        &self.source
587    }
588
589    /// Consume the loader and return ownership of the underlying weight source.
590    pub fn into_source(self) -> S {
591        self.source
592    }
593}
594
595// ─────────────────────────────────────────────────────────────────────────────
596// Private helpers
597// ─────────────────────────────────────────────────────────────────────────────
598
/// Extract the `"layers.N."` prefix from a tensor name, or return `None`.
///
/// Matches names of the form `"layers.<decimal>.<anything>"`; the index
/// must be a non-empty run of ASCII digits.
fn extract_layer_prefix(name: &str) -> Option<String> {
    // Must start with the literal "layers." marker.
    let after = name.strip_prefix("layers.")?;

    // The layer index runs up to the next '.'; names without one don't match.
    let (idx, _rest) = after.split_once('.')?;

    // Accept only a pure decimal index (rejects "", "x", "1a", …).
    (!idx.is_empty() && idx.bytes().all(|b| b.is_ascii_digit()))
        .then(|| format!("layers.{}.", idx))
}
617
618/// Dequantize raw GGUF bytes to f32 using the appropriate scheme.
619fn dequantize_gguf(
620    raw: &[u8],
621    quant_type: &GgufQuantType,
622    n_elements: usize,
623    tensor_name: &str,
624) -> ModelResult<Vec<f32>> {
625    use crate::gguf::dequant;
626    dequant::dequantize(raw, quant_type, n_elements).map_err(|e| {
627        ModelError::simple_load_error(format!(
628            "GgufFileSource: dequantize failed for tensor '{}': {}",
629            tensor_name, e
630        ))
631    })
632}
633
634/// Convert raw SafeTensors bytes to a flat `Vec<f32>`.
635fn convert_safetensors_bytes_to_f32(
636    raw: &[u8],
637    dtype: &SafeTensorDtype,
638    n_elements: usize,
639    tensor_name: &str,
640) -> ModelResult<Vec<f32>> {
641    match dtype {
642        SafeTensorDtype::F32 => {
643            if raw.len() != n_elements * 4 {
644                return Err(ModelError::simple_load_error(format!(
645                    "SafeTensorsSource: F32 tensor '{}' has {} bytes, expected {}",
646                    tensor_name,
647                    raw.len(),
648                    n_elements * 4
649                )));
650            }
651            Ok(raw
652                .chunks_exact(4)
653                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
654                .collect())
655        }
656        SafeTensorDtype::F16 => {
657            if raw.len() != n_elements * 2 {
658                return Err(ModelError::simple_load_error(format!(
659                    "SafeTensorsSource: F16 tensor '{}' has {} bytes, expected {}",
660                    tensor_name,
661                    raw.len(),
662                    n_elements * 2
663                )));
664            }
665            Ok(raw
666                .chunks_exact(2)
667                .map(|b| {
668                    let bits = u16::from_le_bytes([b[0], b[1]]);
669                    half::f16::from_bits(bits).to_f32()
670                })
671                .collect())
672        }
673        SafeTensorDtype::Bf16 => {
674            if raw.len() != n_elements * 2 {
675                return Err(ModelError::simple_load_error(format!(
676                    "SafeTensorsSource: BF16 tensor '{}' has {} bytes, expected {}",
677                    tensor_name,
678                    raw.len(),
679                    n_elements * 2
680                )));
681            }
682            Ok(raw
683                .chunks_exact(2)
684                .map(|b| {
685                    let bits = u16::from_le_bytes([b[0], b[1]]);
686                    half::bf16::from_bits(bits).to_f32()
687                })
688                .collect())
689        }
690        SafeTensorDtype::F64 => {
691            if raw.len() != n_elements * 8 {
692                return Err(ModelError::simple_load_error(format!(
693                    "SafeTensorsSource: F64 tensor '{}' has {} bytes, expected {}",
694                    tensor_name,
695                    raw.len(),
696                    n_elements * 8
697                )));
698            }
699            Ok(raw
700                .chunks_exact(8)
701                .map(|b| {
702                    f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
703                })
704                .collect())
705        }
706    }
707}
708
709/// Compute the number of raw bytes occupied by a GGUF tensor on disk.
710fn compute_gguf_byte_len(
711    quant_type: &GgufQuantType,
712    n_elements: usize,
713    tensor_name: &str,
714) -> ModelResult<usize> {
715    // Helper for block-aligned types
716    let block_check = |block_elems: usize, block_bytes: usize| -> ModelResult<usize> {
717        if n_elements == 0 || !n_elements.is_multiple_of(block_elems) {
718            return Err(ModelError::simple_load_error(format!(
719                "GgufFileSource: tensor '{}' has {} elements, not a multiple of {}",
720                tensor_name, n_elements, block_elems
721            )));
722        }
723        Ok((n_elements / block_elems) * block_bytes)
724    };
725
726    match quant_type {
727        GgufQuantType::F32 => Ok(n_elements * 4),
728        GgufQuantType::F16 | GgufQuantType::BF16 => Ok(n_elements * 2),
729        GgufQuantType::Q4_0 => block_check(32, 18),
730        GgufQuantType::Q4_1 => block_check(32, 20),
731        GgufQuantType::Q5_0 => block_check(32, 22),
732        GgufQuantType::Q5_1 => block_check(32, 24),
733        GgufQuantType::Q8_0 => block_check(32, 34),
734        GgufQuantType::Q8_1 => block_check(32, 36),
735        GgufQuantType::Q2K => block_check(256, 84),
736        GgufQuantType::Q3K => block_check(256, 110),
737        GgufQuantType::Q4K => block_check(256, 144),
738        GgufQuantType::Q5K => block_check(256, 176),
739        GgufQuantType::Q6K => block_check(256, 210),
740        GgufQuantType::Q8K => block_check(256, 292),
741        qt => Err(ModelError::simple_load_error(format!(
742            "GgufFileSource: cannot compute byte size for unsupported quant type {:?} (tensor '{}')",
743            qt, tensor_name
744        ))),
745    }
746}
747
748// ─────────────────────────────────────────────────────────────────────────────
749// Tests
750// ─────────────────────────────────────────────────────────────────────────────
751
752#[cfg(test)]
753mod tests {
754    use super::*;
755
756    // ── SafeTensors binary builder ────────────────────────────────────────────
757
758    /// Build a valid `.safetensors` binary from a slice of `(name, f32_data)` pairs.
759    ///
760    /// All tensors are stored as 1-D F32 arrays. The resulting bytes conform to
761    /// the SafeTensors format specification:
762    /// - 8 bytes: JSON header length (u64 LE)
763    /// - `header_len` bytes: JSON header
764    /// - remaining bytes: raw tensor data (F32 little-endian)
765    fn make_synthetic_safetensors(tensors: &[(&str, Vec<f32>)]) -> Vec<u8> {
766        // Build per-tensor data and accumulate byte offsets
767        let mut data_bytes: Vec<u8> = Vec::new();
768        let mut tensor_metas: Vec<(&str, usize, usize, usize)> = Vec::new(); // name, begin, end, len
769
770        for (name, vals) in tensors {
771            let begin = data_bytes.len();
772            for v in vals.iter() {
773                data_bytes.extend_from_slice(&v.to_le_bytes());
774            }
775            let end = data_bytes.len();
776            tensor_metas.push((name, begin, end, vals.len()));
777        }
778
779        // Build JSON header using serde_json
780        let mut header_map = serde_json::Map::new();
781        for (name, begin, end, n) in &tensor_metas {
782            let entry = serde_json::json!({
783                "dtype": "F32",
784                "shape": [n],
785                "data_offsets": [begin, end]
786            });
787            header_map.insert((*name).to_string(), entry);
788        }
789        let header_json = serde_json::Value::Object(header_map).to_string();
790        let header_bytes = header_json.as_bytes();
791        let header_len = header_bytes.len() as u64;
792
793        // Assemble file
794        let mut out: Vec<u8> = Vec::new();
795        out.extend_from_slice(&header_len.to_le_bytes());
796        out.extend_from_slice(header_bytes);
797        out.extend_from_slice(&data_bytes);
798        out
799    }
800
801    // ── Minimal GGUF binary builder ───────────────────────────────────────────
802
803    /// Build a minimal GGUF binary containing a single F32 tensor.
804    ///
805    /// The file has the minimal valid structure:
806    /// - 4 magic bytes `b"GGUF"`
807    /// - u32 version = 2
808    /// - u64 tensor_count = 1
809    /// - u64 kv_count = 0
810    /// - tensor info: name (u64 len + bytes), shape (u32 ndims=1, u64 dim), quant=0 (F32), offset=0
811    /// - 32-byte alignment padding
812    /// - tensor data: n * 4 bytes of f32 LE
813    fn make_synthetic_gguf_f32(tensor_name: &str, values: &[f32]) -> Vec<u8> {
814        let mut buf: Vec<u8> = Vec::new();
815
816        // Magic
817        buf.extend_from_slice(b"GGUF");
818        // Version = 2 (u32 LE)
819        buf.extend_from_slice(&2u32.to_le_bytes());
820        // tensor_count (u64)
821        buf.extend_from_slice(&1u64.to_le_bytes());
822        // kv_count (u64)
823        buf.extend_from_slice(&0u64.to_le_bytes());
824
825        // No KV metadata entries
826
827        // Tensor info
828        let name_bytes = tensor_name.as_bytes();
829        // name length (u64)
830        buf.extend_from_slice(&(name_bytes.len() as u64).to_le_bytes());
831        // name bytes
832        buf.extend_from_slice(name_bytes);
833        // n_dims (u32)
834        buf.extend_from_slice(&1u32.to_le_bytes());
835        // dim[0] (u64)
836        buf.extend_from_slice(&(values.len() as u64).to_le_bytes());
837        // quant type = 0 (F32, u32)
838        buf.extend_from_slice(&0u32.to_le_bytes());
839        // offset within data section = 0 (u64)
840        buf.extend_from_slice(&0u64.to_le_bytes());
841
842        // Pad header to 32-byte alignment
843        let current_len = buf.len();
844        let aligned = (current_len + 31) & !31;
845        let pad = aligned - current_len;
846        buf.extend(std::iter::repeat_n(0u8, pad));
847
848        // Data section: raw f32 LE
849        for v in values {
850            buf.extend_from_slice(&v.to_le_bytes());
851        }
852
853        buf
854    }
855
856    // ── Tests ─────────────────────────────────────────────────────────────────
857
858    #[test]
859    fn test_safetensors_source_single_tensor() {
860        let tensors = &[("weight", vec![1.0f32, 2.0, 3.0, 4.0])];
861        let data = make_synthetic_safetensors(tensors);
862        let path = std::env::temp_dir().join("kizzasi_test_safetensors_single.safetensors");
863        std::fs::write(&path, &data).expect("write test file");
864
865        let mut src = SafeTensorsSource::open(&path).expect("open SafeTensorsSource");
866        assert!(src.contains("weight"), "tensor 'weight' should be present");
867        let loaded = src.load_tensor("weight").expect("load_tensor weight");
868        assert_eq!(loaded, vec![1.0f32, 2.0, 3.0, 4.0]);
869
870        let _ = std::fs::remove_file(&path);
871    }
872
873    #[test]
874    fn test_weight_source_contains() {
875        let tensors = &[("alpha", vec![0.5f32, 1.5]), ("beta", vec![2.0f32, 3.0])];
876        let data = make_synthetic_safetensors(tensors);
877        let path = std::env::temp_dir().join("kizzasi_test_safetensors_contains.safetensors");
878        std::fs::write(&path, &data).expect("write test file");
879
880        let src = SafeTensorsSource::open(&path).expect("open");
881        assert!(src.contains("alpha"));
882        assert!(src.contains("beta"));
883        assert!(
884            !src.contains("gamma"),
885            "should not contain non-existent tensor"
886        );
887
888        let _ = std::fs::remove_file(&path);
889    }
890
891    #[test]
892    fn test_incremental_loader_layer_prefixes() {
893        let tensors = &[
894            ("layers.0.weight", vec![1.0f32, 2.0]),
895            ("layers.0.bias", vec![0.1f32]),
896            ("layers.1.weight", vec![3.0f32, 4.0]),
897            ("embed", vec![0.5f32]),
898        ];
899        let data = make_synthetic_safetensors(tensors);
900        let path = std::env::temp_dir().join("kizzasi_test_safetensors_layer_prefixes.safetensors");
901        std::fs::write(&path, &data).expect("write test file");
902
903        let src = SafeTensorsSource::open(&path).expect("open");
904        let loader = IncrementalModelLoader::new(src);
905
906        let prefixes = loader.layer_prefixes();
907        assert!(
908            prefixes.contains(&"layers.0.".to_string()),
909            "expected 'layers.0.' in prefixes, got {:?}",
910            prefixes
911        );
912        assert!(
913            prefixes.contains(&"layers.1.".to_string()),
914            "expected 'layers.1.' in prefixes, got {:?}",
915            prefixes
916        );
917        assert!(
918            prefixes.contains(&MISC_PREFIX.to_string()),
919            "expected '{}' in prefixes for 'embed', got {:?}",
920            MISC_PREFIX,
921            prefixes
922        );
923        // _misc. should be last
924        assert_eq!(
925            prefixes.last().map(String::as_str),
926            Some(MISC_PREFIX),
927            "_misc. prefix should be last"
928        );
929
930        let _ = std::fs::remove_file(&path);
931    }
932
933    #[test]
934    fn test_incremental_loader_streaming_callback() {
935        let tensors = &[
936            ("layers.0.weight", vec![1.0f32]),
937            ("layers.0.bias", vec![0.0f32]),
938            ("layers.1.weight", vec![2.0f32]),
939            ("lm_head", vec![3.0f32]),
940        ];
941        let data = make_synthetic_safetensors(tensors);
942        let path = std::env::temp_dir().join("kizzasi_test_safetensors_streaming.safetensors");
943        std::fs::write(&path, &data).expect("write test file");
944
945        let src = SafeTensorsSource::open(&path).expect("open");
946        let mut loader = IncrementalModelLoader::new(src);
947
948        let mut invocation_count = 0usize;
949        loader
950            .load_all_streaming(|_prefix, _tensors| {
951                invocation_count += 1;
952                Ok(())
953            })
954            .expect("streaming failed");
955
956        // Expect 3 callbacks: layers.0., layers.1., _misc.
957        assert_eq!(
958            invocation_count, 3,
959            "expected 3 callbacks (layers.0., layers.1., _misc.), got {}",
960            invocation_count
961        );
962
963        let _ = std::fs::remove_file(&path);
964    }
965
966    #[test]
967    fn test_gguf_file_source_lazy_load() {
968        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
969        let data = make_synthetic_gguf_f32("test_tensor", &values);
970        let path = std::env::temp_dir().join("kizzasi_test_gguf_source.gguf");
971        std::fs::write(&path, &data).expect("write test gguf file");
972
973        let mut src = GgufFileSource::open(&path).expect("open GgufFileSource");
974        assert!(src.contains("test_tensor"), "tensor should be present");
975
976        let loaded = src.load_tensor("test_tensor").expect("load_tensor");
977        assert_eq!(loaded.len(), values.len(), "element count mismatch");
978        for (i, (&got, &expected)) in loaded.iter().zip(values.iter()).enumerate() {
979            assert!(
980                (got - expected).abs() < 1e-5,
981                "element {}: expected {}, got {}",
982                i,
983                expected,
984                got
985            );
986        }
987        assert!(
988            !src.contains("nonexistent"),
989            "nonexistent tensor should not be present"
990        );
991
992        let _ = std::fs::remove_file(&path);
993    }
994
995    #[test]
996    fn test_safetensors_source_multiple_tensors_values() {
997        let tensors = &[("a", vec![10.0f32, 20.0, 30.0]), ("b", vec![-1.0f32, -2.0])];
998        let data = make_synthetic_safetensors(tensors);
999        let path = std::env::temp_dir().join("kizzasi_test_safetensors_multi.safetensors");
1000        std::fs::write(&path, &data).expect("write test file");
1001
1002        let mut src = SafeTensorsSource::open(&path).expect("open");
1003
1004        let a = src.load_tensor("a").expect("load a");
1005        assert_eq!(a, vec![10.0f32, 20.0, 30.0]);
1006
1007        let b = src.load_tensor("b").expect("load b");
1008        assert_eq!(b, vec![-1.0f32, -2.0]);
1009
1010        let _ = std::fs::remove_file(&path);
1011    }
1012
1013    #[test]
1014    fn test_extract_layer_prefix_valid() {
1015        assert_eq!(
1016            extract_layer_prefix("layers.0.weight"),
1017            Some("layers.0.".to_string())
1018        );
1019        assert_eq!(
1020            extract_layer_prefix("layers.123.bias"),
1021            Some("layers.123.".to_string())
1022        );
1023    }
1024
1025    #[test]
1026    fn test_extract_layer_prefix_invalid() {
1027        assert_eq!(extract_layer_prefix("embed"), None);
1028        assert_eq!(extract_layer_prefix("lm_head.weight"), None);
1029        assert_eq!(extract_layer_prefix("layers_bad.0.weight"), None);
1030        assert_eq!(extract_layer_prefix("layers.abc.weight"), None);
1031    }
1032
1033    #[test]
1034    fn test_weight_source_total_bytes_estimate() {
1035        let tensors = &[("x", vec![1.0f32, 2.0])];
1036        let data = make_synthetic_safetensors(tensors);
1037        let expected_size = data.len() as u64;
1038        let path = std::env::temp_dir().join("kizzasi_test_safetensors_bytes_estimate.safetensors");
1039        std::fs::write(&path, &data).expect("write");
1040
1041        let src = SafeTensorsSource::open(&path).expect("open");
1042        assert_eq!(src.total_bytes_estimate(), expected_size);
1043
1044        let _ = std::fs::remove_file(&path);
1045    }
1046
1047    #[test]
1048    fn test_safetensors_source_missing_tensor_error() {
1049        let tensors = &[("existing", vec![1.0f32])];
1050        let data = make_synthetic_safetensors(tensors);
1051        let path = std::env::temp_dir().join("kizzasi_test_safetensors_missing.safetensors");
1052        std::fs::write(&path, &data).expect("write");
1053
1054        let mut src = SafeTensorsSource::open(&path).expect("open");
1055        assert!(src.load_tensor("nonexistent").is_err());
1056
1057        let _ = std::fs::remove_file(&path);
1058    }
1059
1060    #[test]
1061    fn test_incremental_loader_load_layer() {
1062        let tensors = &[
1063            ("layers.0.weight", vec![5.0f32, 6.0]),
1064            ("layers.0.bias", vec![0.5f32]),
1065            ("layers.1.weight", vec![7.0f32]),
1066        ];
1067        let data = make_synthetic_safetensors(tensors);
1068        let path = std::env::temp_dir().join("kizzasi_test_safetensors_load_layer.safetensors");
1069        std::fs::write(&path, &data).expect("write");
1070
1071        let src = SafeTensorsSource::open(&path).expect("open");
1072        let mut loader = IncrementalModelLoader::new(src);
1073
1074        let layer0 = loader
1075            .load_layer("layers.0.")
1076            .expect("load_layer layers.0.");
1077        assert!(layer0.contains_key("layers.0.weight"));
1078        assert!(layer0.contains_key("layers.0.bias"));
1079        assert!(!layer0.contains_key("layers.1.weight"));
1080
1081        let _ = std::fs::remove_file(&path);
1082    }
1083}