from _common._algorithms.BaseReplayBuffer import combined_shape, discount_cumsum, statistics_scalar, ReplayBufferAbstract
import numpy as np
import torch
class ReplayBuffer(ReplayBufferAbstract):
    """Fixed-capacity on-policy trajectory buffer with action masks.

    Transitions are appended with :meth:`store`; :meth:`finish_path` fills in
    advantages and returns for the trajectory that just ended; :meth:`get`
    hands back everything collected so far as float32 torch tensors and
    rewinds the buffer.
    """

    def __init__(self, obs_dim, mask_dim, size: int, gamma: float = 0.99,
                 lam: float = 0.95, with_vf_baseline: bool = True):
        """Allocate the storage arrays.

        Args:
            obs_dim: Per-step observation shape, forwarded to
                ``combined_shape(size, obs_dim)``.
            mask_dim: Per-step action-mask shape, forwarded to
                ``combined_shape(size, mask_dim)``.
            size: Maximum number of transitions the buffer can hold.
            gamma: Discount factor used for returns and advantages.
            lam: Extra decay applied to advantages (combined as
                ``gamma * lam``).
            with_vf_baseline: When true, value estimates are stored and used
                as a baseline in :meth:`finish_path`.
        """
        super().__init__()
        self.obs_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size), dtype=np.float32)
        self.mask_buf = np.zeros(combined_shape(size, mask_dim), dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        # The value buffer only exists when a baseline is in use.
        if with_vf_baseline:
            self.val_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        # ptr: next free slot; path_start_idx: first slot of the open trajectory.
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size
        self.capacity = size
        self.with_vf_baseline = with_vf_baseline

    def store(self, obs, act, mask, rew, val, logp):
        """Append one transition at the current write position.

        ``val`` is ignored when the buffer was built with
        ``with_vf_baseline=False``.

        Raises:
            AssertionError: If the buffer is already full.
        """
        assert self.ptr < self.max_size, "replay buffer is full"
        idx = self.ptr
        self.obs_buf[idx] = obs
        self.act_buf[idx] = act
        self.mask_buf[idx] = mask
        self.rew_buf[idx] = rew
        self.logp_buf[idx] = logp
        if self.with_vf_baseline:
            self.val_buf[idx] = val
        self.ptr = idx + 1

    def finish_path(self, last_val=0):
        """Compute advantages and returns for the trajectory that just ended.

        Covers the slots from ``path_start_idx`` up to (not including)
        ``ptr`` and then marks ``ptr`` as the start of the next trajectory.

        Args:
            last_val: Bootstrap value appended after the final step. Only
                used when ``with_vf_baseline`` is true.
        """
        path_slice = slice(self.path_start_idx, self.ptr)
        if self.with_vf_baseline:
            rews = np.append(self.rew_buf[path_slice], last_val)
            vals = np.append(self.val_buf[path_slice], last_val)
            # TD residuals: r_t + gamma * V(s_{t+1}) - V(s_t).
            deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
            # GAE-lambda style advantage, assuming discount_cumsum returns the
            # discounted cumulative sum of its input.
            self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
            # Drop the trailing entry that belongs to the bootstrap value.
            self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        else:
            # Without a baseline the advantage is derived from rewards alone.
            rews = self.rew_buf[path_slice]
            self.adv_buf[path_slice] = discount_cumsum(rews, self.gamma * self.lam)
            self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)
        self.path_start_idx = self.ptr

    def get(self) -> dict:
        """Return the collected data as float32 tensors and rewind the buffer.

        Advantages are normalised with the mean and std reported by
        ``statistics_scalar``. Only the first ``ptr`` entries are returned, so
        a partially filled buffer is allowed.

        Returns:
            Dict with keys ``obs``, ``act``, ``mask``, ``ret``, ``adv`` and
            ``logp``, each a ``torch.float32`` tensor.

        Raises:
            AssertionError: If the write pointer is past the buffer capacity.
        """
        # A completely full buffer (ptr == max_size) is a valid state reached
        # through store(), so it must be readable as well.
        assert self.ptr <= self.max_size, "write pointer exceeds buffer capacity"
        actual_size = self.ptr
        self.ptr, self.path_start_idx = 0, 0

        advantages = self.adv_buf[:actual_size]
        adv_mean, adv_std = statistics_scalar(advantages)
        # A zero std (constant advantages) would produce 0/0 = NaN; fall back
        # to centring only in that case.
        scale = adv_std if adv_std > 0 else 1.0
        normalised_adv = (advantages - adv_mean) / scale

        data = dict(
            obs=self.obs_buf[:actual_size],
            act=self.act_buf[:actual_size],
            mask=self.mask_buf[:actual_size],
            ret=self.ret_buf[:actual_size],
            adv=normalised_adv,
            logp=self.logp_buf[:actual_size],
        )
        # NOTE: torch.as_tensor may share memory with the NumPy buffers, so
        # consume the returned tensors before storing new transitions.
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}