use super::base_agent::BaseAgent;
use crate::memory::{Experience, OnPolicyBuffer};
use crate::misc::batch_states::batch_states;
use crate::misc::cumsum::cumsum_rev;
use crate::models::BasePolicy;
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::sync::Arc;
use tch::{nn, no_grad, Device, Kind, Tensor};
use ulid::Ulid;
/// Bound applied to the log-probability difference `log π_new(a|s) - log π_old(a|s)`
/// before it is exponentiated into the importance ratio, so the ratio stays within
/// `[e^-8, e^8]` (presumably to keep `exp` numerically safe).
const POLICY_LOG_PROB_RATIO_CLAMP_RANGE: f64 = 8.0;
/// Proximal Policy Optimization agent: clipped surrogate policy loss, clipped value
/// loss and an entropy bonus, trained on-policy from its own experience buffer with
/// GAE advantages.
pub struct PPO {
    /// Identifier attached to every experience this agent appends to its buffer.
    agent_id: Ulid,
    /// Policy network; `forward` yields an action distribution and an optional value
    /// estimate (the update path unwraps the value, so it must be present).
    model: Box<dyn BasePolicy>,
    /// Optimizer stepped once per minibatch during `_update`.
    optimizer: nn::Optimizer,
    /// Experiences collected since the last update; flushed by `_update`.
    buffer: OnPolicyBuffer,
    /// Discount factor.
    gamma: f64,
    /// GAE λ; advantages accumulate backwards with factor `gamma * lambda`.
    lambda: f64,
    /// `_update` runs whenever `t` is a multiple of this value.
    update_interval: usize,
    /// Number of passes over the collected transitions per update.
    epoch: usize,
    /// Transitions per gradient step.
    minibatch_size: usize,
    /// The probability ratio is clipped to `[1 - ε, 1 + ε]`.
    policy_clip_epsilon: f64,
    /// New value predictions are clipped to `old_value ± value_clip_range`.
    value_clip_range: f64,
    /// Weight of the value loss in the total loss.
    value_coef: f64,
    /// Weight of the entropy bonus (subtracted from the total loss).
    entropy_coef: f64,
    /// Whether advantages are standardized (zero mean, unit std) before use.
    gae_std: bool,
    /// Number of `act_and_train` calls so far.
    t: usize,
    /// Episode id attached to appended experiences; replaced when an episode ends.
    current_episode_id: Ulid,
}
// NOTE(review): no SAFETY justification is given for this impl. It asserts `Send`
// for every field, including `Box<dyn BasePolicy>` (the trait object is not declared
// `+ Send` here) and `nn::Optimizer`. Soundness relies on the concrete model and
// optimizer being safe to move across threads — confirm, and consider
// `Box<dyn BasePolicy + Send>` so the compiler checks it instead.
unsafe impl Send for PPO {}
impl PPO {
pub fn new(
model: Box<dyn BasePolicy>,
optimizer: nn::Optimizer,
gamma: f64,
lambda: f64,
update_interval: usize,
epoch: usize,
minibatch_size: usize,
policy_clip_epsilon: f64,
value_clip_range: f64,
value_coef: f64,
entropy_coef: f64,
gae_std: bool,
) -> Self {
assert!(minibatch_size <= update_interval);
PPO {
agent_id: Ulid::new(),
model,
optimizer,
buffer: OnPolicyBuffer::new(),
gamma,
lambda,
update_interval,
epoch,
minibatch_size,
policy_clip_epsilon,
value_clip_range,
value_coef,
entropy_coef,
gae_std,
t: 0,
current_episode_id: Ulid::new(),
}
}
fn _update(&mut self) {
let experiences_per_episode: Vec<Vec<Arc<Experience>>> = self.buffer.flush();
let total_transitions = experiences_per_episode
.iter()
.map(|v| v.len())
.sum::<usize>()
- experiences_per_episode.len();
let n_iter = total_transitions.div_ceil(self.minibatch_size);
let n_data_per_epoch = n_iter * self.minibatch_size;
let n_data = n_data_per_epoch * self.epoch;
let mut rng = thread_rng();
let mut batch_indice = (0..total_transitions).collect::<Vec<usize>>();
let mut all_indice =
Vec::with_capacity(total_transitions * n_data.div_ceil(total_transitions));
for _ in 0..n_data.div_ceil(total_transitions) {
batch_indice.shuffle(&mut rng);
all_indice.extend(batch_indice.iter().cloned());
}
let all_indice = all_indice
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let _skip_first = experiences_per_episode
.iter()
.flat_map(|v| v.iter().skip(1))
.cloned()
.collect::<Vec<Arc<Experience>>>();
let _skip_last = experiences_per_episode
.iter()
.flat_map(|v| v.iter().take(v.len().saturating_sub(1)))
.cloned()
.collect::<Vec<Arc<Experience>>>();
let state = batch_states(
&_skip_last
.iter()
.map(|e| e.state.shallow_clone())
.collect::<Vec<Tensor>>(),
self.model.device(),
);
let next_state = batch_states(
&_skip_first
.iter()
.map(|e| e.state.shallow_clone())
.collect::<Vec<Tensor>>(),
self.model.device(),
);
let _action = batch_states(
&_skip_last
.iter()
.map(|e| e.action.as_ref().unwrap().shallow_clone())
.collect::<Vec<Tensor>>(),
self.model.device(),
);
let action = _action.view([total_transitions as i64, *_action.size().last().unwrap()]);
let reward =
Tensor::from_slice(&_skip_first.iter().map(|e| e.reward).collect::<Vec<f64>>())
.to_device(self.model.device());
let (old_action_distrib, old_value) = self.model.forward(&state);
let old_value = old_value.unwrap().flatten(0, 1).detach();
let old_log_prob = old_action_distrib.log_prob(&action.detach()).detach();
let (_, old_next_value) = self.model.forward(&next_state);
let old_next_value = old_next_value.unwrap().flatten(0, 1).detach();
let non_terminal: Tensor = 1.0
- Tensor::from_slice(
&_skip_first
.iter()
.map(|e| if e.is_episode_terminal { 1.0 } else { 0.0 })
.collect::<Vec<f64>>(),
)
.to_device(self.model.device());
let old_next_value = (old_next_value * non_terminal).detach();
let td_error = (reward + self.gamma * &old_next_value - &old_value).to_device(Device::Cpu);
let _gae = Tensor::from_slice(&cumsum_rev(
&(0..td_error.size()[0])
.map(|i| td_error.double_value(&[i]))
.collect::<Vec<f64>>(),
&_skip_first
.iter()
.map(|e| {
if e.is_episode_terminal {
0.0 } else {
self.gamma * self.lambda
}
})
.collect::<Vec<f64>>(),
))
.to_device(self.model.device())
.detach();
let gae = if self.gae_std {
(&_gae - (&_gae).mean(Kind::Float)) / ((&_gae).std(false) + 1e-8)
} else {
_gae
};
let value_target = &gae + &old_value;
for i in 0..self.epoch {
for j in 0..n_iter {
let minibatch_indice = Tensor::from_slice(
&all_indice[i * n_data_per_epoch + j * self.minibatch_size
..i * n_data_per_epoch + (j + 1) * self.minibatch_size],
)
.to_device(self.model.device());
let (action_distrib, value) = self.model.forward(&state);
let value = value
.unwrap()
.flatten(0, 1)
.index_select(0, &minibatch_indice);
let log_prob = action_distrib.log_prob(&action.detach());
let policy_ratio = (log_prob - &old_log_prob)
.index_select(0, &minibatch_indice)
.clamp(
-POLICY_LOG_PROB_RATIO_CLAMP_RANGE,
POLICY_LOG_PROB_RATIO_CLAMP_RANGE,
)
.exp();
let clipped_policy_ratio = policy_ratio.clamp(
1.0 - self.policy_clip_epsilon,
1.0 + self.policy_clip_epsilon,
);
let _old_value = old_value.index_select(0, &minibatch_indice).detach();
let clipped_value = &_old_value
+ (&value - &_old_value).clamp(-self.value_clip_range, self.value_clip_range);
let minibatch_gae = gae.index_select(0, &minibatch_indice).detach();
let _value_target = (&value_target).index_select(0, &minibatch_indice).detach();
let policy_loss: Tensor = -1.0
* (policy_ratio * &minibatch_gae)
.minimum(&(clipped_policy_ratio * &minibatch_gae))
.mean(Kind::Float);
let value_loss = (&_value_target - value)
.square()
.maximum(&(&_value_target - clipped_value).square())
.mean(Kind::Float);
let entropy_regularized = action_distrib
.entropy()
.index_select(0, &minibatch_indice)
.mean(Kind::Float);
assert!(policy_loss.isnan().any().int64_value(&[]) == 0);
assert!(value_loss.isnan().any().int64_value(&[]) == 0);
assert!(entropy_regularized.isnan().any().int64_value(&[]) == 0);
let loss: Tensor = policy_loss + self.value_coef * value_loss
- self.entropy_coef * entropy_regularized;
self.optimizer.zero_grad();
loss.backward();
self.optimizer.step();
}
}
}
}
impl BaseAgent for PPO {
    /// Returns the most probable action for `obs`, moved to the CPU.
    ///
    /// Runs without gradient tracking and records nothing in the buffer.
    fn act(&self, obs: &Tensor) -> Tensor {
        no_grad(|| {
            let state = batch_states(&vec![obs.shallow_clone()], self.model.device());
            let (action_distrib, _) = self.model.forward(&state);
            action_distrib.most_probable().to_device(Device::Cpu)
        })
    }

