reinforcex 0.0.4

Deep Reinforcement Learning Framework
use super::experience::Experience;
use crate::misc::bounded_vec_deque::BoundedVecDeque;
use crate::misc::cumsum;
use crate::misc::random_access_queue::RandomAccessQueue;
use crate::prob_distributions::BaseDistribution;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tch::Tensor;
use ulid::Ulid;

#[derive(Clone)]
pub struct ReplayBuffer {
    experiences: Arc<Mutex<RandomAccessQueue<Arc<Experience>>>>,
    last_n_experiences_by_episode: Arc<Mutex<HashMap<Ulid, BoundedVecDeque<Arc<Experience>>>>>,
    n_steps: usize,
}

impl ReplayBuffer {
    pub fn new(capacity: usize, n_steps: usize) -> Self {
        assert!(capacity > 0);
        assert!(n_steps > 0);
        Self {
            experiences: Arc::new(Mutex::new(RandomAccessQueue::new(capacity))),
            last_n_experiences_by_episode: Arc::new(Mutex::new(HashMap::new())),
            n_steps,
        }
    }

    pub fn append(
        &self,
        agent_id: Ulid,
        episode_id: Ulid,
        state: Tensor,
        action: Option<Tensor>,
        action_distrib: Option<Box<dyn BaseDistribution>>,
        reward: f64,
        is_episode_terminal: bool,
        gamma: f64,
    ) -> Arc<Experience> {
        let experience = Arc::new(Experience::new(
            agent_id,
            episode_id,
            state,
            action,
            action_distrib,
            reward,
            is_episode_terminal,
            Mutex::new(None),
            Mutex::new(None),
        ));

        let mut last_n_experiences_by_episode = self.last_n_experiences_by_episode.lock().unwrap();

        let last_n_experiences = last_n_experiences_by_episode
            .entry(episode_id)
            .or_insert_with(|| BoundedVecDeque::new(self.n_steps));

        if let Some(exp) = last_n_experiences.push_back(experience.clone()) {
            *exp.n_step_discounted_reward.lock().unwrap() = Some(
                cumsum::cumsum_rev(
                    last_n_experiences
                        .clone_deque()
                        .into_iter()
                        .map(|e| e.reward)
                        .collect::<Vec<f64>>()
                        .as_ref(),
                    &vec![gamma; last_n_experiences.len()],
                )[0],
            );
            *exp.n_step_after_experience.lock().unwrap() = Some(experience.clone());
            self.experiences.lock().unwrap().append(exp);
        }

        if is_episode_terminal {
            if let Some(last_n_experiences) = last_n_experiences_by_episode.remove(&episode_id) {
                let mut rewards = last_n_experiences
                    .clone_deque()
                    .into_iter()
                    .skip(1)
                    .map(|e| e.reward)
                    .collect::<Vec<f64>>();
                rewards.push(0.0);
                for (exp, &q) in last_n_experiences
                    .clone_deque()
                    .into_iter()
                    .zip(cumsum::cumsum_rev(&rewards, &vec![gamma; rewards.len()]).iter())
                {
                    if exp.is_episode_terminal {
                        continue;
                    }
                    *exp.n_step_discounted_reward.lock().unwrap() = Some(q);
                    *exp.n_step_after_experience.lock().unwrap() = Some(experience.clone());
                    self.experiences.lock().unwrap().append(exp);
                }
            }
        }

        experience
    }

    pub fn sample(&self, num_experiences: usize, replacement: bool) -> Vec<Arc<Experience>> {
        if replacement {
            self.experiences
                .lock()
                .unwrap()
                .sample_with_replacement(num_experiences)
                .into_iter()
                .cloned()
                .collect()
        } else {
            self.experiences
                .lock()
                .unwrap()
                .sample_without_replacement(num_experiences)
                .into_iter()
                .cloned()
                .collect()
        }
    }

    pub fn len(&self) -> usize {
        self.experiences.lock().unwrap().len()
    }

    pub fn clear(&self) {
        self.experiences.lock().unwrap().clear();
        self.last_n_experiences_by_episode.lock().unwrap().clear();
    }

