pub mod algorithms;
pub mod logging;
pub mod templates;
use burn_tensor::TensorKind;
use burn_tensor::backend::Backend;
use relayrl_types::prelude::tensor::relayrl::BackendMatcher;
use std::path::PathBuf;
pub use algorithms::PPO::{
IPPOAlgorithm, IPPOParams, MAPPOAlgorithm, MAPPOParams, PPOAlgorithm, PPOKernelTrait, PPOParams,
};
pub use algorithms::REINFORCE::{
IREINFORCEAlgorithm, IREINFORCEParams, MAREINFORCEAlgorithm, MAREINFORCEParams,
REINFORCEParams, ReinforceAlgorithm,
};
pub use templates::base_algorithm::{
AlgorithmError, AlgorithmTrait, StepKernelTrait, TrainableKernelTrait, TrajectoryData,
WeightProvider,
};
/// Common construction arguments shared by every trainer variant in this
/// module; each field is forwarded verbatim to the selected algorithm's `new`.
#[derive(Clone, Debug)]
pub struct TrainerArgs {
    /// Environment directory handed to the algorithm constructor.
    pub env_dir: PathBuf,
    /// Destination path for saved model artifacts (per field name — the
    /// actual save behavior lives in the algorithm implementations).
    pub save_model_path: PathBuf,
    /// Observation-space dimensionality.
    pub obs_dim: usize,
    /// Action-space dimensionality.
    pub act_dim: usize,
    /// Capacity of the experience/trajectory buffer.
    pub buffer_size: usize,
}
/// Selects which PPO-family algorithm a [`PpoTrainer`] is built with.
pub enum PpoTrainerSpec {
    /// Single-policy PPO.
    PPO {
        args: TrainerArgs,
        /// Optional hyperparameters; presumably `None` falls back to the
        /// algorithm's defaults — confirm in `PPOAlgorithm::new`.
        hyperparams: Option<PPOParams>,
    },
    /// Independent PPO (IPPO).
    IPPO {
        args: TrainerArgs,
        /// Optional hyperparameters, forwarded as-is to `IPPOAlgorithm::new`.
        hyperparams: Option<IPPOParams>,
    },
}
impl PpoTrainerSpec {
pub fn ppo(args: TrainerArgs, hyperparams: Option<PPOParams>) -> Self {
Self::PPO { args, hyperparams }
}
pub fn ippo(args: TrainerArgs, hyperparams: Option<IPPOParams>) -> Self {
Self::IPPO { args, hyperparams }
}
}
/// Selects which REINFORCE-family algorithm a [`ReinforceTrainer`] is built with.
pub enum ReinforceTrainerSpec {
    /// Single-policy REINFORCE.
    REINFORCE {
        args: TrainerArgs,
        /// Optional hyperparameters, forwarded as-is to `ReinforceAlgorithm::new`.
        hyperparams: Option<REINFORCEParams>,
    },
    /// Independent REINFORCE (IREINFORCE).
    IREINFORCE {
        args: TrainerArgs,
        /// Optional hyperparameters, forwarded as-is to `IREINFORCEAlgorithm::new`.
        hyperparams: Option<IREINFORCEParams>,
    },
}
impl ReinforceTrainerSpec {
pub fn reinforce(args: TrainerArgs, hyperparams: Option<REINFORCEParams>) -> Self {
Self::REINFORCE { args, hyperparams }
}
pub fn ireinforce(args: TrainerArgs, hyperparams: Option<IREINFORCEParams>) -> Self {
Self::IREINFORCE { args, hyperparams }
}
}
/// Selects which multi-agent algorithm a [`MultiagentTrainer`] is built with.
pub enum MultiagentTrainerSpec {
    /// Multi-agent PPO.
    MAPPO {
        args: TrainerArgs,
        /// Optional hyperparameters, forwarded as-is to `MAPPOAlgorithm::new`.
        hyperparams: Option<MAPPOParams>,
    },
    /// Multi-agent REINFORCE.
    MAREINFORCE {
        args: TrainerArgs,
        /// Optional hyperparameters, forwarded as-is to `MAREINFORCEAlgorithm::new`.
        hyperparams: Option<MAREINFORCEParams>,
    },
}
impl MultiagentTrainerSpec {
pub fn mappo(args: TrainerArgs, hyperparams: Option<MAPPOParams>) -> Self {
Self::MAPPO { args, hyperparams }
}
pub fn mareinforce(args: TrainerArgs, hyperparams: Option<MAREINFORCEParams>) -> Self {
Self::MAREINFORCE { args, hyperparams }
}
}
// NOTE(review): `large_enum_variant` is suppressed instead of boxing the
// bigger payload — presumably trainers are built once and not moved in hot
// paths; confirm before relying on this type being cheap to move.
#[allow(clippy::large_enum_variant)]
/// Runtime dispatch over the single-policy PPO-family algorithms.
pub enum PpoTrainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>,
> {
    /// Wraps a standard PPO algorithm instance.
    PPO(PPOAlgorithm<B, InK, OutK, K>),
    /// Wraps an independent-PPO algorithm instance.
    IPPO(IPPOAlgorithm<B, InK, OutK, K>),
}
impl<B, InK, OutK, K> PpoTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: PPOKernelTrait<B, InK, OutK> + Default,
{
    /// Builds the trainer described by `spec`, handing `kernel` and the
    /// unpacked [`TrainerArgs`] fields to the selected algorithm constructor.
    ///
    /// # Errors
    /// Propagates any [`AlgorithmError`] from the algorithm's `new`.
    pub fn new(spec: PpoTrainerSpec, kernel: K) -> Result<Self, AlgorithmError> {
        match spec {
            // Generic arguments of the algorithm constructors are inferred
            // from the enclosing variant's payload type.
            PpoTrainerSpec::PPO { args, hyperparams } => Ok(Self::PPO(PPOAlgorithm::new(
                hyperparams,
                &args.env_dir,
                &args.save_model_path,
                args.obs_dim,
                args.act_dim,
                args.buffer_size,
                kernel,
            )?)),
            PpoTrainerSpec::IPPO { args, hyperparams } => Ok(Self::IPPO(IPPOAlgorithm::new(
                hyperparams,
                &args.env_dir,
                &args.save_model_path,
                args.obs_dim,
                args.act_dim,
                args.buffer_size,
                kernel,
            )?)),
        }
    }

