//! relayrl_algorithms 0.2.0
//!
//! A collection of Multi-Agent Deep Reinforcement Learning Algorithms (IPPO, MAPPO, etc.)
//! Documentation
use crate::algorithms::{compute_normed_advantages, discounted_cumsum, scalar_stats};
use crate::templates::base_replay_buffer::{
    Batch, BatchKey, BufferSample, BufferTensors, GenericReplayBuffer, ReplayBufferError,
    SampleScalars,
};
use async_trait::async_trait;
use relayrl_types::prelude::action::RelayRLData;
use relayrl_types::prelude::tensor::relayrl::TensorData;
use relayrl_types::prelude::trajectory::RelayRLTrajectory;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Mutex;

/// Backing storage for one epoch of on-policy experience.
///
/// All per-step fields grow in lockstep: index `i` in every vector refers
/// to the same environment step.
struct Buffers {
    // Per-step observation tensors; `sample_buffer` drops `None` entries.
    observations: BufferTensors,
    // Per-step action tensors.
    actions: BufferTensors,
    // Per-step action masks.
    masks: BufferTensors,
    // Scalar reward received at each step.
    rewards: Vec<f32>,
    // Filled in by `finish_path` once a trajectory segment completes.
    advantages: Vec<f32>,
    // Discounted rewards-to-go; also filled in by `finish_path`.
    returns: Vec<f32>,
    // Log-probability of the taken action ("logp_a" in the action's data map).
    logprobs: BufferTensors,
    // Scalar value estimates ("val"); `Some` only when the value-function
    // baseline is enabled.
    values: Option<Vec<f32>>,
}

/// Immutable configuration plus the two cursors that track buffer state.
///
/// The atomics are always mutated while the `Buffers` mutex is held, so
/// they act as shared counters rather than a lock-free protocol.
struct BufferMetadata {
    // Discount factor for rewards-to-go.
    gamma: f32,
    // GAE smoothing factor (only used when `with_vf_baseline` is true).
    lambda: f32,
    // When true, advantages are GAE estimates against the critic's values.
    with_vf_baseline: bool,
    // Maximum number of steps accepted before `sample_buffer` errors out.
    buffer_size: usize,
    // Index one past the most recently inserted step.
    buffer_pointer: AtomicUsize,
    // Index of the first step of the current (unfinished) path.
    buffer_path_start_idx: AtomicUsize,
}

/// On-policy replay buffer for REINFORCE, with an optional GAE value-function
/// baseline. Experience is accumulated via `insert_trajectory` and drained
/// (consumed whole) by `sample_buffer`.
pub struct ReinforceReplayBuffer {
    // Step data, guarded by an async mutex so inserts/samples serialize.
    buffers: Arc<Mutex<Buffers>>,
    // Configuration and cursors; shared read-mostly state.
    metadata: Arc<BufferMetadata>,
}

impl Default for ReinforceReplayBuffer {
    /// One-million-step buffer with `gamma = 0.98`, `lambda = 0.97`, and no
    /// value-function baseline.
    fn default() -> Self {
        Self::new(1_000_000, 0.98, 0.97, false)
    }
}

impl ReinforceReplayBuffer {
    /// Creates a new buffer.
    ///
    /// * `buffer_size` - maximum number of steps before sampling rejects.
    /// * `gamma` - discount factor for rewards-to-go.
    /// * `lambda` - GAE smoothing factor (only used with the value baseline).
    /// * `with_vf_baseline` - when true, advantages are GAE estimates and a
    ///   `values` vector is allocated alongside the other buffers.
    pub fn new(buffer_size: usize, gamma: f32, lambda: f32, with_vf_baseline: bool) -> Self {
        let buffers = Buffers {
            observations: Vec::with_capacity(buffer_size),
            actions: Vec::with_capacity(buffer_size),
            masks: Vec::with_capacity(buffer_size),
            rewards: Vec::with_capacity(buffer_size),
            advantages: Vec::with_capacity(buffer_size),
            returns: Vec::with_capacity(buffer_size),
            logprobs: Vec::with_capacity(buffer_size),
            values: with_vf_baseline.then(|| Vec::with_capacity(buffer_size)),
        };

        Self {
            buffers: Arc::new(Mutex::new(buffers)),
            metadata: Arc::new(BufferMetadata {
                gamma,
                lambda,
                with_vf_baseline,
                buffer_size,
                buffer_pointer: AtomicUsize::new(0),
                buffer_path_start_idx: AtomicUsize::new(0),
            }),
        }
    }

    /// Reinterprets a tensor's raw byte payload as `f32`s and returns the
    /// first element, or 0.0 for an empty tensor.
    fn tensor_scalar_f32(data: &TensorData) -> f32 {
        let values: &[f32] = bytemuck::cast_slice(&data.data);
        values.first().copied().unwrap_or(0.0)
    }

