border-py-gym-env 0.0.5

Reinforcement learning library
//! Wrapper of gym environments implemented in Python.
#![allow(clippy::float_cmp)]
use crate::{AtariWrapper, PyGymEnvConfig};
use anyhow::Result;
use border_core::{record::Record, Act, Env, Info, Obs, Step};
use log::{info, trace};
use pyo3::types::{IntoPyDict, PyTuple};
use pyo3::{types::PyModule, PyObject, Python, ToPyObject};
use std::marker::PhantomData;
use std::{fmt::Debug, time::Duration};
use serde::{Serialize, de::DeserializeOwned};

/// Information given at every step of the interaction with the environment.
///
/// Currently, it is empty and used to match the type signature: the struct has no
/// fields and only serves as the concrete `Info` type of the environment.
pub struct PyGymInfo {}

// Empty implementation: no item of `Info` is overridden for `PyGymInfo`.
impl Info for PyGymInfo {}

/// Converts a [`PyObject`] coming from the Python environment into an observation
/// of type `O`, optionally preprocessing it on the way.
pub trait PyGymEnvObsFilter<O: Obs> {
    /// Configuration from which the filter is built.
    type Config: Clone + Default + Serialize + DeserializeOwned;

    /// Builds the filter from `config`.
    fn build(config: &Self::Config) -> Result<Self>
    where
        Self: Sized;

    /// Filters `obs` and converts it into an observation.
    ///
    /// The returned [`Record`] carries whatever the filter wants to report
    /// alongside the observation.
    fn filt(&mut self, obs: PyObject) -> (O, Record);

    /// Called when the environment is reset.
    ///
    /// The default implementation runs `obs` through `filt` and drops the
    /// record. Stateful filters can override it to reinitialize their state.
    fn reset(&mut self, obs: PyObject) -> O {
        self.filt(obs).0
    }

    /// Returns the default configuration, i.e. `Self::Config::default()`.
    fn default_config() -> Self::Config {
        <Self::Config as Default>::default()
    }
}

/// Converts an action of type `A` into a [`PyObject`] that is passed to the Python
/// environment, optionally preprocessing it on the way.
///
/// This trait should support vectorized environments.
pub trait PyGymEnvActFilter<A: Act> {
    /// Configuration from which the filter is built.
    type Config: Clone + Default + Serialize + DeserializeOwned;

    /// Builds the filter from `config`.
    fn build(config: &Self::Config) -> Result<Self>
    where
        Self: Sized;

    /// Filters `act` and converts it into a [`PyObject`].
    ///
    /// For vectorized environments, `act` should hold the actions for all
    /// environments in the vectorized environment. The returned object will then be
    /// a `PyList` whose elements are the actions of the corresponding environments.
    fn filt(&mut self, act: A) -> (PyObject, Record);

    /// Called when the environment is reset.
    ///
    /// The default implementation does nothing. Stateful filters can override it;
    /// `_is_done` is the value that was handed to the environment's `reset`, which
    /// makes this method usable with vectorized environments.
    fn reset(&mut self, _is_done: &Option<&Vec<i8>>) {}

    /// Returns the default configuration, i.e. `Self::Config::default()`.
    fn default_config() -> Self::Config {
        <Self::Config as Default>::default()
    }
}

/// An environment in [OpenAI gym](https://github.com/openai/gym).
///
/// The struct holds the Python environment object together with the observation
/// filter `OF` and the action filter `AF`, which convert between Python objects
/// and the Rust-side observation type `O` and action type `A`.
#[derive(Debug)]
pub struct PyGymEnv<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: PyGymEnvObsFilter<O>,
    AF: PyGymEnvActFilter<A>,
{
    /// If `true`, the environment is rendered at every step.
    render: bool,
    /// The Python environment object.
    env: PyObject,
    /// `action_space.n` of the Python environment if that attribute exists,
    /// otherwise the first element of `action_space.shape`.
    #[allow(dead_code)]
    action_space: i64,
    /// `observation_space.shape` of the Python environment.
    #[allow(dead_code)]
    observation_space: Vec<usize>,
    /// Step counter that is compared against `max_steps`.
    count_steps: usize,
    /// If set, an episode is flagged as done once `count_steps` reaches this value.
    max_steps: Option<usize>,
    /// Converts Python observations into `O`.
    obs_filter: OF,
    /// Converts actions of type `A` into Python objects.
    act_filter: AF,
    /// Time to sleep after rendering in each step.
    wait_in_render: Duration,
    /// Whether the environment was configured as a pybullet environment.
    pybullet: bool,
    /// Python helper module with the floor/camera utilities for pybullet;
    /// `None` unless the environment was configured as a pybullet environment.
    pybullet_state: Option<PyObject>,
    /// Ties the otherwise unused type parameters `O` and `A` to the struct.
    phantom: PhantomData<(O, A)>,
}

