// kizzasi_model/loader.rs

1//! Weight loading from safetensors format
2//!
3//! This module provides functionality to load pre-trained model weights
4//! from the safetensors format, which is safer and faster than PyTorch
5//! pickle files.
6//!
7//! # Safetensors Format
8//!
9//! Safetensors is a simple format for storing tensors safely (as opposed to pickle)
10//! and that is still fast (zero-copy). It's used by Hugging Face and other ML frameworks.
11//!
12//! # Weight Naming Conventions
13//!
14//! Kizzasi models expect specific weight naming patterns. Each model architecture
15//! has its own convention documented in the respective model module.
16//!
17//! ## Mamba Weight Format
18//!
19//! Mamba models expect the following weight structure:
20//!
21//! ```text
22//! input_proj                      [input_dim, hidden_dim]
23//! output_proj                     [hidden_dim, input_dim]
24//! layers.{i}.norm.weight          [hidden_dim]
25//! layers.{i}.norm.bias            [hidden_dim] (optional)
26//! layers.{i}.in_proj              [hidden_dim, inner_dim*2]
27//! layers.{i}.conv.weight          [out_channels, in_channels, kernel_size]
28//! layers.{i}.conv.bias            [out_channels] (optional)
29//! layers.{i}.ssm.log_a            [state_dim]
30//! layers.{i}.ssm.delta_proj       [inner_dim, inner_dim]
31//! layers.{i}.ssm.delta_bias       [inner_dim]
32//! layers.{i}.ssm.b_proj           [inner_dim, state_dim]
33//! layers.{i}.ssm.c_proj           [inner_dim, state_dim]
34//! layers.{i}.ssm.d_skip           [inner_dim]
35//! layers.{i}.out_proj             [inner_dim, hidden_dim]
36//! ```
37//!
38//! ## RWKV Weight Format
39//!
40//! RWKV v6 models expect:
41//!
42//! ```text
43//! input_proj                      [input_dim, hidden_dim]
44//! output_proj                     [hidden_dim, input_dim]
45//! layers.{i}.norm.weight          [hidden_dim]
46//! layers.{i}.time_mix.w_r         [num_heads, head_dim]
47//! layers.{i}.time_mix.w_k         [num_heads, head_dim]
48//! layers.{i}.time_mix.w_v         [num_heads, head_dim]
49//! layers.{i}.time_mix.w_g         [num_heads, head_dim]
50//! layers.{i}.time_mix.w_a         [num_heads, head_dim]
51//! layers.{i}.time_mix.w_b         [num_heads, head_dim]
52//! layers.{i}.channel_mix.w_r      [hidden_dim]
53//! layers.{i}.channel_mix.w_k      [hidden_dim]
54//! layers.{i}.channel_mix.w_v      [hidden_dim]
55//! ```
56//!
57//! ## HuggingFace Compatibility
58//!
59//! HuggingFace Mamba models use a different architecture and naming:
60//!
61//! ```text
62//! HuggingFace:                    Kizzasi:
63//! backbone.embeddings          →  input_proj
64//! backbone.layers.{i}.norm     →  layers.{i}.norm
65//! backbone.layers.{i}.mixer.in_proj → layers.{i}.in_proj
66//! backbone.layers.{i}.mixer.conv1d → layers.{i}.conv
67//! backbone.layers.{i}.mixer.x_proj → (needs splitting)
68//! backbone.layers.{i}.mixer.dt_proj → layers.{i}.ssm.delta_proj
69//! backbone.layers.{i}.mixer.A_log → layers.{i}.ssm.log_a
70//! backbone.layers.{i}.mixer.D → layers.{i}.ssm.d_skip
71//! backbone.layers.{i}.mixer.out_proj → layers.{i}.out_proj
72//! lm_head                      →  output_proj
73//! ```
74//!
75//! **Important**: HuggingFace's `x_proj` combines time_step, B, and C projections
76//! into a single matrix that must be split during conversion:
77//!
78//! ```text
79//! x_proj [intermediate_size, time_step_rank + state_size*2]
80//!   ↓ split ↓
81//! dt [time_step_rank], B [state_size], C [state_size]
82//! ```
83//!
84//! # Conversion Utilities
85//!
86//! Use `WeightLoader` for advanced loading with validation and name mapping:
87//!
88//! ```ignore
89//! use kizzasi_model::loader::{ModelLoader, WeightLoader};
90//! use kizzasi_model::ModelType;
91//!
92//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
93//! let loader = ModelLoader::new("mamba.safetensors")?;
94//! let weight_loader = WeightLoader::new(loader)
95//!     .model_type(ModelType::Mamba)
96//!     .strict(false);  // Allow missing optional weights
97//!
98//! // Inspect checkpoint structure
99//! weight_loader.print_weights();
100//!
101//! // Get suggested mappings for HuggingFace format
102//! let mappings = weight_loader.suggest_huggingface_mapping();
103//! # Ok(())
104//! # }
105//! ```
106//!
107//! # Example
108//!
109//! ```ignore
110//! use kizzasi_model::loader::ModelLoader;
111//!
112//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
113//! let loader = ModelLoader::new("model.safetensors")?;
114//! let tensor_names = loader.list_tensors();
115//! // Load specific tensors as needed
116//! # Ok(())
117//! # }
118//! ```
119
120use crate::error::{ModelError, ModelResult};
121use crate::ModelType;
122use safetensors::tensor::SafeTensors;
123use scirs2_core::ndarray::{Array1, Array2, ArrayD};
124use std::collections::HashMap;
125use std::fs::File;
126use std::io::Read;
127use std::path::Path;
128
/// Weight loader for safetensors format
///
/// Owns a parsed [`SafeTensors`] view over the checkpoint bytes.
pub struct ModelLoader {
    /// Loaded safetensors data
    ///
    /// NOTE(review): the `'static` lifetime comes from `Box::leak` in the
    /// constructors — the tensor views borrow that leaked allocation, which
    /// is never reclaimed for the lifetime of the process.
    tensors: SafeTensors<'static>,
    /// Raw file data (kept alive for tensors)
    ///
    /// NOTE(review): the constructors leak a *clone* of the buffer for the
    /// tensor views, so the views do not actually borrow this field — it
    /// appears redundant; confirm before removing.
    _data: Vec<u8>,
}
136
137impl ModelLoader {
138    /// Load a safetensors file from disk
139    pub fn new<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
140        let mut file = File::open(path.as_ref())
141            .map_err(|e| ModelError::simple_load_error(format!("Failed to open file: {}", e)))?;
142
143        let mut data = Vec::new();
144        file.read_to_end(&mut data)
145            .map_err(|e| ModelError::simple_load_error(format!("Failed to read file: {}", e)))?;
146
147        // Leak the data to get a 'static lifetime
148        // This is safe because we keep the Vec alive in the struct
149        let data_static = Box::leak(data.clone().into_boxed_slice());
150
151        let tensors = SafeTensors::deserialize(data_static).map_err(|e| {
152            ModelError::simple_load_error(format!("Failed to parse safetensors: {}", e))
153        })?;
154
155        Ok(Self {
156            tensors,
157            _data: data,
158        })
159    }
160
161    /// Load a safetensors from bytes
162    pub fn from_bytes(data: Vec<u8>) -> ModelResult<Self> {
163        let data_static = Box::leak(data.clone().into_boxed_slice());
164
165        let tensors = SafeTensors::deserialize(data_static).map_err(|e| {
166            ModelError::simple_load_error(format!("Failed to parse safetensors: {}", e))
167        })?;
168
169        Ok(Self {
170            tensors,
171            _data: data,
172        })
173    }
174
175    /// List all available tensor names in the file
176    pub fn list_tensors(&self) -> Vec<String> {
177        self.tensors.names().iter().map(|s| s.to_string()).collect()
178    }
179
180    /// Get metadata about a specific tensor
181    pub fn tensor_info(&self, name: &str) -> Option<TensorInfo> {
182        self.tensors.tensor(name).ok().map(|view| TensorInfo {
183            name: name.to_string(),
184            shape: view.shape().to_vec(),
185            dtype: format!("{:?}", view.dtype()),
186        })
187    }
188
189    /// Load a 1D tensor (Array1<f32>)
190    pub fn load_array1(&self, name: &str) -> ModelResult<Array1<f32>> {
191        let view = self.tensors.tensor(name).map_err(|e| {
192            ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
193        })?;
194
195        let shape = view.shape();
196        if shape.len() != 1 {
197            return Err(ModelError::simple_load_error(format!(
198                "Expected 1D tensor for '{}', got shape {:?}",
199                name, shape
200            )));
201        }
202
203        let data = view.data();
204        let float_data = match view.dtype() {
205            safetensors::Dtype::F32 => {
206                // Convert bytes to f32
207                data.chunks_exact(4)
208                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
209                    .collect::<Vec<_>>()
210            }
211            safetensors::Dtype::F64 => {
212                // Convert f64 to f32
213                data.chunks_exact(8)
214                    .map(|chunk| {
215                        let bytes = [
216                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
217                            chunk[7],
218                        ];
219                        f64::from_le_bytes(bytes) as f32
220                    })
221                    .collect::<Vec<_>>()
222            }
223            dtype => {
224                return Err(ModelError::simple_load_error(format!(
225                    "Unsupported dtype for '{}': {:?}",
226                    name, dtype
227                )));
228            }
229        };
230
231        Ok(Array1::from_vec(float_data))
232    }
233
234    /// Load a 2D tensor (Array2<f32>)
235    pub fn load_array2(&self, name: &str) -> ModelResult<Array2<f32>> {
236        let view = self.tensors.tensor(name).map_err(|e| {
237            ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
238        })?;
239
240        let shape = view.shape();
241        if shape.len() != 2 {
242            return Err(ModelError::simple_load_error(format!(
243                "Expected 2D tensor for '{}', got shape {:?}",
244                name, shape
245            )));
246        }
247
248        let data = view.data();
249        let float_data = match view.dtype() {
250            safetensors::Dtype::F32 => data
251                .chunks_exact(4)
252                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
253                .collect::<Vec<_>>(),
254            safetensors::Dtype::F64 => data
255                .chunks_exact(8)
256                .map(|chunk| {
257                    let bytes = [
258                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
259                        chunk[7],
260                    ];
261                    f64::from_le_bytes(bytes) as f32
262                })
263                .collect::<Vec<_>>(),
264            dtype => {
265                return Err(ModelError::simple_load_error(format!(
266                    "Unsupported dtype for '{}': {:?}",
267                    name, dtype
268                )));
269            }
270        };
271
272        Array2::from_shape_vec((shape[0], shape[1]), float_data)
273            .map_err(|e| ModelError::simple_load_error(format!("Failed to create Array2: {}", e)))
274    }
275
276    /// Load a tensor of arbitrary dimension
277    pub fn load_array(&self, name: &str) -> ModelResult<ArrayD<f32>> {
278        let view = self.tensors.tensor(name).map_err(|e| {
279            ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
280        })?;
281
282        let shape = view.shape();
283        let data = view.data();
284
285        let float_data = match view.dtype() {
286            safetensors::Dtype::F32 => data
287                .chunks_exact(4)
288                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
289                .collect::<Vec<_>>(),
290            safetensors::Dtype::F64 => data
291                .chunks_exact(8)
292                .map(|chunk| {
293                    let bytes = [
294                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
295                        chunk[7],
296                    ];
297                    f64::from_le_bytes(bytes) as f32
298                })
299                .collect::<Vec<_>>(),
300            safetensors::Dtype::F16 => {
301                // For F16, we need to convert to f32
302                // Note: This is a simplified conversion
303                data.chunks_exact(2)
304                    .map(|chunk| {
305                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
306                        half::f16::from_bits(bits).to_f32()
307                    })
308                    .collect::<Vec<_>>()
309            }
310            dtype => {
311                return Err(ModelError::simple_load_error(format!(
312                    "Unsupported dtype for '{}': {:?}",
313                    name, dtype
314                )));
315            }
316        };
317
318        ArrayD::from_shape_vec(shape, float_data)
319            .map_err(|e| ModelError::simple_load_error(format!("Failed to create ArrayD: {}", e)))
320    }
321
322    /// Load a 3D tensor as Vec<Vec<Vec<f32>>>
323    ///
324    /// This is useful for convolution weights [out_channels, in_channels, kernel_size]
325    pub fn load_array3(&self, name: &str) -> ModelResult<Vec<Vec<Vec<f32>>>> {
326        let array_d = self.load_array(name)?;
327
328        if array_d.ndim() != 3 {
329            return Err(ModelError::simple_load_error(format!(
330                "Expected 3D tensor for '{}', got {}D tensor",
331                name,
332                array_d.ndim()
333            )));
334        }
335
336        let shape = array_d.shape();
337        let dim0 = shape[0];
338        let dim1 = shape[1];
339        let dim2 = shape[2];
340
341        // Convert ArrayD to nested Vec structure
342        let mut result = Vec::with_capacity(dim0);
343        for i in 0..dim0 {
344            let mut dim1_vec = Vec::with_capacity(dim1);
345            for j in 0..dim1 {
346                let mut dim2_vec = Vec::with_capacity(dim2);
347                for k in 0..dim2 {
348                    dim2_vec.push(array_d[[i, j, k]]);
349                }
350                dim1_vec.push(dim2_vec);
351            }
352            result.push(dim1_vec);
353        }
354
355        Ok(result)
356    }
357
358    /// Check if a tensor exists
359    pub fn has_tensor(&self, name: &str) -> bool {
360        self.tensors.tensor(name).is_ok()
361    }
362
363    /// Load all tensors into a HashMap
364    pub fn load_all(&self) -> ModelResult<HashMap<String, ArrayD<f32>>> {
365        let mut result = HashMap::new();
366        for name in self.list_tensors() {
367            let array = self.load_array(&name)?;
368            result.insert(name, array);
369        }
370        Ok(result)
371    }
372
373    /// Print a summary of all tensors in the file
374    ///
375    /// This is useful for inspecting checkpoint files and understanding their structure
376    pub fn print_summary(&self) {
377        println!("SafeTensors Weight Summary");
378        println!("==========================");
379        println!("Total tensors: {}", self.list_tensors().len());
380        println!();
381
382        // Group by prefix
383        let mut prefixes: HashMap<String, Vec<String>> = HashMap::new();
384        for name in self.list_tensors() {
385            let parts: Vec<&str> = name.split('.').collect();
386            let prefix = if parts.len() > 1 {
387                parts[0..parts.len() - 1].join(".")
388            } else {
389                "root".to_string()
390            };
391            prefixes.entry(prefix).or_default().push(name);
392        }
393
394        for (prefix, tensors) in prefixes.iter() {
395            println!("\n[{}]", prefix);
396            for name in tensors {
397                if let Some(info) = self.tensor_info(name) {
398                    println!(
399                        "  {} - shape: {:?}, dtype: {}",
400                        name, info.shape, info.dtype
401                    );
402                }
403            }
404        }
405    }
406
407    /// Get statistics about tensor sizes
408    pub fn get_size_stats(&self) -> HashMap<String, usize> {
409        let mut stats = HashMap::new();
410        let mut total_params = 0usize;
411
412        for name in self.list_tensors() {
413            if let Some(info) = self.tensor_info(&name) {
414                let size: usize = info.shape.iter().product();
415                stats.insert(name.clone(), size);
416                total_params += size;
417            }
418        }
419
420        stats.insert("__total_parameters".to_string(), total_params);
421        stats
422    }
423
424    /// Search for tensors matching a pattern
425    ///
426    /// # Example
427    /// ```ignore
428    /// // Find all conv weights
429    /// let conv_tensors = loader.search_tensors("conv.weight");
430    /// ```
431    pub fn search_tensors(&self, pattern: &str) -> Vec<String> {
432        self.list_tensors()
433            .into_iter()
434            .filter(|name| name.contains(pattern))
435            .collect()
436    }
437}
438
/// Information about a tensor in the safetensors file
///
/// Purely descriptive metadata returned by `ModelLoader::tensor_info`;
/// holds no tensor data.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name as stored in the checkpoint
    pub name: String,
    /// Shape of the tensor (one entry per dimension)
    pub shape: Vec<usize>,
    /// Data type as string (rendered via `{:?}`, e.g. "F32")
    pub dtype: String,
}
449
/// Builder for loading model weights with validation
pub struct WeightLoader {
    /// Underlying safetensors loader
    loader: ModelLoader,
    /// Expected model architecture, if declared via the builder.
    /// NOTE(review): stored but never read by any method visible in this
    /// file — confirm intended use.
    model_type: Option<ModelType>,
    /// When true (the default), `validate_weights` fails on missing tensors
    strict: bool,
}
456
457impl WeightLoader {
458    /// Create a new weight loader
459    pub fn new(loader: ModelLoader) -> Self {
460        Self {
461            loader,
462            model_type: None,
463            strict: true,
464        }
465    }
466
467    /// Set the expected model type
468    pub fn model_type(mut self, model_type: ModelType) -> Self {
469        self.model_type = Some(model_type);
470        self
471    }
472
473    /// Set whether to enforce strict loading (all weights must be present)
474    pub fn strict(mut self, strict: bool) -> Self {
475        self.strict = strict;
476        self
477    }
478
479    /// Validate that all required weights are present
480    pub fn validate_weights(&self, required: &[&str]) -> ModelResult<()> {
481        if !self.strict {
482            return Ok(());
483        }
484
485        let missing: Vec<_> = required
486            .iter()
487            .filter(|&&name| !self.loader.has_tensor(name))
488            .copied()
489            .collect();
490
491        if !missing.is_empty() {
492            return Err(ModelError::simple_load_error(format!(
493                "Missing required weights: {:?}",
494                missing
495            )));
496        }
497
498        Ok(())
499    }
500
501    /// Get the underlying loader
502    pub fn loader(&self) -> &ModelLoader {
503        &self.loader
504    }
505
506    /// Create a name mapping from source format to target format
507    ///
508    /// # Example
509    /// ```ignore
510    /// let mapping = HashMap::from([
511    ///     ("backbone.layers.0.mixer.in_proj.weight", "layers.0.in_proj"),
512    ///     ("backbone.layers.0.mixer.A_log", "layers.0.ssm.log_a"),
513    /// ]);
514    /// let mapped_loader = WeightLoader::new(loader).with_name_mapping(mapping);
515    /// ```
516    pub fn with_name_mapping(self, _mapping: HashMap<String, String>) -> Self {
517        // TODO: Implement name remapping
518        // This requires storing the mapping and using it during tensor lookups
519        self
520    }
521
522    /// Print available weights and their shapes
523    ///
524    /// This is useful for understanding what weights are available in the checkpoint
525    pub fn print_weights(&self) {
526        self.loader.print_summary();
527    }
528
529    /// Get suggested weight mappings for HuggingFace format
530    ///
531    /// Returns a list of (hf_name, kizzasi_name) pairs that can be used
532    /// to convert HuggingFace checkpoints to Kizzasi format
533    pub fn suggest_huggingface_mapping(&self) -> Vec<(String, String)> {
534        let mut mappings = Vec::new();
535        let tensors = self.loader.list_tensors();
536
537        // Check if this looks like a HuggingFace checkpoint
538        if tensors.iter().any(|t| t.contains("backbone.layers")) {
539            for tensor in &tensors {
540                if let Some(kizzasi_name) = self.hf_to_kizzasi_name(tensor) {
541                    mappings.push((tensor.clone(), kizzasi_name));
542                }
543            }
544        }
545
546        mappings
547    }
548
549    /// Convert HuggingFace weight name to Kizzasi format
550    ///
551    /// # HuggingFace → Kizzasi Mapping
552    ///
553    /// - `backbone.embeddings` → `input_proj`
554    /// - `backbone.layers.{i}.norm.weight` → `layers.{i}.norm.weight`
555    /// - `backbone.layers.{i}.mixer.in_proj` → `layers.{i}.in_proj`
556    /// - `backbone.layers.{i}.mixer.conv1d` → `layers.{i}.conv`
557    /// - `backbone.layers.{i}.mixer.A_log` → `layers.{i}.ssm.log_a`
558    /// - `backbone.layers.{i}.mixer.D` → `layers.{i}.ssm.d_skip`
559    /// - `backbone.layers.{i}.mixer.out_proj` → `layers.{i}.out_proj`
560    /// - `lm_head` → `output_proj`
561    ///
562    /// Note: HuggingFace uses `x_proj` + `dt_proj` for selective parameters,
563    /// while Kizzasi uses separate `delta_proj`, `b_proj`, `c_proj`.
564    /// This requires splitting/combining weights during conversion.
565    fn hf_to_kizzasi_name(&self, hf_name: &str) -> Option<String> {
566        // Simple prefix replacement
567        let name = hf_name
568            .replace("backbone.", "")
569            .replace(".mixer.", ".")
570            .replace("conv1d", "conv")
571            .replace("A_log", "ssm.log_a")
572            .replace(".D", ".ssm.d_skip");
573
574        if name.is_empty() {
575            None
576        } else {
577            Some(name)
578        }
579    }
580}
581
#[cfg(test)]
mod tests {
    use super::*;

    /// TensorInfo is a plain data carrier; constructing one should preserve
    /// every field verbatim.
    #[test]
    fn test_tensor_info() {
        let shape = vec![2, 3];
        let info = TensorInfo {
            name: String::from("test"),
            shape: shape.clone(),
            dtype: String::from("F32"),
        };
        assert_eq!(info.name, "test");
        assert_eq!(info.shape, shape);
    }
}