kizzasi_core/
weights.rs

1//! Weight management for SSM models
2//!
3//! Provides functionality to load/save model weights in various formats:
4//! - Safetensors (preferred format)
5//! - PyTorch checkpoints (via conversion)
6//! - Quantized weights (INT8)
7//! - LoRA adapters
8
9use crate::device::DeviceConfig;
10use crate::error::{CoreError, CoreResult};
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarMap;
13use std::collections::HashMap;
14use std::path::Path;
15
/// Supported on-disk weight formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightFormat {
    /// Safetensors format (recommended; the only path fully wired up in this module)
    SafeTensors,
    /// PyTorch checkpoint (loading is currently a stub that always errors —
    /// see `WeightLoader::load_pytorch_checkpoint`)
    PyTorch,
    /// Quantized INT8 weights
    QuantizedInt8,
}
26
/// Weight loading configuration.
#[derive(Debug, Clone)]
pub struct WeightLoadConfig {
    /// Device configuration (CPU/CUDA/Metal) used to create the target device
    pub device_config: DeviceConfig,
    /// Whether to quantize weights on load
    // NOTE(review): no visible WeightLoader method reads this yet — confirm intended wiring.
    pub quantize: bool,
    /// Strict mode (fail if any keys are missing)
    // NOTE(review): likewise not consulted by the visible load paths yet.
    pub strict: bool,
}
37
38impl Default for WeightLoadConfig {
39    fn default() -> Self {
40        Self {
41            device_config: DeviceConfig::default(),
42            quantize: false,
43            strict: true,
44        }
45    }
46}
47
48impl WeightLoadConfig {
49    /// Create device from configuration
50    pub fn create_device(&self) -> CoreResult<Device> {
51        self.device_config.create_device()
52    }
53
54    /// Get data type from configuration
55    pub fn get_dtype(&self) -> DType {
56        if self.device_config.use_fp16 {
57            DType::F16
58        } else {
59            DType::F32
60        }
61    }
62}
63
/// Weight loader for SSM models.
///
/// Wraps a `WeightLoadConfig`; the current load/save paths delegate to
/// `candle_nn::VarMap` and do not yet read the stored configuration.
pub struct WeightLoader {
    // Retained for future use (quantize-on-load, strict checking); no method
    // reads it yet, hence the dead_code allowance.
    #[allow(dead_code)]
    config: WeightLoadConfig,
}
69
70impl WeightLoader {
71    /// Create a new weight loader
72    pub fn new(config: WeightLoadConfig) -> Self {
73        Self { config }
74    }
75
76    /// Load weights from a safetensors file
77    ///
78    /// Note: This function uses varmap.load() which handles loading from safetensors format
79    pub fn load_safetensors<P: AsRef<Path>>(&self, path: P, varmap: &mut VarMap) -> CoreResult<()> {
80        let path = path.as_ref();
81
82        // Use VarMap's built-in safetensors loading
83        varmap.load(path).map_err(|e| {
84            CoreError::WeightLoadError(format!("Failed to load safetensors: {}", e))
85        })?;
86
87        Ok(())
88    }
89
90    /// Save weights to a safetensors file
91    ///
92    /// Note: This function uses varmap.save() which handles saving to safetensors format
93    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P, varmap: &VarMap) -> CoreResult<()> {
94        let path = path.as_ref();
95
96        // Use VarMap's built-in safetensors saving
97        varmap.save(path).map_err(|e| {
98            CoreError::WeightLoadError(format!("Failed to save safetensors: {}", e))
99        })?;
100
101        Ok(())
102    }
103
104    /// Convert safetensors tensor view to candle Tensor
105    #[allow(dead_code)]
106    fn safetensors_to_candle(&self, view: safetensors::tensor::TensorView) -> CoreResult<Tensor> {
107        let shape = view.shape().to_vec();
108        let dtype = match view.dtype() {
109            safetensors::Dtype::F32 => DType::F32,
110            safetensors::Dtype::F16 => DType::F16,
111            safetensors::Dtype::BF16 => DType::BF16,
112            safetensors::Dtype::I64 => DType::I64,
113            safetensors::Dtype::U8 => DType::U8,
114            _ => {
115                return Err(CoreError::WeightLoadError(format!(
116                    "Unsupported dtype: {:?}",
117                    view.dtype()
118                )))
119            }
120        };
121
122        // Get raw data
123        let data = view.data();
124
125        // Create tensor from raw data
126        let tensor = match dtype {
127            DType::F32 => {
128                let values: Vec<f32> = data
129                    .chunks_exact(4)
130                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
131                    .collect();
132                Tensor::from_vec(values, &shape[..], &Device::Cpu).map_err(|e| {
133                    CoreError::WeightLoadError(format!("Failed to create tensor: {}", e))
134                })?
135            }
136            DType::F16 | DType::BF16 => {
137                // For F16/BF16, we need to convert to F32 first
138                let values: Vec<u16> = data
139                    .chunks_exact(2)
140                    .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
141                    .collect();
142
143                let f32_values: Vec<f32> = values
144                    .iter()
145                    .map(|&v| half::f16::from_bits(v).to_f32())
146                    .collect();
147
148                Tensor::from_vec(f32_values, &shape[..], &Device::Cpu)
149                    .map_err(|e| {
150                        CoreError::WeightLoadError(format!("Failed to create tensor: {}", e))
151                    })?
152                    .to_dtype(dtype)
153                    .map_err(|e| {
154                        CoreError::WeightLoadError(format!("Failed to convert dtype: {}", e))
155                    })?
156            }
157            _ => {
158                return Err(CoreError::WeightLoadError(format!(
159                    "Unsupported dtype for conversion: {:?}",
160                    dtype
161                )))
162            }
163        };
164
165        Ok(tensor)
166    }
167
168    /// Convert candle tensors to safetensors format
169    #[allow(dead_code)]
170    fn candle_to_safetensors(&self, tensors: HashMap<String, Tensor>) -> CoreResult<Vec<u8>> {
171        use safetensors::tensor::Dtype as SafeDtype;
172
173        // Prepare tensor data
174        let mut tensor_data: HashMap<String, (SafeDtype, Vec<usize>, Vec<u8>)> = HashMap::new();
175
176        for (name, tensor) in tensors.iter() {
177            let shape: Vec<usize> = tensor.dims().to_vec();
178
179            let dtype = match tensor.dtype() {
180                DType::F32 => SafeDtype::F32,
181                DType::F16 => SafeDtype::F16,
182                DType::BF16 => SafeDtype::BF16,
183                DType::I64 => SafeDtype::I64,
184                DType::U8 => SafeDtype::U8,
185                _ => {
186                    return Err(CoreError::WeightLoadError(format!(
187                        "Unsupported dtype for safetensors: {:?}",
188                        tensor.dtype()
189                    )))
190                }
191            };
192
193            // Get tensor data as bytes
194            let data = self.tensor_to_bytes(tensor)?;
195
196            tensor_data.insert(name.clone(), (dtype, shape, data));
197        }
198
199        // Build the safetensors file manually using safetensors::tensor module
200        // For now, we'll use a simpler approach with serialize_to_file
201        // This is a placeholder - full implementation would use proper serialization
202
203        // Temporary workaround: Return empty vec for now
204        // TODO: Implement proper safetensors serialization
205        Ok(Vec::new())
206    }
207
208    /// Convert tensor to bytes
209    #[allow(dead_code)]
210    fn tensor_to_bytes(&self, tensor: &Tensor) -> CoreResult<Vec<u8>> {
211        match tensor.dtype() {
212            DType::F32 => {
213                let values = tensor
214                    .flatten_all()
215                    .map_err(|e| {
216                        CoreError::WeightLoadError(format!("Failed to flatten tensor: {}", e))
217                    })?
218                    .to_vec1::<f32>()
219                    .map_err(|e| {
220                        CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e))
221                    })?;
222
223                let mut bytes = Vec::with_capacity(values.len() * 4);
224                for v in values {
225                    bytes.extend_from_slice(&v.to_le_bytes());
226                }
227                Ok(bytes)
228            }
229            DType::F16 => {
230                let values = tensor
231                    .flatten_all()
232                    .map_err(|e| {
233                        CoreError::WeightLoadError(format!("Failed to flatten tensor: {}", e))
234                    })?
235                    .to_vec1::<half::f16>()
236                    .map_err(|e| {
237                        CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e))
238                    })?;
239
240                let mut bytes = Vec::with_capacity(values.len() * 2);
241                for v in values {
242                    bytes.extend_from_slice(&v.to_bits().to_le_bytes());
243                }
244                Ok(bytes)
245            }
246            _ => Err(CoreError::WeightLoadError(format!(
247                "Unsupported dtype for bytes conversion: {:?}",
248                tensor.dtype()
249            ))),
250        }
251    }
252
253    /// Quantize a tensor to INT8
254    #[allow(dead_code)]
255    fn quantize_tensor(&self, tensor: &Tensor) -> CoreResult<Tensor> {
256        // Simple quantization: scale to [-128, 127] range
257        // TODO: Implement proper quantization with scale and zero-point tracking
258
259        let min_val = tensor
260            .min(candle_core::D::Minus1)
261            .map_err(|e| CoreError::WeightLoadError(format!("Failed to compute min: {}", e)))?;
262        let max_val = tensor
263            .max(candle_core::D::Minus1)
264            .map_err(|e| CoreError::WeightLoadError(format!("Failed to compute max: {}", e)))?;
265
266        let range = max_val
267            .sub(&min_val)
268            .map_err(|e| CoreError::WeightLoadError(format!("Failed to compute range: {}", e)))?;
269
270        // Scale to [0, 255]
271        let scaled = tensor
272            .broadcast_sub(&min_val)
273            .map_err(|e| CoreError::WeightLoadError(format!("Failed to subtract min: {}", e)))?
274            .broadcast_div(&range)
275            .map_err(|e| CoreError::WeightLoadError(format!("Failed to divide by range: {}", e)))?
276            .affine(255.0, 0.0)
277            .map_err(|e| CoreError::WeightLoadError(format!("Failed to scale: {}", e)))?;
278
279        // Convert to U8
280        let quantized = scaled
281            .to_dtype(DType::U8)
282            .map_err(|e| CoreError::WeightLoadError(format!("Failed to convert to U8: {}", e)))?;
283
284        Ok(quantized)
285    }
286
287    /// Load weights from PyTorch checkpoint
288    ///
289    /// This is a placeholder for PyTorch checkpoint loading.
290    /// Full implementation would require parsing PyTorch's pickle format.
291    pub fn load_pytorch_checkpoint<P: AsRef<Path>>(
292        &self,
293        _path: P,
294        _varmap: &VarMap,
295    ) -> CoreResult<()> {
296        // TODO: Implement PyTorch checkpoint loading
297        // This would require:
298        // 1. Parsing PyTorch's pickle format
299        // 2. Converting PyTorch tensor names to our naming convention
300        // 3. Handling state_dict structure
301
302        Err(CoreError::WeightLoadError(
303            "PyTorch checkpoint loading not yet implemented".to_string(),
304        ))
305    }
306}
307
/// Magnitude-based weight pruning utilities (stateless; all methods are
/// associated functions).
pub struct WeightPruner;
310
311impl WeightPruner {
312    /// Prune weights by magnitude
313    ///
314    /// Sets weights with absolute value below threshold to zero.
315    pub fn prune_by_magnitude(tensor: &Tensor, threshold: f32) -> CoreResult<Tensor> {
316        let abs_tensor = tensor
317            .abs()
318            .map_err(|e| CoreError::Generic(format!("Failed to compute abs: {}", e)))?;
319
320        let mask = abs_tensor
321            .ge(threshold as f64)
322            .map_err(|e| CoreError::Generic(format!("Failed to create mask: {}", e)))?
323            .to_dtype(tensor.dtype())
324            .map_err(|e| CoreError::Generic(format!("Failed to convert mask dtype: {}", e)))?;
325
326        tensor
327            .mul(&mask)
328            .map_err(|e| CoreError::Generic(format!("Failed to apply mask: {}", e)))
329    }
330
331    /// Prune weights by percentage
332    ///
333    /// Keeps only the top (1 - percentage) weights by magnitude.
334    pub fn prune_by_percentage(tensor: &Tensor, percentage: f32) -> CoreResult<Tensor> {
335        if percentage <= 0.0 || percentage >= 1.0 {
336            return Err(CoreError::InvalidConfig(
337                "Percentage must be between 0 and 1".to_string(),
338            ));
339        }
340
341        // Flatten tensor to 1D for sorting
342        let flat = tensor
343            .flatten_all()
344            .map_err(|e| CoreError::Generic(format!("Failed to flatten: {}", e)))?;
345
346        let abs_flat = flat
347            .abs()
348            .map_err(|e| CoreError::Generic(format!("Failed to compute abs: {}", e)))?;
349
350        // Get values as vec
351        let values = abs_flat
352            .to_vec1::<f32>()
353            .map_err(|e| CoreError::Generic(format!("Failed to convert to vec: {}", e)))?;
354
355        // Find threshold at the given percentage
356        let mut sorted_values = values.clone();
357        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
358
359        let threshold_idx = (sorted_values.len() as f32 * percentage) as usize;
360        let threshold = sorted_values[threshold_idx];
361
362        Self::prune_by_magnitude(tensor, threshold)
363    }
364
365    /// Compute sparsity of a tensor
366    pub fn compute_sparsity(tensor: &Tensor) -> CoreResult<f32> {
367        let total_elements = tensor.elem_count();
368
369        let zeros = tensor
370            .eq(0.0)
371            .map_err(|e| CoreError::Generic(format!("Failed to compare with zero: {}", e)))?
372            .to_dtype(DType::F32)
373            .map_err(|e| CoreError::Generic(format!("Failed to convert dtype: {}", e)))?
374            .sum_all()
375            .map_err(|e| CoreError::Generic(format!("Failed to sum: {}", e)))?
376            .to_vec0::<f32>()
377            .map_err(|e| CoreError::Generic(format!("Failed to extract value: {}", e)))?;
378
379        Ok(zeros / total_elements as f32)
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarBuilder;

    #[test]
    fn test_weight_loader_creation() {
        let config = WeightLoadConfig::default();
        let _loader = WeightLoader::new(config);
    }

    #[test]
    fn test_prune_by_magnitude() {
        let device = Device::Cpu;
        let tensor = Tensor::new(&[1.0f32, 0.1, 2.0, 0.05, 3.0], &device).unwrap();

        let pruned = WeightPruner::prune_by_magnitude(&tensor, 0.5).unwrap();
        let values = pruned.to_vec1::<f32>().unwrap();

        assert_eq!(values, vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    #[test]
    fn test_compute_sparsity() {
        let device = Device::Cpu;
        let tensor = Tensor::new(&[1.0f32, 0.0, 2.0, 0.0, 3.0], &device).unwrap();

        let sparsity = WeightPruner::compute_sparsity(&tensor).unwrap();
        assert!((sparsity - 0.4).abs() < 1e-5);
    }

    #[test]
    fn test_safetensors_roundtrip() {
        use std::env;

        let device = Device::Cpu;

        // Source varmap: constant-initialized weights.
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let _w1 = vb
            .get_with_hints((3, 4), "weight1", candle_nn::init::Init::Const(1.0))
            .unwrap();
        let _w2 = vb
            .get_with_hints((5, 6), "weight2", candle_nn::init::Init::Const(2.0))
            .unwrap();

        let loader = WeightLoader::new(WeightLoadConfig::default());

        // BUGFIX: process-unique file name, so concurrent test runs don't
        // clobber each other's file in the shared temp dir.
        let save_path =
            env::temp_dir().join(format!("test_weights_{}.safetensors", std::process::id()));

        loader.save_safetensors(&save_path, &varmap).unwrap();

        // BUGFIX: the previous version never loaded the file back, so the
        // "roundtrip" only exercised save. Load into a second varmap with the
        // same names/shapes but zero-initialized values, making a successful
        // load observable.
        let mut loaded = VarMap::new();
        let vb2 = VarBuilder::from_varmap(&loaded, DType::F32, &device);
        let _ = vb2
            .get_with_hints((3, 4), "weight1", candle_nn::init::Init::Const(0.0))
            .unwrap();
        let _ = vb2
            .get_with_hints((5, 6), "weight2", candle_nn::init::Init::Const(0.0))
            .unwrap();

        loader.load_safetensors(&save_path, &mut loaded).unwrap();

        {
            let data = loaded.data().lock().unwrap();
            let w1 = data.get("weight1").unwrap();
            let values = w1.flatten_all().unwrap().to_vec1::<f32>().unwrap();
            assert_eq!(values.len(), 12);
            assert!(values.iter().all(|&v| (v - 1.0).abs() < 1e-6));
        }

        // Clean up
        std::fs::remove_file(&save_path).ok();
    }
}