relearn/spaces/
index.rs

//! Definition of [`IndexSpace`], the space of integer indices `0` to `size - 1`.
2use super::{
3    FeatureSpace, FiniteSpace, LogElementSpace, NonEmptySpace, ParameterizedDistributionSpace,
4    ReprSpace, Space, SubsetOrd,
5};
6use crate::logging::{LogError, LogValue, StatsLogger};
7use crate::torch::distributions::Categorical;
8use crate::utils::distributions::ArrayDistribution;
9use ndarray::{s, ArrayBase, DataMut, Ix2};
10use num_traits::{Float, One, Zero};
11use rand::distributions::Distribution;
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use std::cmp::Ordering;
15use std::fmt;
16use tch::{Device, Kind, Tensor};
17
18/// An index space; consists of the integers `0` to `size - 1`
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub struct IndexSpace {
21    pub size: usize,
22}
23
24impl IndexSpace {
25    #[must_use]
26    #[inline]
27    pub const fn new(size: usize) -> Self {
28        Self { size }
29    }
30}
31
32impl fmt::Display for IndexSpace {
33    #[inline]
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        write!(f, "IndexSpace({})", self.size)
36    }
37}
38
39impl Space for IndexSpace {
40    type Element = usize;
41
42    #[inline]
43    fn contains(&self, value: &Self::Element) -> bool {
44        value < &self.size
45    }
46}
47
48impl SubsetOrd for IndexSpace {
49    #[inline]
50    fn subset_cmp(&self, other: &Self) -> Option<Ordering> {
51        self.size.partial_cmp(&other.size)
52    }
53}
54
55impl NonEmptySpace for IndexSpace {
56    #[inline]
57    fn some_element(&self) -> <Self as Space>::Element {
58        assert_ne!(self.size, 0, "space is empty");
59        0
60    }
61}
62
63impl Distribution<<Self as Space>::Element> for IndexSpace {
64    #[inline]
65    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> <Self as Space>::Element {
66        rng.gen_range(0..self.size)
67    }
68}
69
70impl FiniteSpace for IndexSpace {
71    #[inline]
72    fn size(&self) -> usize {
73        self.size
74    }
75
76    #[inline]
77    fn to_index(&self, element: &Self::Element) -> usize {
78        *element
79    }
80
81    #[inline]
82    fn from_index(&self, index: usize) -> Option<Self::Element> {
83        if index >= self.size {
84            None
85        } else {
86            Some(index)
87        }
88    }
89
90    #[inline]
91    fn from_index_unchecked(&self, index: usize) -> Option<Self::Element> {
92        Some(index)
93    }
94}
95
96/// Features are one-hot vectors
97impl FeatureSpace for IndexSpace {
98    #[inline]
99    fn num_features(&self) -> usize {
100        self.size
101    }
102
103    #[inline]
104    fn features_out<'a, F: Float>(
105        &self,
106        element: &Self::Element,
107        out: &'a mut [F],
108        zeroed: bool,
109    ) -> &'a mut [F] {
110        let (out, rest) = out.split_at_mut(self.size);
111        if !zeroed {
112            out.fill(F::zero());
113        }
114        out[self.to_index(element)] = F::one();
115        rest
116    }
117
118    #[inline]
119    fn batch_features_out<'a, I, A>(&self, elements: I, out: &mut ArrayBase<A, Ix2>, zeroed: bool)
120    where
121        I: IntoIterator<Item = &'a Self::Element>,
122        Self::Element: 'a,
123        A: DataMut,
124        A::Elem: Float,
125    {
126        if !zeroed {
127            out.slice_mut(s![.., 0..self.num_features()])
128                .fill(Zero::zero());
129        }
130
131        // Don't zip rows so that we can check whether there are too few rows.
132        let mut rows = out.rows_mut().into_iter();
133        for element in elements {
134            let mut row = rows.next().expect("fewer rows than elements");
135            row[self.to_index(element)] = One::one();
136        }
137    }
138}
139
140/// Represents elements as integer tensors.
141impl ReprSpace<Tensor> for IndexSpace {
142    #[inline]
143    fn repr(&self, element: &Self::Element) -> Tensor {
144        Tensor::scalar_tensor(self.to_index(element) as i64, (Kind::Int64, Device::Cpu))
145    }
146
147    #[inline]
148    fn batch_repr<'a, I>(&self, elements: I) -> Tensor
149    where
150        I: IntoIterator<Item = &'a Self::Element>,
151        Self::Element: 'a,
152    {
153        let indices: Vec<_> = elements
154            .into_iter()
155            .map(|elem| self.to_index(elem) as i64)
156            .collect();
157        Tensor::of_slice(&indices)
158    }
159}
160
161impl ParameterizedDistributionSpace<Tensor> for IndexSpace {
162    type Distribution = Categorical;
163
164    #[inline]
165    fn num_distribution_params(&self) -> usize {
166        self.size
167    }
168
169    #[inline]
170    fn sample_element(&self, params: &Tensor) -> Self::Element {
171        self.from_index(
172            self.distribution(params)
173                .sample()
174                .int64_value(&[])
175                .try_into()
176                .unwrap(),
177        )
178        .unwrap()
179    }
180
181    #[inline]
182    fn distribution(&self, params: &Tensor) -> Self::Distribution {
183        Self::Distribution::new(params)
184    }
185}
186
187/// Log the index as a sample from `0..N`
188impl LogElementSpace for IndexSpace {
189    #[inline]
190    fn log_element<L: StatsLogger + ?Sized>(
191        &self,
192        name: &'static str,
193        element: &Self::Element,
194        logger: &mut L,
195    ) -> Result<(), LogError> {
196        let log_value = LogValue::Index {
197            value: self.to_index(element),
198            size: self.size,
199        };
200        logger.log(name.into(), log_value)
201    }
202}
203
204impl<T: FiniteSpace + ?Sized> From<&T> for IndexSpace {
205    #[inline]
206    fn from(space: &T) -> Self {
207        Self { size: space.size() }
208    }
209}
210
#[cfg(test)]
mod space {
    use super::super::testing;
    use super::*;
    use rstest::rstest;

