1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

use std::iter;
use serde_json::Value;
use serde_json::value::from_value;
use super::error::GymResult;
use rand::{thread_rng, Rng};

/// Enumerates every index tuple of a grid whose dimension sizes are `sizes`.
///
/// Tuples come out in lexicographic order: the first coordinate varies
/// slowest and the last varies fastest. For example, `[2, 3]` yields
/// `[0,0], [0,1], [0,2], [1,0], [1,1], [1,2]`.
/// An empty `sizes` yields no tuples at all. Any zero-sized dimension also
/// makes the result empty.
fn tuples(sizes: &[u64]) -> Vec<Vec<u64>> {
	if sizes.is_empty() {
		return Vec::new();
	}

	// Grow the product one dimension at a time: every partial prefix built
	// so far is extended with each coordinate of the next dimension.
	let mut prefixes: Vec<Vec<u64>> = vec![Vec::new()];
	for &size in sizes {
		let mut extended = Vec::new();
		for prefix in &prefixes {
			for coord in 0..size {
				let mut tuple = prefix.clone();
				tuple.push(coord);
				extended.push(tuple);
			}
		}
		prefixes = extended;
	}
	prefixes
}

/// Description of an action or observation space of a gym environment.
///
/// The variants correspond to the `"name"` values of the server's JSON
/// description (`"Discrete"`, `"Box"`, `"Tuple"`).
#[derive(Debug, Clone)]
pub enum Space {
	/// A discrete space whose values are the integers `0..n`.
	DISCRETE {
		n: u64,
	},
	/// A box of real values.
	///
	/// `shape` gives the dimensions. `high` and `low` hold the per-element
	/// upper and lower bounds.
	BOX {
		shape: Vec<u64>,
		high: Vec<f64>,
		low: Vec<f64>,
	},
	/// A composite space made of several sub-spaces.
	TUPLE {
		spaces: Vec<Box<Space>>,
	},
}

impl Space {

	/// Builds a `Space` from the JSON description sent by the gym server.
	///
	/// The object must contain a string `"name"` field.
	/// - `"Discrete"` spaces additionally need an unsigned integer `"n"`.
	/// - `"Box"` spaces need `"shape"`, `"high"` and `"low"` arrays.
	///
	/// # Errors
	///
	/// Returns an error when `"name"`, `"n"`, `"shape"`, `"high"` or `"low"`
	/// is missing or has the wrong JSON type. The error is converted, via `?`,
	/// from the `serde_json` deserialization error.
	///
	/// # Panics
	///
	/// Panics for `"Tuple"` spaces, whose parsing is not implemented, and for
	/// any unrecognized space name.
	pub(crate) fn from_json(info: &Value) -> GymResult<Space> {
		// Indexing a `Value` with a missing key yields `Null`. Deserializing
		// through `from_value` therefore turns absent or mistyped fields into
		// errors that `?` propagates, instead of panicking on `unwrap`.
		let name: String = from_value(info["name"].clone())?;
		match name.as_str() {
			"Discrete" => {
				let n: u64 = from_value(info["n"].clone())?;
				Ok(Space::DISCRETE{n: n})
			},
			"Box" => {
				let shape = from_value(info["shape"].clone())?;
				let high = from_value(info["high"].clone())?;
				let low = from_value(info["low"].clone())?;

				Ok(Space::BOX{shape: shape, high: high, low: low})
			},
			"Tuple" => panic!("Parsing for Tuple spaces is not yet implemented"),
			other => panic!("Unrecognized space name: {}", other)
		}
	}

	/// Draws one random point from the space, flattened into a `Vec<f64>`.
	///
	/// * `DISCRETE { n }` yields a single value drawn uniformly from `0..n`.
	/// * `BOX` yields one value per element. The element count is the product
	///   of `shape`, so an empty `shape` is a scalar box with one element.
	///   Element `i` is drawn uniformly from `[low[i], high[i])`, or is
	///   exactly `low[i]` when the two bounds are equal.
	/// * `TUPLE` concatenates the samples of its sub-spaces in order.
	///
	/// # Panics
	///
	/// Panics in these cases:
	/// - a `DISCRETE` space has `n == 0`;
	/// - a `BOX` has `low[i] > high[i]`, which `gen_range` rejects;
	/// - `low` or `high` holds fewer entries than the product of `shape`.
	pub fn sample(&self) -> Vec<f64> {
		let mut rng = thread_rng();
		match *self {
			Space::DISCRETE{n} => {
				// `gen_range` samples uniformly, without the modulo bias of
				// `gen::<u64>() % n`.
				vec![rng.gen_range(0, n) as f64]
			},
			Space::BOX{ref shape, ref high, ref low} => {
				// Only the element count matters, so compute it directly
				// rather than materialising every index tuple of the shape.
				// NOTE(review): assumes `low`/`high` are flattened with one
				// entry per element, in matching order — confirm with server.
				let len: usize = shape.iter().map(|&dim| dim as usize).product();
				(0..len)
					.map(|i| {
						let (lo, hi) = (low[i], high[i]);
						// `gen_range` requires `lo < hi`. A degenerate interval
						// has exactly one possible value.
						if lo == hi { lo } else { rng.gen_range(lo, hi) }
					})
					.collect()
			},
			Space::TUPLE{ref spaces} => {
				spaces.iter().flat_map(|space| space.sample()).collect()
			}
		}
	}
}