1use candle::{DType, Device, Result, Shape, Tensor, Var};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7#[derive(Clone)]
12pub struct VarMap {
13 data: Arc<Mutex<HashMap<String, Var>>>,
14}
15
16impl VarMap {
17 #[allow(clippy::new_without_default)]
19 pub fn new() -> Self {
20 let data = Arc::new(Mutex::new(HashMap::new()));
21 Self { data }
22 }
23
24 pub fn all_vars(&self) -> Vec<Var> {
26 let tensor_data = self.data.lock().unwrap();
27 #[allow(clippy::map_clone)]
28 tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
29 }
30
31 pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
33 let tensor_data = self.data.lock().unwrap();
34 let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
35 safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
36 Ok(())
37 }
38
39 pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
44 let path = path.as_ref();
45 let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
46 let mut tensor_data = self.data.lock().unwrap();
47 for (name, var) in tensor_data.iter_mut() {
48 let data = data.load(name, var.device())?;
49 if let Err(err) = var.set(&data) {
50 candle::bail!("error setting {name} using data from {path:?}: {err}",)
51 }
52 }
53 Ok(())
54 }
55
56 pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
58 let tensor_data = self.data.lock().unwrap();
59 let name = name.as_ref();
60 match tensor_data.get(name) {
61 None => candle::bail!("cannot find {name} in VarMap"),
62 Some(var) => {
63 if let Err(err) = var.set(value.as_ref()) {
64 candle::bail!("error setting {name}: {err}",)
65 }
66 }
67 }
68 Ok(())
69 }
70
71 pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
76 &mut self,
77 iter: I,
78 ) -> Result<()> {
79 let tensor_data = self.data.lock().unwrap();
80 for (name, value) in iter {
81 let name = name.as_ref();
82 match tensor_data.get(name) {
83 None => candle::bail!("cannot find {name} in VarMap"),
84 Some(var) => {
85 if let Err(err) = var.set(value.as_ref()) {
86 candle::bail!("error setting {name}: {err}",)
87 }
88 }
89 }
90 }
91 Ok(())
92 }
93
94 pub fn get<S: Into<Shape>>(
96 &self,
97 shape: S,
98 path: &str,
99 init: crate::Init,
100 dtype: DType,
101 device: &Device,
102 ) -> Result<Tensor> {
103 let shape = shape.into();
104 let mut tensor_data = self.data.lock().unwrap();
105 if let Some(tensor) = tensor_data.get(path) {
106 let tensor_shape = tensor.shape();
107 if &shape != tensor_shape {
108 candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
109 }
110 return Ok(tensor.as_tensor().clone());
111 }
112 let var = init.var(shape, dtype, device)?;
113 let tensor = var.as_tensor().clone();
114 tensor_data.insert(path.to_string(), var);
115 Ok(tensor)
116 }
117
118 pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
119 &self.data
120 }
121}