use anyhow::Result;
use border_core::record::{Record, RecordValue};
use candle_core::{DType, Device, Tensor, WithDType, D};
use candle_nn::VarMap;
use log::trace;
use ndarray::ArrayD;
use num_traits::AsPrimitive;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;

mod named_tensors;
mod quantile_loss;

pub use named_tensors::NamedTensors;
pub use quantile_loss::quantile_huber_loss;

pub mod actor;
pub mod critic;

/// Type of the critic loss.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum CriticLoss {
    /// Mean squared error.
    Mse,

    /// Smooth L1 loss.
    SmoothL1,
}

/// Applies a soft (Polyak) update to the variables in `dest`, matching
/// variables by name: `dest = tau * src + (1 - tau) * dest`.
pub fn track(dest: &VarMap, src: &VarMap, tau: f64) -> Result<()> {
    trace!("dest");
    let dest = dest.data().lock().unwrap();
    trace!("src");
    let src = src.data().lock().unwrap();

    dest.iter().for_each(|(k_dest, v_dest)| {
        let v_src = src.get(k_dest).unwrap();
        let t_src = v_src.as_tensor();
        let t_dest = v_dest.as_tensor();
        let t_dest = ((tau * t_src).unwrap() + (1.0 - tau) * t_dest).unwrap();
        v_dest.set(&t_dest).unwrap();
    });

    Ok(())
}

/// Like [`track`], but looks up each source variable by replacing the
/// substring `ss_dest` in the destination key with `ss_src`.
pub fn track_with_replace_substring(
    dest: &VarMap,
    src: &VarMap,
    tau: f64,
    (ss_src, ss_dest): (&str, &str),
) -> Result<()> {
    trace!("dest");
    let dest = dest.data().lock().unwrap();
    trace!("src");
    let src = src.data().lock().unwrap();

    dest.iter().for_each(|(k_dest, v_dest)| {
        let k_src = k_dest.replace(ss_dest, ss_src);
        let v_src = src.get(&k_src).unwrap();
        let t_src = v_src.as_tensor();
        let t_dest = v_dest.as_tensor();
        let t_dest = ((tau * t_src).unwrap() + (1.0 - tau) * t_dest).unwrap();
        v_dest.set(&t_dest).unwrap();
    });

    Ok(())
}
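
// A minimal sketch of using the substring replacement, assuming the source and
// destination maps differ only by a key prefix; the names "main" and "target"
// are illustrative, not part of the library.
#[test]
fn test_track_with_replace_substring() -> Result<()> {
    use candle_core::{DType, Device, Tensor};
    use candle_nn::Init;

    let tau = 0.5;
    let t_src = Tensor::from_slice(&[2.0f32, 4.0], (2,), &Device::Cpu)?;
    let t_dest = Tensor::from_slice(&[0.0f32, 0.0], (2,), &Device::Cpu)?;

    let vm_src = VarMap::new();
    vm_src.get((2,), "main.w", Init::Const(0.0), DType::F32, &Device::Cpu)?;
    vm_src.data().lock().unwrap().get("main.w").unwrap().set(&t_src)?;

    let vm_dest = VarMap::new();
    vm_dest.get((2,), "target.w", Init::Const(0.0), DType::F32, &Device::Cpu)?;
    vm_dest
        .data()
        .lock()
        .unwrap()
        .get("target.w")
        .unwrap()
        .set(&t_dest)?;

    // Keys "target.*" in dest are looked up as "main.*" in src before blending.
    track_with_replace_substring(&vm_dest, &vm_src, tau, ("main", "target"))?;

    let t = vm_dest
        .data()
        .lock()
        .unwrap()
        .get("target.w")
        .unwrap()
        .as_tensor()
        .clone();
    // 0.5 * [2, 4] + 0.5 * [0, 0] = [1, 2]
    assert_eq!(t.to_vec1::<f32>()?, vec![1.0f32, 2.0]);

    Ok(())
}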

/// Has an output dimension.
pub trait OutDim {
    /// Returns the output dimension.
    fn get_out_dim(&self) -> i64;

    /// Sets the output dimension.
    fn set_out_dim(&mut self, v: i64);
}

#[test]
fn test_track() -> Result<()> {
    use candle_core::{DType, Device, Tensor};
    use candle_nn::Init;

    let tau = 0.7;
    let t_src = Tensor::from_slice(&[1.0f32, 2.0, 3.0], (3,), &Device::Cpu)?;
    let t_dest = Tensor::from_slice(&[4.0f32, 5.0, 6.0], (3,), &Device::Cpu)?;
    let t = ((tau * &t_src).unwrap() + (1.0 - tau) * &t_dest).unwrap();

    let vm_src = {
        let vm = VarMap::new();
        let init = Init::Randn {
            mean: 0.0,
            stdev: 1.0,
        };
        vm.get((3,), "var1", init, DType::F32, &Device::Cpu)?;
        vm.data().lock().unwrap().get("var1").unwrap().set(&t_src)?;
        vm
    };
    let vm_dest = {
        let vm = VarMap::new();
        let init = Init::Randn {
            mean: 0.0,
            stdev: 1.0,
        };
        vm.get((3,), "var1", init, DType::F32, &Device::Cpu)?;
        vm.data()
            .lock()
            .unwrap()
            .get("var1")
            .unwrap()
            .set(&t_dest)?;
        vm
    };
    track(&vm_dest, &vm_src, tau)?;

    let t_ = vm_dest
        .data()
        .lock()
        .unwrap()
        .get("var1")
        .unwrap()
        .as_tensor()
        .clone();

    println!("{:?}", t);
    println!("{:?}", t_);
    assert!((t - t_)?.abs()?.sum(0)?.to_scalar::<f32>()? < 1e-32);

    Ok(())
}

/// Smooth L1 loss, averaged over all elements: for `d = |x - y|`, each
/// element contributes `0.5 * d^2` if `d < 1` and `d - 0.5` otherwise.
pub fn smooth_l1_loss(x: &Tensor, y: &Tensor) -> Result<Tensor, candle_core::Error> {
    let device = x.device();
    let d = (x - y)?.abs()?;
    let m1 = d.lt(1.0)?.to_dtype(DType::F32)?.to_device(&device)?;
    let m2 = Tensor::try_from(1f32)?
        .to_device(&device)?
        .broadcast_sub(&m1)?;
    (((0.5 * m1)? * d.powf(2.0))? + m2 * (d - 0.5))?.mean_all()
}
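
// A small hand-computed check (illustrative, not part of the original API):
// with x = [0.5, 2.0] and y = [0.0, 0.0], the elementwise losses are
// 0.5 * 0.5^2 = 0.125 and 2.0 - 0.5 = 1.5, so the mean is 0.8125.
#[test]
fn test_smooth_l1_loss() -> Result<()> {
    let x = Tensor::from_slice(&[0.5f32, 2.0], (2,), &Device::Cpu)?;
    let y = Tensor::from_slice(&[0.0f32, 0.0], (2,), &Device::Cpu)?;
    let loss = smooth_l1_loss(&x, &y)?.to_scalar::<f32>()?;
    assert!((loss - 0.8125).abs() < 1e-6);
    Ok(())
}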

/// Returns the standard deviation of all elements of the tensor
/// (population standard deviation, i.e., the biased estimator).
pub fn std(t: &Tensor) -> f32 {
    t.broadcast_sub(&t.mean_all().unwrap())
        .unwrap()
        .powf(2f64)
        .unwrap()
        .mean_all()
        .unwrap()
        .sqrt()
        .unwrap()
        .to_vec0::<f32>()
        .unwrap()
}
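
// A hand-computed check (illustrative): for [1, 2, 3] the mean is 2, the
// squared deviations are [1, 0, 1], so std = sqrt(2/3).
#[test]
fn test_std() -> Result<()> {
    let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], (3,), &Device::Cpu)?;
    assert!((std(&t) - (2f32 / 3.0).sqrt()).abs() < 1e-6);
    Ok(())
}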

/// Creates a [`Record`] containing the mean and standard deviation of each
/// variable in the given [`VarMap`], keyed as `{name}_mean` and `{name}_std`.
pub fn param_stats(varmap: &VarMap) -> Record {
    let mut record = Record::empty();

    for (k, v) in varmap.data().lock().unwrap().iter() {
        let m: f32 = v.mean_all().unwrap().to_vec0().unwrap();
        let k_mean = format!("{}_mean", &k);
        record.insert(k_mean, RecordValue::Scalar(m));

        let m: f32 = std(v.as_tensor());
        let k_std = format!("{}_std", &k);
        record.insert(k_std, RecordValue::Scalar(m));
    }

    record
}

/// Converts a `Vec` into a 1-D tensor on the CPU, casting elements from `T1`
/// to `T2` and optionally prepending a batch dimension.
pub fn vec_to_tensor<T1, T2>(v: Vec<T1>, add_batch_dim: bool) -> Result<Tensor>
where
    T1: AsPrimitive<T2>,
    T2: WithDType,
{
    let v = v.iter().map(|e| e.as_()).collect::<Vec<_>>();
    let t: Tensor = TryFrom::<Vec<T2>>::try_from(v).unwrap();

    match add_batch_dim {
        true => Ok(t.unsqueeze(0)?),
        false => Ok(t),
    }
}
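
// An illustrative conversion: an i64 vector cast to f32 with a batch
// dimension added (the concrete types here are arbitrary examples).
#[test]
fn test_vec_to_tensor() -> Result<()> {
    let t = vec_to_tensor::<i64, f32>(vec![1, 2, 3], true)?;
    assert_eq!(t.dims(), &[1, 3]);
    assert_eq!(t.to_vec2::<f32>()?, vec![vec![1.0f32, 2.0, 3.0]]);
    Ok(())
}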

/// Converts an [`ArrayD`] into a tensor of the same shape on the CPU, casting
/// elements from `T1` to `T2` and optionally prepending a batch dimension.
pub fn arrayd_to_tensor<T1, T2>(a: ArrayD<T1>, add_batch_dim: bool) -> Result<Tensor>
where
    T1: AsPrimitive<T2>,
    T2: WithDType,
{
    let shape = a.shape();
    let v = a.iter().map(|e| e.as_()).collect::<Vec<_>>();
    let t: Tensor = TryFrom::<Vec<T2>>::try_from(v)?;
    let t = t.reshape(shape)?;

    match add_batch_dim {
        true => Ok(t.unsqueeze(0)?),
        false => Ok(t),
    }
}

/// Converts a tensor into an [`ArrayD`], optionally dropping the leading
/// batch dimension.
pub fn tensor_to_arrayd<T>(t: Tensor, delete_batch_dim: bool) -> Result<ArrayD<T>>
where
    T: WithDType,
{
    let shape = match delete_batch_dim {
        false => t.dims().to_vec(),
        true => t.dims()[1..].to_vec(),
    };
    let v: Vec<T> = t.flatten_all()?.to_vec1()?;

    Ok(ndarray::Array1::<T>::from(v).into_shape(ndarray::IxDyn(&shape))?)
}
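
// An illustrative round trip: ArrayD -> Tensor (batch dim added) ->
// ArrayD (batch dim removed) should reproduce the original array.
#[test]
fn test_arrayd_tensor_roundtrip() -> Result<()> {
    let a = ndarray::array![[1.0f32, 2.0], [3.0, 4.0]].into_dyn();
    let t = arrayd_to_tensor::<f32, f32>(a.clone(), true)?;
    assert_eq!(t.dims(), &[1, 2, 2]);
    let b = tensor_to_arrayd::<f32>(t, true)?;
    assert_eq!(a, b);
    Ok(())
}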

/// Returns a 1-D tensor whose elements are `gamma * (1 - done)`, where `done`
/// is 1 if the sample is terminated (or truncated, when given) and 0 otherwise.
pub fn gamma_not_done(
    gamma: f32,
    is_terminated: Vec<i8>,
    is_truncated: Option<Vec<i8>>,
    device: &Device,
) -> Result<Tensor> {
    let batch_size = is_terminated.len();
    let not_done = if let Some(is_truncated) = is_truncated.as_ref() {
        is_terminated
            .iter()
            .zip(is_truncated.iter())
            .map(|(e1, e2)| (1f32 - (*e1 | *e2) as f32) * gamma)
            .collect::<Vec<_>>()
    } else {
        is_terminated
            .iter()
            .map(|e1| (1f32 - *e1 as f32) * gamma)
            .collect::<Vec<_>>()
    };
    Ok(Tensor::from_slice(&not_done[..], (batch_size,), device)?)
}
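
// A quick sanity check (illustrative): entries that are terminated or
// truncated yield 0, the rest yield gamma.
#[test]
fn test_gamma_not_done() -> Result<()> {
    let t = gamma_not_done(0.99, vec![0, 1, 0], Some(vec![0, 0, 1]), &Device::Cpu)?;
    assert_eq!(t.to_vec1::<f32>()?, vec![0.99f32, 0.0, 0.0]);
    Ok(())
}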

/// Converts a reward vector into a 1-D tensor on the given device.
pub fn reward(reward: Vec<f32>, device: &Device) -> Result<Tensor> {
    let batch_size = reward.len();
    Ok(Tensor::from_slice(&reward[..], (batch_size,), device)?)
}

/// Asymmetric (expectile) L2 loss: `mean(|tau - 1_{u < 0}| * u^2)`.
pub fn asymmetric_l2_loss(u: &Tensor, tau: f64) -> Result<Tensor> {
    Ok(((tau - u.lt(0f32)?.to_dtype(DType::F32)?)?.abs()? * u.powf(2.0)?)?.mean_all()?)
}
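
// Hand-computed example (illustrative): with u = [-1, 1] and tau = 0.7, the
// weights are |0.7 - 1| = 0.3 and |0.7 - 0| = 0.7, so the mean loss is 0.5.
#[test]
fn test_asymmetric_l2_loss() -> Result<()> {
    let u = Tensor::from_slice(&[-1.0f32, 1.0], (2,), &Device::Cpu)?;
    let loss = asymmetric_l2_loss(&u, 0.7)?.to_scalar::<f32>()?;
    assert!((loss - 0.5).abs() < 1e-6);
    Ok(())
}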

/// Elementwise inverse hyperbolic tangent, computed as
/// `0.5 * ln((1 + t) / (1 - t))` with inputs clamped away from ±1
/// for numerical stability.
pub fn atanh(t: &Tensor) -> Result<Tensor> {
    let t = t.clamp(-0.999999, 0.999999)?;
    Ok((0.5 * (((1. + &t)? / (1. - &t)?)?).log()?)?)
}
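
// A numerical check (illustrative): atanh should invert tanh away from the
// clamp boundary.
#[test]
fn test_atanh() -> Result<()> {
    let x = Tensor::from_slice(&[-0.5f32, 0.0, 0.5], (3,), &Device::Cpu)?;
    let y = atanh(&x.tanh()?)?;
    let diff = (y - &x)?.abs()?.sum(0)?.to_scalar::<f32>()?;
    assert!(diff < 1e-5);
    Ok(())
}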

/// Negative log-determinant of the Jacobian of `tanh`, expressed in terms of
/// the squashed value `a = tanh(u)`: returns `-sum(log(1 - a^2))` over the
/// last dimension, with `a` clamped away from ±1 for numerical stability.
pub fn log_jacobian_tanh(a: &Tensor) -> Result<Tensor> {
    let a = a.clamp(-0.999999, 0.999999)?;
    Ok((-1f64 * (1f64 - a.powf(2.0)?)?.log()?)?.sum(D::Minus1)?)
}
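
// A hand-computed check (illustrative): for a = [0.0, 0.5], the value is
// -(ln(1 - 0.0) + ln(1 - 0.25)) = -ln(0.75).
#[test]
fn test_log_jacobian_tanh() -> Result<()> {
    let a = Tensor::from_slice(&[0.0f32, 0.5], (2,), &Device::Cpu)?;
    let v = log_jacobian_tanh(&a)?.to_scalar::<f32>()?;
    let expected = -(0.75f32.ln());
    assert!((v - expected).abs() < 1e-6);
    Ok(())
}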