1use crate::model::ModelBase;
3use log::trace;
4use serde::{Deserialize, Serialize};
5mod named_tensors;
6mod quantile_loss;
7use border_core::record::{Record, RecordValue};
8pub use named_tensors::NamedTensors;
9use ndarray::ArrayD;
10use num_traits::cast::AsPrimitive;
11pub use quantile_loss::quantile_huber_loss;
12use std::convert::TryFrom;
13use tch::{nn::VarStore, Tensor};
14
15#[allow(clippy::upper_case_acronyms)]
17#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
18pub enum CriticLoss {
19 Mse,
21
22 SmoothL1,
24}
25
26pub fn track<M: ModelBase>(dest: &mut M, src: &mut M, tau: f64) {
32 let src = &mut src.get_var_store().variables();
33 let dest = &mut dest.get_var_store().variables();
34 debug_assert_eq!(src.len(), dest.len());
35
36 let names = src.keys();
37 tch::no_grad(|| {
38 for name in names {
39 let src = src.get(name).unwrap();
40 let dest = dest.get_mut(name).unwrap();
41 dest.copy_(&(tau * src + (1.0 - tau) * &*dest));
42 }
43 });
44 trace!("soft update");
45}
46
/// Returns a new vector holding the elements of `s1` followed by those of `s2`.
///
/// Either slice may be empty; two empty slices yield an empty vector.
pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec<i64> {
    // `concat` sizes the output once and copies both slices straight into it,
    // so no temporary `Vec` has to be built for the second slice.
    [s1, s2].concat()
}
53
/// Accessors for an `i64` output-dimension value.
///
/// NOTE(review): presumably implemented by model configurations whose output
/// size has to be adjusted after construction; confirm against implementors.
pub trait OutDim {
    /// Returns the current output dimension.
    fn get_out_dim(&self) -> i64;

    /// Sets the output dimension to `v`.
    fn set_out_dim(&mut self, v: i64);
}
62
63pub fn param_stats(var_store: &VarStore) -> Record {
65 let mut record = Record::empty();
66
67 for (k, v) in var_store.variables() {
68 let m = f32::try_from(v.mean(tch::Kind::Float)).expect("Failed to convert Tensor to f32");
70 let k_mean = format!("{}_mean", &k);
71 record.insert(k_mean, RecordValue::Scalar(m));
72
73 let m = f32::try_from(v.std(false)).expect("Failed to convert Tensor to f32");
74 let k_std = format!("{}_std", k);
75 record.insert(k_std, RecordValue::Scalar(m));
76 }
77
78 record
79}
80
81pub fn vec_to_tensor<T1, T2>(v: Vec<T1>, add_batch_dim: bool) -> Tensor
82where
83 T1: AsPrimitive<T2>,
84 T2: Copy + 'static + tch::kind::Element,
85{
86 let v = v.iter().map(|e| e.as_()).collect::<Vec<_>>();
87 let t: Tensor = TryFrom::<Vec<T2>>::try_from(v).unwrap();
88
89 match add_batch_dim {
90 true => t.unsqueeze(0),
91 false => t,
92 }
93}
94
95pub fn arrayd_to_tensor<T1, T2>(a: ArrayD<T1>, add_batch_dim: bool) -> Tensor
97where
98 T1: AsPrimitive<T2>,
99 T2: Copy + 'static + tch::kind::Element,
100{
101 let v = a.iter().map(|e| e.as_()).collect::<Vec<_>>();
102 let t: Tensor = TryFrom::<Vec<T2>>::try_from(v).unwrap();
103
104 match add_batch_dim {
105 true => t.unsqueeze(0),
106 false => t,
107 }
108}
109
110pub fn tensor_to_arrayd<T>(t: Tensor, delete_batch_dim: bool) -> ArrayD<T>
112where
113 T: tch::kind::Element + Copy,
114{
115 let shape = match delete_batch_dim {
116 false => t.size()[..].iter().map(|x| *x as usize).collect::<Vec<_>>(),
117 true => t.size()[1..]
118 .iter()
119 .map(|x| *x as usize)
120 .collect::<Vec<_>>(),
121 };
122 let v = Vec::<T>::try_from(&t.flatten(0, -1)).expect("Failed to convert from Tensor to Vec");
123
124 ndarray::Array1::<T>::from(v)
125 .into_shape(ndarray::IxDyn(&shape))
126 .unwrap()
127}