    /// Zero is the smallest index, so every non-empty space contains it.
    #[rstest]
    fn contains_zero(#[values(1, 5)] size: usize) {
        assert!(IndexSpace::new(size).contains(&0));
    }

    /// An index far beyond `size` is rejected.
    #[rstest]
    fn not_contains_too_large(#[values(1, 5)] size: usize) {
        assert!(!IndexSpace::new(size).contains(&100));
    }

    /// Delegates to the shared `testing::check_contains_samples` helper with 100 samples.
    #[rstest]
    fn contains_samples(#[values(1, 5)] size: usize) {
        testing::check_contains_samples(&IndexSpace::new(size), 100);
    }
}
235
#[cfg(test)]
mod subset_ord {
    use super::super::SubsetOrd;
    use super::*;
    use std::cmp::Ordering;

    #[test]
    fn same_eq() {
        let (lhs, rhs) = (IndexSpace::new(2), IndexSpace::new(2));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs.subset_cmp(&rhs), Some(Ordering::Equal));
    }

    #[test]
    fn different_not_eq() {
        let (larger, smaller) = (IndexSpace::new(2), IndexSpace::new(1));
        assert_ne!(larger, smaller);
        assert_ne!(larger.subset_cmp(&smaller), Some(Ordering::Equal));
    }

    #[test]
    fn same_subset_of() {
        let space = IndexSpace::new(2);
        assert!(space.subset_of(&IndexSpace::new(2)));
    }

    #[test]
    fn smaller_strict_subset_of() {
        let smaller = IndexSpace::new(1);
        assert!(smaller.strict_subset_of(&IndexSpace::new(2)));
    }

    #[test]
    fn larger_not_subset_of() {
        let larger = IndexSpace::new(3);
        assert!(!larger.subset_of(&IndexSpace::new(1)));
    }
}
275
#[cfg(test)]
mod finite_space {
    use super::super::testing;
    use super::*;
    use rstest::rstest;

    /// Shared round-trip check between `from_index`, `to_index`, and `size`.
    #[rstest]
    fn from_to_index_iter_size(#[values(1, 5)] size: usize) {
        testing::check_from_to_index_iter_size(&IndexSpace::new(size));
    }

    /// Shared check of `from_index` against 100 sampled elements.
    #[rstest]
    fn from_index_sampled(#[values(1, 5)] size: usize) {
        testing::check_from_index_sampled(&IndexSpace::new(size), 100);
    }

