entity-gym-rs 0.8.0

Rust bindings for the entity-gym library
Documentation
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use atomicbox::AtomicOptionBox;
use crossbeam::channel::{bounded, Receiver, Sender};
use crossbeam::sync::{Parker, Unparker};
use ragged_buffer::ragged_buffer::RaggedBuffer;
use std::thread;

use super::{Action, ActionMask, ActionSpace, ActionType, Environment, ObsSpace, Observation};

/// A batch of environments stepped in parallel by a fixed pool of worker threads.
///
/// Every agent of every environment owns one observation slot; `num_envs` is the total
/// number of slots. `reset` and `act` return one observation per slot, in slot order.
pub struct VecEnv {
    /// State shared with the worker threads.
    inner: Arc<VecEnvInner>,
    /// One task channel per worker thread.
    tasks: Vec<Sender<Task>>,
    /// Parked by `reset`/`act` until a worker signals that every slot has been published.
    wait_on_obs: Parker,

    /// Number of features of each entity type, in the order of `obs_space.entities`.
    pub num_feats: Vec<usize>,
    /// Observation space, read from a probe environment created with seed 99999.
    pub obs_space: ObsSpace,
    /// Action types and their spaces, read from the same probe environment.
    pub action_space: Vec<(ActionType, ActionSpace)>,
    /// Total number of observation slots (one per agent across all environments).
    pub num_envs: usize,

    // Accumulated time in nanoseconds spent in the send / wait / collect phases of `act`;
    // printed (in ms) when the `VecEnv` is dropped.
    total_send: u64,
    total_wait: u64,
    total_collect: u64,
}

/// Commands sent from `VecEnv` to each of its worker threads.
enum Task {
    /// Reset every environment owned by the worker and publish the initial observations.
    Reset,
    /// Step every environment owned by the worker.
    ///
    /// Holds one optional ragged buffer per action type; each buffer is indexed by global
    /// observation slot. A `None` entry yields no action of that type for any agent.
    RawBatchAct(Arc<Vec<Option<RaggedBuffer<i64>>>>),
    /// Leave the worker's task loop.
    Exit,
}

/// State shared between `VecEnv` and its worker threads.
struct VecEnvInner {
    /// One slot per agent; workers store observations into it, the main thread takes them.
    obs: Vec<AtomicOptionBox<Observation>>,
    /// Number of slots published in the current round; zeroed by the main thread before
    /// each round is dispatched.
    completed: AtomicUsize,
    /// Wakes the main thread parked in `reset`/`act`; used by the worker that completes a round.
    wake_obs: Unparker,
}

impl VecEnv {
    #[allow(clippy::mutex_atomic)]
    pub fn new<T: Environment + Send + 'static>(
        create_env: Arc<dyn Fn(u64) -> T + Send + Sync>,
        num_envs: usize,
        threads: usize,
        // Used to offset seeding when running multiple sets of environments in different processes.
        first_env_index: u64,
    ) -> VecEnv {
        let parker = Parker::new();
        let unparker = parker.unparker().clone();
        let inner = Arc::new(VecEnvInner {
            obs: (0..num_envs).map(|_| AtomicOptionBox::none()).collect(),
            completed: AtomicUsize::new(0),
            wake_obs: unparker,
        });
        let mut senders = Vec::new();
        for i in 0..threads {
            let (task_tx, task_rx) = bounded(num_envs);
            let inner = inner.clone();
            let create_env = create_env.clone();
            thread::spawn(move || {
                inner.worker(task_rx, create_env, i, threads, num_envs, first_env_index)
            });
            senders.push(task_tx)
        }
        let env = create_env(99999);
        VecEnv {
            inner,
            tasks: senders,
            wait_on_obs: parker,

            num_feats: env
                .obs_space()
                .entities
                .iter()
                .map(|(_, e)| e.features.len())
                .collect(),
            obs_space: env.obs_space(),
            action_space: env.action_space(),
            num_envs,

            total_send: 0,
            total_wait: 0,
            total_collect: 0,
        }
    }

    pub fn reset(&mut self) -> Vec<Box<Observation>> {
        self.inner.completed.store(0, Ordering::SeqCst);
        for task in &mut self.tasks {
            task.send(Task::Reset).unwrap();
        }
        self.wait_on_obs.park();
        self.inner
            .obs
            .iter()
            .map(|obs| obs.take(Ordering::SeqCst).unwrap())
            .collect()
    }

    pub fn act(&mut self, actions: Vec<Option<RaggedBuffer<i64>>>) -> Vec<Box<Observation>> {
        //println!();
        let start_time = std::time::Instant::now();
        self.inner.completed.store(0, Ordering::SeqCst);
        let actions = Arc::new(actions);
        for task in &self.tasks {
            task.send(Task::RawBatchAct(actions.clone())).unwrap();
        }
        let send_ns = start_time.elapsed().as_nanos();
        //println!("Sending actions: {} ns", send_ns);
        drop(actions);
        self.total_send += send_ns as u64;
        let start_time = std::time::Instant::now();
        self.wait_on_obs.park();
        let wait_ns = start_time.elapsed().as_nanos();
        //println!("Await obs:       {} ns", wait_ns);
        self.total_wait += wait_ns as u64;
        let start_time = std::time::Instant::now();
        let obss = self
            .inner
            .obs
            .iter()
            .map(|obs| obs.take(Ordering::SeqCst).unwrap())
            .collect();
        let collect_ns = start_time.elapsed().as_nanos();
        //println!("Collecting obs:  {} ns", collect_ns);
        self.total_collect += collect_ns as u64;
        obss
    }
}

