//! Trust region policy optimization (TRPO).

use super::{
BuildPolicy, HistoryFeatures, PackedTensor, ParameterizedDistributionSpace, Policy,
SeqIterative, SeqPacked, StatsLogger,
};
use crate::torch::backends::WithCudnnEnabled;
use crate::torch::modules::{AsModule, BuildModule, Module};
use crate::torch::optimizers::{
BuildOptimizer, ConjugateGradientOptimizer, ConjugateGradientOptimizerConfig,
OptimizerStepError, TrustRegionOptimizer,
};
use crate::utils::distributions::ArrayDistribution;
use log::warn;
use serde::{Deserialize, Serialize};
use tch::{Device, Kind, Tensor};
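
/// Configuration for [`Trpo`], a trust-region policy optimization policy.
///
/// # Example
///
/// A minimal sketch of building a policy, assuming some `MlpConfig` type that
/// implements [`BuildModule`] with a module supporting [`SeqPacked`] and
/// [`SeqIterative`] (the name is illustrative, not part of this module):
///
/// ```ignore
/// // `MlpConfig` is a hypothetical module configuration.
/// let config = TrpoConfig::<MlpConfig>::default();
/// let policy = config.build_policy(4, 2, Device::Cpu);
/// ```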
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrpoConfig<MB, OC = ConjugateGradientOptimizerConfig> {
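    /// Configuration for building the policy network module.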
pub policy_fn_config: MB,
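    /// Configuration for building the trust-region optimizer.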
pub optimizer_config: OC,
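    /// Maximum mean KL divergence between the old and new policies allowed in
    /// a single optimization step.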
pub max_policy_step_kl: f64,
}

impl<MB, OC> Default for TrpoConfig<MB, OC>
where
MB: Default,
OC: Default,
{
fn default() -> Self {
Self {
policy_fn_config: MB::default(),
optimizer_config: OC::default(),
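            // A commonly used default for the per-step KL bound.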
max_policy_step_kl: 0.01,
}
}
}

impl<MB, OC> BuildPolicy for TrpoConfig<MB, OC>
where
MB: BuildModule,
MB::Module: SeqPacked + SeqIterative,
OC: BuildOptimizer,
OC::Optimizer: TrustRegionOptimizer,
{
type Policy = Trpo<MB::Module, OC::Optimizer>;
fn build_policy(&self, in_dim: usize, out_dim: usize, device: Device) -> Self::Policy {
let policy_fn = self.policy_fn_config.build_module(in_dim, out_dim, device);
        let optimizer = self
            .optimizer_config
            .build_optimizer(policy_fn.trainable_variables())
            .expect("failed to build policy optimizer");
Trpo {
policy_fn,
optimizer,
max_policy_step_kl: self.max_policy_step_kl,
}
}
}
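
/// Trust region policy optimization (TRPO) policy.
///
/// Each update performs a single trust-region step on the surrogate advantage
/// objective:
///
/// ```text
/// minimize    -E[ (pi_new(a|s) / pi_old(a|s)) * A(s, a) ]
/// subject to  mean KL divergence(pi_old, pi_new) <= max_policy_step_kl
/// ```
///
/// # Reference
///
/// "Trust Region Policy Optimization" by Schulman et al. (2015).
/// <https://arxiv.org/abs/1502.05477>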
#[derive(Debug, Default, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct Trpo<M, O = ConjugateGradientOptimizer> {
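    /// Policy network module.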
policy_fn: M,
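    /// Trust-region optimizer over the policy parameters.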
optimizer: O,
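    /// Maximum mean KL divergence allowed per update step.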
max_policy_step_kl: f64,
}

impl<M: Module, O> AsModule for Trpo<M, O> {
type Module = M;
fn as_module(&self) -> &Self::Module {
&self.policy_fn
}
fn as_module_mut(&mut self) -> &mut Self::Module {
&mut self.policy_fn
}
}

impl<M, O> Policy for Trpo<M, O>
where
M: Module + SeqPacked + SeqIterative,
O: TrustRegionOptimizer,
{
type PolicyModule = M;
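
    /// Update the policy with one trust-region step on the collected history.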
fn update<AS: ParameterizedDistributionSpace<Tensor> + ?Sized>(
&mut self,
features: &dyn HistoryFeatures,
advantages: PackedTensor,
action_space: &AS,
logger: &mut dyn StatsLogger,
) {
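        // The trust-region step requires second derivatives of the KL
        // divergence, which some cuDNN kernels (RNNs in particular) do not
        // support. Disable cuDNN unless the module supports cuDNN second
        // derivatives.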
let _cudnn_disable_guard = if self.policy_fn.has_cudnn_second_derivatives() {
None
} else {
Some(WithCudnnEnabled::new(false))
};
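
        // Packed observation features and the actions taken in the history.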
let observation_features = features.observation_features();
let actions = features.actions().tensor();
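
        // Evaluate the pre-update policy without tracking gradients. The
        // resulting distribution and log probabilities are the fixed reference
        // point for the surrogate objective and the KL constraint.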
let (initial_distribution, initial_log_probs) = {
let _no_grad = tch::no_grad_guard();
let policy_output = self.policy_fn.seq_packed(observation_features);
let distribution = action_space.distribution(policy_output.tensor());
let log_probs = distribution.log_probs(actions);
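            // Log the mean policy entropy as a diagnostic.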
let entropy = distribution.entropy().mean(Kind::Float);
logger.log_scalar("entropy", entropy.into());
(distribution, log_probs)
};
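
        // Surrogate loss and trust-region distance:
        //   loss     = -E[ratio * advantage], ratio = pi_new(a|s) / pi_old(a|s)
        //   distance = mean KL divergence of the initial distribution from the
        //              updated one.
        // The optimizer re-evaluates this closure (with gradients) as needed.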
let mut policy_loss_distance_fn = || {
let policy_output = self.policy_fn.seq_packed(observation_features);
let distribution = action_space.distribution(policy_output.tensor());
let log_probs = distribution.log_probs(actions);
let likelihood_ratio = (log_probs - &initial_log_probs).exp();
let loss = -(likelihood_ratio * advantages.tensor()).mean(Kind::Float);
let distance = initial_distribution
.kl_divergence_from(&distribution)
.mean(Kind::Float);
(loss, distance)
};
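
        // Take a single trust-region step: minimize the surrogate loss subject
        // to the mean KL divergence staying below `max_policy_step_kl`.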
let result = self.optimizer.trust_region_backward_step(
&mut policy_loss_distance_fn,
self.max_policy_step_kl,
logger,
);
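        // A NaN loss or constraint indicates corrupted values, so panic;
        // other step failures are only logged as warnings.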
if let Err(error) = result {
match error {
OptimizerStepError::NaNLoss => panic!("NaN loss in policy optimization"),
OptimizerStepError::NaNConstraint => {
panic!("NaN constraint in policy optimization")
}
err => warn!("error in policy step: {}", err),
};
}
}
}