use rlevo_core::base::{Observation, TensorConversionError, TensorConvertible};
use serde::{Deserialize, Serialize};
/// Fixed-length observation vector for the lunar-lander environment.
///
/// The eight components live in a plain `f32` array; the accessor methods on
/// this type give each slot a name (`x`, `y`, `vx`, `vy`, `angle`,
/// `angular_vel`, `leg1_contact`, `leg2_contact`).
/// NOTE(review): this layout looks like Gymnasium's `LunarLander` observation —
/// confirm units/ranges against the environment that produces it.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct LunarLanderObservation {
    /// Raw observation components, in accessor index order (0..=7).
    pub values: [f32; 8],
}
impl LunarLanderObservation {
    // Slot indices into `values`, one per named accessor below.
    const IDX_X: usize = 0;
    const IDX_Y: usize = 1;
    const IDX_VX: usize = 2;
    const IDX_VY: usize = 3;
    const IDX_ANGLE: usize = 4;
    const IDX_ANGULAR_VEL: usize = 5;
    const IDX_LEG1_CONTACT: usize = 6;
    const IDX_LEG2_CONTACT: usize = 7;

    /// Wraps eight raw observation components without validation.
    pub fn new(values: [f32; 8]) -> Self {
        Self { values }
    }

    /// Component at index 0, exposed as `x`.
    pub fn x(&self) -> f32 {
        self.values[Self::IDX_X]
    }

    /// Component at index 1, exposed as `y`.
    pub fn y(&self) -> f32 {
        self.values[Self::IDX_Y]
    }

    /// Component at index 2, exposed as `vx`.
    pub fn vx(&self) -> f32 {
        self.values[Self::IDX_VX]
    }

    /// Component at index 3, exposed as `vy`.
    pub fn vy(&self) -> f32 {
        self.values[Self::IDX_VY]
    }

    /// Component at index 4, exposed as `angle`.
    pub fn angle(&self) -> f32 {
        self.values[Self::IDX_ANGLE]
    }

    /// Component at index 5, exposed as `angular_vel`.
    pub fn angular_vel(&self) -> f32 {
        self.values[Self::IDX_ANGULAR_VEL]
    }

    /// Component at index 6, exposed as `leg1_contact`.
    ///
    /// Returned as a float; presumably 1.0 on contact and 0.0 otherwise
    /// (TODO confirm — nothing here enforces it).
    pub fn leg1_contact(&self) -> f32 {
        self.values[Self::IDX_LEG1_CONTACT]
    }

    /// Component at index 7, exposed as `leg2_contact`.
    ///
    /// Returned as a float; presumably 1.0 on contact and 0.0 otherwise
    /// (TODO confirm — nothing here enforces it).
    pub fn leg2_contact(&self) -> f32 {
        self.values[Self::IDX_LEG2_CONTACT]
    }

    /// `true` when no component is NaN or infinite.
    pub fn is_finite(&self) -> bool {
        self.values.iter().copied().all(f32::is_finite)
    }
}
impl Default for LunarLanderObservation {
    /// Returns an observation whose eight components are all `0.0`.
    fn default() -> Self {
        Self::new([0.0_f32; 8])
    }
}
impl Observation<1> for LunarLanderObservation {
    /// Rank-1 shape with a single axis holding the eight components.
    fn shape() -> [usize; 1] {
        const LEN: usize = 8;
        [LEN]
    }
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for LunarLanderObservation {
    /// Builds a rank-1 tensor of length 8 from `values` on `device`.
    fn to_tensor(
        &self,
        device: &<B as burn::tensor::backend::BackendTypes>::Device,
    ) -> burn::tensor::Tensor<B, 1> {
        burn::tensor::Tensor::from_floats(self.values, device)
    }

    /// Rebuilds an observation from a rank-1 tensor of shape `[8]`.
    ///
    /// # Errors
    ///
    /// Returns [`TensorConversionError`] when the tensor's shape is not `[8]`,
    /// when its data cannot be read back as `f32`, or when the element count
    /// does not match eight.
    fn from_tensor(tensor: burn::tensor::Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        let dims = tensor.dims();
        if dims.as_slice() != [8] {
            return Err(TensorConversionError {
                message: format!("expected shape [8], got {dims:?}"),
            });
        }
        // The backend's float element type is not necessarily f32 (e.g. f16 or
        // f64 backends); `into_vec::<f32>` rejects a dtype mismatch, so convert
        // the data to f32 explicitly first. This is a no-op for f32 backends.
        let data = tensor
            .into_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .map_err(|e| TensorConversionError {
                message: e.to_string(),
            })?;
        // Fallible conversion instead of `copy_from_slice`, which would panic on
        // a length mismatch.
        let values = <[f32; 8]>::try_from(data.as_slice()).map_err(|_| TensorConversionError {
            message: format!("expected 8 elements, got {}", data.len()),
        })?;
        Ok(Self { values })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::Flex;

    type TestBackend = Flex;

    /// Encodes an observation through the trait under test.
    fn encode(obs: &LunarLanderObservation) -> burn::tensor::Tensor<TestBackend, 1> {
        let device = Default::default();
        <LunarLanderObservation as TensorConvertible<1, TestBackend>>::to_tensor(obs, &device)
    }

    /// Decodes a tensor through the trait under test.
    fn decode(
        tensor: burn::tensor::Tensor<TestBackend, 1>,
    ) -> Result<LunarLanderObservation, TensorConversionError> {
        <LunarLanderObservation as TensorConvertible<1, TestBackend>>::from_tensor(tensor)
    }

    #[test]
    fn test_shape() {
        let expected: [usize; 1] = [8];
        assert_eq!(LunarLanderObservation::shape(), expected);
    }

    #[test]
    fn test_default_is_finite() {
        let obs = LunarLanderObservation::default();
        assert!(obs.is_finite());
    }

    #[test]
    fn round_trips_through_tensor() {
        let original = LunarLanderObservation::new([0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 1.0, 0.0]);
        let restored = decode(encode(&original)).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn from_tensor_rejects_wrong_shape() {
        let device = Default::default();
        let short = burn::tensor::Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 2.0], &device);
        let err = decode(short).unwrap_err();
        assert!(err.message.contains("expected shape [8]"));
    }
}