1use super::{
3 FeatureSpace, FiniteSpace, LogElementSpace, NonEmptySpace, ParameterizedDistributionSpace,
4 ReprSpace, Space, SubsetOrd,
5};
6use crate::logging::{LogError, LogValue, StatsLogger};
7use crate::torch::distributions::Categorical;
8use crate::utils::distributions::ArrayDistribution;
9use ndarray::{s, ArrayBase, DataMut, Ix2};
10use num_traits::{Float, One, Zero};
11use rand::distributions::Distribution;
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use std::cmp::Ordering;
15use std::fmt;
16use tch::{Device, Kind, Tensor};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub struct IndexSpace {
21 pub size: usize,
22}
23
24impl IndexSpace {
25 #[must_use]
26 #[inline]
27 pub const fn new(size: usize) -> Self {
28 Self { size }
29 }
30}
31
32impl fmt::Display for IndexSpace {
33 #[inline]
34 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35 write!(f, "IndexSpace({})", self.size)
36 }
37}
38
39impl Space for IndexSpace {
40 type Element = usize;
41
42 #[inline]
43 fn contains(&self, value: &Self::Element) -> bool {
44 value < &self.size
45 }
46}
47
48impl SubsetOrd for IndexSpace {
49 #[inline]
50 fn subset_cmp(&self, other: &Self) -> Option<Ordering> {
51 self.size.partial_cmp(&other.size)
52 }
53}
54
55impl NonEmptySpace for IndexSpace {
56 #[inline]
57 fn some_element(&self) -> <Self as Space>::Element {
58 assert_ne!(self.size, 0, "space is empty");
59 0
60 }
61}
62
63impl Distribution<<Self as Space>::Element> for IndexSpace {
64 #[inline]
65 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> <Self as Space>::Element {
66 rng.gen_range(0..self.size)
67 }
68}
69
70impl FiniteSpace for IndexSpace {
71 #[inline]
72 fn size(&self) -> usize {
73 self.size
74 }
75
76 #[inline]
77 fn to_index(&self, element: &Self::Element) -> usize {
78 *element
79 }
80
81 #[inline]
82 fn from_index(&self, index: usize) -> Option<Self::Element> {
83 if index >= self.size {
84 None
85 } else {
86 Some(index)
87 }
88 }
89
90 #[inline]
91 fn from_index_unchecked(&self, index: usize) -> Option<Self::Element> {
92 Some(index)
93 }
94}
95
96impl FeatureSpace for IndexSpace {
98 #[inline]
99 fn num_features(&self) -> usize {
100 self.size
101 }
102
103 #[inline]
104 fn features_out<'a, F: Float>(
105 &self,
106 element: &Self::Element,
107 out: &'a mut [F],
108 zeroed: bool,
109 ) -> &'a mut [F] {
110 let (out, rest) = out.split_at_mut(self.size);
111 if !zeroed {
112 out.fill(F::zero());
113 }
114 out[self.to_index(element)] = F::one();
115 rest
116 }
117
118 #[inline]
119 fn batch_features_out<'a, I, A>(&self, elements: I, out: &mut ArrayBase<A, Ix2>, zeroed: bool)
120 where
121 I: IntoIterator<Item = &'a Self::Element>,
122 Self::Element: 'a,
123 A: DataMut,
124 A::Elem: Float,
125 {
126 if !zeroed {
127 out.slice_mut(s![.., 0..self.num_features()])
128 .fill(Zero::zero());
129 }
130
131 let mut rows = out.rows_mut().into_iter();
133 for element in elements {
134 let mut row = rows.next().expect("fewer rows than elements");
135 row[self.to_index(element)] = One::one();
136 }
137 }
138}
139
140impl ReprSpace<Tensor> for IndexSpace {
142 #[inline]
143 fn repr(&self, element: &Self::Element) -> Tensor {
144 Tensor::scalar_tensor(self.to_index(element) as i64, (Kind::Int64, Device::Cpu))
145 }
146
147 #[inline]
148 fn batch_repr<'a, I>(&self, elements: I) -> Tensor
149 where
150 I: IntoIterator<Item = &'a Self::Element>,
151 Self::Element: 'a,
152 {
153 let indices: Vec<_> = elements
154 .into_iter()
155 .map(|elem| self.to_index(elem) as i64)
156 .collect();
157 Tensor::of_slice(&indices)
158 }
159}
160
161impl ParameterizedDistributionSpace<Tensor> for IndexSpace {
162 type Distribution = Categorical;
163
164 #[inline]
165 fn num_distribution_params(&self) -> usize {
166 self.size
167 }
168
169 #[inline]
170 fn sample_element(&self, params: &Tensor) -> Self::Element {
171 self.from_index(
172 self.distribution(params)
173 .sample()
174 .int64_value(&[])
175 .try_into()
176 .unwrap(),
177 )
178 .unwrap()
179 }
180
181 #[inline]
182 fn distribution(&self, params: &Tensor) -> Self::Distribution {
183 Self::Distribution::new(params)
184 }
185}
186
187impl LogElementSpace for IndexSpace {
189 #[inline]
190 fn log_element<L: StatsLogger + ?Sized>(
191 &self,
192 name: &'static str,
193 element: &Self::Element,
194 logger: &mut L,
195 ) -> Result<(), LogError> {
196 let log_value = LogValue::Index {
197 value: self.to_index(element),
198 size: self.size,
199 };
200 logger.log(name.into(), log_value)
201 }
202}
203
204impl<T: FiniteSpace + ?Sized> From<&T> for IndexSpace {
205 #[inline]
206 fn from(space: &T) -> Self {
207 Self { size: space.size() }
208 }
209}
210
/// Tests of the `Space` implementation.
#[cfg(test)]
mod space {
    use super::super::testing;
    use super::*;
    use rstest::rstest;

