//! relayrl_algorithms 0.3.0
//!
//! A collection of Multi-Agent Deep Reinforcement Learning algorithms
//! (IPPO, MAPPO, etc.). See the crate documentation for details.
use crate::templates::base_replay_buffer::{
    Batch, BatchKey, BufferSample, GenericReplayBuffer, ReplayBufferError, SampleScalars,
};
use async_trait::async_trait;
use relayrl_types::prelude::tensor::relayrl::TensorData;
use relayrl_types::prelude::trajectory::RelayRLTrajectory;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Ring-buffer columns for stored transitions; index `i` across all five
/// vectors describes one (obs, act, next_obs, rew, done) transition.
struct Buffers {
    // `Option` because a trajectory step may lack an obs/act tensor.
    observations: Vec<Option<TensorData>>,
    actions: Vec<Option<TensorData>>,
    next_observations: Vec<Option<TensorData>>,
    rewards: Vec<f32>,
    // 1.0 = terminal transition, 0.0 otherwise (stored as f32 for batching).
    dones: Vec<f32>,
    pointer: usize,      // next write slot; wraps modulo capacity
    current_size: usize, // valid entries so far, capped at capacity
}

/// Immutable buffer configuration, shared via `Arc` so it can be read
/// without taking the storage lock.
struct BufferMetadata {
    buffer_size: usize, // maximum number of stored transitions (ring capacity)
    batch_size: usize,  // transitions returned per `sample_buffer` call
}

/// Off-policy replay buffer for DDPG-style training.
///
/// Stores (obs, act, next_obs, rew, done) transitions in a fixed-capacity
/// ring buffer and supports uniform random sampling without removal.
pub struct DDPGReplayBuffer {
    buffers: Arc<Mutex<Buffers>>, // mutable storage, guarded by an async lock
    metadata: Arc<BufferMetadata>, // read-only configuration
}

impl Default for DDPGReplayBuffer {
    fn default() -> Self {
        Self::new(1_000_000, 128)
    }
}

impl DDPGReplayBuffer {
    /// Creates a buffer holding at most `buffer_size` transitions, sampled
    /// in batches of `batch_size`. All five columns are preallocated to
    /// full capacity so inserts never reallocate.
    ///
    /// # Panics
    /// Panics if `buffer_size` is zero (a zero capacity would otherwise
    /// cause a modulo-by-zero panic on the first insert).
    pub fn new(buffer_size: usize, batch_size: usize) -> Self {
        assert!(buffer_size > 0, "buffer_size must be greater than zero");
        let capacity = buffer_size;
        let buffers = Buffers {
            observations: Vec::with_capacity(capacity),
            actions: Vec::with_capacity(capacity),
            next_observations: Vec::with_capacity(capacity),
            rewards: Vec::with_capacity(capacity),
            dones: Vec::with_capacity(capacity),
            pointer: 0,
            current_size: 0,
        };
        Self {
            buffers: Arc::new(Mutex::new(buffers)),
            metadata: Arc::new(BufferMetadata {
                buffer_size,
                batch_size,
            }),
        }
    }

    /// Number of transitions each `sample_buffer` call returns.
    pub fn batch_size(&self) -> usize {
        self.metadata.batch_size
    }

    /// Best-effort count of stored transitions.
    ///
    /// Returns the accurate count when the storage lock is uncontended and
    /// falls back to `0` if the lock is currently held elsewhere (the
    /// authoritative size check happens inside `sample_buffer`, under the
    /// lock). The previous implementation unconditionally returned 0.
    pub fn current_size(&self) -> usize {
        self.buffers
            .try_lock()
            .map(|buffers| buffers.current_size)
            .unwrap_or(0)
    }
}

