//! Conversion utilities for importing PyTorch checkpoints stored in the
//! safetensors format: loading tensors, renaming weights, and inferring
//! architecture metadata.

use crate::{CoreError, CoreResult};
use candle_core::{DType, Device, Tensor};
use safetensors::SafeTensors;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

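/// Metadata about a PyTorch checkpoint, inferred from tensor names and
/// shapes rather than read from an explicit config file.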
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchCheckpoint {
    /// Detected architecture name (e.g. "mamba"), or "unknown".
    pub architecture: String,
    /// Layer count, inferred from the highest `layers.N` index seen.
    pub num_layers: Option<usize>,
    /// Inner/hidden dimension, if it could be inferred.
    pub hidden_dim: Option<usize>,
    /// Model (embedding) dimension, if it could be inferred.
    pub d_model: Option<usize>,
    /// State-space dimension, if it could be inferred.
    pub d_state: Option<usize>,
    /// Additional free-form metadata (e.g. the total tensor count).
    pub metadata: HashMap<String, String>,
}

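/// A single weight-renaming rule used during checkpoint conversion.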
#[derive(Debug, Clone)]
pub struct WeightMapping {
    /// Substring matched against source tensor names.
    pub source_pattern: String,
    /// Name the tensor is stored under after mapping.
    pub target_name: String,
    /// Whether to transpose rank-2 tensors during mapping.
    pub transpose: bool,
}

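/// Loads PyTorch safetensors checkpoints as candle tensors and converts
/// them to `ndarray` types, applying registered [`WeightMapping`] rules.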
pub struct PyTorchConverter {
    /// Device on which loaded tensors are materialized.
    device: Device,
    /// Mapping rules, applied in registration order.
    mappings: Vec<WeightMapping>,
}

impl PyTorchConverter {
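    /// Creates a converter that materializes loaded tensors on `device`.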
    pub fn new(device: Device) -> Self {
        Self {
            device,
            mappings: Vec::new(),
        }
    }

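    /// Creates a converter that keeps all tensors on the CPU.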
    pub fn new_cpu() -> Self {
        Self::new(Device::Cpu)
    }

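    /// Registers a mapping rule. `source` is matched as a substring of each
    /// checkpoint tensor name; the first matching rule wins.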
    pub fn add_mapping(&mut self, source: &str, target: &str, transpose: bool) {
        self.mappings.push(WeightMapping {
            source_pattern: source.to_string(),
            target_name: target.to_string(),
            transpose,
        });
    }

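    /// Reads a safetensors file and returns all tensors keyed by name.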
    pub fn load_safetensors(&self, path: impl AsRef<Path>) -> CoreResult<HashMap<String, Tensor>> {
        let data = std::fs::read(path.as_ref())
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to read file: {}", e)))?;

        let tensors = SafeTensors::deserialize(&data).map_err(|e| {
            CoreError::WeightLoadError(format!("Failed to deserialize safetensors: {}", e))
        })?;

        let mut weights = HashMap::new();

        for (name, tensor_view) in tensors.tensors() {
            let tensor = self.safetensor_to_candle(&tensor_view)?;
            weights.insert(name.to_string(), tensor);
        }

        Ok(weights)
    }

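    /// Converts a borrowed safetensors `TensorView` into a candle `Tensor`
    /// on the converter's device, preserving shape and dtype.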
    fn safetensor_to_candle(&self, view: &safetensors::tensor::TensorView) -> CoreResult<Tensor> {
        let shape: Vec<usize> = view.shape().to_vec();
        let dtype = match view.dtype() {
            safetensors::Dtype::F32 => DType::F32,
            safetensors::Dtype::F16 => DType::F16,
            safetensors::Dtype::BF16 => DType::BF16,
            safetensors::Dtype::F64 => DType::F64,
            safetensors::Dtype::I64 => DType::I64,
            safetensors::Dtype::U8 => DType::U8,
            _ => {
                return Err(CoreError::WeightLoadError(format!(
                    "Unsupported dtype: {:?}",
                    view.dtype()
                )))
            }
        };

        let data = view.data();
        Tensor::from_raw_buffer(data, dtype, &shape, &self.device)
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to create tensor: {}", e)))
    }

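    /// Converts a rank-2 tensor into an `Array2<f32>`, casting other dtypes
    /// to `f32` first.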
    pub fn tensor_to_array2(&self, tensor: &Tensor) -> CoreResult<Array2<f32>> {
        if tensor.rank() != 2 {
            return Err(CoreError::WeightLoadError(format!(
                "Expected 2D tensor, got rank {}",
                tensor.rank()
            )));
        }

        let shape = tensor.shape();
        let rows = shape.dims()[0];
        let cols = shape.dims()[1];

        let tensor_f32 = if tensor.dtype() != DType::F32 {
            tensor.to_dtype(DType::F32).map_err(|e| {
                CoreError::WeightLoadError(format!("Failed to convert dtype: {}", e))
            })?
        } else {
            tensor.clone()
        };

        let data: Vec<f32> = tensor_f32
            .to_vec2()
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e)))?
            .into_iter()
            .flatten()
            .collect();

        Array2::from_shape_vec((rows, cols), data).map_err(CoreError::ShapeError)
    }

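    /// Converts a rank-1 tensor into an `Array1<f32>`, casting other dtypes
    /// to `f32` first.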
    pub fn tensor_to_array1(&self, tensor: &Tensor) -> CoreResult<Array1<f32>> {
        if tensor.rank() != 1 {
            return Err(CoreError::WeightLoadError(format!(
                "Expected 1D tensor, got rank {}",
                tensor.rank()
            )));
        }

        let tensor_f32 = if tensor.dtype() != DType::F32 {
            tensor.to_dtype(DType::F32).map_err(|e| {
                CoreError::WeightLoadError(format!("Failed to convert dtype: {}", e))
            })?
        } else {
            tensor.clone()
        };

        let data: Vec<f32> = tensor_f32
            .to_vec1()
            .map_err(|e| CoreError::WeightLoadError(format!("Failed to convert to vec: {}", e)))?;

        Ok(Array1::from_vec(data))
    }

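    /// Applies the registered mappings to a weight map. Unmatched tensors
    /// are passed through under their original names.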
    pub fn apply_mappings(
        &self,
        weights: HashMap<String, Tensor>,
    ) -> CoreResult<HashMap<String, Tensor>> {
        let mut mapped_weights = HashMap::new();

        for (source_name, tensor) in weights {
            let mut mapped = false;
            // The first registered mapping whose pattern matches wins.
            for mapping in &self.mappings {
                if source_name.contains(&mapping.source_pattern) {
                    let mut target_tensor = tensor.clone();

                    // Transpose rank-2 weights (e.g. PyTorch Linear layers) when requested.
                    if mapping.transpose && target_tensor.rank() == 2 {
                        target_tensor = target_tensor
                            .t()
                            .map_err(|e| {
                                CoreError::WeightLoadError(format!("Failed to transpose: {}", e))
                            })?
                            .contiguous()
                            .map_err(|e| {
                                CoreError::WeightLoadError(format!(
                                    "Failed to make contiguous: {}",
                                    e
                                ))
                            })?;
                    }

                    mapped_weights.insert(mapping.target_name.clone(), target_tensor);
                    mapped = true;
                    break;
                }
            }

            // Tensors without a matching rule keep their original names.
            if !mapped {
                mapped_weights.insert(source_name, tensor);
            }
        }

        Ok(mapped_weights)
    }

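    /// Heuristically infers the architecture and key dimensions of a
    /// checkpoint from its tensor names and shapes. Detection is
    /// best-effort and depends on the exporter's naming conventions.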
    pub fn detect_architecture(
        &self,
        weights: &HashMap<String, Tensor>,
    ) -> CoreResult<PyTorchCheckpoint> {
        let mut metadata = HashMap::new();
        let mut architecture = "unknown".to_string();
        let mut num_layers = None;
        let mut hidden_dim = None;
        let mut d_model = None;
        let mut d_state = None;

        for (name, tensor) in weights {
            // Compare case-insensitively so that e.g. "A_log" matches the
            // "a_log" pattern below.
            let lname = name.to_lowercase();

            // Check the most specific patterns first so that a Mamba-2
            // checkpoint (whose names may also contain "mixer") is not
            // misclassified as Mamba-1.
            if lname.contains("ssd") || lname.contains("mamba2") {
                architecture = "mamba2".to_string();
            } else if lname.contains("mixer") || lname.contains("ssm") {
                architecture = "mamba".to_string();
            } else if lname.contains("s4") {
                architecture = "s4d".to_string();
            } else if lname.contains("s5") || lname.contains("block_diagonal") {
                architecture = "s5".to_string();
            } else if lname.contains("retention") {
                architecture = "retnet".to_string();
            }

            // Infer the layer count from the highest "layers.N" index seen.
            if name.contains("layers.") {
                if let Some(layer_str) = name.split("layers.").nth(1) {
                    if let Some(layer_num_str) = layer_str.split('.').next() {
                        if let Ok(layer_num) = layer_num_str.parse::<usize>() {
                            num_layers = Some(num_layers.unwrap_or(0).max(layer_num + 1));
                        }
                    }
                }
            }

            // Infer dimensions from well-known parameter shapes.
            let shape = tensor.shape();
            if (lname.contains("in_proj") || lname.contains("embedding")) && shape.rank() == 2 {
                d_model = Some(shape.dims()[0]);
            }
            if (lname.contains("dt_proj") || lname.contains("ssm")) && shape.rank() == 2 {
                hidden_dim = Some(shape.dims()[0]);
            }
            if (lname.contains("a_log") || lname.contains("lambda")) && shape.rank() >= 1 {
                d_state = Some(shape.dims()[shape.rank() - 1]);
            }
        }

        metadata.insert("num_weights".to_string(), weights.len().to_string());

        Ok(PyTorchCheckpoint {
            architecture,
            num_layers,
            hidden_dim,
            d_model,
            d_state,
            metadata,
        })
    }

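    /// Registers the default name mappings for Mamba-style checkpoints.
    /// Linear projection weights are transposed because PyTorch stores them
    /// as `(out_features, in_features)`.
    ///
    /// A typical conversion flow (sketch; the file name is illustrative):
    ///
    /// ```ignore
    /// let mut converter = PyTorchConverter::new_cpu();
    /// converter.create_mamba_mappings();
    /// let raw = converter.load_safetensors("mamba.safetensors")?;
    /// let weights = converter.apply_mappings(raw)?;
    /// ```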
    pub fn create_mamba_mappings(&mut self) {
        self.add_mapping("embedding.weight", "embedding_w", false);

        // Cover up to 32 layers; rules for layers absent from a given
        // checkpoint simply never match.
        for i in 0..32 {
            let prefix = format!("layers.{}", i);
            let target_prefix = format!("layer_{}", i);

            self.add_mapping(
                &format!("{}.mixer.in_proj", prefix),
                &format!("{}.in_proj_w", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.mixer.out_proj", prefix),
                &format!("{}.out_proj_w", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.mixer.conv1d.weight", prefix),
                &format!("{}.conv1d_w", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.mixer.conv1d.bias", prefix),
                &format!("{}.conv1d_b", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.mixer.dt_proj", prefix),
                &format!("{}.dt_proj_w", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.mixer.A_log", prefix),
                &format!("{}.a_log", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.mixer.D", prefix),
                &format!("{}.d_param", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.norm.weight", prefix),
                &format!("{}.norm_w", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.norm.bias", prefix),
                &format!("{}.norm_b", target_prefix),
                false,
            );
        }

        self.add_mapping("lm_head.weight", "output_w", true);
    }

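    /// Registers the default name mappings for S4D-style checkpoints.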
    pub fn create_s4d_mappings(&mut self) {
        for i in 0..32 {
            let prefix = format!("layers.{}", i);
            let target_prefix = format!("layer_{}", i);

            self.add_mapping(
                &format!("{}.input_proj", prefix),
                &format!("{}.input_proj", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.output_proj", prefix),
                &format!("{}.output_proj", target_prefix),
                true,
            );
            self.add_mapping(
                &format!("{}.lambda", prefix),
                &format!("{}.lambda", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.B", prefix),
                &format!("{}.b", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.C", prefix),
                &format!("{}.c", target_prefix),
                false,
            );
            self.add_mapping(
                &format!("{}.D", prefix),
                &format!("{}.d", target_prefix),
                false,
            );
        }
    }
}

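/// Loads every tensor from a safetensors checkpoint onto the CPU.
///
/// A minimal usage sketch (marked `ignore` since the checkpoint path is
/// illustrative):
///
/// ```ignore
/// let weights = load_pytorch_checkpoint("model.safetensors")?;
/// for (name, tensor) in &weights {
///     println!("{name}: {:?}", tensor.shape());
/// }
/// ```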
pub fn load_pytorch_checkpoint(path: impl AsRef<Path>) -> CoreResult<HashMap<String, Tensor>> {
    let converter = PyTorchConverter::new_cpu();
    converter.load_safetensors(path)
}
pub fn detect_checkpoint_architecture(path: impl AsRef<Path>) -> CoreResult<PyTorchCheckpoint> {
    let converter = PyTorchConverter::new_cpu();
    let weights = converter.load_safetensors(path)?;
    converter.detect_architecture(&weights)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_converter_creation() {
        let converter = PyTorchConverter::new_cpu();
        assert_eq!(converter.mappings.len(), 0);
    }

    #[test]
    fn test_add_mapping() {
        let mut converter = PyTorchConverter::new_cpu();
        converter.add_mapping("layers.0.weight", "layer_0_w", true);
        assert_eq!(converter.mappings.len(), 1);
        assert_eq!(converter.mappings[0].source_pattern, "layers.0.weight");
        assert_eq!(converter.mappings[0].target_name, "layer_0_w");
        assert!(converter.mappings[0].transpose);
    }

    #[test]
    fn test_mamba_mappings() {
        let mut converter = PyTorchConverter::new_cpu();
        converter.create_mamba_mappings();
        assert!(!converter.mappings.is_empty());
        let has_layer_0 = converter
            .mappings
            .iter()
            .any(|m| m.target_name.contains("layer_0"));
        assert!(has_layer_0);
    }

    #[test]
    fn test_s4d_mappings() {
        let mut converter = PyTorchConverter::new_cpu();
        converter.create_s4d_mappings();
        assert!(!converter.mappings.is_empty());
    }

    #[test]
    fn test_tensor_conversion() {
        let converter = PyTorchConverter::new_cpu();

        let data = vec![vec![1.0f32, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let tensor = Tensor::new(data, &Device::Cpu).unwrap();

        let array = converter.tensor_to_array2(&tensor).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array[[0, 0]], 1.0);
        assert_eq!(array[[1, 2]], 6.0);
    }

    #[test]
    fn test_tensor_1d_conversion() {
        let converter = PyTorchConverter::new_cpu();

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::new(data, &Device::Cpu).unwrap();

        let array = converter.tensor_to_array1(&tensor).unwrap();
        assert_eq!(array.len(), 4);
        assert_eq!(array[0], 1.0);
        assert_eq!(array[3], 4.0);
    }

    #[test]
    fn test_architecture_detection() {
        let converter = PyTorchConverter::new_cpu();
        let mut weights = HashMap::new();

        let tensor = Tensor::zeros((256, 128), DType::F32, &Device::Cpu).unwrap();
        weights.insert("layers.0.mixer.in_proj.weight".to_string(), tensor.clone());
        weights.insert("layers.0.mixer.A_log".to_string(), tensor.clone());

        let checkpoint = converter.detect_architecture(&weights).unwrap();
        assert_eq!(checkpoint.architecture, "mamba");
        assert_eq!(checkpoint.num_layers, Some(1));
    }

    #[test]
    fn test_checkpoint_metadata() {
        let checkpoint = PyTorchCheckpoint {
            architecture: "mamba2".to_string(),
            num_layers: Some(24),
            hidden_dim: Some(768),
            d_model: Some(768),
            d_state: Some(16),
            metadata: HashMap::new(),
        };

        assert_eq!(checkpoint.architecture, "mamba2");
        assert_eq!(checkpoint.num_layers, Some(24));
    }
}