kizzasi_core/pytorch_compat.rs

//! # PyTorch Compatibility Layer
//!
//! Utilities for loading and converting PyTorch checkpoints to Kizzasi format.
//!
//! ## Features
//!
//! - **Checkpoint Loading**: Load PyTorch checkpoints exported to `.safetensors` files
//! - **Tensor Conversion**: Convert PyTorch tensors to ndarray/candle format
//! - **Weight Mapping**: Map PyTorch layer names to Kizzasi layer names
//! - **Architecture Detection**: Automatically detect model architecture
//! - **Validation**: Verify checkpoint compatibility
//!
//! ## Supported Formats
//!
//! - PyTorch state_dict (via safetensors)
//! - HuggingFace checkpoints
//! - Custom Mamba/SSM checkpoints
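//!
//! ## Example
//!
//! A minimal usage sketch. The module path and the checkpoint filename below are
//! illustrative assumptions, not fixed by this module:
//!
//! ```ignore
//! use kizzasi_core::pytorch_compat::{detect_checkpoint_architecture, PyTorchConverter};
//!
//! // Inspect a checkpoint that was exported to safetensors format.
//! let info = detect_checkpoint_architecture("model.safetensors")?;
//! println!("architecture: {}, layers: {:?}", info.architecture, info.num_layers);
//!
//! // Load the raw weights and rename them with the built-in Mamba mapping table.
//! let mut converter = PyTorchConverter::new_cpu();
//! converter.create_mamba_mappings();
//! let weights = converter.load_safetensors("model.safetensors")?;
//! let mapped = converter.apply_mappings(weights)?;
//! ```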

use crate::{CoreError, CoreResult};
use candle_core::{DType, Device, Tensor};
use safetensors::SafeTensors;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

/// PyTorch checkpoint metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchCheckpoint {
    /// Model architecture name
    pub architecture: String,
    /// Number of layers
    pub num_layers: Option<usize>,
    /// Hidden dimension
    pub hidden_dim: Option<usize>,
    /// Model dimension
    pub d_model: Option<usize>,
    /// State dimension
    pub d_state: Option<usize>,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

/// Weight mapping configuration for different architectures
#[derive(Debug, Clone)]
pub struct WeightMapping {
    /// Source layer name pattern (PyTorch)
    pub source_pattern: String,
    /// Target layer name (Kizzasi)
    pub target_name: String,
    /// Whether to transpose the weight matrix
    pub transpose: bool,
}

/// PyTorch checkpoint converter
pub struct PyTorchConverter {
    /// Device for tensor operations
    device: Device,
    /// Weight mappings
    mappings: Vec<WeightMapping>,
}

impl PyTorchConverter {
    /// Create a new PyTorch converter
    pub fn new(device: Device) -> Self {
        Self {
            device,
            mappings: Vec::new(),
        }
    }

    /// Create converter with CPU device
    pub fn new_cpu() -> Self {
        Self::new(Device::Cpu)
    }

    /// Add a weight mapping
    pub fn add_mapping(&mut self, source: &str, target: &str, transpose: bool) {
        self.mappings.push(WeightMapping {
            source_pattern: source.to_string(),
            target_name: target.to_string(),
            transpose,
        });
    }

    /// Load checkpoint from safetensors file
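    ///
    /// Only the safetensors container is parsed here; pickle-based `.pth` files must
    /// be exported to safetensors first. A usage sketch (the filename is a placeholder):
    ///
    /// ```ignore
    /// let converter = PyTorchConverter::new_cpu();
    /// let weights = converter.load_safetensors("checkpoint.safetensors")?;
    /// for (name, tensor) in &weights {
    ///     println!("{}: {:?}", name, tensor.shape());
    /// }
    /// ```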
    pub fn load_safetensors(&self, path: impl AsRef<Path>) -> CoreResult<HashMap<String, Tensor>> {
        let data = std::fs::read(path.as_ref())
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to read file: {}", e)))?;

        let tensors = SafeTensors::deserialize(&data).map_err(|e| {
            CoreError::WeightLoadError(format!("Failed to deserialize safetensors: {}", e))
        })?;

        let mut weights = HashMap::new();

        for (name, tensor_view) in tensors.tensors() {
            let tensor = self.safetensor_to_candle(&tensor_view)?;
            weights.insert(name.to_string(), tensor);
        }

        Ok(weights)
    }

    /// Convert safetensor view to candle tensor
    fn safetensor_to_candle(&self, view: &safetensors::tensor::TensorView) -> CoreResult<Tensor> {
        let shape: Vec<usize> = view.shape().to_vec();
        let dtype = match view.dtype() {
            safetensors::Dtype::F32 => DType::F32,
            safetensors::Dtype::F16 => DType::F16,
            safetensors::Dtype::BF16 => DType::BF16,
            safetensors::Dtype::F64 => DType::F64,
            safetensors::Dtype::I64 => DType::I64,
            safetensors::Dtype::U8 => DType::U8,
            _ => {
                return Err(CoreError::WeightLoadError(format!(
                    "Unsupported dtype: {:?}",
                    view.dtype()
                )))
            }
        };

        let data = view.data();
        Tensor::from_raw_buffer(data, dtype, &shape, &self.device)
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to create tensor: {}", e)))
    }

    /// Convert candle tensor to ndarray Array2
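    ///
    /// Tensors that are not already `f32` are cast to `f32` first. A sketch:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let converter = PyTorchConverter::new_cpu();
    /// let t = Tensor::zeros((4, 8), DType::F16, &Device::Cpu)?;
    /// let arr = converter.tensor_to_array2(&t)?; // Array2<f32> with shape (4, 8)
    /// ```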
    pub fn tensor_to_array2(&self, tensor: &Tensor) -> CoreResult<Array2<f32>> {
        if tensor.rank() != 2 {
            return Err(CoreError::WeightLoadError(format!(
                "Expected 2D tensor, got rank {}",
                tensor.rank()
            )));
        }

        let shape = tensor.shape();
        let rows = shape.dims()[0];
        let cols = shape.dims()[1];

        // Convert to f32 if needed
        let tensor_f32 = if tensor.dtype() != DType::F32 {
            tensor.to_dtype(DType::F32).map_err(|e| {
                CoreError::WeightLoadError(format!("Failed to convert dtype: {}", e))
            })?
        } else {
            tensor.clone()
        };

        // Get data as Vec<f32>
        let data: Vec<f32> = tensor_f32
            .to_vec2()
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e)))?
            .into_iter()
            .flatten()
            .collect();

        Array2::from_shape_vec((rows, cols), data).map_err(CoreError::ShapeError)
    }

    /// Convert candle tensor to ndarray Array1
    pub fn tensor_to_array1(&self, tensor: &Tensor) -> CoreResult<Array1<f32>> {
        if tensor.rank() != 1 {
            return Err(CoreError::WeightLoadError(format!(
                "Expected 1D tensor, got rank {}",
                tensor.rank()
            )));
        }

        // Convert to f32 if needed
        let tensor_f32 = if tensor.dtype() != DType::F32 {
            tensor.to_dtype(DType::F32).map_err(|e| {
                CoreError::WeightLoadError(format!("Failed to convert dtype: {}", e))
            })?
        } else {
            tensor.clone()
        };

        // Get data as Vec<f32>
        let data: Vec<f32> = tensor_f32
            .to_vec1()
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e)))?;

        Ok(Array1::from_vec(data))
    }

    /// Apply weight mappings to convert PyTorch names to Kizzasi names
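    ///
    /// The first mapping whose `source_pattern` is contained in the source name wins,
    /// and weights with no matching pattern keep their original name. A sketch with a
    /// dummy stand-in tensor:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    /// use std::collections::HashMap;
    ///
    /// let mut converter = PyTorchConverter::new_cpu();
    /// converter.add_mapping("layers.0.mixer.in_proj", "layer_0.in_proj_w", true);
    ///
    /// let mut weights = HashMap::new();
    /// weights.insert(
    ///     "layers.0.mixer.in_proj.weight".to_string(),
    ///     Tensor::zeros((512, 256), DType::F32, &Device::Cpu)?,
    /// );
    ///
    /// // Renamed to "layer_0.in_proj_w" and transposed to shape (256, 512).
    /// let mapped = converter.apply_mappings(weights)?;
    /// assert!(mapped.contains_key("layer_0.in_proj_w"));
    /// ```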
    pub fn apply_mappings(
        &self,
        weights: HashMap<String, Tensor>,
    ) -> CoreResult<HashMap<String, Tensor>> {
        let mut mapped_weights = HashMap::new();

        for (source_name, tensor) in weights {
            // Try to find a matching mapping
            let mut mapped = false;
            for mapping in &self.mappings {
                if source_name.contains(&mapping.source_pattern) {
                    let mut target_tensor = tensor.clone();

                    // Transpose if needed
                    if mapping.transpose && target_tensor.rank() == 2 {
                        target_tensor = target_tensor
                            .t()
                            .map_err(|e| {
                                CoreError::WeightLoadError(format!("Failed to transpose: {}", e))
                            })?
                            .contiguous()
                            .map_err(|e| {
                                CoreError::WeightLoadError(format!(
                                    "Failed to make contiguous: {}",
                                    e
                                ))
                            })?;
                    }

                    mapped_weights.insert(mapping.target_name.clone(), target_tensor);
                    mapped = true;
                    break;
                }
            }

            // If no mapping found, keep original name
            if !mapped {
                mapped_weights.insert(source_name, tensor);
            }
        }

        Ok(mapped_weights)
    }

    /// Detect architecture from checkpoint
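    ///
    /// Detection is heuristic: the architecture is inferred from substrings in the
    /// weight names (e.g. `mixer`/`ssm` for Mamba, `retention` for RetNet) and the
    /// dimensions from tensor shapes, so treat the result as a best-effort hint.
    ///
    /// ```ignore
    /// let converter = PyTorchConverter::new_cpu();
    /// let weights = converter.load_safetensors("model.safetensors")?;
    /// let info = converter.detect_architecture(&weights)?;
    /// println!("{} ({:?} layers)", info.architecture, info.num_layers);
    /// ```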
    pub fn detect_architecture(
        &self,
        weights: &HashMap<String, Tensor>,
    ) -> CoreResult<PyTorchCheckpoint> {
        let mut metadata = HashMap::new();
        let mut architecture = "unknown".to_string();
        let mut num_layers = None;
        let mut hidden_dim = None;
        let mut d_model = None;
        let mut d_state = None;

        // Analyze weight names to detect architecture
        for (name, tensor) in weights {
            // Detect Mamba
            if name.contains("mixer") || name.contains("ssm") {
                architecture = "mamba".to_string();
            }
            // Detect Mamba-2
            else if name.contains("ssd") || name.contains("mamba2") {
                architecture = "mamba2".to_string();
            }
            // Detect S4/S4D
            else if name.contains("s4") {
                architecture = "s4d".to_string();
            }
            // Detect S5
            else if name.contains("s5") || name.contains("block_diagonal") {
                architecture = "s5".to_string();
            }
            // Detect RetNet
            else if name.contains("retention") {
                architecture = "retnet".to_string();
            }

            // Extract dimensions
            if name.contains("layers.") {
                // Count layers
                if let Some(layer_str) = name.split("layers.").nth(1) {
                    if let Some(layer_num_str) = layer_str.split('.').next() {
                        if let Ok(layer_num) = layer_num_str.parse::<usize>() {
                            num_layers = Some(num_layers.unwrap_or(0).max(layer_num + 1));
                        }
                    }
                }
            }

            // Detect dimensions from tensor shapes
            let shape = tensor.shape();
            if (name.contains("in_proj") || name.contains("embedding")) && shape.rank() == 2 {
                // Linear weights are stored as [out_features, in_features] and embeddings
                // as [vocab, d_model], so the model dimension is the trailing axis.
                d_model = Some(shape.dims()[1]);
            }
            if (name.contains("dt_proj") || name.contains("ssm")) && shape.rank() == 2 {
                hidden_dim = Some(shape.dims()[0]);
            }
            if (name.contains("A_log") || name.contains("a_log") || name.contains("lambda"))
                && shape.rank() >= 1
            {
                d_state = Some(shape.dims()[shape.rank() - 1]);
            }
        }

        metadata.insert("num_weights".to_string(), weights.len().to_string());

        Ok(PyTorchCheckpoint {
            architecture,
            num_layers,
            hidden_dim,
            d_model,
            d_state,
            metadata,
        })
    }

    /// Create default mappings for Mamba architecture
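    ///
    /// Mappings are generated for up to 32 layers; patterns for layers that do not
    /// exist in a given checkpoint simply never match. A sketch:
    ///
    /// ```ignore
    /// let mut converter = PyTorchConverter::new_cpu();
    /// converter.create_mamba_mappings();
    /// // e.g. "layers.3.mixer.in_proj.weight" is renamed to "layer_3.in_proj_w" (transposed).
    /// let weights = converter.load_safetensors("mamba.safetensors")?;
    /// let mapped = converter.apply_mappings(weights)?;
    /// ```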
    pub fn create_mamba_mappings(&mut self) {
        // Input embeddings
        self.add_mapping("embedding.weight", "embedding_w", false);

        // Per-layer mappings; 32 layers covers common model sizes, adjust as needed
        for i in 0..32 {
            let prefix = format!("layers.{}", i);
            let target_prefix = format!("layer_{}", i);

            self.add_mapping(
                &format!("{}.mixer.in_proj", prefix),
                &format!("{}.in_proj_w", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.mixer.out_proj", prefix),
                &format!("{}.out_proj_w", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.mixer.conv1d.weight", prefix),
                &format!("{}.conv1d_w", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.mixer.conv1d.bias", prefix),
                &format!("{}.conv1d_b", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.mixer.dt_proj", prefix),
                &format!("{}.dt_proj_w", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.mixer.A_log", prefix),
                &format!("{}.a_log", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.mixer.D", prefix),
                &format!("{}.d_param", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.norm.weight", prefix),
                &format!("{}.norm_w", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.norm.bias", prefix),
                &format!("{}.norm_b", target_prefix),
                false,
            );
        }

        // Output head
        self.add_mapping("lm_head.weight", "output_w", true);
    }

    /// Create default mappings for S4D architecture
    pub fn create_s4d_mappings(&mut self) {
        for i in 0..32 {
            let prefix = format!("layers.{}", i);
            let target_prefix = format!("layer_{}", i);

            self.add_mapping(
                &format!("{}.input_proj", prefix),
                &format!("{}.input_proj", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.output_proj", prefix),
                &format!("{}.output_proj", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.lambda", prefix),
                &format!("{}.lambda", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.B", prefix),
                &format!("{}.b", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.C", prefix),
                &format!("{}.c", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.D", prefix),
                &format!("{}.d", target_prefix),
                false,
            );
        }
    }
}

/// Helper function to load PyTorch checkpoint from safetensors
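///
/// A convenience wrapper around [`PyTorchConverter::load_safetensors`] on the CPU
/// device. A sketch (the path is a placeholder):
///
/// ```ignore
/// let weights = load_pytorch_checkpoint("model.safetensors")?;
/// println!("loaded {} tensors", weights.len());
/// ```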
pub fn load_pytorch_checkpoint(path: impl AsRef<Path>) -> CoreResult<HashMap<String, Tensor>> {
    let converter = PyTorchConverter::new_cpu();
    converter.load_safetensors(path)
}

/// Helper function to detect architecture from checkpoint file
pub fn detect_checkpoint_architecture(path: impl AsRef<Path>) -> CoreResult<PyTorchCheckpoint> {
    let converter = PyTorchConverter::new_cpu();
    let weights = converter.load_safetensors(path)?;
    converter.detect_architecture(&weights)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_converter_creation() {
        let converter = PyTorchConverter::new_cpu();
        assert_eq!(converter.mappings.len(), 0);
    }

    #[test]
    fn test_add_mapping() {
        let mut converter = PyTorchConverter::new_cpu();
        converter.add_mapping("layers.0.weight", "layer_0_w", true);
        assert_eq!(converter.mappings.len(), 1);
        assert_eq!(converter.mappings[0].source_pattern, "layers.0.weight");
        assert_eq!(converter.mappings[0].target_name, "layer_0_w");
        assert!(converter.mappings[0].transpose);
    }

    #[test]
    fn test_mamba_mappings() {
        let mut converter = PyTorchConverter::new_cpu();
        converter.create_mamba_mappings();
        assert!(!converter.mappings.is_empty());
        // Should have mappings for multiple layers
        let has_layer_0 = converter
            .mappings
            .iter()
            .any(|m| m.target_name.contains("layer_0"));
        assert!(has_layer_0);
    }

    #[test]
    fn test_s4d_mappings() {
        let mut converter = PyTorchConverter::new_cpu();
        converter.create_s4d_mappings();
        assert!(!converter.mappings.is_empty());
    }

    #[test]
    fn test_tensor_conversion() {
        let converter = PyTorchConverter::new_cpu();

        // Create a test tensor
        let data = vec![vec![1.0f32, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let tensor = Tensor::new(data, &Device::Cpu).unwrap();

        let array = converter.tensor_to_array2(&tensor).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array[[0, 0]], 1.0);
        assert_eq!(array[[1, 2]], 6.0);
    }

    #[test]
    fn test_tensor_1d_conversion() {
        let converter = PyTorchConverter::new_cpu();

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::new(data, &Device::Cpu).unwrap();

        let array = converter.tensor_to_array1(&tensor).unwrap();
        assert_eq!(array.len(), 4);
        assert_eq!(array[0], 1.0);
        assert_eq!(array[3], 4.0);
    }

    #[test]
    fn test_architecture_detection() {
        let converter = PyTorchConverter::new_cpu();
        let mut weights = HashMap::new();

        // Create dummy tensors that look like Mamba weights
        let tensor = Tensor::zeros((256, 128), DType::F32, &Device::Cpu).unwrap();
        weights.insert("layers.0.mixer.in_proj.weight".to_string(), tensor.clone());
        weights.insert("layers.0.mixer.A_log".to_string(), tensor.clone());

        let checkpoint = converter.detect_architecture(&weights).unwrap();
        assert_eq!(checkpoint.architecture, "mamba");
        assert_eq!(checkpoint.num_layers, Some(1));
    }

    #[test]
    fn test_checkpoint_metadata() {
        let checkpoint = PyTorchCheckpoint {
            architecture: "mamba2".to_string(),
            num_layers: Some(24),
            hidden_dim: Some(768),
            d_model: Some(768),
            d_state: Some(16),
            metadata: HashMap::new(),
        };

        assert_eq!(checkpoint.architecture, "mamba2");
        assert_eq!(checkpoint.num_layers, Some(24));
    }
}