use std::fmt::Debug;
use rand::Rng;
use serde::Serialize;
use super::{sample_space::SampleSpace, space::Space};
/// A value type that can act as the element of a bounded (box-shaped) space.
///
/// Every operation receives the bounds explicitly, so the trait itself holds no state.
/// The implementations that accompany this trait treat both bounds as inclusive.
pub trait Bounded: Sized + Clone + Debug {
    /// Reports whether `value` lies between `low` and `high`.
    fn in_bounds(value: &Self, low: &Self, high: &Self) -> bool;

    /// Draws a random value between `low` and `high`, using `rng` as the source of randomness.
    fn sample_uniform<R>(rng: &mut R, low: &Self, high: &Self) -> Self
    where
        R: Rng;

    /// Consumes `value` and returns it restricted to the range spanned by `low` and `high`.
    fn clamp(value: Self, low: &Self, high: &Self) -> Self;
}
/// A space delimited by a lower and an upper bound of the same [`Bounded`] type.
///
/// Serialization only requires `B: Serialize` (see the `serde(bound)` attribute).
#[derive(Clone, Debug, Serialize)]
#[serde(bound = "B: Serialize")]
pub struct BoxSpace<B>
where
    B: Bounded,
{
    /// Lower bound of the space.
    pub low: B,
    /// Upper bound of the space.
    pub high: B,
}
impl<B: Bounded> BoxSpace<B> {
pub fn new(low: B, high: B) -> Self {
Self { low, high }
}
}
/// Identity conversion: a `BoxSpace` can be borrowed as itself.
impl<B> AsRef<Self> for BoxSpace<B>
where
    B: Bounded,
{
    fn as_ref(&self) -> &Self {
        self
    }
}
impl<B: Bounded> Space for BoxSpace<B> {
type Element = B;
fn contains(&self, value: &B) -> bool {
B::in_bounds(value, &self.low, &self.high)
}
}
impl<B: Bounded> SampleSpace for BoxSpace<B> {
type Mask = ();
fn sample<R: Rng>(&self, rng: &mut R, _mask: Option<&Self::Mask>) -> B {
B::sample_uniform(rng, &self.low, &self.high)
}
}
/// Element-wise bounds for a flat vector of `f64`.
///
/// `value`, `low` and `high` are expected to have the same length; each element is compared
/// against the bound at the same index, with both bounds inclusive.
impl Bounded for Vec<f64> {
    /// Returns `true` only when all three vectors have the same length and every element of
    /// `value` satisfies `low[i] <= value[i] <= high[i]`.
    ///
    /// A length mismatch against either bound yields `false` rather than a panic. A NaN
    /// element (or NaN bound) makes the comparison fail, so the result is `false`.
    fn in_bounds(value: &Self, low: &Self, high: &Self) -> bool {
        // `zip` stops at the shortest input, so both bound lengths must be checked up front;
        // otherwise trailing elements of `value` would go unchecked.
        value.len() == low.len()
            && value.len() == high.len()
            && value
                .iter()
                .zip(low.iter().zip(high.iter()))
                .all(|(v, (lo, hi))| lo <= v && v <= hi)
    }

    /// Samples each element independently from the inclusive range `low[i]..=high[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `low` and `high` differ in length. `gen_range` itself is documented by
    /// `rand` to panic on an empty range (i.e. `low[i] > high[i]`).
    fn sample_uniform<R: Rng>(rng: &mut R, low: &Self, high: &Self) -> Self {
        assert_eq!(low.len(), high.len(), "Vec<f64> bounds length mismatch");
        low.iter()
            .zip(high.iter())
            .map(|(lo, hi)| rng.gen_range(*lo..=*hi))
            .collect()
    }

    /// Clamps each element of `value` into `low[i]..=high[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `value`, `low` and `high` do not all have the same length, or (via
    /// `f64::clamp`) if any `low[i] > high[i]` or a bound is NaN.
    fn clamp(value: Self, low: &Self, high: &Self) -> Self {
        // Without these checks `zip` would silently drop trailing elements and return a
        // vector shorter than the input.
        assert_eq!(low.len(), high.len(), "Vec<f64> bounds length mismatch");
        assert_eq!(value.len(), low.len(), "Vec<f64> value length mismatch");
        value
            .into_iter()
            .zip(low.iter().zip(high.iter()))
            .map(|(v, (lo, hi))| v.clamp(*lo, *hi))
            .collect()
    }
}
/// Scalar `f64` bounds; both ends are inclusive.
impl Bounded for f64 {
    /// Returns whether `low <= value <= high`. Any NaN operand makes this `false`.
    fn in_bounds(value: &Self, low: &Self, high: &Self) -> bool {
        (*low..=*high).contains(value)
    }

