pub mod traits {
pub use burn_tensor::{Tensor, TensorKind, backend::Backend};
use std::any::Any;
pub use thiserror::Error;
pub use uuid::Uuid;
/// Errors returned by the environment traits declared in this module.
///
/// Every variant wraps a free-form `String`; the `#[error(...)]` attributes
/// generate a `Display` impl that prints a fixed prefix followed by that string.
#[derive(Debug, Error, Clone)]
pub enum EnvironmentError {
/// Displayed as `Environment error: <message>`.
#[error("Environment error: {0}")]
EnvironmentError(String),
/// Displayed as `Observation building error: <message>`.
#[error("Observation building error: {0}")]
ObservationBuildingError(String),
/// Displayed as `Training performance return error: <message>`.
#[error("Training performance return error: {0}")]
TrainingPerformanceReturnError(String),
}
/// Category tag for an environment, as returned by `Environment::kind`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EnvironmentKind {
/// Scalar category — presumably reported by `ScalarEnvironment` implementors (confirm with callers).
Scalar,
/// Vector category — presumably reported by `VectorEnvironment` implementors (confirm with callers).
Vector,
/// Any other category, identified by a free-form label.
Other(String),
/// Category is not known.
Unknown,
}
/// Element data type of an environment tensor, tagged by the dtype family it
/// belongs to.
///
/// Values of this type are returned by `Environment::observation_dtype` and
/// `Environment::action_dtype`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EnvDType {
    /// A dtype taken from the `EnvNdArrayDType` set.
    NdArray(EnvNdArrayDType),
    /// A dtype taken from the `EnvTchDType` set.
    Tch(EnvTchDType),
}

impl std::fmt::Display for EnvDType {
    /// Writes `<Variant>(<inner dtype>)`, for example `NdArray(F32)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label and the inner value first, then format once.
        let (family, dtype): (&str, &dyn std::fmt::Display) = match self {
            Self::NdArray(inner) => ("NdArray", inner),
            Self::Tch(inner) => ("Tch", inner),
        };
        write!(f, "{}({})", family, dtype)
    }
}
/// Element data types that can appear in the `Tch` variant of `EnvDType`.
///
/// The enum is fieldless, so it derives `Copy` in addition to `Clone`;
/// existing `.clone()` calls keep working.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnvTchDType {
    F16,
    Bf16,
    F32,
    F64,
    I8,
    I16,
    I32,
    I64,
    U8,
    Bool,
}

impl std::fmt::Display for EnvTchDType {
    /// Writes the variant name (for example `Bf16`).
    ///
    /// The name goes through [`std::fmt::Formatter::pad`], so width, fill and
    /// alignment flags such as `{:>6}` are honoured. Plain `{}` output is the
    /// bare variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::F16 => "F16",
            Self::Bf16 => "Bf16",
            Self::F32 => "F32",
            Self::F64 => "F64",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::U8 => "U8",
            Self::Bool => "Bool",
        };
        f.pad(name)
    }
}
/// Element data types that can appear in the `NdArray` variant of `EnvDType`.
///
/// The enum is fieldless, so it derives `Copy` in addition to `Clone`;
/// existing `.clone()` calls keep working.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnvNdArrayDType {
    F16,
    F32,
    F64,
    I8,
    I16,
    I32,
    I64,
    Bool,
}

