relearn/spaces/
boolean.rs

1//! `BooleanSpace` definition
2use super::{
3    FeatureSpace, FiniteSpace, LogElementSpace, NonEmptySpace, ParameterizedDistributionSpace,
4    ReprSpace, Space, SubsetOrd,
5};
6use crate::logging::{LogError, LogValue, StatsLogger};
7use crate::torch::distributions::Bernoulli;
8use crate::utils::distributions::ArrayDistribution;
9use crate::utils::num_array::{BuildFromArray1D, NumArray1D};
10use num_traits::Float;
11use rand::distributions::Distribution;
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use std::cmp::Ordering;
15use std::fmt;
16use tch::{Device, Kind, Tensor};
17
18/// The space `{false, true}`.
19#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub struct BooleanSpace;
21
22impl BooleanSpace {
23    #[must_use]
24    #[inline]
25    pub const fn new() -> Self {
26        BooleanSpace
27    }
28}
29
30impl fmt::Display for BooleanSpace {
31    #[inline]
32    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33        write!(f, "BooleanSpace")
34    }
35}
36
37impl Space for BooleanSpace {
38    type Element = bool;
39
40    #[inline]
41    fn contains(&self, _value: &Self::Element) -> bool {
42        true
43    }
44}
45
46impl SubsetOrd for BooleanSpace {
47    #[inline]
48    fn subset_cmp(&self, _other: &Self) -> Option<Ordering> {
49        Some(Ordering::Equal)
50    }
51}
52
53impl FiniteSpace for BooleanSpace {
54    #[inline]
55    fn size(&self) -> usize {
56        2
57    }
58
59    #[inline]
60    fn to_index(&self, element: &Self::Element) -> usize {
61        (*element).into()
62    }
63
64    #[inline]
65    fn from_index(&self, index: usize) -> Option<Self::Element> {
66        match index {
67            0 => Some(false),
68            1 => Some(true),
69            _ => None,
70        }
71    }
72
73    #[inline]
74    fn from_index_unchecked(&self, index: usize) -> Option<Self::Element> {
75        Some(index != 0)
76    }
77}
78
79impl NonEmptySpace for BooleanSpace {
80    #[inline]
81    fn some_element(&self) -> Self::Element {
82        false
83    }
84}
85
86/// Represent elements as a Boolean valued tensor.
87impl ReprSpace<Tensor> for BooleanSpace {
88    #[inline]
89    fn repr(&self, element: &Self::Element) -> Tensor {
90        Tensor::scalar_tensor(i64::from(*element), (Kind::Bool, Device::Cpu))
91    }
92
93    #[inline]
94    fn batch_repr<'a, I>(&self, elements: I) -> Tensor
95    where
96        I: IntoIterator<Item = &'a Self::Element>,
97        I::IntoIter: ExactSizeIterator + Clone,
98        Self::Element: 'a,
99    {
100        let elements: Vec<_> = elements.into_iter().copied().collect();
101        Tensor::of_slice(&elements)
102    }
103}
104
105impl ParameterizedDistributionSpace<Tensor> for BooleanSpace {
106    type Distribution = Bernoulli;
107
108    #[inline]
109    fn num_distribution_params(&self) -> usize {
110        1
111    }
112
113    #[inline]
114    fn sample_element(&self, params: &Tensor) -> Self::Element {
115        self.distribution(params).sample().into()
116    }
117
118    #[inline]
119    fn distribution(&self, params: &Tensor) -> Self::Distribution {
120        Self::Distribution::new(params.squeeze_dim(-1))
121    }
122}
123
124/// Features are `[0.0]` for `false` and `[1.0]` for `true`
125impl FeatureSpace for BooleanSpace {
126    #[inline]
127    fn num_features(&self) -> usize {
128        1
129    }
130
131    #[inline]
132    fn features_out<'a, F: Float>(
133        &self,
134        element: &Self::Element,
135        out: &'a mut [F],
136        _zeroed: bool,
137    ) -> &'a mut [F] {
138        out[0] = if *element { F::one() } else { F::zero() };
139        &mut out[1..]
140    }
141
142    #[inline]
143    fn features<T>(&self, element: &Self::Element) -> T
144    where
145        T: BuildFromArray1D,
146        <T::Array as NumArray1D>::Elem: Float,
147    {
148        if *element {
149            T::Array::ones(1).into()
150        } else {
151            T::Array::zeros(1).into()
152        }
153    }
154}
155
156impl Distribution<<Self as Space>::Element> for BooleanSpace {
157    #[inline]
158    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> <Self as Space>::Element {
159        rng.gen()
160    }
161}
162
163impl LogElementSpace for BooleanSpace {
164    #[inline]
165    fn log_element<L: StatsLogger + ?Sized>(
166        &self,
167        name: &'static str,
168        element: &Self::Element,
169        logger: &mut L,
170    ) -> Result<(), LogError> {
171        logger.log(name.into(), LogValue::Scalar(u8::from(*element).into()))
172    }
173}
174
/// Tests of the `Space` implementation.
#[cfg(test)]
mod space {
    use super::super::testing;
    use super::*;

