border_tch_agent/
util.rs

1//! Utilities.
2use crate::model::ModelBase;
3use log::trace;
4use serde::{Deserialize, Serialize};
5mod named_tensors;
6mod quantile_loss;
7use border_core::record::{Record, RecordValue};
8pub use named_tensors::NamedTensors;
9pub use quantile_loss::quantile_huber_loss;
10use std::convert::TryFrom;
11use tch::nn::VarStore;
12
13/// Critic loss type.
14#[allow(clippy::upper_case_acronyms)]
15#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
16pub enum CriticLoss {
17    /// Mean squared error.
18    Mse,
19
20    /// Smooth L1 loss.
21    SmoothL1,
22}
23
24/// Apply soft update on variables.
25///
26/// Variables are identified by their names.
27/// 
28/// dest = tau * src + (1.0 - tau) * dest
29pub fn track<M: ModelBase>(dest: &mut M, src: &mut M, tau: f64) {
30    let src = &mut src.get_var_store().variables();
31    let dest = &mut dest.get_var_store().variables();
32    debug_assert_eq!(src.len(), dest.len());
33
34    let names = src.keys();
35    tch::no_grad(|| {
36        for name in names {
37            let src = src.get(name).unwrap();
38            let dest = dest.get_mut(name).unwrap();
39            dest.copy_(&(tau * src + (1.0 - tau) * &*dest));
40        }
41    });
42    trace!("soft update");
43}
44
/// Concatenates two slices into a newly allocated vector.
///
/// The elements of `s1` come first, followed by the elements of `s2`.
/// Either slice may be empty.
pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec<i64> {
    // `concat` allocates the output once with the combined length and copies
    // both slices into it, without building an intermediate `Vec` for `s2`.
    [s1, s2].concat()
}
51
/// Interface for reading and updating an output dimension.
pub trait OutDim {
    /// Returns the output dimension.
    fn get_out_dim(&self) -> i64;

    /// Sets the output dimension to `v`.
    fn set_out_dim(&mut self, v: i64);
}
60
61/// Returns the mean and standard deviation of the parameters.
62pub fn param_stats(var_store: &VarStore) -> Record {
63    let mut record = Record::empty();
64
65    for (k, v) in var_store.variables() {
66        // let m: f32 = v.mean(tch::Kind::Float).into();
67        let m = f32::try_from(v.mean(tch::Kind::Float)).expect("Failed to convert Tensor to f32");
68        let k_mean = format!("{}_mean", &k);
69        record.insert(k_mean, RecordValue::Scalar(m));
70
71        let m = f32::try_from(v.std(false)).expect("Failed to convert Tensor to f32");
72        let k_std = format!("{}_std", k);
73        record.insert(k_std, RecordValue::Scalar(m));
74    }
75
76    record
77}