entity_gym_rs/
lib.rs

// Turn on the unstable `doc_cfg` rustdoc feature only when the `docsrs` cfg is set
// (presumably passed by the docs.rs build — confirm in Cargo.toml / docs.rs metadata).
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # EntityGym for Rust
//!
//! [EntityGym](https://github.com/entity-neural-network/entity-gym) is a Python library that defines a novel entity-based abstraction for reinforcement learning environments, which enables highly ergonomic and efficient training of deep reinforcement learning agents.
//! This crate provides bindings that allow Rust programs to implement the entity-gym API and run neural network agents trained with [enn-trainer](https://github.com/entity-neural-network/enn-trainer).

/// High-level API for interacting with neural network agents.
pub mod agent;
// Example environments; not `pub`, so they are not part of the crate's public API.
mod examples;
/// Low-level API that mirrors the entity-gym Python API. Not intended for direct use.
pub mod low_level;
12
#[cfg(feature = "python")]
mod python {
    //! Python bindings, compiled only when the `python` feature is enabled.
    //!
    //! Registers the `VecObs` and `PyVecEnv` classes and the `multisnake` factory
    //! function in the `entity_gym_rs` Python extension module.

    use std::sync::Arc;

    use pyo3::prelude::*;

    use crate::examples::multisnake::MultiSnake;

    // `py_vec_env` and `VecEnv` are brought into scope by this glob re-export of `low_level`.
    pub use super::low_level::*;

    use self::py_vec_env::{PyVecEnv, VecObs};

    // Builds a vectorized MultiSnake environment and wraps it for Python.
    //
    // Each sub-environment comes from the factory closure. The closure receives an index from
    // `VecEnv` and forwards it as the last argument of `MultiSnake::new`. The index presumably
    // starts at `first_env_index`; confirm in `VecEnv::new`.
    // `num_envs`, `threads` and `first_env_index` are passed through to `VecEnv::new` unchanged.
    //
    // Python keyword defaults: board_size=10, first_env_index=0, num_snakes=1,
    // max_snake_length=10, max_steps=100.
    //
    // NOTE: these are plain `//` comments on purpose. A `///` doc comment on a `#[pyfunction]`
    // is exported by PyO3 as the function's Python `__doc__`.
    #[pyfunction(
        board_size = "10",
        first_env_index = "0",
        num_snakes = "1",
        max_snake_length = "10",
        max_steps = "100"
    )]
    fn multisnake(
        num_envs: usize,
        threads: usize,
        board_size: usize,
        first_env_index: u64,
        num_snakes: usize,
        max_snake_length: usize,
        max_steps: usize,
    ) -> PyVecEnv {
        let env = VecEnv::new(
            Arc::new(move |env_index| {
                MultiSnake::new(board_size, num_snakes, max_snake_length, max_steps, env_index)
            }),
            num_envs,
            threads,
            first_env_index,
        );
        PyVecEnv { env }
    }

    // Python module initializer. The function name determines the importable module name
    // `entity_gym_rs`. Plain `//` comments are used so that the module's `__doc__` is not
    // changed.
    #[pymodule]
    fn entity_gym_rs(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_class::<VecObs>()?;
        m.add_class::<PyVecEnv>()?;
        m.add_function(wrap_pyfunction!(multisnake, m)?)?;
        Ok(())
    }
}