cervo_core/
epsilon.rs

1// Author: Tom Solberg <tom.solberg@embark-studios.com>
2// Copyright © 2022, Embark Studios AB, all rights reserved.
3// Created: 11 May 2022
4
5/*!
6Utilities for filling noise inputs for an inference model.
7*/
8
9use std::cell::RefCell;
10
11use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper};
12use anyhow::{bail, Result};
13use perchance::PerchanceContext;
14use rand::thread_rng;
15use rand_distr::{Distribution, StandardNormal};
16
/// Source of the noise values that an [`EpsilonInjector`] writes into a model input.
///
/// The generators shipped in this module either sample from a standard normal distribution or, for testing, emit a
/// constant. Custom noise-generators can be implemented and passed via [`EpsilonInjector::with_generator`].
pub trait NoiseGenerator {
    /// Fill `out` with noise values.
    ///
    /// `count` is the number of values requested by the caller. The implementations in this module ignore it and
    /// simply overwrite every element of `out`.
    fn generate(&self, count: usize, out: &mut [f32]);
}
22
/// A non-noisy noise generator, primarily intended for debugging or testing purposes.
///
/// Every value it produces is the single constant it was constructed with.
pub struct ConstantGenerator {
    /// The constant written to every output slot.
    value: f32,
}

impl ConstantGenerator {
    /// Will generate the provided `value` when called.
    pub fn for_value(value: f32) -> Self {
        Self { value }
    }

    /// Convenience function for a constant generator for zeros.
    pub fn zeros() -> Self {
        // Must be 0.0: this previously produced 1.0, making it indistinguishable from `ones()`.
        Self::for_value(0.0)
    }

