//! relayrl_algorithms 0.3.0
//!
//! A collection of Multi-Agent Deep Reinforcement Learning algorithms (IPPO, MAPPO, etc.).
//! This module provides the multi-agent DDPG (MADDPG) replay buffer, a thin wrapper
//! over the single-agent DDPG replay buffer.
use crate::templates::base_replay_buffer::{Batch, GenericReplayBuffer, ReplayBufferError};

use async_trait::async_trait;
use relayrl_types::prelude::trajectory::RelayRLTrajectory;

use std::any::Any;

// The single-agent DDPG replay buffer that this multi-agent wrapper delegates to.
type SharedReplayBuffer = super::super::replay_buffer::DDPGReplayBuffer;

/// Replay buffer for multi-agent DDPG (MADDPG).
///
/// Newtype wrapper around the single-agent [`SharedReplayBuffer`]; all trait
/// methods delegate directly to the inner buffer.
pub struct MultiagentDDPGReplayBuffer {
    // Underlying single-agent DDPG buffer that performs the actual storage/sampling.
    inner: SharedReplayBuffer,
}

/// Conventional short alias (MADDPG) for [`MultiagentDDPGReplayBuffer`].
pub type MADDPGReplayBuffer = MultiagentDDPGReplayBuffer;

/// Default capacity (number of stored transitions) of the replay buffer.
const DEFAULT_BUFFER_SIZE: usize = 1_000_000;
/// Default number of transitions drawn per call to `sample_buffer`.
const DEFAULT_BATCH_SIZE: usize = 128;

impl Default for MultiagentDDPGReplayBuffer {
    /// Builds a buffer with the conventional DDPG defaults:
    /// [`DEFAULT_BUFFER_SIZE`] capacity and [`DEFAULT_BATCH_SIZE`] batch size.
    fn default() -> Self {
        Self::new(DEFAULT_BUFFER_SIZE, DEFAULT_BATCH_SIZE)
    }
}

impl MultiagentDDPGReplayBuffer {
    /// Creates a replay buffer holding up to `buffer_size` transitions and
    /// sampling `batch_size` transitions per draw.
    pub fn new(buffer_size: usize, batch_size: usize) -> Self {
        // Delegate construction to the shared single-agent buffer.
        let inner = SharedReplayBuffer::new(buffer_size, batch_size);
        Self { inner }
    }
}

#[async_trait]
impl GenericReplayBuffer for MultiagentDDPGReplayBuffer {
    /// Stores a trajectory by forwarding it to the inner single-agent buffer.
    ///
    /// Returns whatever the inner buffer's `insert_trajectory` yields
    /// (an opaque `Box<dyn Any>`), or a [`ReplayBufferError`] on failure.
    async fn insert_trajectory(
        &self,
        trajectory: RelayRLTrajectory,
    ) -> Result<Box<dyn Any>, ReplayBufferError> {
        self.inner.insert_trajectory(trajectory).await
    }

    /// Samples a [`Batch`] by delegating to the inner buffer's `sample_buffer`.
    async fn sample_buffer(&self) -> Result<Batch, ReplayBufferError> {
        self.inner.sample_buffer().await
    }
}