    /// Samples an action for `obs`, records the step in the on-policy buffer and runs
    /// an update every `update_interval` calls.
    ///
    /// `reward` is stored together with this observation. The returned action lives
    /// on the CPU.
    fn act_and_train(&mut self, obs: &Tensor, reward: f64) -> Tensor {
        self.t += 1;
        let state = batch_states(&vec![obs.shallow_clone()], self.model.device());
        let action_distrib = no_grad(|| {
            let (action_distrib, _) = self.model.forward(&state);
            action_distrib
        });
        let action = action_distrib.sample().detach().to_device(Device::Cpu);
        self.buffer.append(
            self.agent_id,
            self.current_episode_id,
            state,
            Some(action.shallow_clone()),
            Some(action_distrib),
            reward,
            false,
        );
        if self.t % self.update_interval == 0 {
            self._update();
        }
        action
    }

    /// Records the final observation of the current episode, flagged as terminal and
    /// carrying no action, then starts a fresh episode id.
    ///
    /// Does not increment `t` and does not trigger an update.
    fn stop_episode_and_train(&mut self, obs: &Tensor, reward: f64) {
        let state = batch_states(&vec![obs.shallow_clone()], self.model.device());
        self.buffer.append(
            self.agent_id,
            self.current_episode_id,
            state,
            None,
            None,
            reward,
            true,
        );
        self.current_episode_id = Ulid::new();
    }

