1use core::slice::Iter;
2use std::{ops::Index, slice::SliceIndex};
3
4use itertools::Itertools;
5use rand::seq::SliceRandom;
6use serde::{Deserialize, Deserializer, Serialize};
7
8use crate::utils::random::generator;
9
10use super::engines::reset_engine::{Reset, ResetEngine};
11
12fn deserialize_vec_with_null<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
13where
14 D: Deserializer<'de>,
15{
16 let vec_opt: Option<Vec<Option<f64>>> = Deserialize::deserialize(deserializer)?;
17 Ok(vec_opt
18 .unwrap_or_default()
19 .into_iter()
20 .map(|x| x.unwrap_or(f64::NAN))
21 .collect())
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Registers {
26 #[serde(deserialize_with = "deserialize_vec_with_null")]
27 data: Vec<f64>,
28 n_actions: usize,
29}
30
/// Outcome of searching a run of registers for its maximum value.
///
/// Derives the common traits so results can be printed, cloned and compared
/// (for example inside `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArgmaxResult {
    /// Indices of every register holding the maximum value (more than one on ties).
    MaxValues(Vec<usize>),
    /// The maximum was infinite or NaN, so no meaningful winner exists.
    Overflow,
}
35
/// A single register index selected from an `ArgmaxResult`, or the lack of one.
///
/// Derives the common traits so selections can be printed, copied and compared
/// (for example inside `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActionRegister {
    /// Index of the selected register.
    Value(usize),
    /// No register could be selected.
    Overflow,
}
40
41impl ArgmaxResult {
42 pub fn one(&self) -> ActionRegister {
43 match self {
44 ArgmaxResult::MaxValues(indices) if indices.len() == 1 => {
45 ActionRegister::Value(indices[0])
46 }
47 _ => ActionRegister::Overflow,
48 }
49 }
50
51 pub fn any(&self) -> ActionRegister {
52 match self {
53 ArgmaxResult::MaxValues(indices) if !indices.is_empty() => {
54 ActionRegister::Value(indices.choose(&mut generator()).copied().unwrap())
55 }
56 _ => ActionRegister::Overflow,
57 }
58 }
59}
60
/// Selects which registers `Registers::argmax` examines.
///
/// Derives the common traits so the selector can be printed, copied and
/// compared (for example inside `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArgmaxInput {
    /// Examine every register.
    All,
    /// Examine only the leading action registers.
    ActionRegisters,
}
65
66impl Reset<Registers> for ResetEngine {
67 fn reset(item: &mut Registers) {
68 for value in item.data.as_mut_slice() {
69 *value = 0.
70 }
71 }
72}
73
74impl Registers {
75 pub fn new(n_actions: usize, n_working_registers: usize) -> Self {
76 let data = vec![0.; n_actions + n_working_registers];
77
78 Registers { data, n_actions }
79 }
80
81 pub fn argmax(&self, range: ArgmaxInput) -> ArgmaxResult {
82 let range_to_use = match range {
83 ArgmaxInput::All => 0..(self.data.len()),
84 ArgmaxInput::ActionRegisters => 0..(self.n_actions),
85 };
86
87 let sliced_data = &self.data[range_to_use];
88 let max_value = sliced_data
89 .iter()
90 .copied()
91 .reduce(f64::max)
92 .expect("Sliced values to not be of cardinality 0.");
93
94 if max_value.is_infinite() || max_value.is_nan() {
95 return ArgmaxResult::Overflow;
96 }
97
98 let max_indices = sliced_data
99 .iter()
100 .copied()
101 .enumerate()
102 .filter(|(_, v)| v == &max_value)
103 .map(|(i, _)| i)
104 .collect_vec();
105
106 ArgmaxResult::MaxValues(max_indices)
107 }
108
109 pub fn len(&self) -> usize {
110 let Registers { data, .. } = self;
111 data.len()
112 }
113
114 pub fn is_empty(&self) -> bool {
115 self.data.is_empty()
116 }
117
118 pub fn update(&mut self, index: usize, value: f64) {
119 let Registers { data, .. } = self;
120 data[index] = value;
121 }
122
123 pub fn get(&self, index: usize) -> &f64 {
124 let Registers { data, .. } = self;
125 data.get(index).unwrap()
126 }
127
128 pub fn iter(&self) -> Iter<'_, f64> {
129 self.data.iter()
130 }
131}
132
133impl<Idx> Index<Idx> for Registers
134where
135 Idx: SliceIndex<[f64]>,
136{
137 type Output = Idx::Output;
138
139 fn index(&self, index: Idx) -> &Self::Output {
140 &self.data[index]
141 }
142}
143
#[cfg(test)]
mod tests {
    use crate::core::registers::Registers;

    /// Indexing with a range goes through the `Index` impl and yields a slice
    /// of the register values.
    #[test]
    fn given_registers_when_indexed_with_range_then_slice_is_returned() {
        let mut registers = Registers::new(9, 1);
        registers.update(0, 1.);

        // Borrow the first two registers as a slice.
        let slice = &registers[0..2];

        assert_eq!(slice, &[1., 0.]);
    }
}