    /// Shorthand for [`Self::new`] with a PPO spec.
    pub fn ppo(
        args: TrainerArgs,
        hyperparams: Option<PPOParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(PpoTrainerSpec::ppo(args, hyperparams), kernel)
    }

    /// Shorthand for [`Self::new`] with an IPPO spec.
    pub fn ippo(
        args: TrainerArgs,
        hyperparams: Option<IPPOParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(PpoTrainerSpec::ippo(args, hyperparams), kernel)
    }

    /// Forwards the per-epoch reset to whichever algorithm is active.
    pub fn reset_epoch(&mut self) {
        match self {
            Self::PPO(inner) => inner.reset_epoch(),
            Self::IPPO(inner) => inner.reset_epoch(),
        }
    }
}
#[cfg(feature = "ndarray-backend")]
// Only compiled for the ndarray backend, and only when the kernel can supply
// weights (`WeightProvider`) and the backend matches itself
// (`BackendMatcher<Backend = B>`).
impl<B, InK, OutK, K> PpoTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: PPOKernelTrait<B, InK, OutK> + WeightProvider + Default,
{
    /// Delegates to the active algorithm's `acquire_model_module`.
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>> {
        match self {
            Self::PPO(algorithm) => algorithm.acquire_model_module(),
            Self::IPPO(algorithm) => algorithm.acquire_model_module(),
        }
    }
}
// NOTE(review): as with `PpoTrainer`, `large_enum_variant` is suppressed
// rather than boxing the larger payload — confirm this type is not moved in
// hot paths.
#[allow(clippy::large_enum_variant)]
/// Runtime dispatch over the single-policy REINFORCE-family algorithms.
pub enum ReinforceTrainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>,
> {
    /// Wraps a standard REINFORCE algorithm instance.
    REINFORCE(ReinforceAlgorithm<B, InK, OutK, K>),
    /// Wraps an independent-REINFORCE algorithm instance.
    IREINFORCE(IREINFORCEAlgorithm<B, InK, OutK, K>),
}
impl<B, InK, OutK, K> ReinforceTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK> + TrainableKernelTrait + Default,
{
    /// Builds the trainer described by `spec`, handing `kernel` and the
    /// unpacked [`TrainerArgs`] fields to the selected algorithm constructor.
    ///
    /// # Errors
    /// Propagates any [`AlgorithmError`] from the algorithm's `new`.
    pub fn new(spec: ReinforceTrainerSpec, kernel: K) -> Result<Self, AlgorithmError> {
        match spec {
            // Generic arguments of the algorithm constructors are inferred
            // from the enclosing variant's payload type.
            ReinforceTrainerSpec::REINFORCE { args, hyperparams } => {
                Ok(Self::REINFORCE(ReinforceAlgorithm::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?))
            }
            ReinforceTrainerSpec::IREINFORCE { args, hyperparams } => {
                Ok(Self::IREINFORCE(IREINFORCEAlgorithm::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?))
            }
        }
    }