impl std::fmt::Display for EnvNdArrayDType {
    /// Writes the variant name (for example `F32`).
    ///
    /// The name goes through [`std::fmt::Formatter::pad`], so width, fill and
    /// alignment flags such as `{:>6}` are honoured. Plain `{}` output is the
    /// bare variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::F16 => "F16",
            Self::F32 => "F32",
            Self::F64 => "F64",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::Bool => "Bool",
        };
        f.pad(name)
    }
}
/// Identifier of a single environment instance; a plain alias of `Uuid`.
///
/// Used as the `env_id` of vector-environment results and as the key in
/// `VectorEnvironment::step` / `VectorEnvironment::reset` arguments.
pub type EnvironmentUuid = Uuid;
/// Auxiliary information attached to reset/step results, stored as a list of
/// string pairs (presumably key/value — confirm with producers).
pub type EnvInfo = Vec<(String, String)>;
/// Value returned by [`ScalarEnvironment::reset`].
#[derive(Debug, Clone)]
pub struct ScalarEnvReset<B, const D_IN: usize, KInput>
where
    B: Backend,
    KInput: TensorKind<B>,
{
    /// Observation tensor (`D_IN` dimensions, kind `KInput`).
    pub observation: Tensor<B, D_IN, KInput>,
    /// Optional auxiliary information; `None` when nothing is reported.
    pub info: Option<EnvInfo>,
}
/// Value returned by [`ScalarEnvironment::step`].
#[derive(Debug, Clone)]
pub struct ScalarEnvStep<B, const D_IN: usize, KInput>
where
    B: Backend,
    KInput: TensorKind<B>,
{
    /// Observation tensor (`D_IN` dimensions, kind `KInput`).
    pub observation: Tensor<B, D_IN, KInput>,
    /// Reward reported for this step.
    pub reward: f32,
    /// Episode-ended flag. NOTE(review): presumably the Gymnasium-style
    /// "terminated by the task itself" signal — confirm with implementors.
    pub terminated: bool,
    /// Episode-cut-off flag. NOTE(review): presumably the Gymnasium-style
    /// "truncated externally" signal — confirm with implementors.
    pub truncated: bool,
    /// Optional auxiliary information; `None` when nothing is reported.
    pub info: Option<EnvInfo>,
}
/// Per-environment entry of the list returned by [`VectorEnvironment::reset`].
#[derive(Debug, Clone)]
pub struct VectorEnvReset<B, const D_IN: usize, KInput>
where
    B: Backend,
    KInput: TensorKind<B>,
{
    /// Identifier of the environment this entry belongs to.
    pub env_id: EnvironmentUuid,
    /// Observation tensor (`D_IN` dimensions, kind `KInput`).
    pub observation: Tensor<B, D_IN, KInput>,
    /// Optional auxiliary information; `None` when nothing is reported.
    pub info: Option<EnvInfo>,
}
/// Per-environment entry of the list returned by [`VectorEnvironment::step`].
#[derive(Debug, Clone)]
pub struct VectorEnvStep<B, const D_IN: usize, KInput>
where
    B: Backend,
    KInput: TensorKind<B>,
{
    /// Identifier of the environment this entry belongs to.
    pub env_id: EnvironmentUuid,
    /// Observation tensor (`D_IN` dimensions, kind `KInput`).
    pub observation: Tensor<B, D_IN, KInput>,
    /// Reward reported for this step.
    pub reward: f32,
    /// Episode-ended flag. NOTE(review): presumably the Gymnasium-style
    /// "terminated by the task itself" signal — confirm with implementors.
    pub terminated: bool,
    /// Episode-cut-off flag. NOTE(review): presumably the Gymnasium-style
    /// "truncated externally" signal — confirm with implementors.
    pub truncated: bool,
    /// Optional auxiliary information; `None` when nothing is reported.
    pub info: Option<EnvInfo>,
}
/// Shorthand for the trait-object type `dyn VectorEnvironment<...>` with the
/// same generic arguments.
///
/// This is a bare `dyn` type, so it has to be used behind a pointer such as
/// `Box`, as `EnvironmentHandle::Vector` does.
pub type DynVectorEnv<B, const D_IN: usize, const D_OUT: usize, KInput, KOutput> =
dyn VectorEnvironment<B, D_IN, D_OUT, KInput, KOutput>;
/// A [`ScalarEnvironment`] that can also be duplicated through a trait object.
///
/// `Clone` cannot be called on a `dyn` value directly, so this trait exposes
/// [`clone_box`](DynScalarEnvironment::clone_box) to return a boxed copy.
pub trait DynScalarEnvironment<B, const D_IN: usize, const D_OUT: usize, KInput, KOutput>:
    ScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput> + Send + Sync
where
    B: Backend,
    KInput: TensorKind<B>,
    KOutput: TensorKind<B>,
{
    /// Returns a boxed copy of this environment as a trait object.
    fn clone_box(&self) -> Box<dyn DynScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput>>;
}
/// Blanket implementation: every `'static`, `Clone + Send + Sync` type that is
/// a [`ScalarEnvironment`] is automatically a [`DynScalarEnvironment`].
impl<
        B: Backend,
        const D_IN: usize,
        const D_OUT: usize,
        KInput: TensorKind<B>,
        KOutput: TensorKind<B>,
        Env,
    > DynScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput> for Env
where
    Env: ScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput> + Clone + Send + Sync + 'static,
{
    /// Clones the concrete environment and boxes the copy as a trait object.
    fn clone_box(&self) -> Box<dyn DynScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput>> {
        let duplicate: Env = Clone::clone(self);
        Box::new(duplicate)
    }
}
/// Makes `Box<dyn DynScalarEnvironment<...>>` cloneable by delegating to
/// [`DynScalarEnvironment::clone_box`].
impl<
        B: Backend,
        const D_IN: usize,
        const D_OUT: usize,
        KInput: TensorKind<B>,
        KOutput: TensorKind<B>,
    > Clone for Box<dyn DynScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput>>
{
    fn clone(&self) -> Self {
        // Dereference down to the trait object so the call is dispatched to
        // the boxed environment's own `clone_box`.
        (**self).clone_box()
    }
}
/// Owning handle over a boxed environment, tagged by which interface it
/// exposes.
///
/// Produced by [`Environment::into_handle`].
pub enum EnvironmentHandle<B, const D_IN: usize, const D_OUT: usize, KInput, KOutput>
where
    B: Backend,
    KInput: TensorKind<B>,
    KOutput: TensorKind<B>,
{
    /// A boxed [`DynScalarEnvironment`] trait object.
    Scalar(Box<dyn DynScalarEnvironment<B, D_IN, D_OUT, KInput, KOutput>>),
    /// A boxed [`VectorEnvironment`] trait object (spelled via [`DynVectorEnv`]).
    Vector(Box<DynVectorEnv<B, D_IN, D_OUT, KInput, KOutput>>),
}
/// An [`Environment`] stepped one action at a time.
///
/// Both methods take `&self`, so any state an implementation updates between
/// calls has to live behind interior mutability.
pub trait ScalarEnvironment<B, const D_IN: usize, const D_OUT: usize, KindIn, KindOut>:
    Environment<B, D_IN, D_OUT, KindIn, KindOut> + Send + Sync