    /// Index 0 is a member of every non-empty index space.
    #[rstest]
    fn contains_zero(#[values(1, 5)] size: usize) {
        assert!(IndexSpace::new(size).contains(&0));
    }

    /// An index far beyond the size is not a member.
    #[rstest]
    fn not_contains_too_large(#[values(1, 5)] size: usize) {
        assert!(!IndexSpace::new(size).contains(&100));
    }

    /// Runs the shared sample-containment check with an argument of 100.
    #[rstest]
    fn contains_samples(#[values(1, 5)] size: usize) {
        testing::check_contains_samples(&IndexSpace::new(size), 100);
    }
}
235
/// Tests of the `SubsetOrd` implementation.
#[cfg(test)]
mod subset_ord {
    use super::super::SubsetOrd;
    use super::*;
    use std::cmp::Ordering;

    /// Spaces of equal size are equal and compare as `Equal`.
    #[test]
    fn same_eq() {
        let (left, right) = (IndexSpace::new(2), IndexSpace::new(2));
        assert_eq!(left, right);
        assert_eq!(left.subset_cmp(&right), Some(Ordering::Equal));
    }

    /// Spaces of different size are unequal and do not compare as `Equal`.
    #[test]
    fn different_not_eq() {
        let (larger, smaller) = (IndexSpace::new(2), IndexSpace::new(1));
        assert!(larger != smaller);
        assert_ne!(larger.subset_cmp(&smaller), Some(Ordering::Equal));
    }

    /// A space is a (non-strict) subset of an equal-sized space.
    #[test]
    fn same_subset_of() {
        let space = IndexSpace::new(2);
        let same = IndexSpace::new(2);
        assert!(space.subset_of(&same));
    }

    /// A smaller space is a strict subset of a larger one.
    #[test]
    fn smaller_strict_subset_of() {
        let smaller = IndexSpace::new(1);
        let larger = IndexSpace::new(2);
        assert!(smaller.strict_subset_of(&larger));
    }

    /// A larger space is not a subset of a smaller one.
    #[test]
    fn larger_not_subset_of() {
        let larger = IndexSpace::new(3);
        let smaller = IndexSpace::new(1);
        assert!(!larger.subset_of(&smaller));
    }
}
275
/// Tests of the `FiniteSpace` implementation, delegating to the shared `testing` checks.
#[cfg(test)]
mod finite_space {
    use super::super::testing;
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn from_to_index_iter_size(#[values(1, 5)] size: usize) {
        testing::check_from_to_index_iter_size(&IndexSpace::new(size));
    }

    #[rstest]
    fn from_index_sampled(#[values(1, 5)] size: usize) {
        testing::check_from_index_sampled(&IndexSpace::new(size), 100);
    }

    #[rstest]
    fn from_index_invalid(#[values(1, 5)] size: usize) {
        testing::check_from_index_invalid(&IndexSpace::new(size));
    }
}
300
/// Tests of the `FeatureSpace` implementation.
#[cfg(test)]
mod feature_space {
    use super::*;

    /// The feature count equals the size of the space.
    #[test]
    fn num_features() {
        let space = IndexSpace::new(3);
        assert_eq!(space.num_features(), 3);
    }