    /// Finalizes the trajectory segment `[path_start_idx, pointer)` by
    /// filling `advantages` and `returns` for those steps.
    ///
    /// `final_value` is the bootstrap value of the state *after* the last
    /// recorded step: pass `None` (treated as 0.0) for a true terminal
    /// state, or the critic's value estimate when the path is cut off
    /// before termination.
    fn finish_path(&self, buffers: &mut Buffers, final_value: Option<f32>) {
        let bootstrap = final_value.unwrap_or(0.0);
        let start = self.metadata.buffer_path_start_idx.load(Ordering::SeqCst);
        let end = self.metadata.buffer_pointer.load(Ordering::SeqCst);
        if start >= end {
            // Nothing recorded since the last finished path.
            return;
        }
        let slice = start..end;
        let path_len = end - start;

        if self.metadata.with_vf_baseline {
            // GAE-lambda: append the bootstrap value so the last delta and
            // the last return both look one step past the recorded path.
            let mut rewards = buffers.rewards[slice.clone()].to_vec();
            // `values` is always `Some` when `with_vf_baseline` is set (see
            // `new`); the empty fallback would panic in the delta loop below,
            // so it only documents an unreachable state.
            let mut values = buffers
                .values
                .as_ref()
                .map(|v| v[slice.clone()].to_vec())
                .unwrap_or_default();
            rewards.push(bootstrap);
            values.push(bootstrap);

            // TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
            let deltas: Vec<f32> = (0..path_len)
                .map(|i| rewards[i] + self.metadata.gamma * values[i + 1] - values[i])
                .collect();
            let advantages = discounted_cumsum(&deltas, self.metadata.gamma * self.metadata.lambda);
            buffers.advantages[slice.clone()].copy_from_slice(&advantages);

            // Rewards-to-go must include the bootstrap value (the appended
            // `rewards` vec), then drop the trailing entry belonging to the
            // bootstrap step itself. Previously the bootstrap was omitted,
            // under-estimating the return of every step in a cut-off path.
            let returns = discounted_cumsum(&rewards, self.metadata.gamma);
            buffers.returns[slice.clone()].copy_from_slice(&returns[..path_len]);
        } else {
            // Plain REINFORCE: the advantage IS the discounted reward-to-go,
            // so compute the cumulative sum once and reuse it for both fields.
            let returns = discounted_cumsum(&buffers.rewards[slice.clone()], self.metadata.gamma);
            buffers.advantages[slice.clone()].copy_from_slice(&returns);
            buffers.returns[slice.clone()].copy_from_slice(&returns);
        }

        // The next path begins where this one ended.
        self.metadata
            .buffer_path_start_idx
            .store(end, Ordering::SeqCst);
    }
}

#[async_trait]
impl GenericReplayBuffer for ReinforceReplayBuffer {
    /// Appends every step of `trajectory` to the buffer.
    ///
    /// Each step stores its observation/action/mask tensors, its reward, an
    /// optional `logp_a` tensor and (when the value baseline is enabled) a
    /// scalar `val` from the action's auxiliary data map. When a step is
    /// flagged `done`, the path is finished and advantages/returns are
    /// computed for it.
    ///
    /// Returns the cumulative `(episode_return, episode_length)` over the
    /// whole trajectory, boxed as `dyn Any`.
    async fn insert_trajectory(
        &self,
        trajectory: RelayRLTrajectory,
    ) -> Result<Box<dyn Any>, ReplayBufferError> {
        let mut buffers = self.buffers.lock().await;
        let mut episode_return = 0.0f32;
        let mut episode_length = 0i32;

        for action in &trajectory.actions {
            episode_length += 1;
            let reward = action.get_rew();
            episode_return += reward;

            buffers.observations.push(action.get_obs().cloned());
            buffers.actions.push(action.get_act().cloned());
            buffers.masks.push(action.get_mask().cloned());
            // Placeholder; replaced below if the step carries `logp_a`.
            buffers.logprobs.push(None);
            buffers.rewards.push(reward);
            // Filled in by `finish_path` once the path completes.
            buffers.advantages.push(0.0);
            buffers.returns.push(0.0);

            if let Some(map) = action.get_data() {
                if let Some(RelayRLData::Tensor(logp)) = map.get("logp_a")
                    && let Some(slot) = buffers.logprobs.last_mut()
                {
                    *slot = Some(logp.clone());
                }
                if self.metadata.with_vf_baseline {
                    // A missing or non-tensor `val` falls back to 0.0 so the
                    // values vector stays aligned with the other buffers.
                    let value = match map.get("val") {
                        Some(RelayRLData::Tensor(val)) => Self::tensor_scalar_f32(val),
                        _ => 0.0,
                    };
                    if let Some(values) = buffers.values.as_mut() {
                        values.push(value);
                    }
                }
            } else if self.metadata.with_vf_baseline
                && let Some(values) = buffers.values.as_mut()
            {
                values.push(0.0);
            }

            // Single atomic increment instead of a separate load + store;
            // equivalent under the buffer mutex, but simpler and race-free.
            self.metadata.buffer_pointer.fetch_add(1, Ordering::SeqCst);

            if action.get_done() {
                // Terminal state: the bootstrap value is 0. Previously the
                // final reward was passed as the bootstrap, which double
                // counted it in the GAE deltas and return targets.
                self.finish_path(&mut buffers, None);
            }
        }

        Ok(Box::new((episode_return, episode_length)))
    }

