pub mod algorithms;
pub mod logging;
pub mod templates;
use burn_tensor::TensorKind;
use burn_tensor::backend::Backend;
use relayrl_types::prelude::tensor::relayrl::BackendMatcher;
use std::path::PathBuf;
pub use algorithms::DDPG::{
DDPGAlgorithm, DDPGKernelTrait, DDPGParams, IDDPGAlgorithm, IDDPGParams, MADDPGAlgorithm,
MADDPGParams, MultiagentDDPGKernelTrait,
};
pub use algorithms::PPO::{
IPPOAlgorithm, IPPOParams, MAPPOAlgorithm, MAPPOParams, MultiagentPPOKernelTrait, PPOAlgorithm,
PPOKernelTrait, PPOParams,
};
pub use algorithms::REINFORCE::{
IREINFORCEAlgorithm, IREINFORCEParams, MAREINFORCEAlgorithm, MAREINFORCEParams,
MultiagentReinforceKernelTrait, REINFORCEKernelTrait, REINFORCEParams, ReinforceAlgorithm,
};
pub use algorithms::TD3::{
ITD3Algorithm, ITD3Params, MATD3Algorithm, MATD3Params, MultiagentTD3KernelTrait, TD3Algorithm,
TD3KernelTrait, TD3Params,
};
pub use templates::base_algorithm::{
AlgorithmError, AlgorithmTrait, MultiagentKernelTrait, StepKernelTrait, TrajectoryData,
WeightProvider,
};
/// Construction arguments shared by every trainer family in this module.
#[derive(Clone, Debug)]
pub struct TrainerArgs {
    // Directory of the environment the trainer is attached to.
    pub env_dir: PathBuf,
    // Path the trained model is saved to.
    pub save_model_path: PathBuf,
    // Observation-space dimensionality.
    pub obs_dim: usize,
    // Action-space dimensionality.
    pub act_dim: usize,
    // Capacity of the trajectory/experience buffer.
    pub buffer_size: usize,
}
#[cfg(all(
    any(feature = "tch-model", feature = "onnx-model"),
    any(feature = "ndarray-backend", feature = "tch-backend")
))]
/// Builds a serialized `ModelModule` for the active tensor backend from a
/// plain MLP layer description.
///
/// `layer_specs` holds one entry per layer as
/// `(in_dim, out_dim, weights, biases)` — NOTE(review): inferred from the
/// builder call sites; confirm against `build_onnx_mlp_bytes` /
/// `build_pt_mlp_temp`.
///
/// Returns `None` when `layer_specs` is empty, when the builder yields no
/// bytes (or fails), when module construction fails, or when the backend
/// reported by `B::get_supported_backend()` is not covered by the enabled
/// feature combination.
pub fn acquire_model_module<B: Backend + BackendMatcher<Backend = B>>(
    model_name: &str,
    layer_specs: Vec<(usize, usize, Vec<f32>, Vec<f32>)>,
    input_dtype: relayrl_types::data::tensor::DType,
    output_dtype: relayrl_types::data::tensor::DType,
    input_shape: Vec<usize>,
    output_shape: Vec<usize>,
    device: Option<relayrl_types::data::tensor::DeviceType>,
) -> Option<relayrl_types::model::ModelModule<B>> {
    use relayrl_types::data::tensor::SupportedTensorBackend;
    use relayrl_types::model::{ModelFileType, ModelMetadata, ModelModule};
    // An empty spec cannot describe a usable network.
    if layer_specs.is_empty() {
        return None;
    }
    match B::get_supported_backend() {
        #[cfg(all(feature = "ndarray-backend", feature = "onnx-model"))]
        SupportedTensorBackend::NdArray => {
            use crate::algorithms::onnx_builder::build_onnx_mlp_bytes;
            let onnx_bytes = build_onnx_mlp_bytes(&layer_specs);
            if onnx_bytes.is_empty() {
                return None;
            }
            let model_file = format!("{}.onnx", model_name);
            let metadata = ModelMetadata {
                model_file,
                model_type: ModelFileType::Onnx,
                input_dtype,
                output_dtype,
                input_shape,
                output_shape,
                default_device: device,
            };
            ModelModule::from_onnx_bytes(onnx_bytes, metadata).ok()
        }
        #[cfg(all(feature = "tch-backend", feature = "tch-model"))]
        SupportedTensorBackend::Tch => {
            use crate::algorithms::pt_builder::build_pt_mlp_temp;
            // `_temp_path` is kept in scope so the temporary file outlives
            // the read — NOTE(review): presumed RAII cleanup guard; confirm.
            let (pt_bytes, _temp_path) = build_pt_mlp_temp(&layer_specs).ok()?;
            if pt_bytes.is_empty() {
                return None;
            }
            let model_file = format!("{}.pt", model_name);
            let metadata = ModelMetadata {
                model_file,
                model_type: ModelFileType::Pt,
                input_dtype,
                output_dtype,
                input_shape,
                output_shape,
                default_device: device,
            };
            ModelModule::from_pt_bytes(pt_bytes, metadata).ok()
        }
        // Backend not covered by the compiled feature set.
        _ => None,
    }
}
/// Declarative description of which PPO-family trainer to build; consumed by
/// `PpoTrainer::new`.
pub enum PpoTrainerSpec {
    PPO {
        args: TrainerArgs,
        hyperparams: Option<PPOParams>,
    },
    IPPO {
        args: TrainerArgs,
        hyperparams: Option<IPPOParams>,
    },
}
impl PpoTrainerSpec {
    /// Spec for a shared-policy PPO trainer.
    pub fn ppo(args: TrainerArgs, hyperparams: Option<PPOParams>) -> Self {
        PpoTrainerSpec::PPO { hyperparams, args }
    }
    /// Spec for an independent-PPO (IPPO) trainer.
    pub fn ippo(args: TrainerArgs, hyperparams: Option<IPPOParams>) -> Self {
        PpoTrainerSpec::IPPO { hyperparams, args }
    }
}
/// Declarative description of which REINFORCE-family trainer to build;
/// consumed by `ReinforceTrainer::new`.
pub enum ReinforceTrainerSpec {
    REINFORCE {
        args: TrainerArgs,
        hyperparams: Option<REINFORCEParams>,
    },
    IREINFORCE {
        args: TrainerArgs,
        hyperparams: Option<IREINFORCEParams>,
    },
}
impl ReinforceTrainerSpec {
    /// Spec for a shared-policy REINFORCE trainer.
    pub fn reinforce(args: TrainerArgs, hyperparams: Option<REINFORCEParams>) -> Self {
        ReinforceTrainerSpec::REINFORCE { hyperparams, args }
    }
    /// Spec for an independent-REINFORCE (IREINFORCE) trainer.
    pub fn ireinforce(args: TrainerArgs, hyperparams: Option<IREINFORCEParams>) -> Self {
        ReinforceTrainerSpec::IREINFORCE { hyperparams, args }
    }
}
/// Declarative description of which multiagent trainer to build. Unlike the
/// single-agent specs, the kernel travels inside the spec because
/// `MultiagentTrainer::new` takes no separate kernel argument.
pub enum MultiagentTrainerSpec<K> {
    MADDPG {
        args: TrainerArgs,
        hyperparams: Option<MADDPGParams>,
        kernel: K,
    },
    MAPPO {
        args: TrainerArgs,
        hyperparams: Option<MAPPOParams>,
        kernel: K,
    },
    MAREINFORCE {
        args: TrainerArgs,
        hyperparams: Option<MAREINFORCEParams>,
        kernel: K,
    },
    MATD3 {
        args: TrainerArgs,
        hyperparams: Option<MATD3Params>,
        kernel: K,
    },
}
impl<K> MultiagentTrainerSpec<K> {
    /// Spec for a MADDPG trainer.
    pub fn maddpg(args: TrainerArgs, hyperparams: Option<MADDPGParams>, kernel: K) -> Self {
        MultiagentTrainerSpec::MADDPG {
            kernel,
            hyperparams,
            args,
        }
    }
    /// Spec for a MAPPO trainer.
    pub fn mappo(args: TrainerArgs, hyperparams: Option<MAPPOParams>, kernel: K) -> Self {
        MultiagentTrainerSpec::MAPPO {
            kernel,
            hyperparams,
            args,
        }
    }
    /// Spec for a MAREINFORCE trainer.
    pub fn mareinforce(
        args: TrainerArgs,
        hyperparams: Option<MAREINFORCEParams>,
        kernel: K,
    ) -> Self {
        MultiagentTrainerSpec::MAREINFORCE {
            kernel,
            hyperparams,
            args,
        }
    }
    /// Spec for a MATD3 trainer.
    pub fn matd3(args: TrainerArgs, hyperparams: Option<MATD3Params>, kernel: K) -> Self {
        MultiagentTrainerSpec::MATD3 {
            kernel,
            hyperparams,
            args,
        }
    }
}
// PPO-family trainer: a single enum over the shared-policy (PPO) and
// independent (IPPO) algorithm implementations so callers can treat both
// uniformly.
#[allow(clippy::large_enum_variant)]
pub enum PpoTrainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>,
> {
    PPO(PPOAlgorithm<B, InK, OutK, K>),
    IPPO(IPPOAlgorithm<B, InK, OutK, K>),
}
impl<B, InK, OutK, K> PpoTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: PPOKernelTrait<B, InK, OutK> + Default,
{
    /// Constructs the trainer variant described by `spec`, handing `kernel`
    /// to the underlying algorithm.
    pub fn new(spec: PpoTrainerSpec, kernel: K) -> Result<Self, AlgorithmError> {
        match spec {
            PpoTrainerSpec::PPO { args, hyperparams } => {
                let algorithm = PPOAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::PPO(algorithm))
            }
            PpoTrainerSpec::IPPO { args, hyperparams } => {
                let algorithm = IPPOAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::IPPO(algorithm))
            }
        }
    }
    /// Shorthand for `new` with a PPO spec.
    pub fn ppo(
        args: TrainerArgs,
        hyperparams: Option<PPOParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(PpoTrainerSpec::PPO { args, hyperparams }, kernel)
    }
    /// Shorthand for `new` with an IPPO spec.
    pub fn ippo(
        args: TrainerArgs,
        hyperparams: Option<IPPOParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(PpoTrainerSpec::IPPO { args, hyperparams }, kernel)
    }
    /// Forwards the per-epoch reset to whichever algorithm is wrapped.
    pub fn reset_epoch(&mut self) {
        match self {
            Self::PPO(algo) => algo.reset_epoch(),
            Self::IPPO(algo) => algo.reset_epoch(),
        }
    }
}
// Model-module extraction; only compiled for the ndarray backend build and
// only when the kernel can provide weights (`WeightProvider`).
#[cfg(feature = "ndarray-backend")]
impl<B, InK, OutK, K> PpoTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: PPOKernelTrait<B, InK, OutK> + WeightProvider + Default,
{
    // Delegates to the wrapped algorithm's `acquire_model_module`.
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>> {
        match self {
            Self::PPO(algorithm) => algorithm.acquire_model_module(),
            Self::IPPO(algorithm) => algorithm.acquire_model_module(),
        }
    }
}
// REINFORCE-family trainer: one enum over the shared-policy (REINFORCE) and
// independent (IREINFORCE) algorithm implementations.
#[allow(clippy::large_enum_variant)]
pub enum ReinforceTrainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>,
> {
    REINFORCE(ReinforceAlgorithm<B, InK, OutK, K>),
    IREINFORCE(IREINFORCEAlgorithm<B, InK, OutK, K>),
}
impl<B, InK, OutK, K> ReinforceTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>
        + REINFORCEKernelTrait<B, InK, OutK>
        + WeightProvider
        + Default,
{
    /// Constructs the trainer variant described by `spec`, handing `kernel`
    /// to the underlying algorithm.
    pub fn new(spec: ReinforceTrainerSpec, kernel: K) -> Result<Self, AlgorithmError> {
        match spec {
            ReinforceTrainerSpec::REINFORCE { args, hyperparams } => {
                let algorithm = ReinforceAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::REINFORCE(algorithm))
            }
            ReinforceTrainerSpec::IREINFORCE { args, hyperparams } => {
                let algorithm = IREINFORCEAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::IREINFORCE(algorithm))
            }
        }
    }
    /// Shorthand for `new` with a REINFORCE spec.
    pub fn reinforce(
        args: TrainerArgs,
        hyperparams: Option<REINFORCEParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(ReinforceTrainerSpec::REINFORCE { args, hyperparams }, kernel)
    }
    /// Shorthand for `new` with an IREINFORCE spec.
    pub fn ireinforce(
        args: TrainerArgs,
        hyperparams: Option<IREINFORCEParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(ReinforceTrainerSpec::IREINFORCE { args, hyperparams }, kernel)
    }
    /// Intentionally a no-op here — NOTE(review): unlike `PpoTrainer`, this
    /// does not forward to the wrapped algorithms; confirm that REINFORCE
    /// variants carry no per-epoch state to clear.
    pub fn reset_epoch(&mut self) {}
    /// Delegates model-module extraction to the wrapped algorithm.
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>> {
        match self {
            Self::REINFORCE(algo) => algo.acquire_model_module(),
            Self::IREINFORCE(algo) => algo.acquire_model_module(),
        }
    }
}
// Multiagent trainer: one enum over the four multiagent algorithm families.
// A single kernel type `K` must satisfy all four kernel traits so any
// variant can be built from the same kernel.
#[allow(clippy::large_enum_variant)]
pub enum MultiagentTrainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: MultiagentPPOKernelTrait<B, InK, OutK>
        + MultiagentReinforceKernelTrait<B, InK, OutK>
        + MultiagentDDPGKernelTrait<B, InK, OutK>
        + MultiagentTD3KernelTrait<B, InK, OutK>,
