Skip to main content

lgp/core/registers.rs

1use core::slice::Iter;
2use std::{ops::Index, slice::SliceIndex};
3
4use itertools::Itertools;
5use rand::seq::SliceRandom;
6use serde::{Deserialize, Deserializer, Serialize};
7
8use crate::utils::random::generator;
9
10use super::engines::reset_engine::{Reset, ResetEngine};
11
12fn deserialize_vec_with_null<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
13where
14    D: Deserializer<'de>,
15{
16    let vec_opt: Option<Vec<Option<f64>>> = Deserialize::deserialize(deserializer)?;
17    Ok(vec_opt
18        .unwrap_or_default()
19        .into_iter()
20        .map(|x| x.unwrap_or(f64::NAN))
21        .collect())
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Registers {
26    #[serde(deserialize_with = "deserialize_vec_with_null")]
27    data: Vec<f64>,
28    n_actions: usize,
29}
30
/// Outcome of [`Registers::argmax`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArgmaxResult {
    /// Every index, counted from the start of the inspected range, whose value
    /// equals the (finite) maximum.
    MaxValues(Vec<usize>),
    /// The maximum of the inspected range was infinite or NaN.
    Overflow,
}
35
/// A single register index chosen from an [`ArgmaxResult`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActionRegister {
    /// The selected register index.
    Value(usize),
    /// No index could be selected: the argmax overflowed, was empty, or (for
    /// [`ArgmaxResult::one`]) was tied.
    Overflow,
}
40
41impl ArgmaxResult {
42    pub fn one(&self) -> ActionRegister {
43        match self {
44            ArgmaxResult::MaxValues(indices) if indices.len() == 1 => {
45                ActionRegister::Value(indices[0])
46            }
47            _ => ActionRegister::Overflow,
48        }
49    }
50
51    pub fn any(&self) -> ActionRegister {
52        match self {
53            ArgmaxResult::MaxValues(indices) if !indices.is_empty() => {
54                ActionRegister::Value(indices.choose(&mut generator()).copied().unwrap())
55            }
56            _ => ActionRegister::Overflow,
57        }
58    }
59}
60
/// Selects which registers [`Registers::argmax`] inspects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArgmaxInput {
    /// Every register, action and working alike.
    All,
    /// Only the leading `n_actions` action registers.
    ActionRegisters,
}
65
66impl Reset<Registers> for ResetEngine {
67    fn reset(item: &mut Registers) {
68        for value in item.data.as_mut_slice() {
69            *value = 0.
70        }
71    }
72}
73
74impl Registers {
75    pub fn new(n_actions: usize, n_working_registers: usize) -> Self {
76        let data = vec![0.; n_actions + n_working_registers];
77
78        Registers { data, n_actions }
79    }
80
81    pub fn argmax(&self, range: ArgmaxInput) -> ArgmaxResult {
82        let range_to_use = match range {
83            ArgmaxInput::All => 0..(self.data.len()),
84            ArgmaxInput::ActionRegisters => 0..(self.n_actions),
85        };
86
87        let sliced_data = &self.data[range_to_use];
88        let max_value = sliced_data
89            .iter()
90            .copied()
91            .reduce(f64::max)
92            .expect("Sliced values to not be of cardinality 0.");
93
94        if max_value.is_infinite() || max_value.is_nan() {
95            return ArgmaxResult::Overflow;
96        }
97
98        let max_indices = sliced_data
99            .iter()
100            .copied()
101            .enumerate()
102            .filter(|(_, v)| v == &max_value)
103            .map(|(i, _)| i)
104            .collect_vec();
105
106        ArgmaxResult::MaxValues(max_indices)
107    }
108
109    pub fn len(&self) -> usize {
110        let Registers { data, .. } = self;
111        data.len()
112    }
113
114    pub fn is_empty(&self) -> bool {
115        self.data.is_empty()
116    }
117
118    pub fn update(&mut self, index: usize, value: f64) {
119        let Registers { data, .. } = self;
120        data[index] = value;
121    }
122
123    pub fn get(&self, index: usize) -> &f64 {
124        let Registers { data, .. } = self;
125        data.get(index).unwrap()
126    }
127
128    pub fn iter(&self) -> Iter<'_, f64> {
129        self.data.iter()
130    }
131}
132
133impl<Idx> Index<Idx> for Registers
134where
135    Idx: SliceIndex<[f64]>,
136{
137    type Output = Idx::Output;
138
139    fn index(&self, index: Idx) -> &Self::Output {
140        &self.data[index]
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::Registers;

    #[test]
    fn given_registers_when_indexed_with_range_then_slice_is_returned() {
        // Nine action registers plus one working register, all starting at zero.
        let mut bank = Registers::new(9, 1);
        bank.update(0, 1.0);

        let expected: [f64; 2] = [1.0, 0.0];
        assert_eq!(&bank[0..2], &expected[..]);
    }
}