1use anyhow::Result;
7
8pub trait WeightSource {
11 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)>;
12
13 fn has(&self, key: &str) -> bool {
15 let _ = key;
16 false
17 }
18}
19
20impl<T: WeightSource + ?Sized> WeightSource for &mut T {
21 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
22 (*self).take(key, transpose)
23 }
24}
25
/// An in-memory weight store keyed by name.
///
/// Each entry holds a flat `Vec<f32>` together with its shape.
#[derive(Clone, Debug, Default)]
pub struct MapWeights {
    /// Stored tensors: weight name -> `(data, shape)`.
    pub tensors: std::collections::HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl MapWeights {
    /// Stores `data` with the given `shape` under `key`.
    ///
    /// Any tensor previously stored under the same key is replaced. No check
    /// is made here that `shape` agrees with `data.len()`.
    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
        let name: String = key.into();
        let entry = (data, shape);
        self.tensors.insert(name, entry);
    }
}
37
38impl WeightSource for MapWeights {
39 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
40 let (data, shape) = self
41 .tensors
42 .remove(key)
43 .ok_or_else(|| anyhow::anyhow!("missing weight: {key}"))?;
44 if !transpose {
45 return Ok((data, shape));
46 }
47 if shape.len() != 2 {
48 return Err(anyhow::anyhow!("transpose requires rank-2 weight: {key}"));
49 }
50 let rows = shape[0];
51 let cols = shape[1];
52 let mut out = vec![0f32; rows * cols];
53 for r in 0..rows {
54 for c in 0..cols {
55 out[c * rows + r] = data[r * cols + c];
56 }
57 }
58 Ok((out, vec![cols, rows]))
59 }
60}