    pub fn get_n_steps(&self) -> usize {
        self.n_steps
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::*;
    use std::{sync::Arc, thread::sleep, time::Duration};
    use tch::Tensor;

    #[test]
    fn test_replay_buffer_new() {
        let buffer = ReplayBuffer::new(100, 5);
        assert_eq!(buffer.len(), 0);
    }

    #[test]
    fn test_replay_buffer_append_and_len() {
        let buffer = ReplayBuffer::new(100, 1);
        let state = Tensor::from_slice(&[1.0]);
        let episode_id = Ulid::new();
        buffer.append(
            episode_id,
            episode_id,
            state.shallow_clone(),
            None,
            None,
            1.0,
            false,
            1.0,
        );
        assert_eq!(buffer.len(), 0);
        buffer.append(
            episode_id,
            episode_id,
            state.shallow_clone(),
            None,
            None,
            1.0,
            false,
            1.0,
        );
        assert_eq!(buffer.len(), 1);
        buffer.append(
            episode_id,
            episode_id,
            state.shallow_clone(),
            None,
            None,
            1.0,
            false,
            1.0,
        );
        assert_eq!(buffer.len(), 2);
    }

    #[test]
    fn test_is_episode_terminal() {
        let buffer = ReplayBuffer::new(100, 1);
        let state = Tensor::from_slice(&[1.0]);
        let episode_id = Ulid::new();
        buffer.append(
            episode_id,
            episode_id,
            state.shallow_clone(),
            None,
            None,
            1.0,
            true,
            1.0,
        );
    }

    #[test]
    fn test_replay_buffer_sample() {
        let buffer = ReplayBuffer::new(100, 5);
        let episode_id = Ulid::new();
        for i in 0..10 {
            let state = Tensor::from_slice(&[i as f64]);
            buffer.append(
                episode_id, episode_id, state, None, None, i as f64, false, 1.0,
            );
        }
        let samples = buffer.sample(3, false);
        assert_eq!(samples.len(), 3);
    }

    #[test]
    fn test_replay_buffer_terminal_state() {
        let buffer = ReplayBuffer::new(100, 5);
        let episode_id = Ulid::new();
        for i in 0..5 {
            let state = Tensor::from_slice(&[i as f64]);
            buffer.append(
                episode_id,
                episode_id,
                state,
                None,
                None,
                i as f64,
                i == 4,
                1.0,
            );
        }
        let last_n_experiences_by_episode = buffer.last_n_experiences_by_episode.lock().unwrap();
        assert_eq!(
            last_n_experiences_by_episode
                .get(&episode_id)
                .map(|v| v.len())
                .unwrap_or(0),
            0
        );
    }

    #[test]
    fn test_q_value_and_next_experience_update() {
        let buffer = ReplayBuffer::new(100, 2);
        let state1 = Tensor::from_slice(&[0.0]);
        let state2 = Tensor::from_slice(&[1.0]);
        let state3 = Tensor::from_slice(&[2.0]);
        let state4 = Tensor::from_slice(&[3.0]);
        let state5 = Tensor::from_slice(&[4.0]);
        let state6 = Tensor::from_slice(&[5.0]);
        let state7 = Tensor::from_slice(&[6.0]);
        let state8 = Tensor::from_slice(&[7.0]);
        let state9 = Tensor::from_slice(&[8.0]);

        let episode1_id = Ulid::new();
        let episode2_id = Ulid::new();

        buffer.append(
            episode1_id,
            episode1_id,
            state1,
            None,
            None,
            0.0,
            false,
            0.9,
        );
        buffer.append(
            episode1_id,
            episode1_id,
            state2,
            None,
            None,
            2.0,
            false,
            0.9,
        );
        buffer.append(episode1_id, episode1_id, state3, None, None, 3.0, true, 0.9);
        buffer.append(
            episode2_id,
            episode2_id,
            state4,
            None,
            None,
            0.0,
            false,
            0.9,
        );
        buffer.append(
            episode2_id,
            episode2_id,
            state5,
            None,
            None,
            0.0,
            false,
            0.9,
        );
        buffer.append(
            episode2_id,
            episode2_id,
            state6,
            None,
            None,
            0.0,
            false,
            0.9,
        );
        buffer.append(
            episode2_id,
            episode2_id,
            state7,
            None,
            None,
            0.0,
            false,
            0.9,
        );
        buffer.append(
            episode2_id,
            episode2_id,
            state8,
            None,
            None,
            0.0,
            false,
            0.9,
        );
        buffer.append(episode2_id, episode2_id, state9, None, None, 5.0, true, 0.9);

        for experience in buffer.sample(7, false) {
            let n_step_discounted_reward = *experience.n_step_discounted_reward.lock().unwrap();
            let n_step_after_experience = experience.n_step_after_experience.lock().unwrap();
            let expected_q_value;
            let state_val = experience.state.double_value(&[]);
            if state_val == 0.0 {
                expected_q_value = 2.0 + 0.9 * 3.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else if state_val == 1.0 {
                expected_q_value = 3.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else if state_val == 3.0 {
                expected_q_value = 0.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else if state_val == 4.0 {
                expected_q_value = 0.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else if state_val == 5.0 {
                expected_q_value = 0.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else if state_val == 6.0 {
                expected_q_value = 0.9 * 5.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else if state_val == 7.0 {
                expected_q_value = 5.0;
                assert!((n_step_discounted_reward.unwrap() - expected_q_value).abs() < 1e-6);
                assert!(n_step_after_experience.is_some());
            } else {
                panic!("Unexpected state")
            }
        }
    }

    #[test]
    fn test_concurrent_append_and_sample_with_threads() {
        let buffer = Arc::new(ReplayBuffer::new(200, 3));
        let n_threads = 10;

        (0..n_threads).into_par_iter().for_each(|i| {
            let episode_id = Ulid::new();
            for j in 1..100 {
                let state = Tensor::from_slice(&[i as f64, j as f64]);
                buffer.append(episode_id, episode_id, state, None, None, 1.0, false, 0.99);
                sleep(Duration::from_millis(1));

                if j % 10 == 0 {
                    let samples = buffer.sample(3, true);
                    assert!(samples.len() == 3);
                }
            }
        });

        assert!(buffer.len() > 0);
        let samples = buffer.sample(5, true);
        assert_eq!(samples.len(), 5);
    }
}