kizzasi_model/loader.rs
1//! Weight loading from safetensors format
2//!
3//! This module provides functionality to load pre-trained model weights
4//! from the safetensors format, which is safer and faster than PyTorch
5//! pickle files.
6//!
7//! # Safetensors Format
8//!
9//! Safetensors is a simple format for storing tensors safely (as opposed to pickle)
10//! and that is still fast (zero-copy). It's used by Hugging Face and other ML frameworks.
11//!
12//! # Weight Naming Conventions
13//!
14//! Kizzasi models expect specific weight naming patterns. Each model architecture
15//! has its own convention documented in the respective model module.
16//!
17//! ## Mamba Weight Format
18//!
19//! Mamba models expect the following weight structure:
20//!
21//! ```text
22//! input_proj [input_dim, hidden_dim]
23//! output_proj [hidden_dim, input_dim]
24//! layers.{i}.norm.weight [hidden_dim]
25//! layers.{i}.norm.bias [hidden_dim] (optional)
26//! layers.{i}.in_proj [hidden_dim, inner_dim*2]
27//! layers.{i}.conv.weight [out_channels, in_channels, kernel_size]
28//! layers.{i}.conv.bias [out_channels] (optional)
29//! layers.{i}.ssm.log_a [state_dim]
30//! layers.{i}.ssm.delta_proj [inner_dim, inner_dim]
31//! layers.{i}.ssm.delta_bias [inner_dim]
32//! layers.{i}.ssm.b_proj [inner_dim, state_dim]
33//! layers.{i}.ssm.c_proj [inner_dim, state_dim]
34//! layers.{i}.ssm.d_skip [inner_dim]
35//! layers.{i}.out_proj [inner_dim, hidden_dim]
36//! ```
37//!
38//! ## RWKV Weight Format
39//!
40//! RWKV v6 models expect:
41//!
42//! ```text
43//! input_proj [input_dim, hidden_dim]
44//! output_proj [hidden_dim, input_dim]
45//! layers.{i}.norm.weight [hidden_dim]
46//! layers.{i}.time_mix.w_r [num_heads, head_dim]
47//! layers.{i}.time_mix.w_k [num_heads, head_dim]
48//! layers.{i}.time_mix.w_v [num_heads, head_dim]
49//! layers.{i}.time_mix.w_g [num_heads, head_dim]
50//! layers.{i}.time_mix.w_a [num_heads, head_dim]
51//! layers.{i}.time_mix.w_b [num_heads, head_dim]
52//! layers.{i}.channel_mix.w_r [hidden_dim]
53//! layers.{i}.channel_mix.w_k [hidden_dim]
54//! layers.{i}.channel_mix.w_v [hidden_dim]
55//! ```
56//!
57//! ## HuggingFace Compatibility
58//!
59//! HuggingFace Mamba models use a different architecture and naming:
60//!
61//! ```text
62//! HuggingFace: Kizzasi:
63//! backbone.embeddings → input_proj
64//! backbone.layers.{i}.norm → layers.{i}.norm
65//! backbone.layers.{i}.mixer.in_proj → layers.{i}.in_proj
66//! backbone.layers.{i}.mixer.conv1d → layers.{i}.conv
67//! backbone.layers.{i}.mixer.x_proj → (needs splitting)
68//! backbone.layers.{i}.mixer.dt_proj → layers.{i}.ssm.delta_proj
69//! backbone.layers.{i}.mixer.A_log → layers.{i}.ssm.log_a
70//! backbone.layers.{i}.mixer.D → layers.{i}.ssm.d_skip
71//! backbone.layers.{i}.mixer.out_proj → layers.{i}.out_proj
72//! lm_head → output_proj
73//! ```
74//!
75//! **Important**: HuggingFace's `x_proj` combines time_step, B, and C projections
76//! into a single matrix that must be split during conversion:
77//!
78//! ```text
79//! x_proj [intermediate_size, time_step_rank + state_size*2]
80//! ↓ split ↓
81//! dt [time_step_rank], B [state_size], C [state_size]
82//! ```
83//!
84//! # Conversion Utilities
85//!
86//! Use `WeightLoader` for advanced loading with validation and name mapping:
87//!
88//! ```ignore
89//! use kizzasi_model::loader::{ModelLoader, WeightLoader};
90//! use kizzasi_model::ModelType;
91//!
92//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
93//! let loader = ModelLoader::new("mamba.safetensors")?;
94//! let weight_loader = WeightLoader::new(loader)
95//! .model_type(ModelType::Mamba)
96//! .strict(false); // Allow missing optional weights
97//!
98//! // Inspect checkpoint structure
99//! weight_loader.print_weights();
100//!
101//! // Get suggested mappings for HuggingFace format
102//! let mappings = weight_loader.suggest_huggingface_mapping();
103//! # Ok(())
104//! # }
105//! ```
106//!
107//! # Example
108//!
109//! ```ignore
110//! use kizzasi_model::loader::ModelLoader;
111//!
112//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
113//! let loader = ModelLoader::new("model.safetensors")?;
114//! let tensor_names = loader.list_tensors();
115//! // Load specific tensors as needed
116//! # Ok(())
117//! # }
118//! ```
119
120use crate::error::{ModelError, ModelResult};
121use crate::ModelType;
122use safetensors::tensor::SafeTensors;
123use scirs2_core::ndarray::{Array1, Array2, ArrayD};
124use std::collections::HashMap;
125use std::fs::File;
126use std::io::Read;
127use std::path::Path;
128
129/// Weight loader for safetensors format
130pub struct ModelLoader {
131 /// Loaded safetensors data
132 tensors: SafeTensors<'static>,
133 /// Raw file data (kept alive for tensors)
134 _data: Vec<u8>,
135}
136
137impl ModelLoader {
138 /// Load a safetensors file from disk
139 pub fn new<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
140 let mut file = File::open(path.as_ref())
141 .map_err(|e| ModelError::simple_load_error(format!("Failed to open file: {}", e)))?;
142
143 let mut data = Vec::new();
144 file.read_to_end(&mut data)
145 .map_err(|e| ModelError::simple_load_error(format!("Failed to read file: {}", e)))?;
146
147 // Leak the data to get a 'static lifetime
148 // This is safe because we keep the Vec alive in the struct
149 let data_static = Box::leak(data.clone().into_boxed_slice());
150
151 let tensors = SafeTensors::deserialize(data_static).map_err(|e| {
152 ModelError::simple_load_error(format!("Failed to parse safetensors: {}", e))
153 })?;
154
155 Ok(Self {
156 tensors,
157 _data: data,
158 })
159 }
160
161 /// Load a safetensors from bytes
162 pub fn from_bytes(data: Vec<u8>) -> ModelResult<Self> {
163 let data_static = Box::leak(data.clone().into_boxed_slice());
164
165 let tensors = SafeTensors::deserialize(data_static).map_err(|e| {
166 ModelError::simple_load_error(format!("Failed to parse safetensors: {}", e))
167 })?;
168
169 Ok(Self {
170 tensors,
171 _data: data,
172 })
173 }
174
175 /// List all available tensor names in the file
176 pub fn list_tensors(&self) -> Vec<String> {
177 self.tensors.names().iter().map(|s| s.to_string()).collect()
178 }
179
180 /// Get metadata about a specific tensor
181 pub fn tensor_info(&self, name: &str) -> Option<TensorInfo> {
182 self.tensors.tensor(name).ok().map(|view| TensorInfo {
183 name: name.to_string(),
184 shape: view.shape().to_vec(),
185 dtype: format!("{:?}", view.dtype()),
186 })
187 }
188
189 /// Load a 1D tensor (Array1<f32>)
190 pub fn load_array1(&self, name: &str) -> ModelResult<Array1<f32>> {
191 let view = self.tensors.tensor(name).map_err(|e| {
192 ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
193 })?;
194
195 let shape = view.shape();
196 if shape.len() != 1 {
197 return Err(ModelError::simple_load_error(format!(
198 "Expected 1D tensor for '{}', got shape {:?}",
199 name, shape
200 )));
201 }
202
203 let data = view.data();
204 let float_data = match view.dtype() {
205 safetensors::Dtype::F32 => {
206 // Convert bytes to f32
207 data.chunks_exact(4)
208 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
209 .collect::<Vec<_>>()
210 }
211 safetensors::Dtype::F64 => {
212 // Convert f64 to f32
213 data.chunks_exact(8)
214 .map(|chunk| {
215 let bytes = [
216 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
217 chunk[7],
218 ];
219 f64::from_le_bytes(bytes) as f32
220 })
221 .collect::<Vec<_>>()
222 }
223 dtype => {
224 return Err(ModelError::simple_load_error(format!(
225 "Unsupported dtype for '{}': {:?}",
226 name, dtype
227 )));
228 }
229 };
230
231 Ok(Array1::from_vec(float_data))
232 }
233
234 /// Load a 2D tensor (Array2<f32>)
235 pub fn load_array2(&self, name: &str) -> ModelResult<Array2<f32>> {
236 let view = self.tensors.tensor(name).map_err(|e| {
237 ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
238 })?;
239
240 let shape = view.shape();
241 if shape.len() != 2 {
242 return Err(ModelError::simple_load_error(format!(
243 "Expected 2D tensor for '{}', got shape {:?}",
244 name, shape
245 )));
246 }
247
248 let data = view.data();
249 let float_data = match view.dtype() {
250 safetensors::Dtype::F32 => data
251 .chunks_exact(4)
252 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
253 .collect::<Vec<_>>(),
254 safetensors::Dtype::F64 => data
255 .chunks_exact(8)
256 .map(|chunk| {
257 let bytes = [
258 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
259 chunk[7],
260 ];
261 f64::from_le_bytes(bytes) as f32
262 })
263 .collect::<Vec<_>>(),
264 dtype => {
265 return Err(ModelError::simple_load_error(format!(
266 "Unsupported dtype for '{}': {:?}",
267 name, dtype
268 )));
269 }
270 };
271
272 Array2::from_shape_vec((shape[0], shape[1]), float_data)
273 .map_err(|e| ModelError::simple_load_error(format!("Failed to create Array2: {}", e)))
274 }
275
276 /// Load a tensor of arbitrary dimension
277 pub fn load_array(&self, name: &str) -> ModelResult<ArrayD<f32>> {
278 let view = self.tensors.tensor(name).map_err(|e| {
279 ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
280 })?;
281
282 let shape = view.shape();
283 let data = view.data();
284
285 let float_data = match view.dtype() {
286 safetensors::Dtype::F32 => data
287 .chunks_exact(4)
288 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
289 .collect::<Vec<_>>(),
290 safetensors::Dtype::F64 => data
291 .chunks_exact(8)
292 .map(|chunk| {
293 let bytes = [
294 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
295 chunk[7],
296 ];
297 f64::from_le_bytes(bytes) as f32
298 })
299 .collect::<Vec<_>>(),
300 safetensors::Dtype::F16 => {
301 // For F16, we need to convert to f32
302 // Note: This is a simplified conversion
303 data.chunks_exact(2)
304 .map(|chunk| {
305 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
306 half::f16::from_bits(bits).to_f32()
307 })
308 .collect::<Vec<_>>()
309 }
310 dtype => {
311 return Err(ModelError::simple_load_error(format!(
312 "Unsupported dtype for '{}': {:?}",
313 name, dtype
314 )));
315 }
316 };
317
318 ArrayD::from_shape_vec(shape, float_data)
319 .map_err(|e| ModelError::simple_load_error(format!("Failed to create ArrayD: {}", e)))
320 }
321
322 /// Load a 3D tensor as Vec<Vec<Vec<f32>>>
323 ///
324 /// This is useful for convolution weights [out_channels, in_channels, kernel_size]
325 pub fn load_array3(&self, name: &str) -> ModelResult<Vec<Vec<Vec<f32>>>> {
326 let array_d = self.load_array(name)?;
327
328 if array_d.ndim() != 3 {
329 return Err(ModelError::simple_load_error(format!(
330 "Expected 3D tensor for '{}', got {}D tensor",
331 name,
332 array_d.ndim()
333 )));
334 }
335
336 let shape = array_d.shape();
337 let dim0 = shape[0];
338 let dim1 = shape[1];
339 let dim2 = shape[2];
340
341 // Convert ArrayD to nested Vec structure
342 let mut result = Vec::with_capacity(dim0);
343 for i in 0..dim0 {
344 let mut dim1_vec = Vec::with_capacity(dim1);
345 for j in 0..dim1 {
346 let mut dim2_vec = Vec::with_capacity(dim2);
347 for k in 0..dim2 {
348 dim2_vec.push(array_d[[i, j, k]]);
349 }
350 dim1_vec.push(dim2_vec);
351 }
352 result.push(dim1_vec);
353 }
354
355 Ok(result)
356 }
357
358 /// Check if a tensor exists
359 pub fn has_tensor(&self, name: &str) -> bool {
360 self.tensors.tensor(name).is_ok()
361 }
362
363 /// Load all tensors into a HashMap
364 pub fn load_all(&self) -> ModelResult<HashMap<String, ArrayD<f32>>> {
365 let mut result = HashMap::new();
366 for name in self.list_tensors() {
367 let array = self.load_array(&name)?;
368 result.insert(name, array);
369 }
370 Ok(result)
371 }
372
373 /// Print a summary of all tensors in the file
374 ///
375 /// This is useful for inspecting checkpoint files and understanding their structure
376 pub fn print_summary(&self) {
377 println!("SafeTensors Weight Summary");
378 println!("==========================");
379 println!("Total tensors: {}", self.list_tensors().len());
380 println!();
381
382 // Group by prefix
383 let mut prefixes: HashMap<String, Vec<String>> = HashMap::new();
384 for name in self.list_tensors() {
385 let parts: Vec<&str> = name.split('.').collect();
386 let prefix = if parts.len() > 1 {
387 parts[0..parts.len() - 1].join(".")
388 } else {
389 "root".to_string()
390 };
391 prefixes.entry(prefix).or_default().push(name);
392 }
393
394 for (prefix, tensors) in prefixes.iter() {
395 println!("\n[{}]", prefix);
396 for name in tensors {
397 if let Some(info) = self.tensor_info(name) {
398 println!(
399 " {} - shape: {:?}, dtype: {}",
400 name, info.shape, info.dtype
401 );
402 }
403 }
404 }
405 }
406
407 /// Get statistics about tensor sizes
408 pub fn get_size_stats(&self) -> HashMap<String, usize> {
409 let mut stats = HashMap::new();
410 let mut total_params = 0usize;
411
412 for name in self.list_tensors() {
413 if let Some(info) = self.tensor_info(&name) {
414 let size: usize = info.shape.iter().product();
415 stats.insert(name.clone(), size);
416 total_params += size;
417 }
418 }
419
420 stats.insert("__total_parameters".to_string(), total_params);
421 stats
422 }
423
424 /// Search for tensors matching a pattern
425 ///
426 /// # Example
427 /// ```ignore
428 /// // Find all conv weights
429 /// let conv_tensors = loader.search_tensors("conv.weight");
430 /// ```
431 pub fn search_tensors(&self, pattern: &str) -> Vec<String> {
432 self.list_tensors()
433 .into_iter()
434 .filter(|name| name.contains(pattern))
435 .collect()
436 }
437}
438
439/// Information about a tensor in the safetensors file
440#[derive(Debug, Clone)]
441pub struct TensorInfo {
442 /// Tensor name
443 pub name: String,
444 /// Shape of the tensor
445 pub shape: Vec<usize>,
446 /// Data type as string
447 pub dtype: String,
448}
449
450/// Builder for loading model weights with validation
451pub struct WeightLoader {
452 loader: ModelLoader,
453 model_type: Option<ModelType>,
454 strict: bool,
455}
456
457impl WeightLoader {
458 /// Create a new weight loader
459 pub fn new(loader: ModelLoader) -> Self {
460 Self {
461 loader,
462 model_type: None,
463 strict: true,
464 }
465 }
466
467 /// Set the expected model type
468 pub fn model_type(mut self, model_type: ModelType) -> Self {
469 self.model_type = Some(model_type);
470 self
471 }
472
473 /// Set whether to enforce strict loading (all weights must be present)
474 pub fn strict(mut self, strict: bool) -> Self {
475 self.strict = strict;
476 self
477 }
478
479 /// Validate that all required weights are present
480 pub fn validate_weights(&self, required: &[&str]) -> ModelResult<()> {
481 if !self.strict {
482 return Ok(());
483 }
484
485 let missing: Vec<_> = required
486 .iter()
487 .filter(|&&name| !self.loader.has_tensor(name))
488 .copied()
489 .collect();
490
491 if !missing.is_empty() {
492 return Err(ModelError::simple_load_error(format!(
493 "Missing required weights: {:?}",
494 missing
495 )));
496 }
497
498 Ok(())
499 }
500
501 /// Get the underlying loader
502 pub fn loader(&self) -> &ModelLoader {
503 &self.loader
504 }
505
506 /// Create a name mapping from source format to target format
507 ///
508 /// # Example
509 /// ```ignore
510 /// let mapping = HashMap::from([
511 /// ("backbone.layers.0.mixer.in_proj.weight", "layers.0.in_proj"),
512 /// ("backbone.layers.0.mixer.A_log", "layers.0.ssm.log_a"),
513 /// ]);
514 /// let mapped_loader = WeightLoader::new(loader).with_name_mapping(mapping);
515 /// ```
516 pub fn with_name_mapping(self, _mapping: HashMap<String, String>) -> Self {
517 // TODO: Implement name remapping
518 // This requires storing the mapping and using it during tensor lookups
519 self
520 }
521
522 /// Print available weights and their shapes
523 ///
524 /// This is useful for understanding what weights are available in the checkpoint
525 pub fn print_weights(&self) {
526 self.loader.print_summary();
527 }
528
529 /// Get suggested weight mappings for HuggingFace format
530 ///
531 /// Returns a list of (hf_name, kizzasi_name) pairs that can be used
532 /// to convert HuggingFace checkpoints to Kizzasi format
533 pub fn suggest_huggingface_mapping(&self) -> Vec<(String, String)> {
534 let mut mappings = Vec::new();
535 let tensors = self.loader.list_tensors();
536
537 // Check if this looks like a HuggingFace checkpoint
538 if tensors.iter().any(|t| t.contains("backbone.layers")) {
539 for tensor in &tensors {
540 if let Some(kizzasi_name) = self.hf_to_kizzasi_name(tensor) {
541 mappings.push((tensor.clone(), kizzasi_name));
542 }
543 }
544 }
545
546 mappings
547 }
548
549 /// Convert HuggingFace weight name to Kizzasi format
550 ///
551 /// # HuggingFace → Kizzasi Mapping
552 ///
553 /// - `backbone.embeddings` → `input_proj`
554 /// - `backbone.layers.{i}.norm.weight` → `layers.{i}.norm.weight`
555 /// - `backbone.layers.{i}.mixer.in_proj` → `layers.{i}.in_proj`
556 /// - `backbone.layers.{i}.mixer.conv1d` → `layers.{i}.conv`
557 /// - `backbone.layers.{i}.mixer.A_log` → `layers.{i}.ssm.log_a`
558 /// - `backbone.layers.{i}.mixer.D` → `layers.{i}.ssm.d_skip`
559 /// - `backbone.layers.{i}.mixer.out_proj` → `layers.{i}.out_proj`
560 /// - `lm_head` → `output_proj`
561 ///
562 /// Note: HuggingFace uses `x_proj` + `dt_proj` for selective parameters,
563 /// while Kizzasi uses separate `delta_proj`, `b_proj`, `c_proj`.
564 /// This requires splitting/combining weights during conversion.
565 fn hf_to_kizzasi_name(&self, hf_name: &str) -> Option<String> {
566 // Simple prefix replacement
567 let name = hf_name
568 .replace("backbone.", "")
569 .replace(".mixer.", ".")
570 .replace("conv1d", "conv")
571 .replace("A_log", "ssm.log_a")
572 .replace(".D", ".ssm.d_skip");
573
574 if name.is_empty() {
575 None
576 } else {
577 Some(name)
578 }
579 }
580}
581
582#[cfg(test)]
583mod tests {
584 use super::*;
585
586 #[test]
587 fn test_tensor_info() {
588 let info = TensorInfo {
589 name: "test".to_string(),
590 shape: vec![2, 3],
591 dtype: "F32".to_string(),
592 };
593 assert_eq!(info.name, "test");
594 assert_eq!(info.shape, vec![2, 3]);
595 }
596}