    /// Shared check that out-of-range indices are rejected by `from_index`.
    #[rstest]
    fn from_index_invalid(#[values(1, 5)] size: usize) {
        testing::check_from_index_invalid(&IndexSpace::new(size));
    }
}
300
#[cfg(test)]
mod feature_space {
    use super::*;

    #[test]
    fn num_features() {
        let space = IndexSpace::new(3);
        assert_eq!(space.num_features(), 3);
    }

    // Element 1 of a 3-index space, with its expected one-hot features.
    features_tests!(f, IndexSpace::new(3), 1, [0.0, 1.0, 0.0]);
    // A batch of elements, with one expected one-hot row per element, in order.
    batch_features_tests!(
        b,
        IndexSpace::new(3),
        [2, 0, 1],
        [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    );
}
318
#[cfg(test)]
mod repr_space_tensor {
    use super::*;

    /// Each element is represented by a scalar `Int64` tensor holding its index.
    #[test]
    fn repr() {
        let space = IndexSpace::new(3);
        for &(element, index) in &[(0_usize, 0_i64), (1, 1), (2, 2)] {
            assert_eq!(
                space.repr(&element),
                Tensor::scalar_tensor(index, (Kind::Int64, Device::Cpu))
            );
        }
    }

    /// A batch (including repeats) becomes a 1-D tensor of indices in order.
    #[test]
    fn batch_repr() {
        let space = IndexSpace::new(3);
        let batch = [0, 1, 2, 1];
        let expected = Tensor::of_slice(&[0_i64, 1, 2, 1]);
        assert_eq!(space.batch_repr(&batch), expected);
    }
}
349
#[cfg(test)]
mod parameterized_sample_space_tensor {
    use super::*;
    use std::ops::RangeInclusive;

    #[test]
    fn num_sample_params() {
        assert_eq!(3, IndexSpace::new(3).num_distribution_params());
    }

    /// With every other logit at negative infinity, index 1 is always sampled.
    #[test]
    fn sample_element_deterministic() {
        let space = IndexSpace::new(3);
        let params = Tensor::of_slice(&[f32::NEG_INFINITY, 0.0, f32::NEG_INFINITY]);
        (0..10).for_each(|_| assert_eq!(1, space.sample_element(&params)));
    }

    /// Index 0 has a `-inf` logit and must never be sampled.
    #[test]
    fn sample_element_two_of_three() {
        let space = IndexSpace::new(3);
        let params = Tensor::of_slice(&[f32::NEG_INFINITY, 0.0, 0.0]);
        for _ in 0..10 {
            assert_ne!(space.sample_element(&params), 0);
        }
    }

    /// Approximate interval for the number of successes in `n` Bernoulli(`p`) trials,
    /// with bounds rounded to the nearest integer.
    // The bounds are at most about `n`, so the f64 -> u64 cast cannot truncate meaningfully.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)] // negative f64 casts to 0.0 as desired
    fn bernoulli_confidence_interval(p: f64, n: u64) -> RangeInclusive<u64> {
        // Using Wald method <https://en.wikipedia.org/wiki/Binomial_distribution#Wald_method>
        // Quantile for error rate of 1e-5
        let z = 4.4;
        let trials = n as f64;
        let mean = trials * p;
        let half_width = z * (p * (1.0 - p) * trials).sqrt();
        let low = (mean - half_width).round() as u64;
        let high = (mean + half_width).round() as u64;
        low..=high
    }

    /// Empirical frequencies match softmax([-1, 0, 1]) within the Wald interval.
    #[test]
    fn sample_element_check_distribution() {
        let space = IndexSpace::new(3);
        let params = Tensor::of_slice(&[-1.0, 0.0, 1.0]);
        // Corresponding approximate probabilities
        let probs = [0.090, 0.245, 0.665];
        let n = 5000;

        let mut counts = [0_u64; 3];
        for _ in 0..n {
            match space.sample_element(&params) {
                index @ 0..=2 => counts[index] += 1,
                _ => panic!(),
            }
        }
        // Check that the counts are within their expected intervals
        for (&count, &p) in counts.iter().zip(probs.iter()) {
            assert!(bernoulli_confidence_interval(p, n).contains(&count));
        }
    }
}