where
    B: Backend,
    KindIn: TensorKind<B>,
    KindOut: TensorKind<B>,
{
    /// Applies one `action` tensor and returns the resulting [`ScalarEnvStep`].
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn step(
        &self,
        action: Tensor<B, D_OUT, KindOut>,
    ) -> Result<ScalarEnvStep<B, D_IN, KindIn>, EnvironmentError>;

    /// Resets the environment and returns the resulting [`ScalarEnvReset`].
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn reset(&self) -> Result<ScalarEnvReset<B, D_IN, KindIn>, EnvironmentError>;
}
/// An [`Environment`] whose calls address several environments at once, each
/// identified by an [`EnvironmentUuid`].
///
/// All methods take `&self`, so any state an implementation updates between
/// calls has to live behind interior mutability.
pub trait VectorEnvironment<B, const D_IN: usize, const D_OUT: usize, KindIn, KindOut>:
    Environment<B, D_IN, D_OUT, KindIn, KindOut> + Send + Sync
where
    B: Backend,
    KindIn: TensorKind<B>,
    KindOut: TensorKind<B>,
{
    /// Takes a requested environment count and returns a list of environment
    /// identifiers.
    ///
    /// NOTE(review): presumably creates `num_envs` environments and returns one
    /// id per environment — confirm with implementors.
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn init_num_envs(&self, num_envs: usize) -> Result<Vec<EnvironmentUuid>, EnvironmentError>;

    /// Takes `(environment id, action)` pairs and returns a list of
    /// [`VectorEnvStep`] values, each carrying its own `env_id`.
    ///
    /// NOTE(review): the signature does not fix the order or length of the
    /// returned list relative to `actions` — confirm with implementors.
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn step(
        &self,
        actions: &[(EnvironmentUuid, Tensor<B, D_OUT, KindOut>)],
    ) -> Result<Vec<VectorEnvStep<B, D_IN, KindIn>>, EnvironmentError>;

    /// Takes the ids of the environments to reset and returns a list of
    /// [`VectorEnvReset`] values, each carrying its own `env_id`.
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn reset(
        &self,
        env_ids: &[EnvironmentUuid],
    ) -> Result<Vec<VectorEnvReset<B, D_IN, KindIn>>, EnvironmentError>;
}
/// Base trait shared by [`ScalarEnvironment`] and [`VectorEnvironment`].
///
/// Implementors must be `Send + Sync`.
pub trait Environment<B, const D_IN: usize, const D_OUT: usize, KindIn, KindOut>:
    Send + Sync
where
    B: Backend,
    KindIn: TensorKind<B>,
    KindOut: TensorKind<B>,
{
    /// Runs the environment.
    ///
    /// NOTE(review): what "running" entails (blocking loop, single episode, …)
    /// is not visible in this file — confirm with implementors.
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn run_environment(&self) -> Result<(), EnvironmentError>;

    /// Builds an observation and returns it type-erased as `Box<dyn Any>`;
    /// callers must downcast to the concrete type they expect.
    ///
    /// # Errors
    ///
    /// Returns an [`EnvironmentError`] when the implementation fails.
    fn build_observation(&self) -> Result<Box<dyn Any>, EnvironmentError>;

    /// Data type reported for observations.
    fn observation_dtype(&self) -> EnvDType;

    /// Data type reported for actions.
    fn action_dtype(&self) -> EnvDType;

    /// Category of this environment.
    fn kind(&self) -> EnvironmentKind;

    /// Consumes the boxed environment and wraps it in an [`EnvironmentHandle`].
    fn into_handle(self: Box<Self>) -> EnvironmentHandle<B, D_IN, D_OUT, KindIn, KindOut>;
}
/// Hook for computing a training performance return value.
///
/// NOTE(review): what the "performance return" represents is not visible in
/// this file — confirm with implementors.
pub trait TrainingPerformanceReturnFn {
/// Computes the performance return, type-erased as `Box<dyn Any>`; callers
/// must downcast to the concrete type they expect.
///
/// # Errors
///
/// Returns an `EnvironmentError` on failure (presumably the
/// `TrainingPerformanceReturnError` variant — confirm with implementors).
fn calculate_performance_return(&self) -> Result<Box<dyn Any>, EnvironmentError>;
}
}
pub use traits::*;