// border_tch_agent/util/named_tensors.rs

1use std::{collections::HashMap, iter::FromIterator};
2use tch::{nn::VarStore, Device::Cpu, Tensor};
3
/// Named tensors to send model parameters using a channel.
///
/// Holds detached, CPU-resident copies of tensors keyed by their variable
/// name in the originating [`tch::nn::VarStore`] (see [`NamedTensors::copy_from`]).
pub struct NamedTensors {
    // Map of variable name -> detached CPU tensor copy.
    pub named_tensors: HashMap<String, Tensor>,
}
8
9impl NamedTensors {
10    /// Copy data of VarStore to CPU.
11    pub fn copy_from(vs: &VarStore) -> Self {
12        let src = vs.variables();
13
14        tch::no_grad(|| NamedTensors {
15            named_tensors: HashMap::from_iter(src.iter().map(|(k, v)| {
16                let v = v.detach().to(Cpu).data();
17                (k.clone(), v)
18            })),
19        })
20    }
21
22    /// Copy named tensors to [VarStore].
23    pub fn copy_to(&self, vs: &mut VarStore) {
24        let src = &self.named_tensors;
25        let dest = &mut vs.variables();
26        // let device = vs.device();
27        debug_assert_eq!(src.len(), dest.len());
28
29        tch::no_grad(|| {
30            for (name, src) in src.iter() {
31                let dest = dest.get_mut(name).unwrap();
32                dest.copy_(src);
33            }
34        });
35    }
36}
37
38impl Clone for NamedTensors {
39    fn clone(&self) -> Self {
40        let src = &self.named_tensors;
41
42        tch::no_grad(|| NamedTensors {
43            named_tensors: HashMap::from_iter(src.iter().map(|(k, v)| {
44                let v = v.detach().to(Cpu).data();
45                (k.clone(), v)
46            })),
47        })
48    }
49}
50
#[cfg(test)]
mod test {
    use super::NamedTensors;
    use std::convert::{TryFrom, TryInto};
    use tch::{
        nn::{self, Module},
        Device::Cpu,
        Tensor,
    };

    /// Round-trips parameters from one VarStore to another through
    /// [`NamedTensors`] and checks the second model then matches the first.
    #[test]
    fn test_named_tensors() {
        tch::manual_seed(42);

        let input = Tensor::try_from(vec![1., 2., 3.])
            .unwrap()
            .internal_cast_float(false);

        // Source model on the CPU.
        let vs1 = nn::VarStore::new(Cpu);
        let model1 = nn::seq()
            .add(nn::linear(&vs1.root() / "layer1", 3, 8, Default::default()))
            .add(nn::linear(&vs1.root() / "layer2", 8, 2, Default::default()));

        // Destination model, on CUDA when available, with different random weights.
        let mut vs2 = nn::VarStore::new(tch::Device::cuda_if_available());
        let model2 = nn::seq()
            .add(nn::linear(&vs2.root() / "layer1", 3, 8, Default::default()))
            .add(nn::linear(&vs2.root() / "layer2", 8, 2, Default::default()));
        let device = vs2.device();

        let t1: Vec<f64> = model1.forward(&input).try_into().unwrap();
        let t2: Vec<f64> = model2.forward(&input.to(device)).try_into().unwrap();

        // Transfer vs1's parameters into vs2.
        let nt = NamedTensors::copy_from(&vs1);
        nt.copy_to(&mut vs2);

        let t3: Vec<f64> = model2.forward(&input.to(device)).try_into().unwrap();

        for ((a, b), c) in t1.iter().zip(t2.iter()).zip(t3.iter()) {
            // Before the copy the two models must disagree; afterwards they agree
            // to within 0.1% relative tolerance.
            assert!((a - b).abs() >= a.abs() * 0.001);
            assert!((a - c).abs() < a.abs() * 0.001);
        }
    }
}
96}