use super::super::{n_backward_steps, ToLog};
use super::{
AdvantageFn, BuildCritic, Critic, Device, HistoryFeatures, PackedTensor, SeqPacked,
StatsLogger, StepValueTarget,
};
use crate::torch::modules::{BuildModule, Module};
use crate::torch::optimizers::{AdamConfig, BuildOptimizer, Optimizer};
use serde::{Deserialize, Serialize};
use tch::{COptimizer, Reduction};
/// Configuration for a [`ValuesOpt`] critic: a state-value function trained by
/// repeated gradient steps toward step-value targets.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct ValuesOptConfig<MB, OC = AdamConfig> {
    /// Builder for the state-value function module (built with output dim 1).
    pub state_value_fn_config: MB,
    /// Builder for the optimizer over the value function's trainable variables.
    pub optimizer_config: OC,
    /// Advantage estimator used when computing advantages for the actor.
    pub advantage_fn: AdvantageFn,
    /// Step-value target used as the regression target during updates.
    pub target: StepValueTarget,
    /// Number of optimizer steps performed per `update` call.
    pub opt_steps_per_update: u64,
    /// Upper bound applied to the environment discount factor at build time
    /// (the critic uses `min(max_discount_factor, discount_factor)`).
    pub max_discount_factor: f64,
}
impl<MB: Default, OC: Default> Default for ValuesOptConfig<MB, OC> {
fn default() -> Self {
Self {
state_value_fn_config: MB::default(),
optimizer_config: OC::default(),
advantage_fn: AdvantageFn::default(),
target: StepValueTarget::default(),
opt_steps_per_update: 80,
max_discount_factor: 0.99,
}
}
}
impl<MB, OC> BuildCritic for ValuesOptConfig<MB, OC>
where
    MB: BuildModule,
    MB::Module: SeqPacked,
    OC: BuildOptimizer,
    OC::Optimizer: Optimizer,
{
    type Critic = ValuesOpt<MB::Module, OC::Optimizer>;

    /// Build a [`ValuesOpt`] critic on the given device.
    ///
    /// The state-value module is built with a single scalar output, and the
    /// effective discount factor is the environment's `discount_factor`
    /// clamped from above by `self.max_discount_factor`.
    ///
    /// # Panics
    /// Panics if the optimizer cannot be constructed over the value
    /// function's trainable variables.
    // Narrowing f64 -> f32 is intentional; discount factors are in [0, 1].
    #[allow(clippy::cast_possible_truncation)]
    fn build_critic(&self, in_dim: usize, discount_factor: f64, device: Device) -> Self::Critic {
        // Scalar state-value head: one output per (observation) feature vector.
        let state_value_fn = self.state_value_fn_config.build_module(in_dim, 1, device);
        let optimizer = self
            .optimizer_config
            .build_optimizer(state_value_fn.trainable_variables())
            .expect("failed to build optimizer for the state-value function");
        ValuesOpt {
            state_value_fn,
            optimizer,
            advantage_fn: self.advantage_fn,
            target: self.target,
            discount_factor: self.max_discount_factor.min(discount_factor) as f32,
            opt_steps_per_update: self.opt_steps_per_update,
        }
    }
}
/// A critic that learns a state-value function by iterative optimization
/// toward step-value targets.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct ValuesOpt<M, O = COptimizer> {
    // State-value module; evaluated over packed observation sequences.
    state_value_fn: M,
    // Optimizer over `state_value_fn`'s trainable variables.
    optimizer: O,
    // Advantage estimator applied with the learned value function.
    advantage_fn: AdvantageFn,
    // Produces regression targets for the value function during `update`.
    target: StepValueTarget,
    // Effective discount factor (already capped at build time).
    discount_factor: f32,
    // Number of optimizer steps per `update` call.
    opt_steps_per_update: u64,
}
impl<M, O> Critic for ValuesOpt<M, O>
where
    M: SeqPacked,
    O: Optimizer,
{
    /// Estimate advantages from the history features using the current
    /// state-value function and the configured advantage estimator.
    fn advantages(&self, features: &dyn HistoryFeatures) -> PackedTensor {
        let value_fn = &self.state_value_fn;
        self.advantage_fn
            .advantages(value_fn, self.discount_factor, features)
    }

    /// Fit the state-value function to the step-value targets by running
    /// `opt_steps_per_update` gradient steps on a mean-squared-error loss.
    fn update(&mut self, features: &dyn HistoryFeatures, logger: &mut dyn StatsLogger) {
        // Targets are fixed for the whole update; compute them without
        // tracking gradients.
        let targets = tch::no_grad(|| {
            self.target
                .targets(&self.state_value_fn, self.discount_factor, features)
        });
        let observations = features.observation_features();
        // Full-batch loss: predicted values vs. targets (the minibatch
        // sampler is a no-op and the loss ignores its argument).
        let loss_fn = |_| {
            let predicted = self
                .state_value_fn
                .seq_packed(observations)
                .tensor()
                .squeeze_dim(-1);
            predicted.mse_loss(targets.tensor(), Reduction::Mean)
        };
        n_backward_steps(
            &mut self.optimizer,
            || {},
            loss_fn,
            self.opt_steps_per_update,
            logger,
            ToLog::All,
            "critic update error",
        );
    }
}