use crate::envs::Successor;
use crate::simulation::PartialStep;
use crate::spaces::{FeatureSpace, ReprSpace, Space};
use crate::torch::packed::{PackedSeqIter, PackedStructure, PackedTensor};
use crate::torch::ExclusiveTensor;
use crate::utils::sequence::Sequence;
use ndarray::Axis;
use once_cell::unsync::OnceCell;
use std::cmp::Reverse;
use tch::{Device, Tensor};
/// Tensor-valued views of a batch of experience history.
///
/// Every accessor hands back a borrowed [`PackedTensor`] (or a pair of them), so an
/// implementation is expected to own the tensors it returns.
pub trait HistoryFeatures {
/// Features of the step observations, as a packed tensor.
fn observation_features(&self) -> &PackedTensor;
/// Observation features together with a validity mask.
///
/// Returns `(extended_observations, is_invalid)`. In the implementation in this file
/// each episode contributes one entry more than it has steps, and `is_invalid` is
/// `true` at the entries for which no observation was available.
fn extended_observation_features(&self) -> (&PackedTensor, &PackedTensor);
/// Representations of the step actions, as a packed tensor.
fn actions(&self) -> &PackedTensor;
/// Step rewards, as a packed tensor.
fn rewards(&self) -> &PackedTensor;
/// The device associated with the returned tensors.
fn device(&self) -> Device;
}
/// Feature tensors for a set of episodes, each computed on first request and then cached.
///
/// The caches are `once_cell::unsync::OnceCell`s, so each tensor is built at most once
/// per instance.
#[derive(Debug)]
pub struct LazyHistoryFeatures<'a, OS: Space + ?Sized, AS: Space + ?Sized, E> {
// Episodes, reordered by the constructor from longest to shortest.
episodes: Vec<E>,
// Space used to turn observations into feature vectors.
observation_space: &'a OS,
// Space used to turn actions into their tensor representation.
action_space: &'a AS,
// Device every produced tensor is moved to.
device: Device,
// Packed structure built from `episode length + 1` for each episode.
extended_structure: PackedStructure,
// Lazily initialized results of the corresponding `HistoryFeatures` accessors.
cached_observation_features: OnceCell<PackedTensor>,
// Pair of (extended observation features, is-invalid mask).
cached_extended_observation_features: OnceCell<(PackedTensor, PackedTensor)>,
cached_actions: OnceCell<PackedTensor>,
cached_rewards: OnceCell<PackedTensor>,
}
impl<'a, OS, AS, E> LazyHistoryFeatures<'a, OS, AS, E>
where
    OS: Space + ?Sized,
    AS: Space + ?Sized,
    E: Sequence,
{
    /// Creates a lazy feature view over `episodes`.
    ///
    /// The episodes are collected and reordered from longest to shortest (the relative
    /// order of equally long episodes is unspecified). No tensor is built here; every
    /// cache starts out empty.
    ///
    /// # Panics
    ///
    /// Panics if `PackedStructure::from_sorted_sequence_lengths` rejects the lengths.
    /// They are supplied in non-increasing order, each one larger than the episode
    /// length by one.
    pub fn new<I>(
        episodes: I,
        observation_space: &'a OS,
        action_space: &'a AS,
        device: Device,
    ) -> Self
    where
        I: IntoIterator<Item = E>,
    {
        let mut sorted_episodes = episodes.into_iter().collect::<Vec<_>>();
        sorted_episodes.sort_unstable_by_key(|episode| Reverse(episode.len()));

        // One extra slot per episode: the observation that follows the final step.
        let extended_lengths = sorted_episodes.iter().map(|episode| episode.len() + 1);
        let extended_structure =
            PackedStructure::from_sorted_sequence_lengths(extended_lengths).unwrap();

        Self {
            episodes: sorted_episodes,
            observation_space,
            action_space,
            device,
            extended_structure,
            cached_observation_features: OnceCell::new(),
            cached_extended_observation_features: OnceCell::new(),
            cached_actions: OnceCell::new(),
            cached_rewards: OnceCell::new(),
        }
    }

    /// Total number of steps over all episodes.
    ///
    /// The extended structure holds one additional entry per episode, so the episode
    /// count is subtracted from its length.
    pub fn num_steps(&self) -> usize {
        let extended_len = self.extended_structure.len();
        extended_len - self.episodes.len()
    }

    /// Number of episodes held by this view.
    pub fn num_episodes(&self) -> usize {
        self.episodes.len()
    }

    /// Whether this view holds no episodes at all.
    pub fn is_empty(&self) -> bool {
        self.episodes.is_empty()
    }

    /// Packed structure of the plain (non-extended) steps: the extended structure
    /// trimmed by one.
    fn structure(&self) -> PackedStructure {
        self.extended_structure.clone().trim(1)
    }
}
impl<'a, OS, AS, E> HistoryFeatures for LazyHistoryFeatures<'a, OS, AS, E>
where
    OS: FeatureSpace + ?Sized,
    AS: ReprSpace<Tensor> + ?Sized,
    E: Sequence<Item = &'a PartialStep<OS::Element, AS::Element>>
        + IntoIterator<Item = &'a PartialStep<OS::Element, AS::Element>>
        + Copy,
    E::IntoIter: DoubleEndedIterator,
{
    /// Observation features of every step, computed on first call and cached afterwards.
    fn observation_features(&self) -> &PackedTensor {
        self.cached_observation_features.get_or_init(|| {
            let step_observations =
                PackedSeqIter::from_sorted(&self.episodes).map(|step| &step.observation);
            let features = self
                .observation_space
                .batch_features::<_, Tensor>(step_observations)
                .to(self.device);
            PackedTensor::from_parts(features, self.structure())
        })
    }

    /// Extended observation features and their is-invalid mask, computed once and cached.
    ///
    /// Each episode is wrapped in `ExtendedEpisodeObservations`, so it yields one entry
    /// more than it has steps. Entries without an observation keep their all-zero
    /// feature row and are flagged `true` in the mask.
    fn extended_observation_features(&self) -> (&PackedTensor, &PackedTensor) {
        let cached = self.cached_extended_observation_features.get_or_init(|| {
            let extended_episodes = self
                .episodes
                .iter()
                .copied()
                .map(ExtendedEpisodeObservations::from);
            let observations = PackedSeqIter::from_sorted(extended_episodes);

            let num_observations = observations.len();
            let num_features = self.observation_space.num_features();
            let mut features = ExclusiveTensor::<f32, _>::zeros((num_observations, num_features));
            let mut invalid_flags = ExclusiveTensor::<bool, _>::zeros(num_observations);

            // The mutable views live in their own scope so that they are gone before the
            // buffers are consumed by `into_tensor` below.
            {
                let mut feature_rows = features.array_view_mut();
                let mut invalid_view = invalid_flags.array_view_mut();
                for (row, maybe_observation) in observations.enumerate() {
                    match maybe_observation {
                        Some(observation) => {
                            let mut row_view = feature_rows.index_axis_mut(Axis(0), row);
                            // NOTE(review): the trailing `true` is presumably the
                            // "buffer is already zeroed" flag -- confirm against
                            // `FeatureSpace::features_out`.
                            self.observation_space.features_out(
                                observation,
                                row_view.as_slice_mut().unwrap(),
                                true,
                            );
                        }
                        None => invalid_view[row] = true,
                    }
                }
            }

            let structure = &self.extended_structure;
            (
                PackedTensor::from_parts(
                    features.into_tensor().to(self.device),
                    structure.clone(),
                ),
                PackedTensor::from_parts(
                    invalid_flags.into_tensor().to(self.device),
                    structure.clone(),
                ),
            )
        });
        (&cached.0, &cached.1)
    }

    /// Action representations of every step, computed on first call and cached afterwards.
    fn actions(&self) -> &PackedTensor {
        self.cached_actions.get_or_init(|| {
            let step_actions = PackedSeqIter::from_sorted(&self.episodes).map(|step| &step.action);
            let repr = self.action_space.batch_repr(step_actions).to(self.device);
            PackedTensor::from_parts(repr, self.structure())
        })
    }

    /// Step rewards as `f32`, computed on first call and cached afterwards.
    // The feedback is converted to `f64` and then deliberately narrowed to `f32`.
    #[allow(clippy::cast_possible_truncation)]
    fn rewards(&self) -> &PackedTensor {
        self.cached_rewards.get_or_init(|| {
            let values = PackedSeqIter::from_sorted(&self.episodes)
                .map(|step| f64::from(step.feedback) as f32)
                .collect::<Vec<_>>();
            let tensor = Tensor::of_slice(&values).to(self.device);
            PackedTensor::from_parts(tensor, self.structure())
        })
    }

    /// Device that all produced tensors are moved to.
    fn device(&self) -> Device {
        self.device
    }
}
/// Wrapper that views an episode as its sequence of "extended" observations.
///
/// The wrapper only stores the episode; the extended view itself is supplied by
/// trait implementations on this type.
struct ExtendedEpisodeObservations<E> {
    /// The wrapped episode.
    episode: E,
}

