relearn/spaces/
boolean.rs1use super::{
3 FeatureSpace, FiniteSpace, LogElementSpace, NonEmptySpace, ParameterizedDistributionSpace,
4 ReprSpace, Space, SubsetOrd,
5};
6use crate::logging::{LogError, LogValue, StatsLogger};
7use crate::torch::distributions::Bernoulli;
8use crate::utils::distributions::ArrayDistribution;
9use crate::utils::num_array::{BuildFromArray1D, NumArray1D};
10use num_traits::Float;
11use rand::distributions::Distribution;
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use std::cmp::Ordering;
15use std::fmt;
16use tch::{Device, Kind, Tensor};
17
18#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub struct BooleanSpace;
21
22impl BooleanSpace {
23 #[must_use]
24 #[inline]
25 pub const fn new() -> Self {
26 BooleanSpace
27 }
28}
29
30impl fmt::Display for BooleanSpace {
31 #[inline]
32 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33 write!(f, "BooleanSpace")
34 }
35}
36
37impl Space for BooleanSpace {
38 type Element = bool;
39
40 #[inline]
41 fn contains(&self, _value: &Self::Element) -> bool {
42 true
43 }
44}
45
46impl SubsetOrd for BooleanSpace {
47 #[inline]
48 fn subset_cmp(&self, _other: &Self) -> Option<Ordering> {
49 Some(Ordering::Equal)
50 }
51}
52
53impl FiniteSpace for BooleanSpace {
54 #[inline]
55 fn size(&self) -> usize {
56 2
57 }
58
59 #[inline]
60 fn to_index(&self, element: &Self::Element) -> usize {
61 (*element).into()
62 }
63
64 #[inline]
65 fn from_index(&self, index: usize) -> Option<Self::Element> {
66 match index {
67 0 => Some(false),
68 1 => Some(true),
69 _ => None,
70 }
71 }
72
73 #[inline]
74 fn from_index_unchecked(&self, index: usize) -> Option<Self::Element> {
75 Some(index != 0)
76 }
77}
78
79impl NonEmptySpace for BooleanSpace {
80 #[inline]
81 fn some_element(&self) -> Self::Element {
82 false
83 }
84}
85
86impl ReprSpace<Tensor> for BooleanSpace {
88 #[inline]
89 fn repr(&self, element: &Self::Element) -> Tensor {
90 Tensor::scalar_tensor(i64::from(*element), (Kind::Bool, Device::Cpu))
91 }
92
93 #[inline]
94 fn batch_repr<'a, I>(&self, elements: I) -> Tensor
95 where
96 I: IntoIterator<Item = &'a Self::Element>,
97 I::IntoIter: ExactSizeIterator + Clone,
98 Self::Element: 'a,
99 {
100 let elements: Vec<_> = elements.into_iter().copied().collect();
101 Tensor::of_slice(&elements)
102 }
103}
104
105impl ParameterizedDistributionSpace<Tensor> for BooleanSpace {
106 type Distribution = Bernoulli;
107
108 #[inline]
109 fn num_distribution_params(&self) -> usize {
110 1
111 }
112
113 #[inline]
114 fn sample_element(&self, params: &Tensor) -> Self::Element {
115 self.distribution(params).sample().into()
116 }
117
118 #[inline]
119 fn distribution(&self, params: &Tensor) -> Self::Distribution {
120 Self::Distribution::new(params.squeeze_dim(-1))
121 }
122}
123
124impl FeatureSpace for BooleanSpace {
126 #[inline]
127 fn num_features(&self) -> usize {
128 1
129 }
130
131 #[inline]
132 fn features_out<'a, F: Float>(
133 &self,
134 element: &Self::Element,
135 out: &'a mut [F],
136 _zeroed: bool,
137 ) -> &'a mut [F] {
138 out[0] = if *element { F::one() } else { F::zero() };
139 &mut out[1..]
140 }
141
142 #[inline]
143 fn features<T>(&self, element: &Self::Element) -> T
144 where
145 T: BuildFromArray1D,
146 <T::Array as NumArray1D>::Elem: Float,
147 {
148 if *element {
149 T::Array::ones(1).into()
150 } else {
151 T::Array::zeros(1).into()
152 }
153 }
154}
155
156impl Distribution<<Self as Space>::Element> for BooleanSpace {
157 #[inline]
158 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> <Self as Space>::Element {
159 rng.gen()
160 }
161}
162
163impl LogElementSpace for BooleanSpace {
164 #[inline]
165 fn log_element<L: StatsLogger + ?Sized>(
166 &self,
167 name: &'static str,
168 element: &Self::Element,
169 logger: &mut L,
170 ) -> Result<(), LogError> {
171 logger.log(name.into(), LogValue::Scalar(u8::from(*element).into()))
172 }
173}
174
#[cfg(test)]
mod space {
    use super::super::testing;
    use super::*;

