use crate::model::ModelBase;
use log::trace;
use serde::{Deserialize, Serialize};
mod quantile_loss;
mod named_tensors;
pub use quantile_loss::quantile_huber_loss;
pub use named_tensors::NamedTensors;
/// Which loss is applied to the critic.
///
/// No `#[serde(...)]` attributes are present, so serde's default representation
/// is used: each unit variant (de)serializes as its Rust name, i.e. `"MSE"` and
/// `"SmoothL1"`.
// `MSE` intentionally stays upper case: renaming it to the clippy-preferred
// `Mse` would change the serialized variant name and with it any stored configs.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum CriticLoss {
    /// Mean squared error (the computation lives in the code matching on this enum).
    MSE,
    /// Smooth L1 loss (the computation lives in the code matching on this enum).
    SmoothL1,
}
/// Soft-updates the parameters of `dest` towards those of `src`.
///
/// For every variable name in `src`'s var store, the variable with the same
/// name in `dest` is overwritten in place with `tau * src + (1 - tau) * dest`.
/// With `tau = 1.0` this copies `src` into `dest`; with `tau = 0.0`, `dest`
/// keeps its values. The update runs inside `tch::no_grad`, so autograd does
/// not record it.
///
/// # Panics
///
/// Panics if a variable name present in `src` has no counterpart in `dest`.
/// In debug builds, also panics if the two stores hold a different number of
/// variables.
pub fn track<M: ModelBase>(dest: &mut M, src: &mut M, tau: f64) {
    let src_vars = src.get_var_store().variables();
    // NOTE: relies on `variables()` returning tensors that share storage with
    // the var store (tch hands out shallow clones). Otherwise the in-place
    // `copy_` below would not reach `dest`'s parameters.
    let mut dest_vars = dest.get_var_store().variables();
    debug_assert_eq!(
        src_vars.len(),
        dest_vars.len(),
        "track: source and target models hold a different number of variables"
    );
    tch::no_grad(|| {
        for (name, src_var) in src_vars.iter() {
            // A count match does not guarantee a name match, so fail with context
            // rather than a bare `unwrap` panic.
            let dest_var = dest_vars.get_mut(name).unwrap_or_else(|| {
                panic!(
                    "track: variable `{}` exists in the source model but not in the target model",
                    name
                )
            });
            dest_var.copy_(&(tau * src_var + (1.0 - tau) * &*dest_var));
        }
    });
    trace!("soft update");
}
/// Returns a new vector holding the elements of `s1` followed by those of `s2`.
///
/// Either slice may be empty; the result is then a copy of the other one.
pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec<i64> {
    // `concat` sizes the output for both slices up front and copies each in,
    // so there is a single allocation and no temporary vector.
    [s1, s2].concat()
}
/// Read/write access to an output dimension stored by the implementor.
///
/// NOTE(review): presumably the size of a model's output layer; confirm
/// against the implementors.
pub trait OutDim {
    /// Returns the current output dimension.
    fn get_out_dim(&self) -> i64;

    /// Replaces the stored output dimension with `dim`.
    fn set_out_dim(&mut self, dim: i64);
}