relayrl_algorithms 0.3.0

A collection of Multi-Agent Deep Reinforcement Learning Algorithms (IPPO, MAPPO, etc.)
//! This module defines a trait that must be implemented by any learning algorithm
//! (such as DQN, PPO, etc.) that is integrated with the RelayRL framework. The trait
//! specifies the required functionality for saving models, receiving trajectories,
//! training the model, and logging training epochs.

use burn_tensor::backend::Backend;
use burn_tensor::{Float, Int, TensorKind};
use relayrl_types::prelude::records::{ArrowTrajectory, CsvTrajectory};
use relayrl_types::prelude::tensor::burn::Tensor;
use relayrl_types::prelude::tensor::relayrl::{BackendMatcher, TensorData, TensorError};
use relayrl_types::prelude::trajectory::RelayRLTrajectory;
use std::collections::HashMap;
use thiserror::Error;

#[derive(Clone, Debug, Error)]
pub enum AlgorithmError {
    #[error("Initialization failed: {0}")]
    InitializationError(String),
    #[error("Insertion of trajectory failed: {0}")]
    TrajectoryInsertionError(String),
    #[error("Buffer sampling failed: {0}")]
    BufferSamplingError(String),
}

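/// The trajectory container formats an algorithm can accept.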
#[allow(clippy::large_enum_variant)]
pub enum TrajectoryType {
    RelayRL(RelayRLTrajectory),
    Csv(CsvTrajectory),
    Arrow(ArrowTrajectory),
}

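/// Conversion of a supported trajectory container into a [`RelayRLTrajectory`].
///
/// Implementors return `None` when the underlying record carries no trajectory
/// payload. A minimal usage sketch (the helper name is illustrative, not part
/// of the crate):
///
/// ```ignore
/// fn normalize<T: TrajectoryData>(t: T) -> Result<RelayRLTrajectory, AlgorithmError> {
///     t.into_relayrl().ok_or_else(|| {
///         AlgorithmError::TrajectoryInsertionError("no trajectory payload".into())
///     })
/// }
/// ```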
pub trait TrajectoryData {
    fn into_relayrl(self) -> Option<RelayRLTrajectory>;
}

impl TrajectoryData for RelayRLTrajectory {
    fn into_relayrl(self) -> Option<RelayRLTrajectory> {
        Some(self)
    }
}

impl TrajectoryData for CsvTrajectory {
    fn into_relayrl(self) -> Option<RelayRLTrajectory> {
        self.trajectory
    }
}

impl TrajectoryData for ArrowTrajectory {
    fn into_relayrl(self) -> Option<RelayRLTrajectory> {
        self.trajectory
    }
}

/// The `AlgorithmTrait` defines the interface that every algorithm implementation must fulfill.
///
/// # Type Parameters
///
/// * `T`: The trajectory source the algorithm consumes. Any type implementing
///   [`TrajectoryData`] (e.g. [`RelayRLTrajectory`], [`CsvTrajectory`], or
///   [`ArrowTrajectory`]) can be converted into a `RelayRLTrajectory` for training.
///
/// # Required Methods
///
/// * `save(&self, filename: &str)`:
///   Save the current model to the specified file. This allows persistence of model state.
///
/// * `receive_trajectory(&mut self, trajectory: T)`:
///   Asynchronously process a received trajectory for training. This method is called
///   when new experience data is available.
///
/// * `train_model(&mut self)`:
///   Trigger the training process of the model. The implementation should update the model
///   based on the accumulated trajectories or experiences.
///
/// * `log_epoch(&mut self)`:
///   Log the training status or results for the current epoch. This may include metrics
///   such as loss, reward averages, etc.
///
/// * `acquire_model(&self)` (feature-gated):
///   Export the trained model as a `ModelModule` for inference.
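///
/// # Example
///
/// A minimal sketch of an implementor. The struct, its fields, and the batch
/// threshold are illustrative placeholders, not part of this crate:
///
/// ```ignore
/// struct ToyAlgorithm {
///     buffer: Vec<RelayRLTrajectory>, // hypothetical rollout buffer
/// }
///
/// impl<T: TrajectoryData> AlgorithmTrait<T> for ToyAlgorithm {
///     fn save(&self, filename: &str) {
///         // serialize model weights to `filename`
///     }
///
///     async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
///         let traj = trajectory.into_relayrl().ok_or_else(|| {
///             AlgorithmError::TrajectoryInsertionError("no trajectory payload".into())
///         })?;
///         self.buffer.push(traj);
///         Ok(self.buffer.len() >= 128) // e.g. signal that a batch is ready
///     }
///
///     fn train_model(&mut self) {
///         // run an optimization pass over `self.buffer`, then clear it
///     }
///
///     fn log_epoch(&mut self) {
///         // emit loss / average-reward metrics for the finished epoch
///     }
/// }
/// ```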
pub trait AlgorithmTrait<T: TrajectoryData> {
    /// Saves the current model to a file specified by `filename`.
    ///
    /// # Arguments
    ///
    /// * `filename` - The path where the model should be saved.
    fn save(&self, filename: &str);

    /// Receives a trajectory of actions and incorporates it into the training process.
    ///
    /// # Arguments
    ///
    /// * `trajectory` - A trajectory containing a sequence of actions experienced by the agent.
    ///
    /// # Returns
    ///
    /// `Ok(bool)` on success (the flag's meaning is algorithm-defined, e.g. whether
    /// enough data has accumulated to train), or an [`AlgorithmError`] if the
    /// trajectory could not be ingested.
    #[allow(async_fn_in_trait)]
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError>;

    /// Triggers the training process of the model.
    ///
    /// This function should implement the logic to update the model based on received trajectories.
    fn train_model(&mut self);

    /// Logs the training progress for the current epoch.
    ///
    /// This method can be used to print or store metrics such as loss, accuracy, rewards, etc.
    fn log_epoch(&mut self);

    /// Acquires the trained model as a ModelModule for inference or export.
    ///
    /// Returns `None` if no model has been trained yet, if weight export is not supported,
    /// or if the required feature flags are not enabled.
    ///
    /// # Type Parameters
    ///
    /// * `B` - The Burn backend type (e.g., NdArray or LibTorch)
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B: Backend + BackendMatcher<Backend = B>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B>>;
}

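/// The distribution parameters produced by a policy network's forward pass.
///
/// `Discrete` carries action probabilities and raw logits; `Continuous` carries a
/// Gaussian `mean` and `std`. In both variants, `logp_a` holds the log-probability
/// of the provided actions when they were passed in (e.g. during training).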
pub enum ForwardOutput<B: Backend + BackendMatcher, const OUT_D: usize> {
    Discrete {
        probs: Tensor<B, OUT_D, Float>,
        logits: Tensor<B, OUT_D, Float>,
        logp_a: Option<Tensor<B, OUT_D, Float>>,
    },
    Continuous {
        mean: Tensor<B, OUT_D, Float>,
        std: Tensor<B, 2, Float>,
        logp_a: Option<Tensor<B, OUT_D, Float>>,
    },
}

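/// An action produced for an environment step: integer indices for discrete
/// action spaces, float values for continuous ones.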
pub enum StepAction<B: Backend + BackendMatcher> {
    Discrete(Tensor<B, 2, Int>),
    Continuous(Tensor<B, 2, Float>),
}

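/// Forward-pass kernel for a policy network.
///
/// Maps a batch of observations and an action mask (plus, optionally, previously
/// taken actions for which log-probabilities are wanted) to a [`ForwardOutput`].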
pub trait ForwardKernelTrait<B: Backend + BackendMatcher, InK: TensorKind<B>, OutK: TensorKind<B>> {
    fn forward<const IN_D: usize, const OUT_D: usize>(
        &self,
        obs: Tensor<B, IN_D, InK>,
        mask: Tensor<B, OUT_D, OutK>,
        act: Option<Tensor<B, OUT_D, OutK>>,
    ) -> ForwardOutput<B, OUT_D>;
}

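/// Inference kernel used in the environment step loop.
///
/// Produces a [`StepAction`] from the policy for the given observations and mask,
/// along with auxiliary tensors (e.g. value estimates or log-probabilities)
/// keyed by name.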
pub trait StepKernelTrait<B: Backend + BackendMatcher, InK: TensorKind<B>, OutK: TensorKind<B>> {
    fn step<const IN_D: usize, const OUT_D: usize>(
        &self,
        obs: Tensor<B, IN_D, InK>,
        mask: Tensor<B, OUT_D, OutK>,
    ) -> Result<(StepAction<B>, HashMap<String, TensorData>), TensorError>;

    fn get_input_dim(&self) -> usize;
    fn get_output_dim(&self) -> usize;
}

/// Trait for extracting per-layer weight specs from a trained policy network.
///
/// Each tuple is `(in_dim, out_dim, flat_weights, flat_biases)` in Burn's row-major
/// `[in, out]` layout, layers ordered input→output. This lets the training side hand
/// weights to the ONNX builder without any filesystem I/O.
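///
/// A sketch of the expected layout for a hypothetical two-layer MLP policy
/// (`obs_dim = 4 → hidden = 8 → act_dim = 2`):
///
/// ```ignore
/// let specs = policy.get_pi_layer_specs().unwrap();
/// assert_eq!((specs[0].0, specs[0].1), (4, 8)); // first layer: in → hidden
/// assert_eq!(specs[0].2.len(), 4 * 8);          // flat weights, row-major [in, out]
/// assert_eq!(specs[0].3.len(), 8);              // flat biases
/// assert_eq!((specs[1].0, specs[1].1), (8, 2)); // second layer: hidden → out
/// ```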
pub trait WeightProvider {
    fn get_pi_layer_specs(&self) -> Option<Vec<(usize, usize, Vec<f32>, Vec<f32>)>>;
}

/// Base trait for kernels used in **multi-agent** algorithms.
///
/// Inherits [`StepKernelTrait`] so multi-agent kernels can be used for inference
/// during the environment step loop, just like independent-algorithm kernels.
/// Algorithm-specific training methods (e.g. `train_epoch`) are defined on
/// per-algorithm sub-traits (`MultiagentPPOKernelTrait`, etc.).
pub trait MultiagentKernelTrait<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
>: StepKernelTrait<B, InK, OutK>
{
    /// Notify the kernel that a new agent slot has been registered.
    ///
    /// Called once per distinct agent key encountered during training.
    fn register_agent(&mut self);
}