    /// `false` is a member of the space.
    #[test]
    fn contains_false() {
        assert!(BooleanSpace.contains(&false));
    }

    /// `true` is a member of the space.
    #[test]
    fn contains_true() {
        assert!(BooleanSpace.contains(&true));
    }

    /// Randomly sampled elements are members of the space.
    #[test]
    fn contains_samples() {
        testing::check_contains_samples(&BooleanSpace, 10);
    }
}
198
/// Tests of equality and the `SubsetOrd` implementation.
#[cfg(test)]
mod subset_ord {
    use super::*;

    /// Two instances compare equal.
    #[test]
    fn eq() {
        let left = BooleanSpace;
        let right = BooleanSpace;
        assert_eq!(left, right);
    }

    /// The subset comparison of the space with itself is `Equal`.
    #[test]
    fn cmp_equal() {
        let ordering = BooleanSpace.subset_cmp(&BooleanSpace);
        assert_eq!(ordering, Some(Ordering::Equal));
    }

    /// The space is not a strict subset of itself.
    #[test]
    fn not_less() {
        let is_strict_subset = BooleanSpace.strict_subset_of(&BooleanSpace);
        assert!(!is_strict_subset);
    }
}
221
/// Tests of the `FiniteSpace` implementation, delegated to the shared
/// checks in the `testing` module.
#[cfg(test)]
mod finite_space {
    use super::super::testing;
    use super::*;

    #[test]
    fn from_to_index_iter_size() {
        testing::check_from_to_index_iter_size(&BooleanSpace);
    }

    #[test]
    fn from_to_index_random() {
        testing::check_from_to_index_random(&BooleanSpace, 10);
    }

    #[test]
    fn from_index_sampled() {
        testing::check_from_index_sampled(&BooleanSpace, 10);
    }

    #[test]
    fn from_index_invalid() {
        testing::check_from_index_invalid(&BooleanSpace);
    }
}
251
/// Tests of the `FeatureSpace` implementation.
#[cfg(test)]
mod feature_space {
    use super::*;

    /// Exactly one feature is used per element.
    #[test]
    fn num_features() {
        let space = BooleanSpace;
        assert_eq!(space.num_features(), 1);
    }

    // `false` encodes as `[0.0]` and `true` as `[1.0]`, individually and in a batch.
    features_tests!(false_, BooleanSpace, false, [0.0]);
    features_tests!(true_, BooleanSpace, true, [1.0]);
    batch_features_tests!(
        batch,
        BooleanSpace,
        [false, true, true, false],
        [[0.0], [1.0], [1.0], [0.0]]
    );
}
270
/// Tests of the `ParameterizedDistributionSpace<Tensor>` implementation.
#[cfg(test)]
mod parameterized_sample_space_tensor {
    use super::*;
    use std::iter;

    /// The distribution is described by a single parameter.
    #[test]
    fn num_sample_params() {
        assert_eq!(1, BooleanSpace.num_distribution_params());
    }

    /// A parameter of `+inf` must always produce `true`.
    #[test]
    fn sample_element_deterministic() {
        let space = BooleanSpace;
        let params = Tensor::of_slice(&[f32::INFINITY]);
        for _ in 0..10 {
            assert!(space.sample_element(&params));
        }
    }

    /// The empirical frequency of `true` must match the parameterized probability.
    #[test]
    fn sample_element_check_distribution() {
        let space = BooleanSpace;
        // logit = 1.0; p ~= 0.731
        let params = Tensor::of_slice(&[1.0f32]);
        let p = 0.731;
        let n = 5000;
        // Count the `true` samples (`true` converts to 1, `false` to 0).
        let count: u64 = iter::repeat_with(|| u64::from(space.sample_element(&params)))
            .take(n)
            .sum();
        // Check that the counts are within a confidence interval
        // Using Wald method <https://en.wikipedia.org/wiki/Binomial_distribution#Wald_method>
        // Quantile for error rate of 1e-5
        let z = 4.4;
        let nf = n as f64;
        let stddev = (p * (1.0 - p) * nf).sqrt(); // ~31.4
        let lower_bound = nf * p - z * stddev; // ~3517
        let upper_bound = nf * p + z * stddev; // ~3793
        assert!(
            lower_bound <= count as f64,
            "count {} is below the lower bound {}",
            count,
            lower_bound
        );
        assert!(
            upper_bound >= count as f64,
            "count {} is above the upper bound {}",
            count,
            upper_bound
        );
    }
}
311}