    /// Convenience function for a constant generator for ones.
    pub fn ones() -> Self {
        Self::for_value(1.0)
    }
}
44
45impl NoiseGenerator for ConstantGenerator {
46    fn generate(&self, _count: usize, out: &mut [f32]) {
47        for o in out {
48            *o = self.value;
49        }
50    }
51}
52
53/// A low quality noise generator which is about twice as fast as the built-in [`HighQualityNoiseGenerator`]. This uses
54/// an XORSHIFT algorithm internally which isn't cryptographically secure.
55///
56/// The default implementation will seed the generator from the current time.
57pub struct LowQualityNoiseGenerator {
58    ctx: RefCell<PerchanceContext>,
59}
60
61impl LowQualityNoiseGenerator {
62    /// Create a new LQNG with the provided fixed seed.
63    pub fn new(seed: u128) -> Self {
64        Self {
65            ctx: RefCell::new(PerchanceContext::new(seed)),
66        }
67    }
68}
69
70impl Default for LowQualityNoiseGenerator {
71    fn default() -> Self {
72        Self {
73            ctx: RefCell::new(PerchanceContext::new(perchance::gen_time_seed())),
74        }
75    }
76}
77
78impl NoiseGenerator for LowQualityNoiseGenerator {
79    /// Generate `count` random values.
80    fn generate(&self, _count: usize, out: &mut [f32]) {
81        let mut ctx = self.ctx.borrow_mut();
82        for o in out {
83            *o = ctx.normal_f32();
84        }
85    }
86}
87
88/// A high quality noise generator which is measurably slower than the LQGN, but still fast enough for most real-time
89/// use-cases.
90///
91/// This implementation uses [`rand::thread_rng`] internally as the entropy source, and uses the optimized
92/// `StandardNormal` distribution for sampling.
93pub struct HighQualityNoiseGenerator {
94    normal_distribution: StandardNormal,
95}
96
97impl Default for HighQualityNoiseGenerator {
98    fn default() -> Self {
99        Self {
100            normal_distribution: StandardNormal,
101        }
102    }
103}
104
105impl NoiseGenerator for HighQualityNoiseGenerator {
106    /// Generate `count` random values.
107    fn generate(&self, _count: usize, out: &mut [f32]) {
108        let mut rng = thread_rng();
109        for o in out {
110            *o = self.normal_distribution.sample(&mut rng);
111        }
112    }
113}
114
115struct EpsilonInjectorState<NG: NoiseGenerator> {
116    count: usize,
117    index: usize,
118    generator: NG,
119
120    inputs: Vec<(String, Vec<usize>)>,
121}
122/// The [`EpsilonInjector`] wraps an inferer to add noise values as one of the input data points. This is useful for
123/// continuous action policies where you might have trained your agent to follow a stochastic policy trained with the
124/// reparametrization trick.
125///
126/// Note that it's fully possible to pass an epsilon directly in your observation, and this is purely a convenience
127/// wrapper.
128pub struct EpsilonInjector<T: Inferer, NG: NoiseGenerator = HighQualityNoiseGenerator> {
129    inner: T,
130
131    state: EpsilonInjectorState<NG>,
132}
133
134impl<T> EpsilonInjector<T, HighQualityNoiseGenerator>
135where
136    T: Inferer,
137{
138    /// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
139    ///
140    /// This function will use [`HighQualityNoiseGenerator`] as the noise source.
141    ///
142    /// # Errors
143    ///
144    /// Will return an error if the provided key doesn't match an input on the model.
145    pub fn wrap(inferer: T, key: &str) -> Result<EpsilonInjector<T, HighQualityNoiseGenerator>> {
146        Self::with_generator(inferer, HighQualityNoiseGenerator::default(), key)
147    }
148}
149
150impl<T, NG> EpsilonInjector<T, NG>
151where
152    T: Inferer,
153    NG: NoiseGenerator,
154{
155    /// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
156    ///
157    /// # Errors
158    ///
159    /// Will return an error if the provided key doesn't match an input on the model.
160    pub fn with_generator(inferer: T, generator: NG, key: &str) -> Result<Self> {
161        let inputs = inferer.input_shapes();
162
163        let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
164            Some((index, (_, shape))) => (index, shape.iter().product()),
165            None => bail!("model has no input key {:?}", key),
166        };
167
168        let inputs = inputs
169            .iter()
170            .filter(|(k, _)| *k != key)
171            .map(|(k, v)| (k.to_owned(), v.to_owned()))
172            .collect::<Vec<_>>();
173
174        Ok(Self {
175            inner: inferer,
176            state: EpsilonInjectorState {
177                index,
178                count,
179                generator,
180                inputs,
181            },
182        })
183    }
184}
185
186impl<T, NG> Inferer for EpsilonInjector<T, NG>
187where
188    T: Inferer,
189    NG: NoiseGenerator,
190{
191    fn select_batch_size(&self, max_count: usize) -> usize {
192        self.inner.select_batch_size(max_count)
193    }
194
195    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
196        let total_count = self.state.count * batch.len();
197        let output = batch.input_slot_mut(self.state.index);
198        self.state.generator.generate(total_count, output);
199
200        self.inner.infer_raw(batch)
201    }
202
203    fn input_shapes(&self) -> &[(String, Vec<usize>)] {
204        &self.state.inputs
205    }
206
207    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
208        self.inner.raw_input_shapes()
209    }
210
211    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
212        self.inner.raw_output_shapes()
213    }
214
215    fn begin_agent(&self, id: u64) {
216        self.inner.begin_agent(id);
217    }
218
219    fn end_agent(&self, id: u64) {
220        self.inner.end_agent(id);
221    }
222}
223
224pub struct EpsilonInjectorWrapper<Inner: InfererWrapper, NG: NoiseGenerator> {
225    inner: Inner,
226    state: EpsilonInjectorState<NG>,
227}
228
229impl<Inner: InfererWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
230    /// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
231    ///
232    /// This function will use [`HighQualityNoiseGenerator`] as the noise source.
233    ///
234    /// # Errors
235    ///
236    /// Will return an error if the provided key doesn't match an input on the model.
237    pub fn wrap(
238        inner: Inner,
239        inferer: &dyn Inferer,
240        key: &str,
241    ) -> Result<EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator>> {
242        Self::with_generator(inner, inferer, HighQualityNoiseGenerator::default(), key)
243    }
244}
245
246impl<Inner, NG> EpsilonInjectorWrapper<Inner, NG>
247where
248    Inner: InfererWrapper,
249    NG: NoiseGenerator,
250{
251    /// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
252    ///
253    /// # Errors
254    ///
255    /// Will return an error if the provided key doesn't match an input on the model.
256    pub fn with_generator(
257        inner: Inner,
258        inferer: &dyn Inferer,
259        generator: NG,
260        key: &str,
261    ) -> Result<Self> {
262        let inputs = inner.input_shapes(inferer);
263
264        let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
265            Some((index, (_, shape))) => (index, shape.iter().product()),
266            None => bail!("model has no input key {:?}", key),
267        };
268
269        let inputs = inputs
270            .iter()
271            .filter(|(k, _)| *k != key)
272            .map(|(k, v)| (k.to_owned(), v.to_owned()))
273            .collect::<Vec<_>>();
274
275        Ok(Self {
276            inner,
277            state: EpsilonInjectorState {
278                index,
279                count,
280                generator,
281                inputs,
282            },
283        })
284    }
285}
286
287impl<Inner, NG> InfererWrapper for EpsilonInjectorWrapper<Inner, NG>
288where
289    Inner: InfererWrapper,
290    NG: NoiseGenerator,
291{
292    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
293        self.inner.invoke(inferer, batch)?;
294        let total_count = self.state.count * batch.len();
295        let output = batch.input_slot_mut(self.state.index);
296        self.state.generator.generate(total_count, output);
297
298        self.inner.invoke(inferer, batch)
299    }
300
301    fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
302        self.state.inputs.as_ref()
303    }
304
305    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
306        self.inner.output_shapes(inferer)
307    }
308
309    fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
310        self.inner.begin_agent(inferer, id);
311    }
312
313    fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
314        self.inner.end_agent(inferer, id);
315    }
316}