    // The test-generating macros below are defined outside this file; the arguments pair an
    // element (or elements) of a size-3 space with the expected one-hot feature vector(s).
    features_tests!(f, IndexSpace::new(3), 1, [0.0, 1.0, 0.0]);
    batch_features_tests!(
        b,
        IndexSpace::new(3),
        [2, 0, 1],
        [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    );
}
318
/// Tests of the `ReprSpace<Tensor>` implementation.
#[cfg(test)]
mod repr_space_tensor {
    use super::*;

    /// Each element is represented by a scalar `Int64` CPU tensor holding its index.
    #[test]
    fn repr() {
        let space = IndexSpace::new(3);
        for &(element, expected) in &[(0_usize, 0_i64), (1, 1), (2, 2)] {
            assert_eq!(
                space.repr(&element),
                Tensor::scalar_tensor(expected, (Kind::Int64, Device::Cpu))
            );
        }
    }

    /// A batch is represented by the tensor of its indices in order.
    #[test]
    fn batch_repr() {
        let space = IndexSpace::new(3);
        let elements = [0, 1, 2, 1];
        let expected = Tensor::of_slice(&[0_i64, 1, 2, 1]);
        let actual = space.batch_repr(&elements);
        assert_eq!(actual, expected);
    }
}
349
/// Tests of the `ParameterizedDistributionSpace<Tensor>` implementation.
#[cfg(test)]
mod parameterized_sample_space_tensor {
    use super::*;
    use std::ops::RangeInclusive;

    /// The number of distribution parameters equals the size of the space.
    #[test]
    fn num_sample_params() {
        let space = IndexSpace::new(3);
        assert_eq!(3, space.num_distribution_params());
    }

    /// With every parameter but one at negative infinity, only index 1 is ever sampled.
    #[test]
    fn sample_element_deterministic() {
        let space = IndexSpace::new(3);
        let params = Tensor::of_slice(&[f32::NEG_INFINITY, 0.0, f32::NEG_INFINITY]);
        for _ in 0..10 {
            assert_eq!(1, space.sample_element(&params));
        }
    }

    /// With the first parameter at negative infinity, index 0 is never sampled.
    #[test]
    fn sample_element_two_of_three() {
        let space = IndexSpace::new(3);
        let params = Tensor::of_slice(&[f32::NEG_INFINITY, 0.0, 0.0]);
        for _ in 0..10 {
            assert_ne!(0, space.sample_element(&params));
        }
    }

    /// Range of success counts expected from `n` Bernoulli trials with success probability `p`.
    ///
    /// Uses the normal approximation `n*p ± z*sqrt(n*p*(1-p))` with `z = 4.4`, which is a
    /// two-sided tail probability of roughly 1e-5, so a correct sampler should almost never
    /// fall outside the returned range.
    // The float-to-integer casts are intentional: bounds are rounded first, and a negative
    // lower bound saturates to 0. `n as f64` is exact for the small trial counts used here.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    #[allow(clippy::cast_precision_loss)]
    fn bernoulli_confidence_interval(p: f64, n: u64) -> RangeInclusive<u64> {
        let z = 4.4;
        let nf = n as f64;
        let stddev = (p * (1.0 - p) * nf).sqrt();
        let lower_bound = nf * p - z * stddev;
        let upper_bound = nf * p + z * stddev;
        (lower_bound.round() as u64)..=(upper_bound.round() as u64)
    }

    /// Sampled index frequencies match the probabilities implied by the parameters.
    #[test]
    fn sample_element_check_distribution() {
        let space = IndexSpace::new(3);
        let params = Tensor::of_slice(&[-1.0, 0.0, 1.0]);
        // softmax([-1, 0, 1]) rounded to three decimals; this treats the parameters as
        // unnormalized log-probabilities, as the negative-infinity tests above also rely on.
        let probs = [0.090, 0.245, 0.665];
        let n: u64 = 5000;

        // counts[i] is the number of times index `i` was sampled. An out-of-range sample
        // panics on the indexing below, which fails the test.
        let mut counts = [0_u64; 3];
        for _ in 0..n {
            counts[space.sample_element(&params)] += 1;
        }

        for (index, (&p, count)) in probs.iter().zip(&counts).enumerate() {
            let interval = bernoulli_confidence_interval(p, n);
            assert!(
                interval.contains(count),
                "index {} sampled {} times out of {}; expected a count within {:?}",
                index,
                count,
                n,
                interval
            );
        }
    }
}
419}