/// Read-only access to a collection of named tensors, exposed as raw bytes.
///
/// The bytes are returned uninterpreted; dtype, shape and endianness are not
/// described by this trait (presumably the caller knows them — TODO confirm).
pub trait WeightLoader {
    /// Lists the name of every tensor held by this loader.
    fn names(&self) -> Vec<String>;

    /// Looks up the raw bytes of the tensor called `name`.
    ///
    /// `None` means the loader has no tensor under that name.
    fn tensor_bytes(&self, name: &str) -> Option<&[u8]>;
}
/// In-memory [`WeightLoader`] that keeps every tensor's bytes in one contiguous buffer.
///
/// `Default` yields an empty loader (no tensors, no bytes).
#[derive(Clone, Default)]
pub struct BytesWeightLoader {
    /// One `(name, offset, len)` record per tensor; `offset..offset + len` is the
    /// tensor's byte range inside `data`.
    entries: Vec<(String, usize, usize)>,
    /// The bytes of all tensors, concatenated.
    data: Vec<u8>,
}

impl std::fmt::Debug for BytesWeightLoader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print the (small) entry table but only the length of `data`: weight blobs can
        // be very large and a raw byte dump is useless in logs.
        f.debug_struct("BytesWeightLoader")
            .field("entries", &self.entries)
            .field("data_len", &self.data.len())
            .finish()
    }
}
impl BytesWeightLoader {
pub fn from_pairs(pairs: Vec<(String, Vec<u8>)>) -> Self {
let total: usize = pairs.iter().map(|(_, b)| b.len()).sum();
let mut data = Vec::with_capacity(total);
let mut entries = Vec::with_capacity(pairs.len());
for (name, bytes) in pairs {
let start = data.len();
let len = bytes.len();
data.extend_from_slice(&bytes);
entries.push((name, start, len));
}
Self { entries, data }
}
}
impl WeightLoader for BytesWeightLoader {
    /// Returns the byte range recorded for `name`.
    ///
    /// Entries are scanned linearly in insertion order, so if a name occurs more than
    /// once the first entry wins. A zero-length tensor yields `Some(&[])`.
    fn tensor_bytes(&self, name: &str) -> Option<&[u8]> {
        let &(_, start, len) = self.entries.iter().find(|entry| entry.0 == name)?;
        Some(&self.data[start..start + len])
    }

    /// Returns every entry's name in insertion order, duplicates included.
    fn names(&self) -> Vec<String> {
        let mut out = Vec::with_capacity(self.entries.len());
        for (entry_name, _, _) in &self.entries {
            out.push(entry_name.clone());
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tensors come back exactly as stored; unknown names yield `None`.
    #[test]
    fn round_trip() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("w".into(), vec![1, 2, 3, 4]),
            ("b".into(), vec![5, 6]),
        ]);
        assert_eq!(loader.tensor_bytes("w"), Some(&[1u8, 2, 3, 4][..]));
        assert_eq!(loader.tensor_bytes("b"), Some(&[5u8, 6][..]));
        assert_eq!(loader.tensor_bytes("missing"), None);
        assert_eq!(loader.names(), vec!["w".to_string(), "b".to_string()]);
    }

    /// An empty input produces a loader with no tensors.
    #[test]
    fn empty_loader() {
        let loader = BytesWeightLoader::from_pairs(Vec::new());
        assert!(loader.names().is_empty());
        assert_eq!(loader.tensor_bytes("w"), None);
    }

    /// A zero-length tensor is still present (`Some(&[])`), which is distinct from `None`.
    #[test]
    fn zero_length_tensor_is_present() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("w".into(), vec![1, 2]),
            ("empty".into(), Vec::new()),
        ]);
        assert_eq!(loader.tensor_bytes("empty"), Some::<&[u8]>(&[]));
        assert_eq!(loader.tensor_bytes("w"), Some(&[1u8, 2][..]));
    }

    /// With duplicate names, lookups resolve to the first occurrence and `names` lists both.
    #[test]
    fn duplicate_names_resolve_to_first() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("w".into(), vec![1]),
            ("w".into(), vec![2, 3]),
        ]);
        assert_eq!(loader.tensor_bytes("w"), Some(&[1u8][..]));
        assert_eq!(loader.names(), vec!["w".to_string(), "w".to_string()]);
    }
}