border_tch_agent/util/
named_tensors.rs1use std::{collections::HashMap, iter::FromIterator};
2use tch::{nn::VarStore, Device::Cpu, Tensor};
3
4pub struct NamedTensors {
6 pub named_tensors: HashMap<String, Tensor>,
7}
8
9impl NamedTensors {
10 pub fn copy_from(vs: &VarStore) -> Self {
12 let src = vs.variables();
13
14 tch::no_grad(|| NamedTensors {
15 named_tensors: HashMap::from_iter(src.iter().map(|(k, v)| {
16 let v = v.detach().to(Cpu).data();
17 (k.clone(), v)
18 })),
19 })
20 }
21
22 pub fn copy_to(&self, vs: &mut VarStore) {
24 let src = &self.named_tensors;
25 let dest = &mut vs.variables();
26 debug_assert_eq!(src.len(), dest.len());
28
29 tch::no_grad(|| {
30 for (name, src) in src.iter() {
31 let dest = dest.get_mut(name).unwrap();
32 dest.copy_(src);
33 }
34 });
35 }
36}
37
38impl Clone for NamedTensors {
39 fn clone(&self) -> Self {
40 let src = &self.named_tensors;
41
42 tch::no_grad(|| NamedTensors {
43 named_tensors: HashMap::from_iter(src.iter().map(|(k, v)| {
44 let v = v.detach().to(Cpu).data();
45 (k.clone(), v)
46 })),
47 })
48 }
49}
50
#[cfg(test)]
mod test {
    use super::NamedTensors;
    use std::convert::{TryFrom, TryInto};
    use tch::{
        nn::{self, Module},
        Device::Cpu,
        Tensor,
    };

    /// Builds the 3 -> 8 -> 2 linear stack used by the test, registering its
    /// parameters under `layer1` and `layer2` in `vs`.
    fn two_layer_net(vs: &nn::VarStore) -> nn::Sequential {
        nn::seq()
            .add(nn::linear(&vs.root() / "layer1", 3, 8, Default::default()))
            .add(nn::linear(&vs.root() / "layer2", 8, 2, Default::default()))
    }

    /// Runs `net` on `input` and returns the output as plain floats.
    fn outputs(net: &nn::Sequential, input: &Tensor) -> Vec<f64> {
        net.forward(input).try_into().unwrap()
    }

    /// Checks that copying parameters from one variable store to another
    /// makes the second network reproduce the outputs of the first.
    #[test]
    fn test_named_tensors() {
        tch::manual_seed(42);

        let input = Tensor::try_from(vec![1., 2., 3.])
            .unwrap()
            .internal_cast_float(false);

        // Source network, always on the CPU.
        let source_vs = nn::VarStore::new(Cpu);
        let source_net = two_layer_net(&source_vs);

        // Destination network, on the GPU when one is available.
        let mut target_vs = nn::VarStore::new(tch::Device::cuda_if_available());
        let target_net = two_layer_net(&target_vs);
        let input_on_target = input.to(target_vs.device());

        let expected = outputs(&source_net, &input);
        let before_copy = outputs(&target_net, &input_on_target);

        let snapshot = NamedTensors::copy_from(&source_vs);
        snapshot.copy_to(&mut target_vs);

        let after_copy = outputs(&target_net, &input_on_target);

        for i in 0..2 {
            let tolerance = expected[i].abs() * 0.001;
            // Independently initialised networks must disagree before the copy...
            assert!((expected[i] - before_copy[i]).abs() >= tolerance);
            // ...and agree once the parameters have been transferred.
            assert!((expected[i] - after_copy[i]).abs() < tolerance);
        }
    }
}