> {
    MADDPG {
        trainer: MADDPGAlgorithm<B, InK, OutK, K>,
    },
    MAPPO {
        trainer: MAPPOAlgorithm<B, InK, OutK, K>,
    },
    MAREINFORCE {
        trainer: MAREINFORCEAlgorithm<B, InK, OutK, K>,
    },
    MATD3 {
        trainer: MATD3Algorithm<B, InK, OutK, K>,
    },
}
impl<B, InK, OutK, K> MultiagentTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: MultiagentPPOKernelTrait<B, InK, OutK>
        + MultiagentReinforceKernelTrait<B, InK, OutK>
        + MultiagentDDPGKernelTrait<B, InK, OutK>
        + MultiagentTD3KernelTrait<B, InK, OutK>
        + Default,
{
    /// Constructs the multiagent trainer variant described by `spec`; the
    /// kernel is carried inside the spec itself.
    pub fn new(spec: MultiagentTrainerSpec<K>) -> Result<Self, AlgorithmError> {
        match spec {
            MultiagentTrainerSpec::MADDPG {
                args,
                hyperparams,
                kernel,
            } => {
                let trainer = MADDPGAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::MADDPG { trainer })
            }
            MultiagentTrainerSpec::MAPPO {
                args,
                hyperparams,
                kernel,
            } => {
                let trainer = MAPPOAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::MAPPO { trainer })
            }
            MultiagentTrainerSpec::MAREINFORCE {
                args,
                hyperparams,
                kernel,
            } => {
                let trainer = MAREINFORCEAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::MAREINFORCE { trainer })
            }
            MultiagentTrainerSpec::MATD3 {
                args,
                hyperparams,
                kernel,
            } => {
                let trainer = MATD3Algorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::MATD3 { trainer })
            }
        }
    }
    /// Shorthand for `new` with a MAPPO spec.
    pub fn mappo(
        args: TrainerArgs,
        hyperparams: Option<MAPPOParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(MultiagentTrainerSpec::MAPPO {
            args,
            hyperparams,
            kernel,
        })
    }
    /// Shorthand for `new` with a MAREINFORCE spec.
    pub fn mareinforce(
        args: TrainerArgs,
        hyperparams: Option<MAREINFORCEParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(MultiagentTrainerSpec::MAREINFORCE {
            args,
            hyperparams,
            kernel,
        })
    }
    /// Shorthand for `new` with a MADDPG spec.
    pub fn maddpg(
        args: TrainerArgs,
        hyperparams: Option<MADDPGParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(MultiagentTrainerSpec::MADDPG {
            args,
            hyperparams,
            kernel,
        })
    }
    /// Shorthand for `new` with a MATD3 spec.
    pub fn matd3(
        args: TrainerArgs,
        hyperparams: Option<MATD3Params>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(MultiagentTrainerSpec::MATD3 {
            args,
            hyperparams,
            kernel,
        })
    }
    /// Intentionally a no-op here — NOTE(review): unlike `PpoTrainer`, this
    /// does not forward a reset to the wrapped algorithms; confirm intended.
    pub fn reset_epoch(&mut self) {}
}
// Model-module extraction; only compiled for the ndarray backend build and
// only when the kernel can provide weights (`WeightProvider`).
#[cfg(feature = "ndarray-backend")]
impl<B, InK, OutK, K> MultiagentTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: MultiagentPPOKernelTrait<B, InK, OutK>
        + MultiagentReinforceKernelTrait<B, InK, OutK>
        + MultiagentDDPGKernelTrait<B, InK, OutK>
        + MultiagentTD3KernelTrait<B, InK, OutK>
        + WeightProvider
        + Default,
{
    // Delegates to the wrapped algorithm's `acquire_model_module`.
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>> {
        match self {
            Self::MADDPG { trainer } => trainer.acquire_model_module(),
            Self::MAPPO { trainer } => trainer.acquire_model_module(),
            Self::MAREINFORCE { trainer } => trainer.acquire_model_module(),
            Self::MATD3 { trainer } => trainer.acquire_model_module(),
        }
    }
}
/// Zero-sized facade exposing constructor functions for every trainer
/// family in this module under one name.
pub struct RelayRLTrainer;
impl RelayRLTrainer {
    /// Builds a shared-policy PPO trainer.
    pub fn ppo<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<PPOParams>,
        kernel: K,
    ) -> Result<PpoTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: PPOKernelTrait<B, InK, OutK> + Default,
    {
        // Type parameters are inferred from the declared return type.
        PpoTrainer::ppo(args, hyperparams, kernel)
    }
    /// Builds an independent-PPO (IPPO) trainer.
    pub fn ippo<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<IPPOParams>,
        kernel: K,
    ) -> Result<PpoTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: PPOKernelTrait<B, InK, OutK> + Default,
    {
        PpoTrainer::ippo(args, hyperparams, kernel)
    }
    /// Builds a shared-policy REINFORCE trainer.
    pub fn reinforce<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<REINFORCEParams>,
        kernel: K,
    ) -> Result<ReinforceTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher<Backend = B>,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: StepKernelTrait<B, InK, OutK>
            + REINFORCEKernelTrait<B, InK, OutK>
            + WeightProvider
            + Default,
    {
        ReinforceTrainer::reinforce(args, hyperparams, kernel)
    }
    /// Builds an independent-REINFORCE (IREINFORCE) trainer.
    pub fn ireinforce<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<IREINFORCEParams>,
        kernel: K,
    ) -> Result<ReinforceTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher<Backend = B>,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: StepKernelTrait<B, InK, OutK>
            + REINFORCEKernelTrait<B, InK, OutK>
            + WeightProvider
            + Default,
    {
        ReinforceTrainer::ireinforce(args, hyperparams, kernel)
    }
    /// Builds a MAPPO multiagent trainer.
    pub fn mappo<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<MAPPOParams>,
        kernel: K,
    ) -> Result<MultiagentTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: MultiagentPPOKernelTrait<B, InK, OutK>
            + MultiagentReinforceKernelTrait<B, InK, OutK>
            + MultiagentDDPGKernelTrait<B, InK, OutK>
            + MultiagentTD3KernelTrait<B, InK, OutK>
            + Default,
    {
        MultiagentTrainer::mappo(args, hyperparams, kernel)
    }
    /// Builds a MAREINFORCE multiagent trainer.
    pub fn mareinforce<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<MAREINFORCEParams>,
        kernel: K,
    ) -> Result<MultiagentTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: MultiagentPPOKernelTrait<B, InK, OutK>
            + MultiagentReinforceKernelTrait<B, InK, OutK>
            + MultiagentDDPGKernelTrait<B, InK, OutK>
            + MultiagentTD3KernelTrait<B, InK, OutK>
            + Default,
    {
        MultiagentTrainer::mareinforce(args, hyperparams, kernel)
    }
}
// Pass-through `AlgorithmTrait` implementation: each method pins the
// trajectory type `T` via turbofish and forwards to the wrapped PPO/IPPO
// algorithm.
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for PpoTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: PPOKernelTrait<B, InK, OutK> + WeightProvider + Default,
    T: TrajectoryData,
{
    // Delegates model saving to the wrapped algorithm.
    fn save(&self, filename: &str) {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
        }
    }
    // Forwards one trajectory to the wrapped algorithm; the meaning of the
    // returned bool is defined by `AlgorithmTrait` (not visible here).
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::PPO(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
            Self::IPPO(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
        }
    }
    // Delegates a training update to the wrapped algorithm.
    fn train_model(&mut self) {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
        }
    }
    // Delegates epoch logging to the wrapped algorithm.
    fn log_epoch(&mut self) {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
        }
    }
    // Only compiled when both a model format and a tensor backend feature
    // are enabled; delegates model extraction for target backend `B2`.
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B2: Backend + BackendMatcher<Backend = B2>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B2>>
    where
        B: 'static,
        B2: 'static,
    {
        match self {
            Self::PPO(algorithm) => AlgorithmTrait::<T>::acquire_model::<B2>(algorithm),
            Self::IPPO(algorithm) => AlgorithmTrait::<T>::acquire_model::<B2>(algorithm),
        }
    }
}
// Pass-through `AlgorithmTrait` implementation: each method pins the
// trajectory type `T` via turbofish and forwards to the wrapped
// REINFORCE/IREINFORCE algorithm.
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for ReinforceTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>
        + REINFORCEKernelTrait<B, InK, OutK>
        + WeightProvider
        + Default,
    T: TrajectoryData,
{
    // Delegates model saving to the wrapped algorithm.
    fn save(&self, filename: &str) {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::save(algorithm, filename),
        }
    }
    // Forwards one trajectory to the wrapped algorithm.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::REINFORCE(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
            Self::IREINFORCE(algorithm) => {
                AlgorithmTrait::<T>::receive_trajectory(algorithm, trajectory).await
            }
        }
    }
    // Delegates a training update to the wrapped algorithm.
    fn train_model(&mut self) {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::train_model(algorithm),
        }
    }
    // Delegates epoch logging to the wrapped algorithm.
    fn log_epoch(&mut self) {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::log_epoch(algorithm),
        }
    }
    // Only compiled when both a model format and a tensor backend feature
    // are enabled; delegates model extraction for target backend `B2`.
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B2: Backend + BackendMatcher<Backend = B2>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B2>>
    where
        B: 'static,
        B2: 'static,
    {
        match self {
            Self::REINFORCE(algorithm) => AlgorithmTrait::<T>::acquire_model::<B2>(algorithm),
            Self::IREINFORCE(algorithm) => AlgorithmTrait::<T>::acquire_model::<B2>(algorithm),
        }
    }
}
// Pass-through `AlgorithmTrait` implementation: each method pins the
// trajectory type `T` via turbofish and forwards to whichever multiagent
// algorithm this enum currently wraps.
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for MultiagentTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: MultiagentPPOKernelTrait<B, InK, OutK>
        + MultiagentReinforceKernelTrait<B, InK, OutK>
        + MultiagentDDPGKernelTrait<B, InK, OutK>
        + MultiagentTD3KernelTrait<B, InK, OutK>
        + WeightProvider
        + Default,
    T: TrajectoryData,
{
    // Delegates model saving to the wrapped algorithm.
    fn save(&self, filename: &str) {
        match self {
            Self::MADDPG { trainer } => AlgorithmTrait::<T>::save(trainer, filename),
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::save(trainer, filename),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::save(trainer, filename),
            Self::MATD3 { trainer } => AlgorithmTrait::<T>::save(trainer, filename),
        }
    }
    // Forwards one trajectory to the wrapped algorithm.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::MADDPG { trainer } => {
                AlgorithmTrait::<T>::receive_trajectory(trainer, trajectory).await
            }
            Self::MAPPO { trainer } => {
                AlgorithmTrait::<T>::receive_trajectory(trainer, trajectory).await
            }
            Self::MAREINFORCE { trainer } => {
                AlgorithmTrait::<T>::receive_trajectory(trainer, trajectory).await
            }
            Self::MATD3 { trainer } => {
                AlgorithmTrait::<T>::receive_trajectory(trainer, trajectory).await
            }
        }
    }
    // Delegates a training update to the wrapped algorithm.
    fn train_model(&mut self) {
        match self {
            Self::MADDPG { trainer } => AlgorithmTrait::<T>::train_model(trainer),
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::train_model(trainer),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::train_model(trainer),
            Self::MATD3 { trainer } => AlgorithmTrait::<T>::train_model(trainer),
        }
    }
    // Delegates epoch logging to the wrapped algorithm.
    fn log_epoch(&mut self) {
        match self {
            Self::MADDPG { trainer } => AlgorithmTrait::<T>::log_epoch(trainer),
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::log_epoch(trainer),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::log_epoch(trainer),
            Self::MATD3 { trainer } => AlgorithmTrait::<T>::log_epoch(trainer),
        }
    }
    // Only compiled when both a model format and a tensor backend feature
    // are enabled; delegates model extraction for target backend `B2`.
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B2: Backend + BackendMatcher<Backend = B2>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B2>>
    where
        B: 'static,
        B2: 'static,
    {
        match self {
            Self::MADDPG { trainer } => AlgorithmTrait::<T>::acquire_model::<B2>(trainer),
            Self::MAPPO { trainer } => AlgorithmTrait::<T>::acquire_model::<B2>(trainer),
            Self::MAREINFORCE { trainer } => AlgorithmTrait::<T>::acquire_model::<B2>(trainer),
            Self::MATD3 { trainer } => AlgorithmTrait::<T>::acquire_model::<B2>(trainer),
        }
    }
}
/// Declarative description of which DDPG-family trainer to build; consumed
/// by `DdpgTrainer::new`.
pub enum DdpgTrainerSpec {
    DDPG {
        args: TrainerArgs,
        hyperparams: Option<DDPGParams>,
    },
    IDDPG {
        args: TrainerArgs,
        hyperparams: Option<IDDPGParams>,
    },
}
impl DdpgTrainerSpec {
    /// Spec for a shared-policy DDPG trainer.
    pub fn ddpg(args: TrainerArgs, hyperparams: Option<DDPGParams>) -> Self {
        DdpgTrainerSpec::DDPG { hyperparams, args }
    }
    /// Spec for an independent-DDPG (IDDPG) trainer.
    pub fn iddpg(args: TrainerArgs, hyperparams: Option<IDDPGParams>) -> Self {
        DdpgTrainerSpec::IDDPG { hyperparams, args }
    }
}
// DDPG-family trainer: one enum over the shared-policy (DDPG) and
// independent (IDDPG) algorithm implementations.
#[allow(clippy::large_enum_variant)]
pub enum DdpgTrainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>,
> {
    DDPG(DDPGAlgorithm<B, InK, OutK, K>),
    IDDPG(IDDPGAlgorithm<B, InK, OutK, K>),
}
impl<B, InK, OutK, K> DdpgTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: DDPGKernelTrait<B, InK, OutK> + Default,
{
    /// Constructs the trainer variant described by `spec`, handing `kernel`
    /// to the underlying algorithm.
    pub fn new(spec: DdpgTrainerSpec, kernel: K) -> Result<Self, AlgorithmError> {
        match spec {
            DdpgTrainerSpec::DDPG { args, hyperparams } => {
                let algorithm = DDPGAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::DDPG(algorithm))
            }
            DdpgTrainerSpec::IDDPG { args, hyperparams } => {
                let algorithm = IDDPGAlgorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::IDDPG(algorithm))
            }
        }
    }
    /// Shorthand for `new` with a DDPG spec.
    pub fn ddpg(
        args: TrainerArgs,
        hyperparams: Option<DDPGParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(DdpgTrainerSpec::DDPG { args, hyperparams }, kernel)
    }
    /// Shorthand for `new` with an IDDPG spec.
    pub fn iddpg(
        args: TrainerArgs,
        hyperparams: Option<IDDPGParams>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(DdpgTrainerSpec::IDDPG { args, hyperparams }, kernel)
    }
    /// Forwards the per-epoch reset to whichever algorithm is wrapped.
    pub fn reset_epoch(&mut self) {
        match self {
            Self::DDPG(algorithm) => algorithm.reset_epoch(),
            Self::IDDPG(algorithm) => algorithm.reset_epoch(),
        }
    }
}
// Pass-through `AlgorithmTrait` implementation: each method pins the
// trajectory type `T` via turbofish and forwards to the wrapped DDPG/IDDPG
// algorithm.
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for DdpgTrainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: DDPGKernelTrait<B, InK, OutK> + WeightProvider + Default,
    T: TrajectoryData,
{
    // Delegates model saving to the wrapped algorithm.
    fn save(&self, filename: &str) {
        match self {
            Self::DDPG(a) => AlgorithmTrait::<T>::save(a, filename),
            Self::IDDPG(a) => AlgorithmTrait::<T>::save(a, filename),
        }
    }
    // Forwards one trajectory to the wrapped algorithm.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::DDPG(a) => AlgorithmTrait::<T>::receive_trajectory(a, trajectory).await,
            Self::IDDPG(a) => AlgorithmTrait::<T>::receive_trajectory(a, trajectory).await,
        }
    }
    // Delegates a training update to the wrapped algorithm.
    fn train_model(&mut self) {
        match self {
            Self::DDPG(a) => AlgorithmTrait::<T>::train_model(a),
            Self::IDDPG(a) => AlgorithmTrait::<T>::train_model(a),
        }
    }
    // Delegates epoch logging to the wrapped algorithm.
    fn log_epoch(&mut self) {
        match self {
            Self::DDPG(a) => AlgorithmTrait::<T>::log_epoch(a),
            Self::IDDPG(a) => AlgorithmTrait::<T>::log_epoch(a),
        }
    }
    // Only compiled when both a model format and a tensor backend feature
    // are enabled; delegates model extraction for target backend `B2`.
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B2: Backend + BackendMatcher<Backend = B2>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B2>>
    where
        B: 'static,
        B2: 'static,
    {
        match self {
            Self::DDPG(a) => AlgorithmTrait::<T>::acquire_model::<B2>(a),
            Self::IDDPG(a) => AlgorithmTrait::<T>::acquire_model::<B2>(a),
        }
    }
}
/// Declarative description of which TD3-family trainer to build; consumed
/// by `Td3Trainer::new`.
pub enum Td3TrainerSpec {
    TD3 {
        args: TrainerArgs,
        hyperparams: Option<TD3Params>,
    },
    ITD3 {
        args: TrainerArgs,
        hyperparams: Option<ITD3Params>,
    },
}
impl Td3TrainerSpec {
    /// Spec for a shared-policy TD3 trainer.
    pub fn td3(args: TrainerArgs, hyperparams: Option<TD3Params>) -> Self {
        Td3TrainerSpec::TD3 { hyperparams, args }
    }
    /// Spec for an independent-TD3 (ITD3) trainer.
    pub fn itd3(args: TrainerArgs, hyperparams: Option<ITD3Params>) -> Self {
        Td3TrainerSpec::ITD3 { hyperparams, args }
    }
}
// TD3-family trainer: one enum over the shared-policy (TD3) and independent
// (ITD3) algorithm implementations.
#[allow(clippy::large_enum_variant)]
pub enum Td3Trainer<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: StepKernelTrait<B, InK, OutK>,
> {
    TD3(TD3Algorithm<B, InK, OutK, K>),
    ITD3(ITD3Algorithm<B, InK, OutK, K>),
}
impl<B, InK, OutK, K> Td3Trainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: TD3KernelTrait<B, InK, OutK> + Default,
{
    /// Constructs the trainer variant described by `spec`, handing `kernel`
    /// to the underlying algorithm.
    pub fn new(spec: Td3TrainerSpec, kernel: K) -> Result<Self, AlgorithmError> {
        match spec {
            Td3TrainerSpec::TD3 { args, hyperparams } => {
                let algorithm = TD3Algorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::TD3(algorithm))
            }
            Td3TrainerSpec::ITD3 { args, hyperparams } => {
                let algorithm = ITD3Algorithm::<B, InK, OutK, K>::new(
                    hyperparams,
                    &args.env_dir,
                    &args.save_model_path,
                    args.obs_dim,
                    args.act_dim,
                    args.buffer_size,
                    kernel,
                )?;
                Ok(Self::ITD3(algorithm))
            }
        }
    }
    /// Shorthand for `new` with a TD3 spec.
    pub fn td3(
        args: TrainerArgs,
        hyperparams: Option<TD3Params>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(Td3TrainerSpec::TD3 { args, hyperparams }, kernel)
    }
    /// Shorthand for `new` with an ITD3 spec.
    pub fn itd3(
        args: TrainerArgs,
        hyperparams: Option<ITD3Params>,
        kernel: K,
    ) -> Result<Self, AlgorithmError> {
        Self::new(Td3TrainerSpec::ITD3 { args, hyperparams }, kernel)
    }
    /// Forwards the per-epoch reset to whichever algorithm is wrapped.
    pub fn reset_epoch(&mut self) {
        match self {
            Self::TD3(algorithm) => algorithm.reset_epoch(),
            Self::ITD3(algorithm) => algorithm.reset_epoch(),
        }
    }
}
// Pass-through `AlgorithmTrait` implementation: each method pins the
// trajectory type `T` via turbofish and forwards to the wrapped TD3/ITD3
// algorithm.
impl<B, InK, OutK, K, T> AlgorithmTrait<T> for Td3Trainer<B, InK, OutK, K>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    K: TD3KernelTrait<B, InK, OutK> + WeightProvider + Default,
    T: TrajectoryData,
{
    // Delegates model saving to the wrapped algorithm.
    fn save(&self, filename: &str) {
        match self {
            Self::TD3(a) => AlgorithmTrait::<T>::save(a, filename),
            Self::ITD3(a) => AlgorithmTrait::<T>::save(a, filename),
        }
    }
    // Forwards one trajectory to the wrapped algorithm.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        match self {
            Self::TD3(a) => AlgorithmTrait::<T>::receive_trajectory(a, trajectory).await,
            Self::ITD3(a) => AlgorithmTrait::<T>::receive_trajectory(a, trajectory).await,
        }
    }
    // Delegates a training update to the wrapped algorithm.
    fn train_model(&mut self) {
        match self {
            Self::TD3(a) => AlgorithmTrait::<T>::train_model(a),
            Self::ITD3(a) => AlgorithmTrait::<T>::train_model(a),
        }
    }
    // Delegates epoch logging to the wrapped algorithm.
    fn log_epoch(&mut self) {
        match self {
            Self::TD3(a) => AlgorithmTrait::<T>::log_epoch(a),
            Self::ITD3(a) => AlgorithmTrait::<T>::log_epoch(a),
        }
    }
    // Only compiled when both a model format and a tensor backend feature
    // are enabled; delegates model extraction for target backend `B2`.
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B2: Backend + BackendMatcher<Backend = B2>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B2>>
    where
        B: 'static,
        B2: 'static,
    {
        match self {
            Self::TD3(a) => AlgorithmTrait::<T>::acquire_model::<B2>(a),
            Self::ITD3(a) => AlgorithmTrait::<T>::acquire_model::<B2>(a),
        }
    }
}
// Facade constructors for the DDPG/TD3 families, plus the previously missing
// multiagent `maddpg`/`matd3` entry points: `MultiagentTrainer` exposes both
// (see its inherent impl) but the facade only surfaced `mappo`/`mareinforce`.
impl RelayRLTrainer {
    /// Builds a shared-policy DDPG trainer.
    pub fn ddpg<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<DDPGParams>,
        kernel: K,
    ) -> Result<DdpgTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: DDPGKernelTrait<B, InK, OutK> + Default,
    {
        DdpgTrainer::<B, InK, OutK, K>::ddpg(args, hyperparams, kernel)
    }
    /// Builds an independent-DDPG (IDDPG) trainer.
    pub fn iddpg<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<IDDPGParams>,
        kernel: K,
    ) -> Result<DdpgTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: DDPGKernelTrait<B, InK, OutK> + Default,
    {
        DdpgTrainer::<B, InK, OutK, K>::iddpg(args, hyperparams, kernel)
    }
    /// Builds a shared-policy TD3 trainer.
    pub fn td3<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<TD3Params>,
        kernel: K,
    ) -> Result<Td3Trainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: TD3KernelTrait<B, InK, OutK> + Default,
    {
        Td3Trainer::<B, InK, OutK, K>::td3(args, hyperparams, kernel)
    }
    /// Builds an independent-TD3 (ITD3) trainer.
    pub fn itd3<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<ITD3Params>,
        kernel: K,
    ) -> Result<Td3Trainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: TD3KernelTrait<B, InK, OutK> + Default,
    {
        Td3Trainer::<B, InK, OutK, K>::itd3(args, hyperparams, kernel)
    }
    /// Builds a MADDPG multiagent trainer (bounds mirror `mappo`).
    pub fn maddpg<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<MADDPGParams>,
        kernel: K,
    ) -> Result<MultiagentTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: MultiagentPPOKernelTrait<B, InK, OutK>
            + MultiagentReinforceKernelTrait<B, InK, OutK>
            + MultiagentDDPGKernelTrait<B, InK, OutK>
            + MultiagentTD3KernelTrait<B, InK, OutK>
            + Default,
    {
        MultiagentTrainer::<B, InK, OutK, K>::maddpg(args, hyperparams, kernel)
    }
    /// Builds a MATD3 multiagent trainer (bounds mirror `mappo`).
    pub fn matd3<B, InK, OutK, K>(
        args: TrainerArgs,
        hyperparams: Option<MATD3Params>,
        kernel: K,
    ) -> Result<MultiagentTrainer<B, InK, OutK, K>, AlgorithmError>
    where
        B: Backend + BackendMatcher,
        InK: TensorKind<B>,
        OutK: TensorKind<B>,
        K: MultiagentPPOKernelTrait<B, InK, OutK>
            + MultiagentReinforceKernelTrait<B, InK, OutK>
            + MultiagentDDPGKernelTrait<B, InK, OutK>
            + MultiagentTD3KernelTrait<B, InK, OutK>
            + Default,
    {
        MultiagentTrainer::<B, InK, OutK, K>::matd3(args, hyperparams, kernel)
    }
}