use burn_tensor::backend::Backend;
use burn_tensor::{Float, Int, TensorKind};
use relayrl_types::prelude::records::{ArrowTrajectory, CsvTrajectory};
use relayrl_types::prelude::tensor::burn::Tensor;
use relayrl_types::prelude::tensor::relayrl::{BackendMatcher, TensorData, TensorError};
use relayrl_types::prelude::trajectory::RelayRLTrajectory;
use std::collections::HashMap;
use thiserror::Error;
#[derive(Clone, Debug, Error)]
pub enum AlgorithmError {
#[error("Initialization failed: {0}")]
InitializationError(String),
#[error("Insertion of trajectory failed: {0}")]
TrajectoryInsertionError(String),
#[error("Buffer sampling failed: {0}")]
BufferSamplingError(String),
}
/// A trajectory in one of the supported container formats.
///
/// `Csv` and `Arrow` wrap record-backed trajectories from
/// `relayrl_types::prelude::records`. `RelayRL` holds a native trajectory.
// The variants store their payloads inline. Boxing the large variant to silence
// `clippy::large_enum_variant` would change the public payload types of these
// variants, so the lint is allowed instead.
#[allow(clippy::large_enum_variant)]
pub enum TrajectoryType {
    /// A native RelayRL trajectory.
    RelayRL(RelayRLTrajectory),
    /// A CSV-record trajectory.
    Csv(CsvTrajectory),
    /// An Arrow-record trajectory.
    Arrow(ArrowTrajectory),
}
/// Conversion of a trajectory container into a [`RelayRLTrajectory`].
///
/// This is the bound required of the trajectory type accepted by
/// [`AlgorithmTrait`].
pub trait TrajectoryData {
    /// Consumes `self` and yields the contained RelayRL trajectory.
    ///
    /// Returns `None` when the container holds no RelayRL trajectory.
    fn into_relayrl(self) -> Option<RelayRLTrajectory>;
}
/// Identity conversion: a native trajectory is already in RelayRL form.
impl TrajectoryData for RelayRLTrajectory {
    /// Always returns `Some(self)`.
    fn into_relayrl(self) -> Option<Self> {
        Some(self)
    }
}
/// Unwraps the RelayRL trajectory carried by a CSV-record trajectory.
impl TrajectoryData for CsvTrajectory {
    /// Moves out the optional inner `trajectory` field and discards the rest.
    fn into_relayrl(self) -> Option<RelayRLTrajectory> {
        let Self { trajectory, .. } = self;
        trajectory
    }
}
/// Unwraps the RelayRL trajectory carried by an Arrow-record trajectory.
impl TrajectoryData for ArrowTrajectory {
    /// Moves out the optional inner `trajectory` field and discards the rest.
    fn into_relayrl(self) -> Option<RelayRLTrajectory> {
        let Self { trajectory, .. } = self;
        trajectory
    }
}
/// Interface implemented by training algorithms.
///
/// `T` is the trajectory container the algorithm accepts. It must be
/// convertible to a [`RelayRLTrajectory`] via [`TrajectoryData`].
pub trait AlgorithmTrait<T>
where
    T: TrajectoryData,
{
    /// Saves the algorithm to `filename` (format is implementation-defined).
    fn save(&self, filename: &str);

    /// Hands one trajectory to the algorithm.
    ///
    /// The meaning of the `bool` in the `Ok` case is implementation-defined.
    ///
    /// # Errors
    ///
    /// Returns an [`AlgorithmError`] when the trajectory cannot be accepted.
    // Allowed deliberately: `async fn` in a public trait triggers the
    // `async_fn_in_trait` lint because callers cannot name or bound (e.g. `Send`)
    // the returned future.
    #[allow(async_fn_in_trait)]
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError>;

    /// Runs the algorithm's training procedure.
    fn train_model(&mut self);

    /// Emits the algorithm's per-epoch log output.
    fn log_epoch(&mut self);
}
/// Output of [`ForwardKernelTrait::forward`], split by action-space kind.
///
/// `B` is the tensor backend. `OUT_D` is the rank of the per-action output tensors.
pub enum ForwardOutput<B, const OUT_D: usize>
where
    B: Backend + BackendMatcher,
{
    /// Output for a discrete action space.
    Discrete {
        /// Presumably action probabilities (softmax of `logits`); confirm against implementors.
        probs: Tensor<B, OUT_D, Float>,
        /// Presumably unnormalized action scores; confirm against implementors.
        logits: Tensor<B, OUT_D, Float>,
        /// Optional log-probabilities. Presumably `Some` only when `forward`
        /// received `act`; TODO confirm.
        logp_a: Option<Tensor<B, OUT_D, Float>>,
    },
    /// Output for a continuous action space.
    Continuous {
        /// Presumably the mean of the action distribution; confirm against implementors.
        mean: Tensor<B, OUT_D, Float>,
        // NOTE(review): `std` is fixed at rank 2 while `mean` uses `OUT_D`; the
        // ranks only agree when `OUT_D == 2`. Confirm this is intentional.
        /// Presumably the standard deviation of the action distribution.
        std: Tensor<B, 2, Float>,
        /// Optional log-probabilities. Presumably `Some` only when `forward`
        /// received `act`; TODO confirm.
        logp_a: Option<Tensor<B, OUT_D, Float>>,
    },
}
/// Action produced by [`StepKernelTrait::step`].
///
/// Both variants hold rank-2 tensors. Discrete actions use integer elements;
/// continuous actions use floating-point elements.
pub enum StepAction<B>
where
    B: Backend + BackendMatcher,
{
    /// Integer-valued actions (presumably action indices; TODO confirm).
    Discrete(Tensor<B, 2, Int>),
    /// Float-valued actions.
    Continuous(Tensor<B, 2, Float>),
}
/// Forward pass of a policy kernel.
///
/// `InK` is the element kind of the observation tensor. `OutK` is the element
/// kind of the mask and action tensors.
pub trait ForwardKernelTrait<B, InK, OutK>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
{
    /// Evaluates the kernel on `obs` under `mask`.
    ///
    /// `act` is optional. Presumably, when supplied, the returned `logp_a` is
    /// populated for those actions (TODO confirm against implementors).
    fn forward<const IN_D: usize, const OUT_D: usize>(
        &self,
        obs: Tensor<B, IN_D, InK>,
        mask: Tensor<B, OUT_D, OutK>,
        act: Option<Tensor<B, OUT_D, OutK>>,
    ) -> ForwardOutput<B, OUT_D>;
}
/// Action selection for a policy kernel.
///
/// `InK` is the element kind of the observation tensor. `OutK` is the element
/// kind of the mask tensor.
pub trait StepKernelTrait<B, InK, OutK>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
{
    /// Chooses an action for `obs` under `mask`.
    ///
    /// Returns the action together with a map of named auxiliary tensors whose
    /// contents are implementation-defined.
    ///
    /// # Errors
    ///
    /// Returns a [`TensorError`] when the implementation reports a tensor failure.
    fn step<const IN_D: usize, const OUT_D: usize>(
        &self,
        obs: Tensor<B, IN_D, InK>,
        mask: Tensor<B, OUT_D, OutK>,
    ) -> Result<(StepAction<B>, HashMap<String, TensorData>), TensorError>;

