border_candle_agent/util.rs

//! Utilities.
use anyhow::Result;
use border_core::record::{Record, RecordValue};
use candle_core::{DType, Device, Tensor, WithDType, D};
use candle_nn::VarMap;
use log::trace;
use ndarray::ArrayD;
use num_traits::AsPrimitive;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;

mod named_tensors;
mod quantile_loss;
pub mod actor;
pub mod critic;

pub use named_tensors::NamedTensors;
pub use quantile_loss::quantile_huber_loss;

/// Critic loss type.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum CriticLoss {
    /// Mean squared error.
    Mse,

    /// Smooth L1 loss.
    SmoothL1,
}

/// Applies a soft (Polyak) update to the variables in `dest`:
///
/// `dest = tau * src + (1 - tau) * dest`
///
/// Variables are matched by name; this panics if a name in `dest` has no
/// counterpart in `src`.
pub fn track(dest: &VarMap, src: &VarMap, tau: f64) -> Result<()> {
    trace!("dest");
    let dest = dest.data().lock().unwrap();
    trace!("src");
    let src = src.data().lock().unwrap();

    dest.iter().for_each(|(k_dest, v_dest)| {
        let v_src = src.get(k_dest).unwrap();
        let t_src = v_src.as_tensor();
        let t_dest = v_dest.as_tensor();
        let t_dest = ((tau * t_src).unwrap() + (1.0 - tau) * t_dest).unwrap();
        v_dest.set(&t_dest).unwrap();
    });

    Ok(())
}

/// Like [`track`], but matches variables across the two maps by name after
/// replacing the substring `ss_dest` in each destination name with `ss_src`.
///
/// This is useful when the source and destination networks use different
/// name prefixes for otherwise identical parameters.
pub fn track_with_replace_substring(
    dest: &VarMap,
    src: &VarMap,
    tau: f64,
    (ss_src, ss_dest): (&str, &str),
) -> Result<()> {
    trace!("dest");
    let dest = dest.data().lock().unwrap();
    trace!("src");
    let src = src.data().lock().unwrap();

    dest.iter().for_each(|(k_dest, v_dest)| {
        let k_src = k_dest.replace(ss_dest, ss_src);
        let v_src = src.get(&k_src).unwrap();
        let t_src = v_src.as_tensor();
        let t_dest = v_dest.as_tensor();
        let t_dest = ((tau * t_src).unwrap() + (1.0 - tau) * t_dest).unwrap();
        v_dest.set(&t_dest).unwrap();
    });

    Ok(())
}
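
// Illustrative check for `track_with_replace_substring` (the substring pair
// "main"/"target" below is hypothetical, not a convention of this crate):
// destination keys containing "target" are updated from source keys in which
// "target" is replaced by "main".
#[test]
fn test_track_with_replace_substring() -> Result<()> {
    use candle_nn::Init;

    let vm_src = VarMap::new();
    vm_src.get((2,), "main.w", Init::Const(1.0), DType::F32, &Device::Cpu)?;
    let vm_dest = VarMap::new();
    vm_dest.get((2,), "target.w", Init::Const(0.0), DType::F32, &Device::Cpu)?;

    track_with_replace_substring(&vm_dest, &vm_src, 0.5, ("main", "target"))?;

    // dest = 0.5 * 1.0 + 0.5 * 0.0 = 0.5
    let t = vm_dest
        .data()
        .lock()
        .unwrap()
        .get("target.w")
        .unwrap()
        .as_tensor()
        .to_vec1::<f32>()?;
    assert_eq!(t, vec![0.5, 0.5]);
    Ok(())
}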
/// Interface for handling output dimensions.
pub trait OutDim {
    /// Returns the output dimension.
    fn get_out_dim(&self) -> i64;

    /// Sets the output dimension.
    fn set_out_dim(&mut self, v: i64);
}
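
// Illustrative implementation of `OutDim` for a hypothetical config type
// (`MlpConfig` below is not part of this crate).
#[test]
fn test_out_dim() {
    struct MlpConfig {
        out_dim: i64,
    }

    impl OutDim for MlpConfig {
        fn get_out_dim(&self) -> i64 {
            self.out_dim
        }

        fn set_out_dim(&mut self, v: i64) {
            self.out_dim = v;
        }
    }

    let mut config = MlpConfig { out_dim: 1 };
    config.set_out_dim(4);
    assert_eq!(config.get_out_dim(), 4);
}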

#[test]
fn test_track() -> Result<()> {
    use candle_nn::Init;

    let tau = 0.7;
    let t_src = Tensor::from_slice(&[1.0f32, 2.0, 3.0], (3,), &Device::Cpu)?;
    let t_dest = Tensor::from_slice(&[4.0f32, 5.0, 6.0], (3,), &Device::Cpu)?;
    let t = ((tau * &t_src).unwrap() + (1.0 - tau) * &t_dest).unwrap();

    let vm_src = {
        let vm = VarMap::new();
        let init = Init::Randn {
            mean: 0.0,
            stdev: 1.0,
        };
        vm.get((3,), "var1", init, DType::F32, &Device::Cpu)?;
        vm.data().lock().unwrap().get("var1").unwrap().set(&t_src)?;
        vm
    };
    let vm_dest = {
        let vm = VarMap::new();
        let init = Init::Randn {
            mean: 0.0,
            stdev: 1.0,
        };
        vm.get((3,), "var1", init, DType::F32, &Device::Cpu)?;
        vm.data()
            .lock()
            .unwrap()
            .get("var1")
            .unwrap()
            .set(&t_dest)?;
        vm
    };
    track(&vm_dest, &vm_src, tau)?;

    let t_ = vm_dest
        .data()
        .lock()
        .unwrap()
        .get("var1")
        .unwrap()
        .as_tensor()
        .clone();

    println!("{:?}", t);
    println!("{:?}", t_);
    assert!((t - t_)?.abs()?.sum(0)?.to_scalar::<f32>()? < 1e-32);

    Ok(())
}

/// Smooth L1 loss with `beta = 1`, averaged over all elements.
///
/// See <https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html>.
pub fn smooth_l1_loss(x: &Tensor, y: &Tensor) -> Result<Tensor, candle_core::Error> {
    let device = x.device();
    let d = (x - y)?.abs()?;
    // m1 selects elements with |x - y| < 1, m2 the rest.
    let m1 = d.lt(1.0)?.to_dtype(DType::F32)?.to_device(device)?;
    let m2 = Tensor::try_from(1f32)?
        .to_device(device)?
        .broadcast_sub(&m1)?;
    (((0.5 * m1)? * d.powf(2.0))? + m2 * (d - 0.5))?.mean_all()
}
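
// Illustrative check of `smooth_l1_loss` against values computed by hand:
// d = |x - y| = [0.5, 2.0] gives per-element losses [0.5 * 0.5^2, 2.0 - 0.5]
// = [0.125, 1.5], whose mean is 0.8125.
#[test]
fn test_smooth_l1_loss() -> Result<()> {
    let x = Tensor::from_slice(&[0.0f32, 3.0], (2,), &Device::Cpu)?;
    let y = Tensor::from_slice(&[0.5f32, 1.0], (2,), &Device::Cpu)?;
    let loss = smooth_l1_loss(&x, &y)?.to_vec0::<f32>()?;
    assert!((loss - 0.8125).abs() < 1e-6);
    Ok(())
}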

/// Returns the (population) standard deviation of all elements of a tensor:
/// `sqrt(mean((t - mean(t))^2))`.
pub fn std(t: &Tensor) -> f32 {
    t.broadcast_sub(&t.mean_all().unwrap())
        .unwrap()
        .powf(2f64)
        .unwrap()
        .mean_all()
        .unwrap()
        .sqrt()
        .unwrap()
        .to_vec0::<f32>()
        .unwrap()
}
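
// Illustrative check: the population standard deviation of [1, 2, 3] is
// sqrt(2/3).
#[test]
fn test_std() -> Result<()> {
    let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], (3,), &Device::Cpu)?;
    assert!((std(&t) - (2f32 / 3f32).sqrt()).abs() < 1e-6);
    Ok(())
}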

/// Returns a [`Record`] with the mean and standard deviation of each
/// parameter in the given [`VarMap`], under the keys `"{name}_mean"` and
/// `"{name}_std"`.
pub fn param_stats(varmap: &VarMap) -> Record {
    let mut record = Record::empty();

    for (k, v) in varmap.data().lock().unwrap().iter() {
        let mean: f32 = v.mean_all().unwrap().to_vec0().unwrap();
        record.insert(format!("{}_mean", &k), RecordValue::Scalar(mean));

        let sd: f32 = std(v.as_tensor());
        record.insert(format!("{}_std", &k), RecordValue::Scalar(sd));
    }

    record
}

/// Converts a `Vec<T1>` into a 1-D tensor, casting each element to `T2`.
///
/// If `add_batch_dim` is true, a batch dimension of size 1 is prepended.
pub fn vec_to_tensor<T1, T2>(v: Vec<T1>, add_batch_dim: bool) -> Result<Tensor>
where
    T1: AsPrimitive<T2>,
    T2: WithDType,
{
    let v = v.iter().map(|e| e.as_()).collect::<Vec<_>>();
    let t: Tensor = TryFrom::<Vec<T2>>::try_from(v)?;

    match add_batch_dim {
        true => Ok(t.unsqueeze(0)?),
        false => Ok(t),
    }
}
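
// Illustrative check: i64 inputs are cast element-wise to f32 and a batch
// dimension is prepended when requested.
#[test]
fn test_vec_to_tensor() -> Result<()> {
    let t = vec_to_tensor::<i64, f32>(vec![1, 2, 3], true)?;
    assert_eq!(t.dims(), &[1, 3]);
    assert_eq!(t.to_vec2::<f32>()?, vec![vec![1.0, 2.0, 3.0]]);
    Ok(())
}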

/// Converts an [`ArrayD`] into a tensor of the same shape, casting each
/// element to `T2`.
///
/// If `add_batch_dim` is true, a batch dimension of size 1 is prepended.
pub fn arrayd_to_tensor<T1, T2>(a: ArrayD<T1>, add_batch_dim: bool) -> Result<Tensor>
where
    T1: AsPrimitive<T2>,
    T2: WithDType,
{
    let shape = a.shape();
    let v = a.iter().map(|e| e.as_()).collect::<Vec<_>>();
    let t: Tensor = TryFrom::<Vec<T2>>::try_from(v)?;
    let t = t.reshape(shape)?;

    match add_batch_dim {
        true => Ok(t.unsqueeze(0)?),
        false => Ok(t),
    }
}

/// Converts a tensor into an [`ArrayD`] of the same shape.
///
/// If `delete_batch_dim` is true, the first dimension is dropped from the
/// resulting shape.
pub fn tensor_to_arrayd<T>(t: Tensor, delete_batch_dim: bool) -> Result<ArrayD<T>>
where
    T: WithDType,
{
    let shape = match delete_batch_dim {
        false => t.dims().to_vec(),
        true => t.dims()[1..].to_vec(),
    };
    let v: Vec<T> = t.flatten_all()?.to_vec1()?;

    Ok(ndarray::Array1::<T>::from(v).into_shape(ndarray::IxDyn(&shape))?)
}
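
// Illustrative round trip: `arrayd_to_tensor` with a batch dimension added,
// followed by `tensor_to_arrayd` with the batch dimension removed, should be
// the identity on the underlying data.
#[test]
fn test_arrayd_tensor_roundtrip() -> Result<()> {
    let a = ndarray::ArrayD::<f32>::zeros(ndarray::IxDyn(&[2, 3]));
    let t = arrayd_to_tensor::<f32, f32>(a.clone(), true)?;
    assert_eq!(t.dims(), &[1, 2, 3]);
    let a_ = tensor_to_arrayd::<f32>(t, true)?;
    assert_eq!(a, a_);
    Ok(())
}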

/// Returns `gamma * (1 - done)` for each element of the batch, where `done`
/// is taken from `is_terminated`.
///
/// When `is_truncated` is given, `done` is 1 if either `is_terminated` or
/// `is_truncated` is true for that element.
pub fn gamma_not_done(
    gamma: f32,
    is_terminated: Vec<i8>,
    is_truncated: Option<Vec<i8>>,
    device: &Device,
) -> Result<Tensor> {
    let batch_size = is_terminated.len();
    let not_done = if let Some(is_truncated) = is_truncated.as_ref() {
        is_terminated
            .iter()
            .zip(is_truncated.iter())
            .map(|(e1, e2)| (1f32 - (*e1 | *e2) as f32) * gamma)
            .collect::<Vec<_>>()
    } else {
        is_terminated
            .iter()
            .map(|e1| (1f32 - *e1 as f32) * gamma)
            .collect::<Vec<_>>()
    };
    Ok(Tensor::from_slice(&not_done[..], (batch_size,), device)?)
}
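
// Illustrative check: with gamma = 0.99, a terminated transition and a
// truncated-only transition both zero out the bootstrap factor.
#[test]
fn test_gamma_not_done() -> Result<()> {
    let t = gamma_not_done(0.99, vec![0, 1, 0], Some(vec![0, 0, 1]), &Device::Cpu)?;
    assert_eq!(t.to_vec1::<f32>()?, vec![0.99, 0.0, 0.0]);
    Ok(())
}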

/// Converts a batch of rewards into a 1-D tensor on the given device.
pub fn reward(reward: Vec<f32>, device: &Device) -> Result<Tensor> {
    let batch_size = reward.len();
    Ok(Tensor::from_slice(&reward[..], (batch_size,), device)?)
}

/// Asymmetric L2 (expectile) loss: `mean(|tau - 1[u < 0]| * u^2)`.
pub fn asymmetric_l2_loss(u: &Tensor, tau: f64) -> Result<Tensor> {
    Ok(((tau - u.lt(0f32)?.to_dtype(DType::F32)?)?.abs()? * u.powf(2.0)?)?.mean_all()?)
}
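
// Illustrative check: with tau = 0.7, negative residuals are weighted by
// |0.7 - 1| = 0.3 and positive ones by 0.7, so for u = [-1, 1] the mean
// loss is (0.3 + 0.7) / 2 = 0.5.
#[test]
fn test_asymmetric_l2_loss() -> Result<()> {
    let u = Tensor::from_slice(&[-1.0f32, 1.0], (2,), &Device::Cpu)?;
    let loss = asymmetric_l2_loss(&u, 0.7)?.to_vec0::<f32>()?;
    assert!((loss - 0.5).abs() < 1e-6);
    Ok(())
}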

/// Inverse hyperbolic tangent, `0.5 * ln((1 + t) / (1 - t))`, with the input
/// clamped to avoid infinities at the boundaries.
pub fn atanh(t: &Tensor) -> Result<Tensor> {
    let t = t.clamp(-0.999999, 0.999999)?;
    Ok((0.5 * (((1. + &t)? / (1. - &t)?)?).log()?)?)
}
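
// Illustrative check: `atanh` inverts `tanh` away from the clamp boundary.
#[test]
fn test_atanh() -> Result<()> {
    let x = Tensor::from_slice(&[-1.0f32, 0.0, 1.0], (3,), &Device::Cpu)?;
    let y = atanh(&x.tanh()?)?;
    let err = (y - &x)?.abs()?.max(0)?.to_vec0::<f32>()?;
    assert!(err < 1e-5);
    Ok(())
}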

/// Log-density correction for the `tanh` transformation.
///
/// For `a = tanh(u)`, returns `-sum(log(1 - a^2))` over the last dimension,
/// i.e., the negative log-determinant of the Jacobian of `tanh`. The input is
/// clamped for numerical stability.
pub fn log_jacobian_tanh(a: &Tensor) -> Result<Tensor> {
    let a = a.clamp(-0.999999, 0.999999)?;
    Ok((-1f64 * (1f64 - a.powf(2.0)?)?.log()?)?.sum(D::Minus1)?)
}
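
// Illustrative check: for a = [0.0, 0.5] the result is
// -(ln(1 - 0.0) + ln(1 - 0.25)) = -ln(0.75).
#[test]
fn test_log_jacobian_tanh() -> Result<()> {
    let a = Tensor::from_slice(&[0.0f32, 0.5], (1, 2), &Device::Cpu)?;
    let l = log_jacobian_tanh(&a)?.to_vec1::<f32>()?;
    assert!((l[0] + 0.75f32.ln()).abs() < 1e-6);
    Ok(())
}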