use super::{
FeatureSpace, FiniteSpace, LogElementSpace, NonEmptySpace, ParameterizedDistributionSpace,
ReprSpace, Space, SubsetOrd,
};
use crate::logging::{LogError, LogValue, StatsLogger};
use crate::torch::distributions::Categorical;
use crate::utils::distributions::ArrayDistribution;
use ndarray::{s, ArrayBase, DataMut, Ix2};
use num_traits::{Float, One, Zero};
use rand::distributions::Distribution;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::{any, fmt};
use tch::{Device, Kind, Tensor};
/// A type whose values can be converted to and from `usize` indices.
///
/// Code in this file relies on every index in `0..Self::SIZE` mapping to a value
/// and on `index()` always returning a number below `Self::SIZE`.
pub trait Indexed {
    /// Number of distinct values of the implementing type.
    const SIZE: usize;

    /// Returns the index associated with this value.
    fn index(&self) -> usize;

    /// Returns the value associated with `index`, or `None` if there is no such value.
    fn from_index(index: usize) -> Option<Self> where Self: Sized;
}
/// A space whose elements are the values of the type `T`.
///
/// The space carries no data of its own; its only field is a zero-sized type marker.
#[derive(Serialize, Deserialize)]
// Empty bound: replaces the `T: Serialize` / `T: Deserialize` bounds serde would infer.
#[serde(bound = "")]
pub struct IndexedTypeSpace<T> {
// The marker is skipped by serde, so the space (de)serializes as a struct with no fields.
#[serde(skip)]
// `fn() -> T` ties the space to `T` without storing a `T`.
element_type: PhantomData<fn() -> T>,
}
impl<T> IndexedTypeSpace<T> {
    /// Creates the space over the type `T`.
    ///
    /// The space holds no state, so construction only instantiates the type marker.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        let element_type = PhantomData;
        Self { element_type }
    }
}
impl<T> Default for IndexedTypeSpace<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<T> fmt::Debug for IndexedTypeSpace<T> {
    /// Writes `IndexedTypeSpace<NAME>`, where `NAME` is the type name of `T`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let element_type_name = any::type_name::<T>();
        write!(f, "IndexedTypeSpace<{}>", element_type_name)
    }
}
impl<T> fmt::Display for IndexedTypeSpace<T> {
    /// Writes `IndexedTypeSpace<NAME>`, where `NAME` is the type name of `T`.
    ///
    /// Width, fill and alignment flags on the formatter are not applied.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("IndexedTypeSpace<")?;
        f.write_str(any::type_name::<T>())?;
        f.write_str(">")
    }
}
// A derived `Clone` / `Copy` would add `T: Clone` / `T: Copy` bounds;
// these hand-written impls apply to every `T`.
impl<T> Clone for IndexedTypeSpace<T> {
    /// Returns a copy of the (stateless) space.
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for IndexedTypeSpace<T> {}
impl<T> Space for IndexedTypeSpace<T>
where
    T: Clone + Send,
{
    type Element = T;

    /// Always `true`: every value of `T` is an element of the space.
    #[inline]
    fn contains(&self, _element: &Self::Element) -> bool {
        true
    }
}
impl<T> PartialEq for IndexedTypeSpace<T> {
    /// Two spaces over the same `T` always compare equal; the space has no state.
    #[inline]
    fn eq(&self, _: &Self) -> bool {
        true
    }
}

impl<T> Eq for IndexedTypeSpace<T> {}
impl<T> Hash for IndexedTypeSpace<T> {
    /// Feeds nothing into the hasher: all spaces over the same `T` compare equal,
    /// so they must also hash identically.
    #[inline]
    fn hash<H>(&self, _: &mut H)
    where
        H: Hasher,
    {
    }
}
impl<T> SubsetOrd for IndexedTypeSpace<T> {
    /// Always reports `Equal`: any two spaces over the same `T` are treated as the same set.
    #[inline]
    fn subset_cmp(&self, _: &Self) -> Option<Ordering> {
        let ordering = Ordering::Equal;
        Some(ordering)
    }
}
impl<T> NonEmptySpace for IndexedTypeSpace<T>
where
    T: Indexed + Clone + Send,
{
    /// Returns the element at index 0.
    ///
    /// # Panics
    /// Panics with "space is empty" if `T::from_index(0)` is `None`.
    #[inline]
    fn some_element(&self) -> Self::Element {
        let first = <T as Indexed>::from_index(0);
        first.expect("space is empty")
    }
}
impl<T> Distribution<T> for IndexedTypeSpace<T>
where
    T: Indexed,
{
    /// Draws an index from `0..T::SIZE` with `gen_range` and converts it to a value.
    ///
    /// # Panics
    /// Panics if `T::from_index` rejects the drawn index.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
        let index = rng.gen_range(0..T::SIZE);
        T::from_index(index).unwrap()
    }
}
impl<T> FiniteSpace for IndexedTypeSpace<T>
where
    T: Indexed + Clone + Send,
{
    /// Number of elements in the space: `T::SIZE`.
    #[inline]
    fn size(&self) -> usize {
        <T as Indexed>::SIZE
    }

    /// Index of `element`, as defined by its `Indexed` implementation.
    #[inline]
    fn to_index(&self, element: &Self::Element) -> usize {
        element.index()
    }

    /// Element at `index`, or `None` if `T::from_index` rejects the index.
    #[inline]
    fn from_index(&self, index: usize) -> Option<Self::Element> {
        <T as Indexed>::from_index(index)
    }
}
/// One-hot feature encoding: an element is represented by `T::SIZE` features,
/// all zero except for a one at the element's index.
impl<T> FeatureSpace for IndexedTypeSpace<T>
where
    T: Indexed + Clone + Send,
{
    /// Length of the feature vector: `T::SIZE`.
    #[inline]
    fn num_features(&self) -> usize {
        <T as Indexed>::SIZE
    }

    /// Writes the one-hot features of `element` into the first `T::SIZE` entries of `out`
    /// and returns the remaining, untouched tail of `out`.
    ///
    /// When `zeroed` is `false` the `T::SIZE` entries are cleared first; when it is `true`
    /// they are assumed to be zero already and only the single one is written.
    ///
    /// # Panics
    /// Panics if `out` holds fewer than `T::SIZE` entries.
    #[inline]
    fn features_out<'a, F: Float>(
        &self,
        element: &Self::Element,
        out: &'a mut [F],
        zeroed: bool,
    ) -> &'a mut [F] {
        let (features, remainder) = out.split_at_mut(T::SIZE);
        if !zeroed {
            features.iter_mut().for_each(|slot| *slot = F::zero());
        }
        let hot = self.to_index(element);
        features[hot] = F::one();
        remainder
    }

    /// Writes the one-hot features of the k-th element into the k-th row of `out`.
    ///
    /// When `zeroed` is `false` the first `num_features()` columns of every row are
    /// cleared first. Rows beyond the number of elements receive no one.
    ///
    /// # Panics
    /// Panics with "fewer rows than elements" if `out` runs out of rows.
    #[inline]
    fn batch_features_out<'a, I, A>(&self, elements: I, out: &mut ArrayBase<A, Ix2>, zeroed: bool)
    where
        I: IntoIterator<Item = &'a Self::Element>,
        Self::Element: 'a,
        A: DataMut,
        A::Elem: Float,
    {
        if !zeroed {
            let mut feature_columns = out.slice_mut(s![.., 0..self.num_features()]);
            feature_columns.fill(Zero::zero());
        }
        let mut rows = out.rows_mut().into_iter();
        elements.into_iter().for_each(|element| {
            let mut row = rows.next().expect("fewer rows than elements");
            row[self.to_index(element)] = One::one();
        });
    }
}
impl<T> ReprSpace<Tensor> for IndexedTypeSpace<T>
where
    T: Indexed + Clone + Send,
{
    /// Represents `element` as a scalar `Int64` CPU tensor holding its index.
    #[inline]
    fn repr(&self, element: &Self::Element) -> Tensor {
        let index = self.to_index(element) as i64;
        Tensor::scalar_tensor(index, (Kind::Int64, Device::Cpu))
    }

    /// Represents `elements` as a tensor built from the slice of their `i64` indices,
    /// in iteration order.
    #[inline]
    fn batch_repr<'a, I>(&self, elements: I) -> Tensor
    where
        I: IntoIterator<Item = &'a Self::Element>,
        Self::Element: 'a,
    {
        let indices = elements
            .into_iter()
            .map(|element| self.to_index(element) as i64)
            .collect::<Vec<i64>>();
        Tensor::of_slice(&indices)
    }
}
impl<T: Indexed + Clone + Send> ParameterizedDistributionSpace<Tensor> for IndexedTypeSpace<T> {
type Distribution = Categorical;
#[inline]
fn num_distribution_params(&self) -> usize {
T::SIZE
}
#[inline]
fn sample_element(&self, params: &Tensor) -> Self::Element {
self.from_index(
self.distribution(params)
.sample()
.int64_value(&[])
.try_into()
.unwrap(),
)
.unwrap()
}
#[inline]
fn distribution(&self, params: &Tensor) -> Self::Distribution {
Self::Distribution::new(params)
}
}
impl<T> LogElementSpace for IndexedTypeSpace<T>
where
    T: Indexed + Clone + Send,
{
    /// Logs `element` under `name` as a `LogValue::Index` carrying the element's index
    /// and the size of the space. Returns whatever the logger returns.
    #[inline]
    fn log_element<L: StatsLogger + ?Sized>(
        &self,
        name: &'static str,
        element: &Self::Element,
        logger: &mut L,
    ) -> Result<(), LogError> {
        let value = self.to_index(element);
        logger.log(
            name.into(),
            LogValue::Index {
                value,
                size: T::SIZE,
            },
        )
    }
}
/// `false` has index 0 and `true` has index 1.
impl Indexed for bool {
    const SIZE: usize = 2;