impl<E> From<E> for ExtendedEpisodeObservations<E> {
    /// Wraps `episode` without modifying it.
    fn from(episode: E) -> Self {
        ExtendedEpisodeObservations { episode }
    }
}
impl<'a, E, O, A> Sequence for ExtendedEpisodeObservations<E>
where
    E: Sequence<Item = &'a PartialStep<O, A>>,
    O: 'a,
    A: 'a,
{
    /// `Some(observation)` where one exists, `None` for a slot without an observation.
    type Item = Option<&'a O>;

    /// One slot per step plus one trailing slot.
    fn len(&self) -> usize {
        self.episode.len() + 1
    }

    /// Never empty: even an episode without steps has the single trailing slot.
    fn is_empty(&self) -> bool {
        false
    }

    /// Looks up the extended observation at `idx`.
    ///
    /// * `idx` addresses a step: that step's observation.
    /// * `idx == 0` on an episode without steps: `Some(None)`.
    /// * `idx` equals the episode length: the observation carried by the last step's
    ///   `Successor::Interrupt`, or `Some(None)` for any other successor.
    /// * anything larger: `None` (out of range).
    fn get(&self, idx: usize) -> Option<Self::Item> {
        if let Some(step) = self.episode.get(idx) {
            return Some(Some(&step.observation));
        }
        // Must be handled before the `idx - 1` below: index 0 missing means the
        // episode has no steps at all.
        if idx == 0 {
            return Some(None);
        }
        if idx != self.episode.len() {
            return None;
        }
        let last_step = self.episode.get(idx - 1).unwrap();
        match &last_step.next {
            Successor::Interrupt(obs) => Some(Some(obs)),
            _ => Some(None),
        }
    }
}
// Unit tests for `LazyHistoryFeatures`, plus a reusable stored-history fixture.
#[cfg(test)]
// The test functions below receive the fixture value by value.
#[allow(clippy::needless_pass_by_value)]
pub(crate) mod tests {
use super::*;
use crate::envs::Successor::{Continue, Interrupt, Terminate};
use crate::feedback::Reward;
use crate::spaces::{BooleanSpace, IndexSpace};
use rstest::{fixture, rstest};
/// Owned episodes together with the spaces and device needed to build features.
pub struct StoredHistory<OS: Space, AS: Space> {
episodes: Vec<Vec<PartialStep<OS::Element, AS::Element>>>,
observation_space: OS,
action_space: AS,
device: Device,
}
impl<OS: Space, AS: Space> StoredHistory<OS, AS> {
/// Builds a `LazyHistoryFeatures` that borrows the stored episodes as slices.
#[allow(clippy::type_complexity)]
pub fn features(
&self,
) -> LazyHistoryFeatures<OS, AS, &[PartialStep<OS::Element, AS::Element>]> {
LazyHistoryFeatures::new(
self.episodes.iter().map(AsRef::as_ref),
&self.observation_space,
&self.action_space,
self.device,
)
}
}
/// Four episodes of lengths 4, 6, 3 and 1 (14 steps in total).
///
/// Their final steps end with `Continue(())`, `Terminate`, `Interrupt(true)` and
/// `Terminate` respectively. Actions are numbered `10 * episode + step` so that the
/// packed ordering is easy to read off in the assertions.
#[fixture]
pub fn history() -> StoredHistory<BooleanSpace, IndexSpace> {
let episodes = vec![
vec![
PartialStep::new(true, 0, Reward(1.0), Continue(())),
PartialStep::new(true, 1, Reward(1.0), Continue(())),
PartialStep::new(true, 2, Reward(1.0), Continue(())),
PartialStep::new(true, 3, Reward(1.0), Continue(())),
],
vec![
PartialStep::new(false, 10, Reward(-1.0), Continue(())),
PartialStep::new(false, 11, Reward(-1.0), Continue(())),
PartialStep::new(false, 12, Reward(0.0), Continue(())),
PartialStep::new(false, 13, Reward(0.0), Continue(())),
PartialStep::new(false, 14, Reward(1.0), Continue(())),
PartialStep::new(false, 15, Reward(1.0), Terminate),
],
vec![
PartialStep::new(false, 20, Reward(2.0), Continue(())),
PartialStep::new(true, 21, Reward(2.0), Continue(())),
PartialStep::new(false, 22, Reward(2.0), Interrupt(true)),
],
vec![PartialStep::new(true, 30, Reward(3.0), Terminate)],
];
StoredHistory {
episodes,
observation_space: BooleanSpace::new(),
action_space: IndexSpace::new(31),
device: Device::Cpu,
}
}
// 4 + 6 + 3 + 1 steps.
#[rstest]
fn num_steps(history: StoredHistory<BooleanSpace, IndexSpace>) {
assert_eq!(history.features().num_steps(), 14);
}
#[rstest]
fn num_episodes(history: StoredHistory<BooleanSpace, IndexSpace>) {
assert_eq!(history.features().num_episodes(), 4);
}
#[rstest]
fn is_empty(history: StoredHistory<BooleanSpace, IndexSpace>) {
assert!(!history.features().is_empty());
}
#[rstest]
fn device(history: StoredHistory<BooleanSpace, IndexSpace>) {
assert_eq!(history.features().device(), Device::Cpu);
}
// Expected values are listed step index by step index, and within each step index
// across the episodes ordered longest first (lengths 6, 4, 3, 1).
// `false` maps to 0.0 and `true` to 1.0, with a trailing feature dimension of 1.
#[rstest]
fn observation_features(history: StoredHistory<BooleanSpace, IndexSpace>) {
let features = history.features();
let actual = features.observation_features();
let expected = &Tensor::of_slice(&[
0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0f32,
])
.unsqueeze(-1);
assert_eq!(actual.tensor(), expected);
}
// Same ordering as above; the tens digit of each action identifies its source episode.
#[rstest]
fn actions(history: StoredHistory<BooleanSpace, IndexSpace>) {
let features = history.features();
let actual = features.actions();
let expected = &Tensor::of_slice(&[
10, 0, 20, 30, 11, 1, 21, 12, 2, 22, 13, 3, 14, 15i64,
]);
assert_eq!(actual.tensor(), expected);
}
// Number of episodes that still have a step at each step index 0..=5.
#[rstest]
fn actions_batch_sizes_tensor(history: StoredHistory<BooleanSpace, IndexSpace>) {
assert_eq!(
history.features().actions().batch_sizes_tensor(),
Tensor::of_slice(&[4, 3, 3, 2, 1, 1])
);
}
// Same ordering as the observation and action tests.
#[rstest]
fn rewards(history: StoredHistory<BooleanSpace, IndexSpace>) {
let features = history.features();
let actual = features.rewards();
let expected = &Tensor::of_slice(&[
-1.0, 1.0, 2.0, 3.0, -1.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 1.0, 1.0f32,
]);
assert_eq!(actual.tensor(), expected);
}
}