// candle_nn/var_map.rs

//! A `VarMap` is a store that holds named variables.
//!
3use candle::{DType, Device, Result, Shape, Tensor, Var};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
8/// and new variables can be added by providing some initialization config in case they are
9/// missing.
10/// `VarMap` structures can be serialized in the safetensors format.
11#[derive(Clone)]
12pub struct VarMap {
13    data: Arc<Mutex<HashMap<String, Var>>>,
14}
15
16impl VarMap {
17    /// Create a new empty `VarMap`.
18    #[allow(clippy::new_without_default)]
19    pub fn new() -> Self {
20        let data = Arc::new(Mutex::new(HashMap::new()));
21        Self { data }
22    }
23
24    /// Retrieve all the variables currently stored in the map.
25    pub fn all_vars(&self) -> Vec<Var> {
26        let tensor_data = self.data.lock().unwrap();
27        #[allow(clippy::map_clone)]
28        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
29    }
30
31    /// Save the map in the safetensors format.
32    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
33        let tensor_data = self.data.lock().unwrap();
34        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
35        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
36        Ok(())
37    }
38
39    /// Load some values from a safetensors file and modify the existing variables to have these
40    /// values.
41    ///
42    /// Note that values for variables that are currently not in the map are not kept.
43    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
44        let path = path.as_ref();
45        let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
46        let mut tensor_data = self.data.lock().unwrap();
47        for (name, var) in tensor_data.iter_mut() {
48            let data = data.load(name, var.device())?;
49            if let Err(err) = var.set(&data) {
50                candle::bail!("error setting {name} using data from {path:?}: {err}",)
51            }
52        }
53        Ok(())
54    }
55
56    /// Set a named variable to some value.
57    pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
58        let tensor_data = self.data.lock().unwrap();
59        let name = name.as_ref();
60        match tensor_data.get(name) {
61            None => candle::bail!("cannot find {name} in VarMap"),
62            Some(var) => {
63                if let Err(err) = var.set(value.as_ref()) {
64                    candle::bail!("error setting {name}: {err}",)
65                }
66            }
67        }
68        Ok(())
69    }
70
71    /// Set some named variables to some values.
72    ///
73    /// If an error is returned, some of the variables might have already been set to their new
74    /// values.
75    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
76        &mut self,
77        iter: I,
78    ) -> Result<()> {
79        let tensor_data = self.data.lock().unwrap();
80        for (name, value) in iter {
81            let name = name.as_ref();
82            match tensor_data.get(name) {
83                None => candle::bail!("cannot find {name} in VarMap"),
84                Some(var) => {
85                    if let Err(err) = var.set(value.as_ref()) {
86                        candle::bail!("error setting {name}: {err}",)
87                    }
88                }
89            }
90        }
91        Ok(())
92    }
93
94    /// Retrieve or add a new variable.
95    pub fn get<S: Into<Shape>>(
96        &self,
97        shape: S,
98        path: &str,
99        init: crate::Init,
100        dtype: DType,
101        device: &Device,
102    ) -> Result<Tensor> {
103        let shape = shape.into();
104        let mut tensor_data = self.data.lock().unwrap();
105        if let Some(tensor) = tensor_data.get(path) {
106            let tensor_shape = tensor.shape();
107            if &shape != tensor_shape {
108                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
109            }
110            return Ok(tensor.as_tensor().clone());
111        }
112        let var = init.var(shape, dtype, device)?;
113        let tensor = var.as_tensor().clone();
114        tensor_data.insert(path.to_string(), var);
115        Ok(tensor)
116    }
117
118    pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
119        &self.data
120    }
121}