border_tch_agent/
util.rs

1//! Utilities.
2use crate::model::ModelBase;
3use log::trace;
4use serde::{Deserialize, Serialize};
5mod named_tensors;
6mod quantile_loss;
7use border_core::record::{Record, RecordValue};
8pub use named_tensors::NamedTensors;
9use ndarray::ArrayD;
10use num_traits::cast::AsPrimitive;
11pub use quantile_loss::quantile_huber_loss;
12use std::convert::TryFrom;
13use tch::{nn::VarStore, Tensor};
14
/// Critic loss type.
///
/// Selects the loss function used when training critic networks.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum CriticLoss {
    /// Mean squared error.
    Mse,

    /// Smooth L1 loss (Huber-like; less sensitive to outliers than MSE).
    SmoothL1,
}
25
26/// Apply soft update on variables.
27///
28/// Variables are identified by their names.
29///
30/// dest = tau * src + (1.0 - tau) * dest
31pub fn track<M: ModelBase>(dest: &mut M, src: &mut M, tau: f64) {
32    let src = &mut src.get_var_store().variables();
33    let dest = &mut dest.get_var_store().variables();
34    debug_assert_eq!(src.len(), dest.len());
35
36    let names = src.keys();
37    tch::no_grad(|| {
38        for name in names {
39            let src = src.get(name).unwrap();
40            let dest = dest.get_mut(name).unwrap();
41            dest.copy_(&(tau * src + (1.0 - tau) * &*dest));
42        }
43    });
44    trace!("soft update");
45}
46
/// Concatenates slices.
///
/// Returns a new vector holding the elements of `s1` followed by those of `s2`.
pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec<i64> {
    s1.iter().chain(s2.iter()).copied().collect()
}
53
/// Interface for handling output dimensions.
pub trait OutDim {
    /// Returns the output dimension.
    fn get_out_dim(&self) -> i64;

    /// Sets the output dimension.
    fn set_out_dim(&mut self, v: i64);
}
62
63/// Returns the mean and standard deviation of the parameters.
64pub fn param_stats(var_store: &VarStore) -> Record {
65    let mut record = Record::empty();
66
67    for (k, v) in var_store.variables() {
68        // let m: f32 = v.mean(tch::Kind::Float).into();
69        let m = f32::try_from(v.mean(tch::Kind::Float)).expect("Failed to convert Tensor to f32");
70        let k_mean = format!("{}_mean", &k);
71        record.insert(k_mean, RecordValue::Scalar(m));
72
73        let m = f32::try_from(v.std(false)).expect("Failed to convert Tensor to f32");
74        let k_std = format!("{}_std", k);
75        record.insert(k_std, RecordValue::Scalar(m));
76    }
77
78    record
79}
80
81pub fn vec_to_tensor<T1, T2>(v: Vec<T1>, add_batch_dim: bool) -> Tensor
82where
83    T1: AsPrimitive<T2>,
84    T2: Copy + 'static + tch::kind::Element,
85{
86    let v = v.iter().map(|e| e.as_()).collect::<Vec<_>>();
87    let t: Tensor = TryFrom::<Vec<T2>>::try_from(v).unwrap();
88
89    match add_batch_dim {
90        true => t.unsqueeze(0),
91        false => t,
92    }
93}
94
95/// Converts [`ndarray::ArrayD`] to [`Tensor`].
96pub fn arrayd_to_tensor<T1, T2>(a: ArrayD<T1>, add_batch_dim: bool) -> Tensor
97where
98    T1: AsPrimitive<T2>,
99    T2: Copy + 'static + tch::kind::Element,
100{
101    let v = a.iter().map(|e| e.as_()).collect::<Vec<_>>();
102    let t: Tensor = TryFrom::<Vec<T2>>::try_from(v).unwrap();
103
104    match add_batch_dim {
105        true => t.unsqueeze(0),
106        false => t,
107    }
108}
109
110/// Converts [`Tensor`] to [`ndarray::ArrayD`].
111pub fn tensor_to_arrayd<T>(t: Tensor, delete_batch_dim: bool) -> ArrayD<T>
112where
113    T: tch::kind::Element + Copy,
114{
115    let shape = match delete_batch_dim {
116        false => t.size()[..].iter().map(|x| *x as usize).collect::<Vec<_>>(),
117        true => t.size()[1..]
118            .iter()
119            .map(|x| *x as usize)
120            .collect::<Vec<_>>(),
121    };
122    let v = Vec::<T>::try_from(&t.flatten(0, -1)).expect("Failed to convert from Tensor to Vec");
123
124    ndarray::Array1::<T>::from(v)
125        .into_shape(ndarray::IxDyn(&shape))
126        .unwrap()
127}