1use crate::model::ModelBase;
3use log::trace;
4use serde::{Deserialize, Serialize};
5mod named_tensors;
6mod quantile_loss;
7use border_core::record::{Record, RecordValue};
8pub use named_tensors::NamedTensors;
9pub use quantile_loss::quantile_huber_loss;
10use std::convert::TryFrom;
11use tch::nn::VarStore;
12
13#[allow(clippy::upper_case_acronyms)]
15#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
16pub enum CriticLoss {
17 Mse,
19
20 SmoothL1,
22}
23
24pub fn track<M: ModelBase>(dest: &mut M, src: &mut M, tau: f64) {
30 let src = &mut src.get_var_store().variables();
31 let dest = &mut dest.get_var_store().variables();
32 debug_assert_eq!(src.len(), dest.len());
33
34 let names = src.keys();
35 tch::no_grad(|| {
36 for name in names {
37 let src = src.get(name).unwrap();
38 let dest = dest.get_mut(name).unwrap();
39 dest.copy_(&(tau * src + (1.0 - tau) * &*dest));
40 }
41 });
42 trace!("soft update");
43}
44
/// Returns a new vector holding the elements of `s1` followed by those of `s2`.
///
/// Typically used to build tensor shapes, e.g. prefixing a batch dimension.
/// The result is allocated once with exactly `s1.len() + s2.len()` elements.
pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec<i64> {
    [s1, s2].concat()
}
51
/// Read/write access to an output dimension (`out_dim`).
///
/// Implementors expose a single `i64` value that callers can query and
/// overwrite; what the dimension means is up to the implementing type.
pub trait OutDim {
    /// Replaces the stored output dimension with `v`.
    fn set_out_dim(&mut self, v: i64);

    /// Reports the currently stored output dimension.
    fn get_out_dim(&self) -> i64;
}
60
61pub fn param_stats(var_store: &VarStore) -> Record {
63 let mut record = Record::empty();
64
65 for (k, v) in var_store.variables() {
66 let m = f32::try_from(v.mean(tch::Kind::Float)).expect("Failed to convert Tensor to f32");
68 let k_mean = format!("{}_mean", &k);
69 record.insert(k_mean, RecordValue::Scalar(m));
70
71 let m = f32::try_from(v.std(false)).expect("Failed to convert Tensor to f32");
72 let k_std = format!("{}_std", k);
73 record.insert(k_std, RecordValue::Scalar(m));
74 }
75
76 record
77}