    /// Shorthand for [`Self::new`] with a REINFORCE spec.
    pub fn reinforce(
        args: TrainerArgs,
        hyperparams: Option<REINFORCEParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(ReinforceTrainerSpec::reinforce(args, hyperparams), kernel)
    }

    /// Shorthand for [`Self::new`] with an IREINFORCE spec.
    pub fn ireinforce(
        args: TrainerArgs,
        hyperparams: Option<IREINFORCEParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(ReinforceTrainerSpec::ireinforce(args, hyperparams), kernel)
    }

    /// Intentionally a no-op, unlike `PpoTrainer::reset_epoch`.
    /// NOTE(review): confirm the REINFORCE algorithms really have no
    /// per-epoch state that needs resetting.
    pub fn reset_epoch(&mut self) {}

    /// Always `None`: this facade exposes no model module for the
    /// REINFORCE variants.
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>>
    where
        B: BackendMatcher<Backend = B>,
    {
        None
    }
}
// NOTE(review): `large_enum_variant` suppressed here too; see the sizing
// caveat — confirm this type is not moved in hot paths.
#[allow(clippy::large_enum_variant)]
/// Runtime dispatch over the multi-agent algorithms.
pub enum MultiagentTrainer<B: Backend + BackendMatcher, InK: TensorKind<B>, OutK: TensorKind<B>> {
    /// Wraps a multi-agent PPO algorithm instance.
    MAPPO {
        trainer: MAPPOAlgorithm<B, InK, OutK>,
    },
    /// Wraps a multi-agent REINFORCE algorithm instance.
    MAREINFORCE {
        trainer: MAREINFORCEAlgorithm<B, InK, OutK>,
    },
}
impl<B, InK, OutK> MultiagentTrainer<B, InK, OutK>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
{
    /// Builds the trainer described by `spec`, handing the unpacked
    /// [`TrainerArgs`] fields to the selected algorithm constructor.
    ///
    /// # Errors
    /// Propagates any [`AlgorithmError`] from the algorithm's `new`.
    pub fn new(spec: MultiagentTrainerSpec) -> Result<Self, AlgorithmError> {
        match spec {
            MultiagentTrainerSpec::MAPPO { args, hyperparams } => {
                // Constructor generics are inferred from the variant's field type.
                let trainer = MAPPOAlgorithm::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                )?;
                Ok(Self::MAPPO { trainer })
            }
            MultiagentTrainerSpec::MAREINFORCE { args, hyperparams } => {
                let trainer = MAREINFORCEAlgorithm::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                )?;
                Ok(Self::MAREINFORCE { trainer })
            }
        }
    }

    /// Shorthand for [`Self::new`] with a MAPPO spec.
    pub fn mappo(
        args: TrainerArgs,
        hyperparams: Option<MAPPOParams>,
    ) -> Result<Self, AlgorithmError> {
        Self::new(MultiagentTrainerSpec::mappo(args, hyperparams))
    }

    /// Shorthand for [`Self::new`] with a MAREINFORCE spec.
    pub fn mareinforce(
        args: TrainerArgs,
        hyperparams: Option<MAREINFORCEParams>,
    ) -> Result<Self, AlgorithmError> {
        Self::new(MultiagentTrainerSpec::mareinforce(args, hyperparams))
    }

    /// Intentionally a no-op, unlike `PpoTrainer::reset_epoch`.
    /// NOTE(review): confirm the multi-agent algorithms have no per-epoch
    /// state that needs resetting.
    pub fn reset_epoch(&mut self) {}
}
#[cfg(feature = "ndarray-backend")]
// Only compiled for the ndarray backend, and only when the backend matches
// itself (`BackendMatcher<Backend = B>`).
impl<B, InK, OutK> MultiagentTrainer<B, InK, OutK>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
{
    /// Delegates to MAPPO's `acquire_model_module`; the MAREINFORCE variant
    /// deliberately yields `None` (no model module is exposed for it).
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>> {
        match self {
            Self::MAPPO { trainer } => trainer.acquire_model_module(),
            Self::MAREINFORCE { .. } => None,
        }
    }
}
/// Stateless facade: namespaced, free-function-style constructors for every
/// trainer variant in this module.
pub struct RelayRLTrainer;
impl RelayRLTrainer {
    /// Builds a single-agent PPO trainer; see [`PpoTrainer::ppo`].
    pub fn ppo<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<PPOParams>,
        kernel: K,
    ) -> Result<PpoTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: PPOKernelTrait<B, InK, OutK> + Default,
    {
        // Generic arguments are inferred from the annotated return type.
        PpoTrainer::ppo(args, hyperparams, kernel)
    }

    /// Builds an independent-PPO trainer; see [`PpoTrainer::ippo`].
    pub fn ippo<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<IPPOParams>,
        kernel: K,
    ) -> Result<PpoTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: PPOKernelTrait<B, InK, OutK> + Default,
    {
        PpoTrainer::ippo(args, hyperparams, kernel)
    }

    /// Builds a REINFORCE trainer; see [`ReinforceTrainer::reinforce`].
    pub fn reinforce<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<REINFORCEParams>,
        kernel: K,
    ) -> Result<ReinforceTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: StepKernelTrait<B, InK, OutK> + TrainableKernelTrait + Default,
    {
        ReinforceTrainer::reinforce(args, hyperparams, kernel)
    }

    /// Builds an independent-REINFORCE trainer; see [`ReinforceTrainer::ireinforce`].
    pub fn ireinforce<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<IREINFORCEParams>,
        kernel: K,
    ) -> Result<ReinforceTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: StepKernelTrait<B, InK, OutK> + TrainableKernelTrait + Default,
    {
        ReinforceTrainer::ireinforce(args, hyperparams, kernel)
    }

    /// Builds a multi-agent PPO trainer; see [`MultiagentTrainer::mappo`].
    pub fn mappo<B, InK, OutK>(
        args: TrainerArgs,
        hyperparams: Option<MAPPOParams>,
    ) -> Result<MultiagentTrainer<B, InK, OutK>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
    {
        MultiagentTrainer::mappo(args, hyperparams)
    }

    /// Builds a multi-agent REINFORCE trainer; see [`MultiagentTrainer::mareinforce`].
    pub fn mareinforce<B, InK, OutK>(
        args: TrainerArgs,
        hyperparams: Option<MAREINFORCEParams>,
    ) -> Result<MultiagentTrainer<B, InK, OutK>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
    {
        MultiagentTrainer::mareinforce(args, hyperparams)
    }
}
// Forwards the AlgorithmTrait surface to whichever PPO variant is active.
// The fully-qualified `AlgorithmTrait::<T>::method(...)` form pins the
// trajectory type parameter `T`; plain method syntax could leave it ambiguous.
// `async fn` in a trait impl requires Rust >= 1.75 (or an async-trait shim on
// the trait definition).
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for PpoTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: PPOKernelTrait<B, InK, OutK> + Default,
    T: TrajectoryData,
{
    // Delegates `save` to the active variant.
    fn save(&self, filename: &str) {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
        }
    }
    // Delegates trajectory ingestion; the returned bool's meaning is defined
    // by the underlying algorithm's AlgorithmTrait impl.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::PPO(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
            Self::IPPO(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
        }
    }
    // Delegates a training step to the active variant.
    fn train_model(&mut self) {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
        }
    }
    // Delegates epoch logging to the active variant.
    fn log_epoch(&mut self) {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
        }
    }
}
// Forwards the AlgorithmTrait surface to whichever REINFORCE variant is
// active; UFCS `AlgorithmTrait::<T>::...` pins the trajectory type `T`.
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for ReinforceTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK> + TrainableKernelTrait + Default,
    T: TrajectoryData,
{
    // Delegates `save` to the active variant.
    fn save(&self, filename: &str) {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
        }
    }
    // Delegates trajectory ingestion; the returned bool's meaning is defined
    // by the underlying algorithm's AlgorithmTrait impl.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::REINFORCE(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
            Self::IREINFORCE(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
        }
    }
    // Delegates a training step to the active variant.
    fn train_model(&mut self) {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
        }
    }
    // Delegates epoch logging to the active variant.
    fn log_epoch(&mut self) {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
        }
    }
}
// Forwards the AlgorithmTrait surface to whichever multi-agent variant is
// active; UFCS `AlgorithmTrait::<T>::...` pins the trajectory type `T`.
impl<B, InK, OutK, T> AlgorithmTrait<T> for MultiagentTrainer<B, InK, OutK>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    T: TrajectoryData,
{
    // Delegates `save` to the active variant.
    fn save(&self, filename: &str) {
        match self {
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::save(trainer, filename),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::save(trainer, filename),
        }
    }
    // Delegates trajectory ingestion; the returned bool's meaning is defined
    // by the underlying algorithm's AlgorithmTrait impl.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::MAPPO { trainer } => {
                AlgorithmTrait::<T>::receive_trajectory(trainer, trajectory).await
            }
            Self::MAREINFORCE { trainer } => {
                AlgorithmTrait::<T>::receive_trajectory(trainer, trajectory).await
            }
        }
    }
    // Delegates a training step to the active variant.
    fn train_model(&mut self) {
        match self {
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::train_model(trainer),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::train_model(trainer),
        }
    }
    // Delegates epoch logging to the active variant.
    fn log_epoch(&mut self) {
        match self {
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::log_epoch(trainer),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::log_epoch(trainer),
        }
    }
}