    /// Draws one value from the inclusive range `low..=high`.
    fn sample_uniform<R: Rng>(rng: &mut R, low: &Self, high: &Self) -> Self {
        let (lo, hi) = (*low, *high);
        rng.gen_range(lo..=hi)
    }

    /// Restricts `value` to `low..=high` using the standard library's `f64::clamp`.
    fn clamp(value: Self, low: &Self, high: &Self) -> Self {
        let (lo, hi) = (*low, *high);
        value.clamp(lo, hi)
    }
}
/// Scalar `f32` bounds; both ends are inclusive.
impl Bounded for f32 {
    /// Returns whether `low <= value <= high`. Any NaN operand makes this `false`.
    fn in_bounds(value: &Self, low: &Self, high: &Self) -> bool {
        (*low..=*high).contains(value)
    }

    /// Draws one value from the inclusive range `low..=high`.
    fn sample_uniform<R: Rng>(rng: &mut R, low: &Self, high: &Self) -> Self {
        let (lo, hi) = (*low, *high);
        rng.gen_range(lo..=hi)
    }

    /// Restricts `value` to `low..=high` using the standard library's `f32::clamp`.
    fn clamp(value: Self, low: &Self, high: &Self) -> Self {
        let (lo, hi) = (*low, *high);
        value.clamp(lo, hi)
    }
}
/// A flat buffer of `f64` values paired with the dimensions it is meant to be read with.
#[derive(Debug, Clone, Serialize)]
pub struct Tensor {
    /// The elements, stored in a single vector.
    pub data: Vec<f64>,
    /// Extent of each dimension. [`Tensor::new`] requires the product of these extents to
    /// equal `data.len()`; direct field access does not enforce that.
    pub shape: Vec<usize>,
}
impl Tensor {
pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
let expected: usize = shape.iter().product();
assert_eq!(
data.len(),
expected,
"data length {} does not match shape {:?} (expected {})",
data.len(),
shape,
expected
);
Self { data, shape }
}
pub fn filled(value: f64, shape: Vec<usize>) -> Self {
let n: usize = shape.iter().product();
Self {
data: vec![value; n],
shape,
}
}
}
/// Element-wise bounds for [`Tensor`]; `value`, `low` and `high` must share one shape and
/// each element is compared against the bound at the same flat index (bounds inclusive).
impl Bounded for Tensor {
    /// Returns `true` only when `value` has the same shape as both bounds and every element
    /// satisfies `low[i] <= value[i] <= high[i]`.
    ///
    /// A shape mismatch yields `false` instead of panicking, matching the `Vec<f64>`
    /// implementation: a wrongly shaped value is simply not a member of the space.
    /// NaN elements or bounds fail the comparison and also yield `false`.
    fn in_bounds(value: &Self, low: &Self, high: &Self) -> bool {
        value.shape == low.shape
            && value.shape == high.shape
            && value
                .data
                .iter()
                .zip(low.data.iter().zip(high.data.iter()))
                .all(|(v, (lo, hi))| lo <= v && v <= hi)
    }

    /// Samples each element independently from `low[i]..=high[i]`; the result takes
    /// `low`'s shape.
    ///
    /// # Panics
    ///
    /// Panics if `low` and `high` have different shapes. `gen_range` itself is documented by
    /// `rand` to panic on an empty range (i.e. `low[i] > high[i]`).
    fn sample_uniform<R: Rng>(rng: &mut R, low: &Self, high: &Self) -> Self {
        assert_eq!(low.shape, high.shape, "Tensor shape mismatch");
        let data: Vec<f64> = low
            .data
            .iter()
            .zip(high.data.iter())
            .map(|(lo, hi)| rng.gen_range(*lo..=*hi))
            .collect();
        Tensor {
            data,
            shape: low.shape.clone(),
        }
    }

    /// Clamps each element of `value` into `low[i]..=high[i]`, keeping `value`'s shape.
    ///
    /// # Panics
    ///
    /// Panics if `value`'s shape differs from the shape of `low` or of `high`, or (via
    /// `f64::clamp`) if any `low[i] > high[i]` or a bound is NaN.
    fn clamp(value: Self, low: &Self, high: &Self) -> Self {
        assert_eq!(value.shape, low.shape, "Tensor shape mismatch");
        // `high` must be validated too: `zip` stops at the shortest input, so a smaller
        // `high` would otherwise produce fewer elements than `shape` describes.
        assert_eq!(value.shape, high.shape, "Tensor shape mismatch");
        let Tensor { data, shape } = value;
        let data: Vec<f64> = data
            .into_iter()
            .zip(low.data.iter().zip(high.data.iter()))
            .map(|(v, (lo, hi))| v.clamp(*lo, *hi))
            .collect();
        Tensor { data, shape }
    }
}