    /// Input dimension of the kernel (presumably the observation size; TODO confirm).
    fn get_input_dim(&self) -> usize;

    /// Output dimension of the kernel (presumably the action-space size; TODO confirm).
    fn get_output_dim(&self) -> usize;
}
/// Exposes a kernel's policy ("pi") layer parameters as plain Rust data.
pub trait WeightProvider {
    /// Returns one tuple per layer, or `None` when the implementor cannot provide them.
    ///
    /// NOTE(review): the tuple presumably reads `(in_dim, out_dim, weights, biases)`,
    /// with weights/biases flattened to `f32`. Confirm the field order and the weight
    /// layout (row- vs column-major) against implementors.
    fn get_pi_layer_specs(&self) -> Option<Vec<(usize, usize, Vec<f32>, Vec<f32>)>>;
}
/// Update hooks for a trainable kernel with separate `pi` and `vf` steps.
///
/// NOTE(review): `pi`/`vf` presumably denote the policy and value-function parts
/// of an actor-critic kernel. Confirm against implementors.
pub trait TrainableKernelTrait {
    /// Performs one `pi` update over a batch.
    ///
    /// `obs`, `act`, `mask` and `logp_old` are slices of tensor data. `adv` is a
    /// slice of `f32` (presumably advantage estimates; TODO confirm). Returns an
    /// `f32` (presumably the loss) and a map of named `f32` values (presumably
    /// training metrics).
    fn train_pi_step(
        &mut self,
        obs: &[TensorData],
        act: &[TensorData],
        mask: &[TensorData],
        adv: &[f32],
        logp_old: &[TensorData],
    ) -> (f32, HashMap<String, f32>);
    /// Performs one `vf` update over a batch.
    ///
    /// `ret` is a slice of `f32` (presumably return targets; TODO confirm).
    /// Returns an `f32` (presumably the loss).
    fn train_vf_step(&mut self, obs: &[TensorData], mask: &[TensorData], ret: &[f32]) -> f32;
}