impl<O, A, OF, AF> PyGymEnv<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: PyGymEnvObsFilter<O>,
    AF: PyGymEnvActFilter<A>,
{
    // /// Constructs an environment.
    // ///
    // /// `name` is the name of the environment, which is implemented in OpenAI gym.
    // pub fn new(
    //     name: &str,
    //     obs_filter: OF,
    //     act_filter: AF,
    //     atari_wrapper: Option<AtariWrapper>,
    // ) -> PyResult<Self> {
    //     let gil = Python::acquire_gil();
    //     let py = gil.python();

    //     // sys.argv is used by pyglet library, which is responsible for rendering.
    //     // Depending on the python interpreter, however, sys.argv can be empty.
    //     // For that case, sys argv is set here.
    //     // See https://github.com/PyO3/pyo3/issues/1241#issuecomment-715952517
    //     let locals = [("sys", py.import("sys")?)].into_py_dict(py);
    //     let _ = py.eval("sys.argv.insert(0, 'PyGymEnv')", None, Some(&locals))?;

    //     // import pybullet-gym if it exists
    //     if py.import("pybulletgym").is_ok() {}

    //     let env = if let Some(mode) = atari_wrapper {
    //         let mode = match mode {
    //             AtariWrapper::Train => true,
    //             AtariWrapper::Eval => false,
    //         };
    //         let gym = py.import("atari_wrappers")?;
    //         let env = gym.call("make_env_single_proc", (name, true, mode), None)?;
    //         env.call_method("seed", (42,), None)?;
    //         env
    //     } else {
    //         let gym = py.import("gym")?;
    //         let env = gym.call("make", (name,), None)?;
    //         env.call_method("seed", (42,), None)?;
    //         env
    //     };

    //     // TODO: consider removing action_space and observation_space.
    //     // Act/obs types are specified by type parameters.
    //     let action_space = env.getattr("action_space")?;
    //     let action_space = if let Ok(val) = action_space.getattr("n") {
    //         val.extract()?
    //     } else {
    //         let action_space: Vec<i64> = action_space.getattr("shape")?.extract()?;
    //         action_space[0]
    //     };
    //     let observation_space = env.getattr("observation_space")?;
    //     let observation_space = observation_space.getattr("shape")?.extract()?;

    //     Ok(PyGymEnv {
    //         render: false,
    //         env: env.into(),
    //         action_space,
    //         observation_space,
    //         // TODO: consider remove RefCell, raw value instead
    //         count_steps: RefCell::new(0),
    //         max_steps: None,
    //         obs_filter,
    //         act_filter,
    //         wait_in_render: Duration::from_millis(0),
    //         pybullet: false,
    //         pybullet_state: None,
    //         phantom: PhantomData,
    //     })
    // }

    /// Set rendering mode.
    ///
    /// If `true`, it renders the state at every step.
    pub fn set_render(&mut self, render: bool) {
        self.render = render;
        if self.pybullet {
            pyo3::Python::with_gil(|py| {
                self.env.call_method0(py, "render").unwrap();
            });
        }
    }

    /// Set the maximum number of steps in the environment.
    ///
    /// Builder-style setter: it consumes the environment and returns it with
    /// `max_steps` replaced by `v`. `None` disables the limit.
    pub fn max_steps(mut self, v: Option<usize>) -> Self {
        self.max_steps = v;
        self
    }

    /// Set time for sleep in rendering.
    ///
    /// `d` is stored as the duration to sleep after rendering in each step.
    pub fn set_wait_in_render(&mut self, d: Duration) {
        self.wait_in_render = d;
    }

    /// Get the number of available actions of atari environments
    pub fn get_num_actions_atari(&self) -> i64 {
        pyo3::Python::with_gil(|py| {
            let act_space = self.env.getattr(py, "action_space").unwrap();
            act_space.getattr(py, "n").unwrap().extract(py).unwrap()
        })
    }
}

