use super::super::{n_backward_steps, ToLog};
use super::{
    BuildPolicy, HistoryFeatures, PackedTensor, ParameterizedDistributionSpace, Policy,
    SeqIterative, SeqPacked, StatsLogger,
};
use crate::torch::modules::{AsModule, BuildModule, Module};
use crate::torch::optimizers::{AdamConfig, BuildOptimizer, Optimizer};
use crate::utils::distributions::ArrayDistribution;
use serde::{Deserialize, Serialize};
use tch::{COptimizer, Device, Kind, Tensor};

/// Configuration for [`Ppo`]
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct PpoConfig<MB, OC = AdamConfig> {
    /// Configuration for building the policy module (policy function).
    pub policy_fn_config: MB,
    /// Configuration for building the optimizer.
    pub optimizer_config: OC,
    /// Number of optimization steps per update.
    pub opt_steps_per_update: u64,
    // TODO: Support minibatches
    // pub minibatch_size: usize,
    /// Clip the likelihood ratio in the surrogate objective to `1 ± clip_distance`.
    ///
    /// This is ε (epsilon) in the paper.
    pub clip_distance: f64,
}

impl<MB, OC> Default for PpoConfig<MB, OC>
where
    MB: Default,
    OC: Default,
{
    fn default() -> Self {
        Self {
            policy_fn_config: MB::default(),
            optimizer_config: OC::default(),
            opt_steps_per_update: 10,
            clip_distance: 0.2,
        }
    }
}

impl<MB, OC> BuildPolicy for PpoConfig<MB, OC>
where
    MB: BuildModule,
    MB::Module: SeqPacked + SeqIterative,
    OC: BuildOptimizer,
    OC::Optimizer: Optimizer,
{
    type Policy = Ppo<MB::Module, OC::Optimizer>;

    fn build_policy(&self, in_dim: usize, out_dim: usize, device: Device) -> Self::Policy {
        let policy_fn = self.policy_fn_config.build_module(in_dim, out_dim, device);
        let optimizer = self
            .optimizer_config
            .build_optimizer(policy_fn.trainable_variables())
            .unwrap();
        Ppo {
            policy_fn,
            optimizer,
            opt_steps_per_update: self.opt_steps_per_update,
            clip_distance: self.clip_distance,
        }
    }
}
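
// Construction sketch (illustrative only): `MlpConfig` is a hypothetical
// stand-in for any `BuildModule` whose module implements `SeqPacked` and
// `SeqIterative`; the input/output dimensions are assumed to be the
// observation-feature and action-distribution-parameter sizes.
//
//     let config: PpoConfig<MlpConfig> = PpoConfig {
//         clip_distance: 0.1,
//         ..PpoConfig::default()
//     };
//     let policy = config.build_policy(obs_dim, action_dim, Device::Cpu);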

/// Proximal Policy Optimization (PPO) with a clipped objective.
///
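/// Maximizes the clipped surrogate objective
/// `E[min(r(θ) A, clip(r(θ), 1 - ε, 1 + ε) A)]`, where `r(θ)` is the
/// likelihood ratio of the updated policy to the pre-update policy and `A` is
/// the advantage estimate. Each call to [`Policy::update`] takes several
/// gradient steps on the negated mean of this objective.
///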
/// # Reference
/// "[Proximal Policy Optimization Algorithms][ppo]" by Schulman et al.
///
/// [ppo]: https://arxiv.org/abs/1707.06347
#[derive(Debug, Default, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct Ppo<M, O = COptimizer> {
    /// Policy module (policy function).
    policy_fn: M,
    /// Optimizer over the policy module's trainable parameters.
    optimizer: O,
    /// Number of optimization steps per update.
    opt_steps_per_update: u64,
    /// Clip the likelihood ratio to `1 ± clip_distance`; ε in the paper.
    clip_distance: f64,
}

impl<M: Module, O> AsModule for Ppo<M, O> {
    type Module = M;
    fn as_module(&self) -> &Self::Module {
        &self.policy_fn
    }
    fn as_module_mut(&mut self) -> &mut Self::Module {
        &mut self.policy_fn
    }
}

impl<M, O> Policy for Ppo<M, O>
where
    M: Module + SeqPacked + SeqIterative,
    O: Optimizer,
{
    type PolicyModule = M;

    fn update<AS: ParameterizedDistributionSpace<Tensor> + ?Sized>(
        &mut self,
        features: &dyn HistoryFeatures,
        advantages: PackedTensor,
        action_space: &AS,
        logger: &mut dyn StatsLogger,
    ) {
        let observation_features = features.observation_features();
        let actions = features.actions().tensor();

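        // Log-probabilities of the chosen actions under the pre-update policy,
        // evaluated without tracking gradients; these are the denominator of
        // the PPO likelihood ratio. The pre-update policy entropy is logged as
        // a diagnostic.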
        let initial_log_probs = {
            let _no_grad = tch::no_grad_guard();

            let policy_output = self.policy_fn.seq_packed(observation_features);
            let distribution = action_space.distribution(policy_output.tensor());
            let log_probs = distribution.log_probs(actions);
            let entropy = distribution.entropy().mean(Kind::Float);
            logger.log_scalar("entropy", entropy.into());

            log_probs
        };

        // TODO Sample a minibatch on each update.
        let sample_minibatch = || {};

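        // Clipped surrogate loss: the negated mean of
        // `min(r * A, clip(r, 1 - ε, 1 + ε) * A)`, where
        // `r = exp(log_probs - initial_log_probs)` is the likelihood ratio and
        // `A` is the advantage estimate.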
        let policy_surrogate_loss_fn = |_| {
            let policy_output = self.policy_fn.seq_packed(observation_features);
            let distribution = action_space.distribution(policy_output.tensor());
            let log_probs = distribution.log_probs(actions);

            let likelihood_ratio = (log_probs - &initial_log_probs).exp();
            let clipped_likelihood_ratio =
                likelihood_ratio.clip(1.0 - self.clip_distance, 1.0 + self.clip_distance);

            (likelihood_ratio * advantages.tensor())
                .min_other(&(clipped_likelihood_ratio * advantages.tensor()))
                .mean(Kind::Float)
                .neg()
        };

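        // Take `opt_steps_per_update` gradient steps on the surrogate loss,
        // logging progress to `logger`.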
        n_backward_steps(
            &mut self.optimizer,
            sample_minibatch,
            policy_surrogate_loss_fn,
            self.opt_steps_per_update,
            logger,
            ToLog::NoAbsLoss, // loss value is offset by a meaningless constant
            "policy update error",
        );
    }
}