1use crate::device::DeviceConfig;
10use crate::error::{CoreError, CoreResult};
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarMap;
13use std::collections::HashMap;
14use std::path::Path;
15
/// On-disk serialization formats a set of model weights may come in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightFormat {
    /// The safetensors format (the only format currently loadable here).
    SafeTensors,
    /// A PyTorch checkpoint; `WeightLoader::load_pytorch_checkpoint` is
    /// still unimplemented for this variant.
    PyTorch,
    /// 8-bit quantized weights.
    QuantizedInt8,
}
26
/// Configuration controlling how model weights are loaded.
#[derive(Debug, Clone)]
pub struct WeightLoadConfig {
    /// Device selection; also carries the fp16 preference consumed by
    /// `get_dtype`.
    pub device_config: DeviceConfig,
    // Presumably enables quantization of loaded weights — not read anywhere
    // in this module yet; TODO confirm intended semantics.
    pub quantize: bool,
    // Presumably makes missing/unexpected tensors an error — not read
    // anywhere in this module yet; TODO confirm intended semantics.
    pub strict: bool,
}
37
38impl Default for WeightLoadConfig {
39 fn default() -> Self {
40 Self {
41 device_config: DeviceConfig::default(),
42 quantize: false,
43 strict: true,
44 }
45 }
46}
47
48impl WeightLoadConfig {
49 pub fn create_device(&self) -> CoreResult<Device> {
51 self.device_config.create_device()
52 }
53
54 pub fn get_dtype(&self) -> DType {
56 if self.device_config.use_fp16 {
57 DType::F16
58 } else {
59 DType::F32
60 }
61 }
62}
63
/// Loads and saves model weights according to a [`WeightLoadConfig`].
pub struct WeightLoader {
    // Kept for future device/quantization handling; the safetensors
    // load/save paths currently delegate to `VarMap` and do not consult it
    // (hence the dead_code allowance).
    #[allow(dead_code)]
    config: WeightLoadConfig,
}
69
70impl WeightLoader {
71 pub fn new(config: WeightLoadConfig) -> Self {
73 Self { config }
74 }
75
76 pub fn load_safetensors<P: AsRef<Path>>(&self, path: P, varmap: &mut VarMap) -> CoreResult<()> {
80 let path = path.as_ref();
81
82 varmap.load(path).map_err(|e| {
84 CoreError::WeightLoadError(format!("Failed to load safetensors: {}", e))
85 })?;
86
87 Ok(())
88 }
89
90 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P, varmap: &VarMap) -> CoreResult<()> {
94 let path = path.as_ref();
95
96 varmap.save(path).map_err(|e| {
98 CoreError::WeightLoadError(format!("Failed to save safetensors: {}", e))
99 })?;
100
101 Ok(())
102 }
103
104 #[allow(dead_code)]
106 fn safetensors_to_candle(&self, view: safetensors::tensor::TensorView) -> CoreResult<Tensor> {
107 let shape = view.shape().to_vec();
108 let dtype = match view.dtype() {
109 safetensors::Dtype::F32 => DType::F32,
110 safetensors::Dtype::F16 => DType::F16,
111 safetensors::Dtype::BF16 => DType::BF16,
112 safetensors::Dtype::I64 => DType::I64,
113 safetensors::Dtype::U8 => DType::U8,
114 _ => {
115 return Err(CoreError::WeightLoadError(format!(
116 "Unsupported dtype: {:?}",
117 view.dtype()
118 )))
119 }
120 };
121
122 let data = view.data();
124
125 let tensor = match dtype {
127 DType::F32 => {
128 let values: Vec<f32> = data
129 .chunks_exact(4)
130 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
131 .collect();
132 Tensor::from_vec(values, &shape[..], &Device::Cpu).map_err(|e| {
133 CoreError::WeightLoadError(format!("Failed to create tensor: {}", e))
134 })?
135 }
136 DType::F16 | DType::BF16 => {
137 let values: Vec<u16> = data
139 .chunks_exact(2)
140 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
141 .collect();
142
143 let f32_values: Vec<f32> = values
144 .iter()
145 .map(|&v| half::f16::from_bits(v).to_f32())
146 .collect();
147
148 Tensor::from_vec(f32_values, &shape[..], &Device::Cpu)
149 .map_err(|e| {
150 CoreError::WeightLoadError(format!("Failed to create tensor: {}", e))
151 })?
152 .to_dtype(dtype)
153 .map_err(|e| {
154 CoreError::WeightLoadError(format!("Failed to convert dtype: {}", e))
155 })?
156 }
157 _ => {
158 return Err(CoreError::WeightLoadError(format!(
159 "Unsupported dtype for conversion: {:?}",
160 dtype
161 )))
162 }
163 };
164
165 Ok(tensor)
166 }
167
168 #[allow(dead_code)]
170 fn candle_to_safetensors(&self, tensors: HashMap<String, Tensor>) -> CoreResult<Vec<u8>> {
171 use safetensors::tensor::Dtype as SafeDtype;
172
173 let mut tensor_data: HashMap<String, (SafeDtype, Vec<usize>, Vec<u8>)> = HashMap::new();
175
176 for (name, tensor) in tensors.iter() {
177 let shape: Vec<usize> = tensor.dims().to_vec();
178
179 let dtype = match tensor.dtype() {
180 DType::F32 => SafeDtype::F32,
181 DType::F16 => SafeDtype::F16,
182 DType::BF16 => SafeDtype::BF16,
183 DType::I64 => SafeDtype::I64,
184 DType::U8 => SafeDtype::U8,
185 _ => {
186 return Err(CoreError::WeightLoadError(format!(
187 "Unsupported dtype for safetensors: {:?}",
188 tensor.dtype()
189 )))
190 }
191 };
192
193 let data = self.tensor_to_bytes(tensor)?;
195
196 tensor_data.insert(name.clone(), (dtype, shape, data));
197 }
198
199 Ok(Vec::new())
206 }
207
208 #[allow(dead_code)]
210 fn tensor_to_bytes(&self, tensor: &Tensor) -> CoreResult<Vec<u8>> {
211 match tensor.dtype() {
212 DType::F32 => {
213 let values = tensor
214 .flatten_all()
215 .map_err(|e| {
216 CoreError::WeightLoadError(format!("Failed to flatten tensor: {}", e))
217 })?
218 .to_vec1::<f32>()
219 .map_err(|e| {
220 CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e))
221 })?;
222
223 let mut bytes = Vec::with_capacity(values.len() * 4);
224 for v in values {
225 bytes.extend_from_slice(&v.to_le_bytes());
226 }
227 Ok(bytes)
228 }
229 DType::F16 => {
230 let values = tensor
231 .flatten_all()
232 .map_err(|e| {
233 CoreError::WeightLoadError(format!("Failed to flatten tensor: {}", e))
234 })?
235 .to_vec1::<half::f16>()
236 .map_err(|e| {
237 CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e))
238 })?;
239
240 let mut bytes = Vec::with_capacity(values.len() * 2);
241 for v in values {
242 bytes.extend_from_slice(&v.to_bits().to_le_bytes());
243 }
244 Ok(bytes)
245 }
246 _ => Err(CoreError::WeightLoadError(format!(
247 "Unsupported dtype for bytes conversion: {:?}",
248 tensor.dtype()
249 ))),
250 }
251 }
252
253 #[allow(dead_code)]
255 fn quantize_tensor(&self, tensor: &Tensor) -> CoreResult<Tensor> {
256 let min_val = tensor
260 .min(candle_core::D::Minus1)
261 .map_err(|e| CoreError::WeightLoadError(format!("Failed to compute min: {}", e)))?;
262 let max_val = tensor
263 .max(candle_core::D::Minus1)
264 .map_err(|e| CoreError::WeightLoadError(format!("Failed to compute max: {}", e)))?;
265
266 let range = max_val
267 .sub(&min_val)
268 .map_err(|e| CoreError::WeightLoadError(format!("Failed to compute range: {}", e)))?;
269
270 let scaled = tensor
272 .broadcast_sub(&min_val)
273 .map_err(|e| CoreError::WeightLoadError(format!("Failed to subtract min: {}", e)))?
274 .broadcast_div(&range)
275 .map_err(|e| CoreError::WeightLoadError(format!("Failed to divide by range: {}", e)))?
276 .affine(255.0, 0.0)
277 .map_err(|e| CoreError::WeightLoadError(format!("Failed to scale: {}", e)))?;
278
279 let quantized = scaled
281 .to_dtype(DType::U8)
282 .map_err(|e| CoreError::WeightLoadError(format!("Failed to convert to U8: {}", e)))?;
283
284 Ok(quantized)
285 }
286
287 pub fn load_pytorch_checkpoint<P: AsRef<Path>>(
292 &self,
293 _path: P,
294 _varmap: &VarMap,
295 ) -> CoreResult<()> {
296 Err(CoreError::WeightLoadError(
303 "PyTorch checkpoint loading not yet implemented".to_string(),
304 ))
305 }
306}
307
308pub struct WeightPruner;
310
311impl WeightPruner {
312 pub fn prune_by_magnitude(tensor: &Tensor, threshold: f32) -> CoreResult<Tensor> {
316 let abs_tensor = tensor
317 .abs()
318 .map_err(|e| CoreError::Generic(format!("Failed to compute abs: {}", e)))?;
319
320 let mask = abs_tensor
321 .ge(threshold as f64)
322 .map_err(|e| CoreError::Generic(format!("Failed to create mask: {}", e)))?
323 .to_dtype(tensor.dtype())
324 .map_err(|e| CoreError::Generic(format!("Failed to convert mask dtype: {}", e)))?;
325
326 tensor
327 .mul(&mask)
328 .map_err(|e| CoreError::Generic(format!("Failed to apply mask: {}", e)))
329 }
330
331 pub fn prune_by_percentage(tensor: &Tensor, percentage: f32) -> CoreResult<Tensor> {
335 if percentage <= 0.0 || percentage >= 1.0 {
336 return Err(CoreError::InvalidConfig(
337 "Percentage must be between 0 and 1".to_string(),
338 ));
339 }
340
341 let flat = tensor
343 .flatten_all()
344 .map_err(|e| CoreError::Generic(format!("Failed to flatten: {}", e)))?;
345
346 let abs_flat = flat
347 .abs()
348 .map_err(|e| CoreError::Generic(format!("Failed to compute abs: {}", e)))?;
349
350 let values = abs_flat
352 .to_vec1::<f32>()
353 .map_err(|e| CoreError::Generic(format!("Failed to convert to vec: {}", e)))?;
354
355 let mut sorted_values = values.clone();
357 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
358
359 let threshold_idx = (sorted_values.len() as f32 * percentage) as usize;
360 let threshold = sorted_values[threshold_idx];
361
362 Self::prune_by_magnitude(tensor, threshold)
363 }
364
365 pub fn compute_sparsity(tensor: &Tensor) -> CoreResult<f32> {
367 let total_elements = tensor.elem_count();
368
369 let zeros = tensor
370 .eq(0.0)
371 .map_err(|e| CoreError::Generic(format!("Failed to compare with zero: {}", e)))?
372 .to_dtype(DType::F32)
373 .map_err(|e| CoreError::Generic(format!("Failed to convert dtype: {}", e)))?
374 .sum_all()
375 .map_err(|e| CoreError::Generic(format!("Failed to sum: {}", e)))?
376 .to_vec0::<f32>()
377 .map_err(|e| CoreError::Generic(format!("Failed to extract value: {}", e)))?;
378
379 Ok(zeros / total_elements as f32)
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarBuilder;

    /// Constructing a loader from the default config must not panic.
    #[test]
    fn test_weight_loader_creation() {
        let _loader = WeightLoader::new(WeightLoadConfig::default());
    }

    /// Elements with |x| < threshold are zeroed; the rest are untouched.
    #[test]
    fn test_prune_by_magnitude() {
        let input = Tensor::new(&[1.0f32, 0.1, 2.0, 0.05, 3.0], &Device::Cpu).unwrap();

        let pruned = WeightPruner::prune_by_magnitude(&input, 0.5)
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();

        assert_eq!(pruned, vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    /// Two zeros out of five elements gives a sparsity of 0.4.
    #[test]
    fn test_compute_sparsity() {
        let input = Tensor::new(&[1.0f32, 0.0, 2.0, 0.0, 3.0], &Device::Cpu).unwrap();

        let ratio = WeightPruner::compute_sparsity(&input).unwrap();

        assert!((ratio - 0.4).abs() < 1e-5);
    }

    /// Saving a populated varmap to a temp file must succeed.
    #[test]
    fn test_safetensors_roundtrip() {
        use std::env;

        let varmap = VarMap::new();
        let builder = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

        let _first = builder
            .get_with_hints((3, 4), "weight1", candle_nn::init::Init::Const(1.0))
            .unwrap();
        let _second = builder
            .get_with_hints((5, 6), "weight2", candle_nn::init::Init::Const(2.0))
            .unwrap();

        let loader = WeightLoader::new(WeightLoadConfig::default());
        let target = env::temp_dir().join("test_weights.safetensors");

        assert!(loader.save_safetensors(&target, &varmap).is_ok());

        // Best-effort cleanup of the temp artifact.
        if target.exists() {
            std::fs::remove_file(target).ok();
        }
    }
}