impl<O, A, OF, AF> Env for PyGymEnv<O, A, OF, AF>
where
    O: Obs,
    A: Act + Debug,
    OF: PyGymEnvObsFilter<O>,
    AF: PyGymEnvActFilter<A>,
{
    type Obs = O;
    type Act = A;
    type Info = PyGymInfo;
    type Config = PyGymEnvConfig<O, A, OF, AF>;

    /// Runs one step and, if the episode ended, resets the environment.
    ///
    /// When the step is flagged as done, the observation returned by the reset
    /// is stored in `init_obs` of the returned [`Step`]; otherwise the step is
    /// returned unchanged. Currently only non-vectorized environments are
    /// supported.
    ///
    /// # Panics
    ///
    /// Panics if `is_done` of the step does not have exactly one element, or if
    /// resetting the environment fails.
    fn step_with_reset(&mut self, a: &Self::Act) -> (Step<Self>, Record)
    where
        Self: Sized,
    {
        let (mut step, record) = self.step(a);
        assert_eq!(step.is_done.len(), 1);

        if step.is_done[0] == 1 {
            // Episode finished: hand the first observation of the next episode
            // back to the caller together with the terminal step.
            step.init_obs = self.reset(None).unwrap();
        }

        (step, record)
    }

    /// Resets the environment, the obs/act filters and returns the observation tensor.
    ///
    /// In this environment, the length of `is_done` is assumed to be 1.
    fn reset(&mut self, is_done: Option<&Vec<i8>>) -> Result<O> {
        trace!("PyGymEnv::reset()");

        // Reset the action filter, required for stateful filters.
        self.act_filter.reset(&is_done);

        // Reset the environment
        let reset = match is_done {
            None => true,
            // when reset() is called in border_core::util::sample()
            Some(v) => {
                debug_assert_eq!(v.len(), 1);
                v[0] != 0
            }
        };

        if !reset {
            Ok(O::dummy(1))
        } else {
            pyo3::Python::with_gil(|py| {
                let obs = self.env.call_method0(py, "reset")?;
                if self.pybullet && self.render {
                    let floor: &PyModule =
                        self.pybullet_state.as_ref().unwrap().extract(py).unwrap();
                    // floor.call1("add_floor", (&self.env,)).unwrap();
                    floor.getattr("add_floor")?.call1((&self.env,)).unwrap();
                }
                Ok(self.obs_filter.reset(obs))
            })
        }
    }

    /// Runs a step of the environment's dynamics.
    ///
    /// It returns [`Step`] and [`Record`] objects.
    /// The [`Record`] is composed of [`Record`]s constructed in [`PyGymEnvObsFilter`] and
    /// [`PyGymEnvActFilter`].
    fn step(&mut self, a: &A) -> (Step<Self>, Record) {
        trace!("PyGymEnv::step()");

        pyo3::Python::with_gil(|py| {
            if self.render {
                if !self.pybullet {
                    let _ = self.env.call_method0(py, "render");
                } else {
                    let cam: &PyModule = self.pybullet_state.as_ref().unwrap().extract(py).unwrap();
                    // cam.call1("update_camera_pos", (&self.env,)).unwrap();
                    cam.getattr("update_camera_pos").unwrap().call1((&self.env,)).unwrap();
                }
                std::thread::sleep(self.wait_in_render);
            }

            let (a_py, record_a) = self.act_filter.filt(a.clone());
            let ret = self.env.call_method(py, "step", (a_py,), None).unwrap();
            let step: &PyTuple = ret.extract(py).unwrap();
            let obs = step.get_item(0).to_owned();
            let (obs, record_o) = self.obs_filter.filt(obs.to_object(py));
            let reward: Vec<f32> = vec![step.get_item(1).extract().unwrap()];
            let mut is_done: Vec<i8> = vec![if step.get_item(2).extract().unwrap() {
                1
            } else {
                0
            }];

            // let c = *self.count_steps.borrow();
            self.count_steps += 1; //.replace(c + 1);
            if let Some(max_steps) = self.max_steps {
                if self.count_steps >= max_steps {
                    is_done[0] = 1;
                    self.count_steps = 0;
                }
            };

            (
                Step::<Self>::new(obs, a.clone(), reward, is_done, PyGymInfo {}, O::dummy(1)),
                record_o.merge(record_a),
            )
        })
    }

    /// Constructs [PyGymEnv].
    fn build(config: &Self::Config, seed: i64) -> Result<Self> {
        let gil = Python::acquire_gil();
        let py = gil.python();

        // sys.argv is used by pyglet library, which is responsible for rendering.
        // Depending on the python interpreter, however, sys.argv can be empty.
        // For that case, sys argv is set here.
        // See https://github.com/PyO3/pyo3/issues/1241#issuecomment-715952517
        let locals = [("sys", py.import("sys")?)].into_py_dict(py);
        let _ = py.eval("sys.argv.insert(0, 'PyGymEnv')", None, Some(&locals))?;
        let path = py.eval("sys.path", None, Some(&locals)).unwrap();
        let ver = py.eval("sys.version", None, Some(&locals)).unwrap();
        info!("Initialize PyGymEnv");
        info!("{}", path);
        info!("Python version = {}", ver);

        // import pybullet-gym if it exists
        if py.import("pybulletgym").is_ok() {}

        let name = config.name.as_str();
        let env = if let Some(mode) = config.atari_wrapper.as_ref() {
            let mode = match mode {
                AtariWrapper::Train => true,
                AtariWrapper::Eval => false,
            };
            let gym = py.import("atari_wrappers")?;
            let env = gym.getattr("make_env_single_proc")?.call((name, true, mode), None)?;
            env.call_method("seed", (seed,), None)?;
            env
        } else if !config.pybullet {
            let gym = py.import("f32_wrapper")?;
            let env = gym.getattr("make_f32")?.call((name,), None)?;
            env.call_method("seed", (seed,), None)?;
            env
        } else {
            let gym = py.import("gym")?;
            let env = gym.getattr("make")?.call((name,), None)?;
            env.call_method("seed", (seed,), None)?;
            env
        };

        // TODO: consider removing action_space and observation_space.
        // Act/obs types are specified by type parameters.
        let action_space = env.getattr("action_space")?;
        let action_space = if let Ok(val) = action_space.getattr("n") {
            val.extract()?
        } else {
            let action_space: Vec<i64> = action_space.getattr("shape")?.extract()?;
            action_space[0]
        };
        let observation_space = env.getattr("observation_space")?;
        let observation_space = observation_space.getattr("shape")?.extract()?;

        let pybullet_state = if !config.pybullet {
            None
        } else {
            let pybullet_state = Python::with_gil(|py| {
                PyModule::from_code(
                    py,
                    r#"
_torsoId = None
_floor = False

def add_floor(env):
    global _floor
    if not _floor:
        p = env.env._p
        import pybullet_data
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        p.loadURDF("plane.urdf")
        _floor = True
        env.env.stateId = p.saveState()

def get_torso_id(p):
    global _torsoId
    if _torsoId is None:
        torsoId = -1
        for i in range(p.getNumBodies()):
            print(p.getBodyInfo(i))
            if p.getBodyInfo(i)[0].decode() == "torso":
                torsoId = i
                print("found torso")
        _torsoId = torsoId
    
    return _torsoId

def update_camera_pos(env):
    p = env.env._p
    torsoId = get_torso_id(env.env._p)
    if torsoId >= 0:
        distance = 5
        yaw = 0
        humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
        p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos)

            "#,
                    "pybullet_state.py",
                    "pybullet_state",
                )
                .unwrap()
                .to_object(py)
            });
            Some(pybullet_state)
        };

        Ok(PyGymEnv {
            env: env.into(),
            action_space,
            observation_space,
            obs_filter: OF::build(&config.obs_filter_config.as_ref().unwrap())?,
            act_filter: AF::build(&config.act_filter_config.as_ref().unwrap())?,
            render: false,
            count_steps: 0,
            wait_in_render: Duration::from_millis(0),
            max_steps: config.max_steps,
            pybullet: config.pybullet,
            pybullet_state,
            phantom: PhantomData,
        })
    }
}