1use anyhow::Result;
19
20pub trait WeightSource {
23 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)>;
24
25 fn has(&self, key: &str) -> bool {
27 let _ = key;
28 false
29 }
30}
31
32impl<T: WeightSource + ?Sized> WeightSource for &mut T {
33 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
34 (*self).take(key, transpose)
35 }
36}
37
/// A set of weight tensors held in memory and looked up by name.
#[derive(Clone, Debug, Default)]
pub struct MapWeights {
    /// Tensor name mapped to its `(data, shape)` pair.
    pub tensors: std::collections::HashMap<String, (Vec<f32>, Vec<usize>)>,
}
43
44impl MapWeights {
45 pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
46 self.tensors.insert(key.into(), (data, shape));
47 }
48}
49
50impl WeightSource for MapWeights {
51 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
52 let (data, shape) = self
53 .tensors
54 .remove(key)
55 .ok_or_else(|| anyhow::anyhow!("missing weight: {key}"))?;
56 if !transpose {
57 return Ok((data, shape));
58 }
59 if shape.len() != 2 {
60 return Err(anyhow::anyhow!("transpose requires rank-2 weight: {key}"));
61 }
62 let rows = shape[0];
63 let cols = shape[1];
64 let mut out = vec![0f32; rows * cols];
65 for r in 0..rows {
66 for c in 0..cols {
67 out[c * rows + r] = data[r * cols + c];
68 }
69 }
70 Ok((out, vec![cols, rows]))
71 }
72}