reinforcex 0.0.4

Deep Reinforcement Learning Framework
use super::experience::Experience;
use crate::misc::bounded_vec_deque::BoundedVecDeque;
use crate::prob_distributions::BaseDistribution;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tch::Tensor;
use ulid::Ulid;

/// Buffer that collects experiences grouped by the episode they belong to and
/// hands them all out at once via [`OnPolicyBuffer::flush`].
#[derive(Default)]
pub struct OnPolicyBuffer {
    /// Experiences recorded since the last flush, keyed by episode id; each
    /// episode's experiences are pushed to the back of its deque as they arrive.
    experiences_by_episode: HashMap<Ulid, BoundedVecDeque<Arc<Experience>>>,
}

impl OnPolicyBuffer {
    /// Bound given to each per-episode `BoundedVecDeque`. It is deliberately far
    /// larger than any expected episode length; what happens if it is reached is
    /// up to `BoundedVecDeque`.
    const MAX_EPISODE_LEN: usize = 1_000_000_000;

    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one experience and returns a shared handle to it.
    ///
    /// The experience is built from the given arguments and stored under
    /// `episode_id`; the deque for that episode is created on first use. The
    /// returned [`Arc`] points at the same `Experience` that the buffer holds.
    ///
    /// The last two `Experience::new` arguments are passed as `Mutex::new(None)`;
    /// NOTE(review): presumably these slots are filled in later by another
    /// component — confirm against `Experience`.
    // The public signature is kept as-is for existing callers, so the lint is
    // silenced rather than fixed by restructuring the parameters.
    #[allow(clippy::too_many_arguments)]
    pub fn append(
        &mut self,
        agent_id: Ulid,
        episode_id: Ulid,
        state: Tensor,
        action: Option<Tensor>,
        action_distrib: Option<Box<dyn BaseDistribution>>,
        reward: f64,
        is_episode_terminal: bool,
    ) -> Arc<Experience> {
        let experience = Arc::new(Experience::new(
            agent_id,
            episode_id,
            state,
            action,
            action_distrib,
            reward,
            is_episode_terminal,
            Mutex::new(None),
            Mutex::new(None),
        ));

        self.experiences_by_episode
            .entry(episode_id)
            .or_insert_with(|| BoundedVecDeque::new(Self::MAX_EPISODE_LEN))
            .push_back(Arc::clone(&experience));

        experience
    }

    /// Removes every buffered experience and returns them, one `Vec` per episode.
    ///
    /// Episodes are returned in ascending episode-id (`Ulid`) order rather than
    /// `HashMap` iteration order. That iteration order is randomized per process,
    /// which would make the returned order differ between otherwise identical
    /// runs. Within an episode the order is whatever `BoundedVecDeque::to_vec`
    /// yields (presumably the `append` order — confirm).
    ///
    /// After this call the buffer is empty.
    pub fn flush(&mut self) -> Vec<Vec<Arc<Experience>>> {
        // `drain` empties the map but keeps its allocation for the next rollout.
        let mut episodes: Vec<_> = self.experiences_by_episode.drain().collect();
        // Episode ids are unique map keys, so an unstable sort is deterministic.
        episodes.sort_unstable_by_key(|(episode_id, _)| *episode_id);
        episodes
            .into_iter()
            .map(|(_, experiences)| experiences.to_vec())
            .collect()
    }
}