    /// Drains the buffer and returns one on-policy batch.
    ///
    /// Advantages are normalized to zero mean / unit std. All buffers and
    /// both cursors are reset after a successful sample, so each batch is
    /// consumed exactly once.
    ///
    /// # Errors
    /// Fails when the buffer is empty, or when the recorded step count
    /// exceeds the configured `buffer_size`.
    async fn sample_buffer(&self) -> Result<Batch, ReplayBufferError> {
        let mut buffers = self.buffers.lock().await;
        let capacity = self.metadata.buffer_pointer.load(Ordering::SeqCst);
        if capacity == 0 {
            return Err(ReplayBufferError::BufferSamplingError(
                "Replay buffer is empty".to_string(),
            ));
        }
        // NOTE(review): on this error path the buffer is never drained, so
        // every subsequent call will fail the same way — confirm callers
        // handle this by recreating the buffer.
        if capacity > self.metadata.buffer_size {
            return Err(ReplayBufferError::BufferSamplingError(
                "Replay buffer capacity exceeded".to_string(),
            ));
        }

        // Normalize advantages; the std is clamped away from zero so a
        // constant-advantage batch does not divide by ~0.
        let adv_raw = &buffers.advantages[..capacity];
        let (adv_mean, adv_std) = scalar_stats(adv_raw);
        let adv_norm = compute_normed_advantages(adv_raw, adv_mean, adv_std.max(1e-8));

        // NOTE(review): `filter_map` silently drops steps whose tensor slot
        // is `None`, which can leave these vectors shorter than Adv/Ret —
        // confirm upstream guarantees every step carries these tensors.
        let obs: Vec<TensorData> = buffers.observations[..capacity]
            .iter()
            .filter_map(|x| x.clone())
            .collect();
        let act: Vec<TensorData> = buffers.actions[..capacity]
            .iter()
            .filter_map(|x| x.clone())
            .collect();
        let mask: Vec<TensorData> = buffers.masks[..capacity]
            .iter()
            .filter_map(|x| x.clone())
            .collect();
        let logp: Vec<TensorData> = buffers.logprobs[..capacity]
            .iter()
            .filter_map(|x| x.clone())
            .collect();
        let ret: Vec<f32> = buffers.returns[..capacity].to_vec();

        // Reset cursors and storage: this is an on-policy buffer, so data is
        // discarded once it has been sampled.
        self.metadata.buffer_pointer.store(0, Ordering::SeqCst);
        self.metadata
            .buffer_path_start_idx
            .store(0, Ordering::SeqCst);
        buffers.observations.clear();
        buffers.actions.clear();
        buffers.masks.clear();
        buffers.rewards.clear();
        buffers.advantages.clear();
        buffers.returns.clear();
        buffers.logprobs.clear();
        if let Some(values) = buffers.values.as_mut() {
            values.clear();
        }

        let mut batch: HashMap<BatchKey, BufferSample> = HashMap::new();
        batch.insert(BatchKey::Obs, BufferSample::Tensors(obs.into_boxed_slice()));
        batch.insert(BatchKey::Act, BufferSample::Tensors(act.into_boxed_slice()));
        batch.insert(
            BatchKey::Mask,
            BufferSample::Tensors(mask.into_boxed_slice()),
        );
        batch.insert(
            BatchKey::Custom("Adv".to_string()),
            BufferSample::Scalars(SampleScalars::F32(adv_norm.into_boxed_slice())),
        );
        batch.insert(
            BatchKey::Custom("Ret".to_string()),
            BufferSample::Scalars(SampleScalars::F32(ret.into_boxed_slice())),
        );
        batch.insert(
            BatchKey::Custom("LogP".to_string()),
            BufferSample::Tensors(logp.into_boxed_slice()),
        );

        Ok(batch)
    }
}