use super::super::{n_backward_steps, ToLog};
use super::{
BuildPolicy, HistoryFeatures, PackedTensor, ParameterizedDistributionSpace, Policy,
SeqIterative, SeqPacked, StatsLogger,
};
use crate::torch::modules::{AsModule, BuildModule, Module};
use crate::torch::optimizers::{AdamConfig, BuildOptimizer, Optimizer};
use crate::utils::distributions::ArrayDistribution;
use serde::{Deserialize, Serialize};
use tch::{COptimizer, Device, Kind, Tensor};
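
/// Configuration for a [`Ppo`] policy: the policy-module builder, the
/// optimizer builder, and the PPO hyperparameters.
///
/// A minimal construction sketch; `MlpConfig` stands in for any module
/// configuration implementing [`BuildModule`], so substitute your own:
///
/// ```ignore
/// let config = PpoConfig {
///     policy_fn_config: MlpConfig::default(),
///     optimizer_config: AdamConfig::default(),
///     opt_steps_per_update: 10,
///     clip_distance: 0.2,
/// };
/// // Policy for 4-dimensional observation features and a 2-parameter
/// // action distribution, on CPU.
/// let policy = config.build_policy(4, 2, Device::Cpu);
/// ```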
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct PpoConfig<MB, OC = AdamConfig> {
    /// Configuration for building the policy module.
    pub policy_fn_config: MB,
    /// Configuration for building the policy optimizer.
    pub optimizer_config: OC,
    /// Number of optimizer steps per policy update.
    pub opt_steps_per_update: u64,
    /// Clipping parameter (ε) for the likelihood ratio in the surrogate
    /// objective.
    pub clip_distance: f64,
}
impl<MB, OC> Default for PpoConfig<MB, OC>
where
MB: Default,
OC: Default,
{
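    /// The default clip range ε = 0.2 matches the value suggested in the
    /// PPO paper; `opt_steps_per_update` defaults to 10.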
fn default() -> Self {
Self {
policy_fn_config: MB::default(),
optimizer_config: OC::default(),
opt_steps_per_update: 10,
clip_distance: 0.2,
}
}
}
impl<MB, OC> BuildPolicy for PpoConfig<MB, OC>
where
MB: BuildModule,
MB::Module: SeqPacked + SeqIterative,
OC: BuildOptimizer,
OC::Optimizer: Optimizer,
{
type Policy = Ppo<MB::Module, OC::Optimizer>;
fn build_policy(&self, in_dim: usize, out_dim: usize, device: Device) -> Self::Policy {
let policy_fn = self.policy_fn_config.build_module(in_dim, out_dim, device);
        let optimizer = self
            .optimizer_config
            .build_optimizer(policy_fn.trainable_variables())
            .expect("failed to build policy optimizer");
Ppo {
policy_fn,
optimizer,
opt_steps_per_update: self.opt_steps_per_update,
clip_distance: self.clip_distance,
}
}
}
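
/// Proximal Policy Optimization (PPO) policy with a clipped surrogate
/// objective.
///
/// Each update ascends (by minimizing the negation of)
///
/// > L(θ) = E_t\[min(r_t(θ) · A_t, clip(r_t(θ), 1 − ε, 1 + ε) · A_t)\]
///
/// where r_t(θ) = π_θ(a_t | s_t) / π_old(a_t | s_t) is the likelihood ratio
/// against the data-collecting policy, A_t is the advantage estimate, and ε
/// is `clip_distance`.
///
/// # Reference
/// Proximal Policy Optimization Algorithms. Schulman et al., 2017.
/// <https://arxiv.org/abs/1707.06347>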
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct Ppo<M, O = COptimizer> {
    /// Policy module mapping observation features to action-distribution
    /// parameters.
    policy_fn: M,
    /// Optimizer over the policy module's trainable variables.
    optimizer: O,
    /// Number of optimizer steps per call to `update`.
    opt_steps_per_update: u64,
    /// Clipping parameter (ε) for the likelihood ratio.
    clip_distance: f64,
}
impl<M: Module, O> AsModule for Ppo<M, O> {
type Module = M;
fn as_module(&self) -> &Self::Module {
&self.policy_fn
}
fn as_module_mut(&mut self) -> &mut Self::Module {
&mut self.policy_fn
}
}
impl<M, O> Policy for Ppo<M, O>
where
M: Module + SeqPacked + SeqIterative,
O: Optimizer,
{
type PolicyModule = M;
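    /// Perform one PPO update.
    ///
    /// 1. Evaluate the log-probabilities of the recorded actions under the
    ///    current parameters without gradient tracking; these become the
    ///    π_old terms of the likelihood ratio.
    /// 2. Take `opt_steps_per_update` gradient steps on the negated clipped
    ///    surrogate objective over the full batch of history features.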
fn update<AS: ParameterizedDistributionSpace<Tensor> + ?Sized>(
&mut self,
features: &dyn HistoryFeatures,
advantages: PackedTensor,
action_space: &AS,
logger: &mut dyn StatsLogger,
) {
let observation_features = features.observation_features();
let actions = features.actions().tensor();
        // Log-probabilities of the recorded actions under the pre-update
        // parameters (π_old). Evaluated without gradient tracking; the
        // likelihood ratio treats these as constants.
        let initial_log_probs = {
            let _no_grad = tch::no_grad_guard();
            let policy_output = self.policy_fn.seq_packed(observation_features);
            let distribution = action_space.distribution(policy_output.tensor());
            let log_probs = distribution.log_probs(actions);
            // Log the mean policy entropy as an exploration diagnostic.
            let entropy = distribution.entropy().mean(Kind::Float);
            logger.log_scalar("entropy", entropy.into());
            log_probs
        };
        // No minibatch sampling: each optimizer step uses the full batch.
        let sample_minibatch = || {};
        // Clipped surrogate loss, re-evaluated at each optimizer step. The
        // argument (the sampled minibatch, here `()`) is unused.
        let policy_surrogate_loss_fn = |_| {
            // Action log-probabilities under the current parameters, with
            // gradient tracking.
            let policy_output = self.policy_fn.seq_packed(observation_features);
            let distribution = action_space.distribution(policy_output.tensor());
            let log_probs = distribution.log_probs(actions);
            // Likelihood ratio r = π_θ(a|s) / π_old(a|s) via exponentiated
            // log-probability difference.
            let likelihood_ratio = (log_probs - &initial_log_probs).exp();
            let clipped_likelihood_ratio =
                likelihood_ratio.clip(1.0 - self.clip_distance, 1.0 + self.clip_distance);
            // Pessimistic (elementwise min) clipped objective, negated so that
            // minimization by the optimizer maximizes the surrogate.
            (likelihood_ratio * advantages.tensor())
                .min_other(&(clipped_likelihood_ratio * advantages.tensor()))
                .mean(Kind::Float)
                .neg()
        };
n_backward_steps(
&mut self.optimizer,
sample_minibatch,
policy_surrogate_loss_fn,
self.opt_steps_per_update,
logger,
            ToLog::NoAbsLoss,
            "policy update error",
);
}
}