    #[inline]
    fn index(&self) -> usize {
        usize::from(*self)
    }

    /// Returns `Some(false)` for 0, `Some(true)` for 1 and `None` for anything else.
    #[inline]
    fn from_index(index: usize) -> Option<Self> {
        if index < Self::SIZE {
            Some(index != 0)
        } else {
            None
        }
    }
}
// Test-only fixture shared by the test modules below.
#[cfg(test)]
mod trit {
use relearn_derive::Indexed;
/// Three-variant enum whose `Indexed` implementation comes from the derive macro.
///
/// The tests in this file expect `Zero`, `One`, `Two` to have indices 0, 1, 2.
#[derive(Debug, Copy, Clone, Indexed, PartialEq, Eq)]
pub enum Trit {
Zero,
One,
Two,
}
}
/// Runs the generic space checks from the `testing` module for `bool` and `Trit`.
#[cfg(test)]
mod space {
    use super::super::testing;
    use super::trit::Trit;
    use super::*;

    /// Runs `testing::check_contains_samples` on the space over `T` with argument 100.
    fn check_contains_samples<T>()
    where
        T: Indexed + Clone + Send,
    {
        testing::check_contains_samples(&IndexedTypeSpace::<T>::new(), 100);
    }

    #[test]
    fn contains_samples_bool() {
        check_contains_samples::<bool>();
    }