impl VecEnvInner {
    /// Returns `(offset, len)`: the contiguous range of observation slots owned by worker
    /// `thread_id` when `total_envs` slots are split across `nthread` workers.
    ///
    /// The first `total_envs % nthread` workers each take one extra slot, so the ranges of
    /// all workers tile `0..total_envs` without gaps or overlaps.
    fn slot_range(thread_id: usize, nthread: usize, total_envs: usize) -> (usize, usize) {
        let base = total_envs / nthread;
        let extra = total_envs % nthread;
        let len = base + usize::from(thread_id < extra);
        // Every earlier worker owns `base` slots, and `min(thread_id, extra)` of them own one more.
        let offset = thread_id * base + thread_id.min(extra);
        (offset, len)
    }

    /// Adds `count` newly published slots to the completion counter and unparks the main
    /// thread if that brings the total for the current round to `total_envs`.
    fn report_completed(&self, count: usize, total_envs: usize) {
        let before = self.completed.fetch_add(count, Ordering::SeqCst);
        if before + count == total_envs {
            self.wake_obs.unpark();
        }
    }

    /// Body of one worker thread.
    ///
    /// Creates enough environments to cover this worker's slot range (one slot per agent),
    /// then processes tasks from `rx`:
    /// - `Task::Reset` resets every environment and publishes the initial observations;
    /// - `Task::RawBatchAct` decodes each agent's actions from the shared ragged buffers,
    ///   steps the environments and publishes the resulting observations;
    /// - `Task::Exit`, or a disconnected channel, ends the loop.
    ///
    /// Environment `k` of this worker is seeded with `seed_offset + offset + k`. Here `offset`
    /// is the worker's first global slot, so different workers never share seeds.
    ///
    /// # Panics
    ///
    /// Panics if this worker owns no slots, if an environment reports zero agents, if the
    /// environments disagree on their agent count, or if the number of owned slots is not a
    /// multiple of the agent count.
    fn worker<T: Environment>(
        &self,
        rx: Receiver<Task>,
        create_env: Arc<dyn Fn(u64) -> T>,
        thread_id: usize,
        nthread: usize,
        total_envs: usize,
        seed_offset: u64,
    ) {
        let (env_offset, local_envs) = Self::slot_range(thread_id, nthread, total_envs);
        assert!(
            local_envs > 0,
            "No environments for thread {} ({}/{} nthread={})",
            thread_id,
            local_envs,
            total_envs,
            nthread,
        );
        let mut agents_per_env = None;
        let mut envs = vec![];
        let mut env_count = 0;
        while env_count < local_envs {
            // Seed from this worker's first global slot rather than from 0. Each environment
            // consumes at least one slot, so seeds of different workers stay disjoint.
            let env = create_env(seed_offset + (env_offset + envs.len()) as u64);
            let agents = env.agents();
            // Zero agents would never advance `env_count` and loop forever.
            assert!(agents > 0, "Environment must have at least one agent");
            assert_eq!(agents, *agents_per_env.get_or_insert(agents));
            env_count += agents;
            envs.push(env);
        }
        let agents_per_env = agents_per_env.expect("at least one environment was created");
        assert!(
            local_envs % agents_per_env == 0,
            "slots per worker ({}) must be a multiple of agents per environment ({})",
            local_envs,
            agents_per_env
        );

        // Action masks of every local slot from the latest published observations, in slot
        // order. They translate the policy's flat action indices back into actor/actee ids,
        // so a `RawBatchAct` must be preceded by a `Reset`.
        let mut action_masks = vec![];
        while let Ok(task) = rx.recv() {
            match task {
                Task::Exit => break,
                Task::Reset => {
                    action_masks.clear();
                    for (i, env) in envs.iter_mut().enumerate() {
                        let env_id = env_offset + i * agents_per_env;
                        let mut obs = env.reset();
                        // The observation right after a reset may already be done: reset again
                        // and carry the reward, done flag and metrics over into the new one.
                        let mut done = obs[0].done;
                        while done {
                            let mut onew = env.reset();
                            done = onew[0].done;
                            for k in 0..obs.len() {
                                onew[k].reward = obs[k].reward;
                                onew[k].done = obs[k].done;
                                onew[k].metrics.extend(obs[k].metrics.clone());
                            }
                            obs = onew;
                        }
                        for (j, obs) in obs.into_iter().enumerate() {
                            action_masks.push(obs.actions.clone());
                            self.obs[env_id + j].store(Some(obs), Ordering::SeqCst);
                        }
                    }
                    self.report_completed(local_envs, total_envs);
                }
                Task::RawBatchAct(ragged_actions) => {
                    let mut new_action_masks = Vec::with_capacity(action_masks.len());
                    for (i, env) in envs.iter_mut().enumerate() {
                        let mut actions = Vec::with_capacity(agents_per_env);
                        for agent in 0..agents_per_env {
                            let env_id = env_offset + i * agents_per_env + agent;
                            // One entry per action type: decode this agent's raw indices using
                            // the mask from the observation the policy acted on.
                            let action = ragged_actions
                                .iter()
                                .enumerate()
                                .map(|(idx_act_type, a)| match a {
                                    Some(a) => {
                                        let subarray = a.subarrays[env_id].clone();
                                        match &action_masks[i * agents_per_env + agent]
                                            [idx_act_type]
                                        {
                                            Some(ActionMask::DenseCategorical {
                                                actors, ..
                                            }) => Some(Action::Categorical {
                                                actors: actors.clone(),
                                                action: a.data[subarray]
                                                    .iter()
                                                    .map(|x| *x as usize)
                                                    .collect(),
                                            }),
                                            Some(ActionMask::SelectEntity { actors, actees }) => {
                                                // Raw values index into the mask's actee list.
                                                Some(Action::SelectEntity {
                                                    actors: actors.clone(),
                                                    actees: a.data[subarray]
                                                        .iter()
                                                        .map(|x| actees[*x as usize])
                                                        .collect(),
                                                })
                                            }
                                            None => None,
                                        }
                                    }
                                    None => None,
                                })
                                .collect::<Vec<_>>();
                            actions.push(action);
                        }
                        let mut obs = env.act(&actions);
                        // Auto-reset finished episodes, keeping the terminal reward, done flag
                        // and metrics on the observation returned to the caller.
                        let mut done = obs[0].done;
                        while done {
                            let mut onew = env.reset();
                            done = onew[0].done;
                            for k in 0..obs.len() {
                                onew[k].reward = obs[k].reward;
                                onew[k].done = obs[k].done;
                                onew[k].metrics.extend(obs[k].metrics.clone());
                            }
                            obs = onew;
                        }
                        let env_id = env_offset + i * agents_per_env;
                        for (j, obs) in obs.into_iter().enumerate() {
                            new_action_masks.push(obs.actions.clone());
                            self.obs[env_id + j].store(Some(obs), Ordering::SeqCst);
                        }
                    }
                    self.report_completed(local_envs, total_envs);
                    action_masks = new_action_masks;
                }
            }
        }
    }
}

impl Drop for VecEnv {
    /// Prints the accumulated `act` timing statistics and tells every worker thread to exit.
    fn drop(&mut self) {
        println!("Total send: {} ms", self.total_send / 1_000_000);
        println!("Total wait: {} ms", self.total_wait / 1_000_000);
        println!("Total collect: {} ms", self.total_collect / 1_000_000);
        // There is exactly one sender per worker, so this sends one `Exit` per thread. Tasks
        // already queued on a channel are received before the `Exit` behind them.
        //
        // A send fails only if that worker has already terminated (e.g. it panicked and
        // dropped its receiver). There is nothing left to stop then. Panicking inside `drop`
        // would abort the process if it happens during unwinding, so the error is ignored.
        for tx in &self.tasks {
            let _ = tx.send(Task::Exit);
        }
    }
}