entrenar/transformer/weights/
convert.rs1pub(crate) fn tensor_to_f32_vec(tensor: &safetensors::tensor::TensorView<'_>) -> Option<Vec<f32>> {
7 use safetensors::Dtype;
8
9 let shape = tensor.shape();
10 let numel: usize = shape.iter().product();
11
12 if numel == 0 {
13 return Some(Vec::new());
14 }
15
16 let data = tensor.data();
17
18 match tensor.dtype() {
19 Dtype::F32 => {
20 let values: Vec<f32> = data
22 .chunks_exact(4)
23 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
24 .collect();
25 Some(values)
26 }
27 Dtype::F16 => {
28 let values: Vec<f32> = data
30 .chunks_exact(2)
31 .map(|chunk| {
32 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
33 half::f16::from_bits(bits).to_f32()
34 })
35 .collect();
36 Some(values)
37 }
38 Dtype::BF16 => {
39 let values: Vec<f32> = data
41 .chunks_exact(2)
42 .map(|chunk| {
43 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
44 half::bf16::from_bits(bits).to_f32()
45 })
46 .collect();
47 Some(values)
48 }
49 Dtype::I32 => {
50 let values: Vec<f32> = data
52 .chunks_exact(4)
53 .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32)
54 .collect();
55 Some(values)
56 }
57 _ => {
58 eprintln!("Warning: Unsupported tensor dtype {:?}, skipping", tensor.dtype());
60 None
61 }
62 }
63}