#[async_trait]
impl GenericReplayBuffer for DDPGReplayBuffer {
    /// Appends every transition in `trajectory` to the ring buffer,
    /// overwriting the oldest entries once capacity is reached.
    ///
    /// Returns `(episode_return, episode_length)` as `(f32, i32)` boxed in
    /// `dyn Any`, matching the trait's opaque return contract.
    async fn insert_trajectory(
        &self,
        trajectory: RelayRLTrajectory,
    ) -> Result<Box<dyn Any>, ReplayBufferError> {
        let mut buffers = self.buffers.lock().await;
        let capacity = self.metadata.buffer_size;
        let actions = &trajectory.actions;

        let mut episode_return = 0.0f32;
        let mut episode_length = 0i32;

        for (i, action) in actions.iter().enumerate() {
            episode_length += 1;
            let rew = action.get_rew();
            episode_return += rew;

            let obs = action.get_obs().cloned();
            let act = action.get_act().cloned();
            let done = if action.get_done() { 1.0f32 } else { 0.0f32 };

            // next_obs: the following step's obs; for terminal or final
            // transitions fall back to the current obs.
            let next_obs = if action.get_done() || i + 1 >= actions.len() {
                action.get_obs().cloned()
            } else {
                actions[i + 1].get_obs().cloned()
            };

            let ptr = buffers.pointer;

            if ptr < buffers.observations.len() {
                // Wrapped around: overwrite the oldest slot in place.
                buffers.observations[ptr] = obs;
                buffers.actions[ptr] = act;
                buffers.next_observations[ptr] = next_obs;
                buffers.rewards[ptr] = rew;
                buffers.dones[ptr] = done;
            } else {
                // Still growing toward capacity: append a new slot.
                buffers.observations.push(obs);
                buffers.actions.push(act);
                buffers.next_observations.push(next_obs);
                buffers.rewards.push(rew);
                buffers.dones.push(done);
            }

            buffers.pointer = (ptr + 1) % capacity;
            buffers.current_size = (buffers.current_size + 1).min(capacity);
        }

        Ok(Box::new((episode_return, episode_length)))
    }

    /// Uniformly samples `batch_size` transitions. Does NOT clear the buffer.
    ///
    /// # Errors
    /// Returns `BufferSamplingError` if fewer than `batch_size` transitions
    /// are stored, or if a sampled slot is missing its obs/act/next-obs
    /// tensor. The previous implementation `filter_map`ed missing tensors
    /// away, which silently shortened the tensor columns and misaligned
    /// them against the reward/done columns.
    async fn sample_buffer(&self) -> Result<Batch, ReplayBufferError> {
        let buffers = self.buffers.lock().await;
        let current_size = buffers.current_size;
        let batch_size = self.metadata.batch_size;

        if current_size < batch_size {
            return Err(ReplayBufferError::BufferSamplingError(format!(
                "DDPG replay buffer has {current_size} transitions, need {batch_size}"
            )));
        }

        // Draw `batch_size` distinct indices in O(batch_size) rather than
        // shuffling all `current_size` indices (the buffer may hold millions).
        // `sample` requires amount <= length, which the check above ensures.
        let mut rng = rand::rng();
        let indices: Vec<usize> =
            rand::seq::index::sample(&mut rng, current_size, batch_size).into_vec();

        // Gather one tensor column, failing loudly on a missing entry so the
        // five columns always stay index-aligned.
        let gather = |slots: &[Option<TensorData>],
                      what: &str|
         -> Result<Vec<TensorData>, ReplayBufferError> {
            indices
                .iter()
                .map(|&i| {
                    slots[i].clone().ok_or_else(|| {
                        ReplayBufferError::BufferSamplingError(format!(
                            "transition {i} is missing its {what} tensor"
                        ))
                    })
                })
                .collect()
        };

        let obs = gather(&buffers.observations, "observation")?;
        let act = gather(&buffers.actions, "action")?;
        let next_obs = gather(&buffers.next_observations, "next-observation")?;
        let rew: Vec<f32> = indices.iter().map(|&i| buffers.rewards[i]).collect();
        let done: Vec<f32> = indices.iter().map(|&i| buffers.dones[i]).collect();

        let mut batch: HashMap<BatchKey, BufferSample> = HashMap::new();
        batch.insert(BatchKey::Obs, BufferSample::Tensors(obs.into_boxed_slice()));
        batch.insert(BatchKey::Act, BufferSample::Tensors(act.into_boxed_slice()));
        batch.insert(
            BatchKey::Custom("NextObs".to_string()),
            BufferSample::Tensors(next_obs.into_boxed_slice()),
        );
        batch.insert(
            BatchKey::Custom("Rew".to_string()),
            BufferSample::Scalars(SampleScalars::F32(rew.into_boxed_slice())),
        );
        batch.insert(
            BatchKey::Custom("Done".to_string()),
            BufferSample::Scalars(SampleScalars::F32(done.into_boxed_slice())),
        );

        Ok(batch)
    }
}

#[cfg(test)]
mod tests {
    use super::DDPGReplayBuffer;
    use crate::templates::base_replay_buffer::GenericReplayBuffer;

    /// Construction stores the requested batch size.
    #[test]
    fn buffer_defaults_are_sane() {
        let replay = DDPGReplayBuffer::new(1000, 64);
        assert_eq!(replay.batch_size(), 64);
    }

    /// Sampling an empty buffer must fail rather than return a short batch.
    #[tokio::test]
    async fn sample_fails_when_underfull() {
        let replay = DDPGReplayBuffer::new(1000, 64);
        assert!(replay.sample_buffer().await.is_err());
    }
}