    #[test]
    fn contains_samples_enum() {
        check_contains_samples::<Trit>();
    }

    /// Runs `testing::check_from_to_index_iter_size` on the space over `T`.
    fn check_from_to_index_iter_size<T>()
    where
        T: Indexed + Clone + Send,
    {
        testing::check_from_to_index_iter_size(&IndexedTypeSpace::<T>::new());
    }

    #[test]
    fn from_to_index_iter_size_bool() {
        check_from_to_index_iter_size::<bool>();
    }

    #[test]
    fn from_to_index_iter_size_enum() {
        check_from_to_index_iter_size::<Trit>();
    }

    /// Runs `testing::check_from_index_sampled` on the space over `T` with argument 20.
    fn check_from_index_sampled<T>()
    where
        T: Indexed + Clone + Send,
    {
        testing::check_from_index_sampled(&IndexedTypeSpace::<T>::new(), 20);
    }

    #[test]
    fn from_index_sampled_bool() {
        check_from_index_sampled::<bool>();
    }

    #[test]
    fn from_index_sampled_enum() {
        check_from_index_sampled::<Trit>();
    }

    /// Runs `testing::check_from_index_invalid` on the space over `T`.
    fn check_from_index_invalid<T>()
    where
        T: Indexed + Clone + Send,
    {
        testing::check_from_index_invalid(&IndexedTypeSpace::<T>::new());
    }

