use crate::{
inference::Parameters,
parameter_initialization::INIT_S_MAX,
simulation::{D_MAX, D_MIN, S_MIN},
};
use burn::{
module::Param,
tensor::{Tensor, TensorData, backend::Backend},
};
pub(crate) fn parameter_clipper<B: Backend>(
parameters: Param<Tensor<B, 1>>,
num_relearning_steps: usize,
enable_short_term: bool,
) -> Param<Tensor<B, 1>> {
let (id, val) = parameters.consume();
let clipped = clip_parameters(
&val.to_data().to_vec().unwrap(),
num_relearning_steps,
enable_short_term,
);
Param::initialized(
id,
Tensor::from_data(TensorData::new(clipped, val.shape()), &B::Device::default())
.require_grad(),
)
}
pub(crate) fn clip_parameters(
parameters: &Parameters,
num_relearning_steps: usize,
enable_short_term: bool,
) -> Vec<f32> {
let mut parameters = parameters.to_vec();
let w17_w18_ceiling = if num_relearning_steps > 1 {
(-(parameters[11].ln() + (2.0f32.powf(parameters[13]) - 1.0).ln() + parameters[14] * 0.3)
/ num_relearning_steps as f32)
.max(0.01)
.sqrt()
.min(2.0)
} else {
2.0
};
let w19_floor = if enable_short_term { 0.01 } else { 0.0 };
let clamps: [(f32, f32); 21] = [
(S_MIN, INIT_S_MAX),
(S_MIN, INIT_S_MAX),
(S_MIN, INIT_S_MAX),
(S_MIN, INIT_S_MAX),
(D_MIN, D_MAX),
(0.001, 4.0),
(0.001, 4.0),
(0.001, 0.75),
(0.0, 4.5),
(0.0, 0.8),
(0.001, 3.5),
(0.001, 5.0),
(0.001, 0.25),
(0.001, 0.9),
(0.0, 4.0),
(0.0, 1.0),
(1.0, 6.0),
(0.0, w17_w18_ceiling),
(0.0, w17_w18_ceiling),
(w19_floor, 0.8),
(0.1, 0.8),
];
parameters
.iter_mut()
.zip(clamps)
.for_each(|(w, (low, high))| *w = w.clamp(low, high));
parameters
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{DEFAULT_PARAMETERS, test_helpers::Tensor};
use burn::backend::ndarray::NdArrayDevice;
#[test]
fn parameter_clipper_works() {
let device = NdArrayDevice::Cpu;
let tensor = Tensor::from_floats(
[0.0, -1000.0, 1000.0, 0.0, 1000.0, -1000.0, 1.0, 0.25, -0.1],
&device,
);
let param = parameter_clipper(Param::from_tensor(tensor), 1, true);
let values = ¶m.to_data().to_vec::<f32>().unwrap();
assert_eq!(
values,
&[0.001, 0.001, 100.0, 0.001, 10.0, 0.001, 1.0, 0.25, 0.0]
);
}
#[test]
fn parameter_clipper_works_with_num_relearning_steps() {
use crate::test_helpers::TestHelper;
let device = NdArrayDevice::Cpu;
let tensor = Tensor::from_floats(DEFAULT_PARAMETERS, &device);
let param = parameter_clipper(Param::from_tensor(tensor), 2, true);
let values = ¶m.to_data().to_vec::<f32>().unwrap();
values[17..=19].assert_approx_eq([0.5425, 0.0912, 0.0658]);
}
}