    #[test]
    fn contains_false() {
        assert!(BooleanSpace::new().contains(&false));
    }

    #[test]
    fn contains_true() {
        assert!(BooleanSpace::new().contains(&true));
    }

    /// Every randomly sampled element must be reported as a member.
    #[test]
    fn contains_samples() {
        testing::check_contains_samples(&BooleanSpace::new(), 10);
    }
}
198
#[cfg(test)]
mod subset_ord {
    use super::*;

    #[test]
    fn eq() {
        let (left, right) = (BooleanSpace, BooleanSpace::new());
        assert_eq!(left, right);
    }

    #[test]
    fn cmp_equal() {
        let ordering = BooleanSpace.subset_cmp(&BooleanSpace);
        assert_eq!(ordering, Some(Ordering::Equal));
    }

    /// A space is never a strict subset of itself.
    #[test]
    fn not_less() {
        let space = BooleanSpace;
        assert!(!space.strict_subset_of(&space));
    }
}
221
#[cfg(test)]
mod finite_space {
    use super::super::testing;
    use super::*;

    /// Shared fixture; `BooleanSpace::new` is a `const fn`.
    const SPACE: BooleanSpace = BooleanSpace::new();

    #[test]
    fn from_to_index_iter_size() {
        testing::check_from_to_index_iter_size(&SPACE);
    }

    #[test]
    fn from_to_index_random() {
        testing::check_from_to_index_random(&SPACE, 10);
    }

    #[test]
    fn from_index_sampled() {
        testing::check_from_index_sampled(&SPACE, 10);
    }

    #[test]
    fn from_index_invalid() {
        testing::check_from_index_invalid(&SPACE);
    }
}
251
#[cfg(test)]
mod feature_space {
    use super::*;

    #[test]
    fn num_features() {
        let space = BooleanSpace::new();
        assert_eq!(space.num_features(), 1);
    }

    // Single-element encodings: `false` -> [0.0], `true` -> [1.0].
    features_tests!(false_, BooleanSpace, false, [0.0]);
    features_tests!(true_, BooleanSpace, true, [1.0]);
    // Batched encoding: one single-feature row per input element.
    batch_features_tests!(
        batch,
        BooleanSpace,
        [false, true, true, false],
        [[0.0], [1.0], [1.0], [0.0]]
    );
}
270
#[cfg(test)]
mod parameterized_sample_space_tensor {
    use super::*;
    use std::iter;

    #[test]
    fn num_sample_params() {
        assert_eq!(1, BooleanSpace.num_distribution_params());
    }

    /// With an infinite parameter every sample is expected to be `true`.
    #[test]
    fn sample_element_deterministic() {
        let space = BooleanSpace;
        let params = Tensor::of_slice(&[f32::INFINITY]);
        for _ in 0..10 {
            assert!(space.sample_element(&params));
        }
    }

    /// With parameter 1.0 the fraction of `true` samples should match
    /// sigmoid(1.0) ~= 0.731, the value this test has always expected.
    #[test]
    fn sample_element_check_distribution() {
        let space = BooleanSpace;
        let logit = 1.0_f32;
        let params = Tensor::of_slice(&[logit]);
        // Exact expected probability instead of the rounded literal 0.731.
        let p = 1.0 / (1.0 + (-f64::from(logit)).exp());
        let n: usize = 5000;
        let count: u64 = iter::repeat_with(|| u64::from(space.sample_element(&params)))
            .take(n)
            .sum();

        // Accept counts within `z` standard deviations of the binomial mean.
        // With z = 4.4 a spurious two-sided failure has probability ~1e-5.
        let z = 4.4;
        let nf = n as f64;
        let mean = nf * p;
        let stddev = (nf * p * (1.0 - p)).sqrt();
        let lower_bound = mean - z * stddev;
        let upper_bound = mean + z * stddev;
        let observed = count as f64;
        assert!(
            lower_bound <= observed,
            "observed {} `true` samples, expected at least {:.1}",
            observed,
            lower_bound
        );
        assert!(
            observed <= upper_bound,
            "observed {} `true` samples, expected at most {:.1}",
            observed,
            upper_bound
        );
    }
}
311}