    #[test]
    fn from_index_invalid_bool() {
        check_from_index_invalid::<bool>();
    }

    #[test]
    fn from_index_invalid_enum() {
        check_from_index_invalid::<Trit>();
    }
}
/// Equality and subset ordering between two spaces over the same element type.
#[cfg(test)]
mod subset_ord {
    use super::super::SubsetOrd;
    use super::trit::Trit;
    use super::*;
    use std::cmp::Ordering;

    /// Builds two independently constructed spaces over `Trit`.
    fn pair() -> (IndexedTypeSpace<Trit>, IndexedTypeSpace<Trit>) {
        (IndexedTypeSpace::new(), IndexedTypeSpace::new())
    }

    #[test]
    fn eq() {
        let (left, right) = pair();
        assert_eq!(left, right);
    }

    #[test]
    fn cmp_equal() {
        let (left, right) = pair();
        assert_eq!(left.subset_cmp(&right), Some(Ordering::Equal));
    }

    #[test]
    fn not_strict_subset() {
        let (left, right) = pair();
        assert!(!left.strict_subset_of(&right));
    }
}
/// Serde round-trip of the space against a fixed token stream.
#[cfg(test)]
mod serialize {
    use super::trit::Trit;
    use super::*;
    use serde_test::{assert_tokens, Token};

    /// The space (de)serializes as a struct named `IndexedTypeSpace` with no fields.
    #[test]
    fn ser_de_tokens() {
        let expected_tokens = [
            Token::Struct {
                name: "IndexedTypeSpace",
                len: 0,
            },
            Token::StructEnd,
        ];
        assert_tokens(&IndexedTypeSpace::<Trit>::new(), &expected_tokens);
    }
}
/// Checks the `Indexed` implementations generated by `#[derive(Indexed)]`.
#[cfg(test)]
mod derive_indexed_macro {
    use super::*;
    use relearn_derive::Indexed;

    #[derive(Debug, Indexed)]
    enum EmptyEnum {}

    #[derive(Debug, Indexed)]
    enum NonEmptyEnum {
        A,
        B,
    }

    #[test]
    fn empty_enum_len() {
        assert_eq!(EmptyEnum::SIZE, 0);
    }

    #[test]
    fn empty_enum_from_index_invalid_0() {
        let decoded = EmptyEnum::from_index(0);
        assert!(decoded.is_none(), "Expected `None`, got {:?}", decoded);
    }

    #[test]
    fn empty_enum_from_index_invalid_1() {
        let decoded = EmptyEnum::from_index(1);
        assert!(decoded.is_none(), "Expected `None`, got {:?}", decoded);
    }

    #[test]
    fn non_empty_enum_len() {
        assert_eq!(NonEmptyEnum::SIZE, 2);
    }

    #[test]
    fn non_empty_enum_to_index() {
        let first = NonEmptyEnum::A;
        let second = NonEmptyEnum::B;
        assert_eq!(first.index(), 0);
        assert_eq!(second.index(), 1);
    }

    #[test]
    fn non_empty_enum_from_index_valid_0() {
        match NonEmptyEnum::from_index(0) {
            Some(NonEmptyEnum::A) => {}
            other => panic!("Expected `Some(NonEmptyEnum::A)`, got {:?}", other),
        }
    }

    #[test]
    fn non_empty_enum_from_index_valid_1() {
        match NonEmptyEnum::from_index(1) {
            Some(NonEmptyEnum::B) => {}
            other => panic!("Expected `Some(NonEmptyEnum::B)`, got {:?}", other),
        }
    }