    /// PPO currently reports no training statistics.
    fn get_statistics(&self) -> Vec<(String, f64)> {
        Vec::new()
    }

    fn get_agent_id(&self) -> &Ulid {
        &self.agent_id
    }

    /// Persistence is not implemented for PPO yet; this is a no-op.
    fn save(&self, _dirname: &str, _ancestors: std::collections::HashSet<String>) {}

    /// Persistence is not implemented for PPO yet; this is a no-op.
    fn load(&self, _dirname: &str, _ancestors: std::collections::HashSet<String>) {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::FCSoftmaxPolicyWithValue;
    use tch::{nn, nn::OptimizerConfig, Device, Kind, Tensor};

    /// The constructor stores the hyper-parameters and starts with `t == 0`.
    #[test]
    fn test_ppo_new() {
        let var_store = nn::VarStore::new(Device::Cpu);
        let adam = nn::Adam::default().build(&var_store, 1e-3).unwrap();
        let policy = FCSoftmaxPolicyWithValue::new(var_store, 4, 2, 2, 64, 0.0);
        let agent = PPO::new(
            Box::new(policy),
            adam,
            0.99,  // gamma
            0.99,  // lambda
            100,   // update_interval
            8,     // epoch
            16,    // minibatch_size
            0.1,   // policy_clip_epsilon
            0.2,   // value_clip_range
            1.0,   // value_coef
            1.0,   // entropy_coef
            false, // gae_std
        );
        assert_eq!(agent.t, 0);
        assert_eq!(agent.gamma, 0.99);
        assert_eq!(agent.epoch, 8);
        assert_eq!(agent.update_interval, 100);
    }

    /// With a constant observation and a reward only for action 2, training should
    /// make action 2 the greedy choice.
    #[test]
    fn test_ppo_act_and_train() {
        let var_store = nn::VarStore::new(Device::Cpu);
        let adam = nn::Adam::default().build(&var_store, 1e-3).unwrap();
        let policy = FCSoftmaxPolicyWithValue::new(var_store, 4, 4, 2, 64, 0.0);
        let mut agent = PPO::new(
            Box::new(policy),
            adam,
            0.5,   // gamma
            0.99,  // lambda
            100,   // update_interval
            3,     // epoch
            32,    // minibatch_size
            0.1,   // policy_clip_epsilon
            0.2,   // value_clip_range
            1.0,   // value_coef
            1.0,   // entropy_coef
            false, // gae_std
        );
        let observation = || Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0]).to_kind(Kind::Float);

        let mut reward = 0.0;
        for step in 1..=2000 {
            let chosen = agent.act_and_train(&observation(), reward).int64_value(&[]);
            reward = if chosen == 2 { 100.0 } else { 0.0 };
            assert!((0..4).contains(&chosen));
            assert_eq!(agent.t, step);
        }

        let final_obs = observation();
        agent.stop_episode_and_train(&final_obs, 1.0);
        for _ in 0..1000 {
            assert_eq!(agent.act(&final_obs).int64_value(&[]), 2);
        }
    }
}