Skip to main content

entrenar/transformer/weights/
convert.rs

//! Tensor format conversion from SafeTensors to f32

3/// Convert SafeTensors tensor view to f32 Vec
4///
5/// Handles bf16, fp16, and fp32 formats.
6pub(crate) fn tensor_to_f32_vec(tensor: &safetensors::tensor::TensorView<'_>) -> Option<Vec<f32>> {
7    use safetensors::Dtype;
8
9    let shape = tensor.shape();
10    let numel: usize = shape.iter().product();
11
12    if numel == 0 {
13        return Some(Vec::new());
14    }
15
16    let data = tensor.data();
17
18    match tensor.dtype() {
19        Dtype::F32 => {
20            // Direct f32 conversion
21            let values: Vec<f32> = data
22                .chunks_exact(4)
23                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
24                .collect();
25            Some(values)
26        }
27        Dtype::F16 => {
28            // fp16 conversion
29            let values: Vec<f32> = data
30                .chunks_exact(2)
31                .map(|chunk| {
32                    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
33                    half::f16::from_bits(bits).to_f32()
34                })
35                .collect();
36            Some(values)
37        }
38        Dtype::BF16 => {
39            // bf16 conversion
40            let values: Vec<f32> = data
41                .chunks_exact(2)
42                .map(|chunk| {
43                    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
44                    half::bf16::from_bits(bits).to_f32()
45                })
46                .collect();
47            Some(values)
48        }
49        Dtype::I32 => {
50            // Integer to float (rare for transformer weights)
51            let values: Vec<f32> = data
52                .chunks_exact(4)
53                .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32)
54                .collect();
55            Some(values)
56        }
57        _ => {
58            // Unsupported dtype
59            eprintln!("Warning: Unsupported tensor dtype {:?}, skipping", tensor.dtype());
60            None
61        }
62    }
63}