    #[test]
    fn non_empty_enum_from_index_invalid_2() {
        let decoded = NonEmptyEnum::from_index(2);
        assert!(decoded.is_none(), "Expected `None`, got {:?}", decoded);
    }
}
// Feature-encoding tests: each `Trit` value is expected to map to a one-hot vector of length 3.
#[cfg(test)]
mod feature_space {
use super::trit::Trit;
use super::*;
/// The space under test, rebuilt for every use.
fn space() -> IndexedTypeSpace<Trit> {
IndexedTypeSpace::new()
}
#[test]
fn num_features() {
let space = space();
assert_eq!(3, space.num_features());
}
// Arguments: test name, space, element, expected feature vector.
// NOTE(review): the macros are defined elsewhere in the project; confirm exactly which checks they generate.
features_tests!(trit_zero, space(), Trit::Zero, [1.0, 0.0, 0.0]);
features_tests!(trit_one, space(), Trit::One, [0.0, 1.0, 0.0]);
features_tests!(trit_two, space(), Trit::Two, [0.0, 0.0, 1.0]);
// Batch variant: one expected row per element, in the same order as the elements.
batch_features_tests!(
trit_batch,
space(),
[Trit::Two, Trit::Zero, Trit::One, Trit::Zero],
[
[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0] ]
);
}
/// Tensor representations of single elements and of batches.
#[cfg(test)]
mod repr_space_tensor {
    use super::trit::Trit;
    use super::*;

    /// Each element is represented by a scalar `Int64` tensor holding its index.
    #[test]
    fn repr() {
        let space = IndexedTypeSpace::<Trit>::new();
        let cases = [(Trit::Zero, 0_i64), (Trit::One, 1), (Trit::Two, 2)];
        for (element, index) in cases.iter() {
            assert_eq!(
                space.repr(element),
                Tensor::scalar_tensor(*index, (Kind::Int64, Device::Cpu))
            );
        }
    }

    /// A batch is represented by the tensor of element indices, in order.
    #[test]
    fn batch_repr() {
        let space = IndexedTypeSpace::<Trit>::new();
        let batch = [Trit::Zero, Trit::One, Trit::Two, Trit::One];
        assert_eq!(
            space.batch_repr(&batch),
            Tensor::of_slice(&[0_i64, 1, 2, 1])
        );
    }
}
/// Tests for sampling elements from tensor-parameterized distributions.
#[cfg(test)]
mod parameterized_sample_space_tensor {
    use super::super::IndexedTypeSpace;
    use super::trit::Trit;
    use super::*;

    #[test]
    fn num_sample_params() {
        let space: IndexedTypeSpace<Trit> = IndexedTypeSpace::new();
        assert_eq!(3, space.num_distribution_params());
    }

    /// Only the middle parameter is finite, so every sample must be `Trit::One`.
    #[test]
    fn sample_element_deterministic() {
        let space: IndexedTypeSpace<Trit> = IndexedTypeSpace::new();
        let params = Tensor::of_slice(&[f32::NEG_INFINITY, 0.0, f32::NEG_INFINITY]);
        for _ in 0..10 {
            assert_eq!(Trit::One, space.sample_element(&params));
        }
    }

    /// The first parameter is negative infinity, so `Trit::Zero` must never be sampled.
    #[test]
    fn sample_element_two_of_three() {
        let space: IndexedTypeSpace<Trit> = IndexedTypeSpace::new();
        let params = Tensor::of_slice(&[f32::NEG_INFINITY, 0.0, 0.0]);
        for _ in 0..10 {
            assert!(Trit::Zero != space.sample_element(&params));
        }
    }

    /// Samples 1000 elements and checks that each variant's count falls in a fixed interval.
    ///
    /// NOTE(review): the intervals are centered near 1000 * softmax([-1, 0, 1])
    /// (about 90 / 245 / 665), which assumes the parameters are logits; confirm
    /// against `Categorical`. The test is statistical and can fail with small probability.
    #[test]
    fn sample_element_check_distribution() {
        let space: IndexedTypeSpace<Trit> = IndexedTypeSpace::new();
        let params = Tensor::of_slice(&[-1.0, 0.0, 1.0]);
        let mut zero_count = 0;
        let mut one_count = 0;
        let mut two_count = 0;
        for _ in 0..1000 {
            match space.sample_element(&params) {
                Trit::Zero => zero_count += 1,
                Trit::One => one_count += 1,
                Trit::Two => two_count += 1,
            }
        }
        assert!((58..=121).contains(&zero_count));
        assert!((197..=292).contains(&one_count));
        assert!((613..=717).contains(&two_count));
    }
}