1use std::cell::RefCell;
10
11use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper};
12use anyhow::{bail, Result};
13use perchance::PerchanceContext;
14use rand::thread_rng;
15use rand_distr::{Distribution, StandardNormal};
16
/// A source of `f32` values that can populate a caller-supplied buffer.
pub trait NoiseGenerator {
    /// Writes generated values into `out`.
    ///
    /// `count` is passed by the caller together with the buffer. The
    /// implementations that accompany this trait ignore it and fill every
    /// element of `out` instead.
    fn generate(&self, count: usize, out: &mut [f32]);
}
22
/// Generator configured with a single fixed `f32` value.
pub struct ConstantGenerator {
    /// The value this generator was configured with.
    value: f32,
}

impl ConstantGenerator {
    /// Creates a generator configured with `value`.
    pub fn for_value(value: f32) -> Self {
        Self { value }
    }

    /// Creates a generator configured with `0.0`.
    pub fn zeros() -> Self {
        // Previously this passed 1.0, which made `zeros()` identical to `ones()`.
        Self::for_value(0.0)
    }

    /// Creates a generator configured with `1.0`.
    pub fn ones() -> Self {
        Self::for_value(1.0)
    }
}
44
45impl NoiseGenerator for ConstantGenerator {
46 fn generate(&self, _count: usize, out: &mut [f32]) {
47 for o in out {
48 *o = self.value;
49 }
50 }
51}
52
53pub struct LowQualityNoiseGenerator {
58 ctx: RefCell<PerchanceContext>,
59}
60
61impl LowQualityNoiseGenerator {
62 pub fn new(seed: u128) -> Self {
64 Self {
65 ctx: RefCell::new(PerchanceContext::new(seed)),
66 }
67 }
68}
69
70impl Default for LowQualityNoiseGenerator {
71 fn default() -> Self {
72 Self {
73 ctx: RefCell::new(PerchanceContext::new(perchance::gen_time_seed())),
74 }
75 }
76}
77
78impl NoiseGenerator for LowQualityNoiseGenerator {
79 fn generate(&self, _count: usize, out: &mut [f32]) {
81 let mut ctx = self.ctx.borrow_mut();
82 for o in out {
83 *o = ctx.normal_f32();
84 }
85 }
86}
87
88pub struct HighQualityNoiseGenerator {
94 normal_distribution: StandardNormal,
95}
96
97impl Default for HighQualityNoiseGenerator {
98 fn default() -> Self {
99 Self {
100 normal_distribution: StandardNormal,
101 }
102 }
103}
104
105impl NoiseGenerator for HighQualityNoiseGenerator {
106 fn generate(&self, _count: usize, out: &mut [f32]) {
108 let mut rng = thread_rng();
109 for o in out {
110 *o = self.normal_distribution.sample(&mut rng);
111 }
112 }
113}
114
115struct EpsilonInjectorState<NG: NoiseGenerator> {
116 count: usize,
117 index: usize,
118 generator: NG,
119
120 inputs: Vec<(String, Vec<usize>)>,
121}
122pub struct EpsilonInjector<T: Inferer, NG: NoiseGenerator = HighQualityNoiseGenerator> {
129 inner: T,
130
131 state: EpsilonInjectorState<NG>,
132}
133
134impl<T> EpsilonInjector<T, HighQualityNoiseGenerator>
135where
136 T: Inferer,
137{
138 pub fn wrap(inferer: T, key: &str) -> Result<EpsilonInjector<T, HighQualityNoiseGenerator>> {
146 Self::with_generator(inferer, HighQualityNoiseGenerator::default(), key)
147 }
148}
149
150impl<T, NG> EpsilonInjector<T, NG>
151where
152 T: Inferer,
153 NG: NoiseGenerator,
154{
155 pub fn with_generator(inferer: T, generator: NG, key: &str) -> Result<Self> {
161 let inputs = inferer.input_shapes();
162
163 let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
164 Some((index, (_, shape))) => (index, shape.iter().product()),
165 None => bail!("model has no input key {:?}", key),
166 };
167
168 let inputs = inputs
169 .iter()
170 .filter(|(k, _)| *k != key)
171 .map(|(k, v)| (k.to_owned(), v.to_owned()))
172 .collect::<Vec<_>>();
173
174 Ok(Self {
175 inner: inferer,
176 state: EpsilonInjectorState {
177 index,
178 count,
179 generator,
180 inputs,
181 },
182 })
183 }
184}
185
186impl<T, NG> Inferer for EpsilonInjector<T, NG>
187where
188 T: Inferer,
189 NG: NoiseGenerator,
190{
191 fn select_batch_size(&self, max_count: usize) -> usize {
192 self.inner.select_batch_size(max_count)
193 }
194
195 fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
196 let total_count = self.state.count * batch.len();
197 let output = batch.input_slot_mut(self.state.index);
198 self.state.generator.generate(total_count, output);
199
200 self.inner.infer_raw(batch)
201 }
202
203 fn input_shapes(&self) -> &[(String, Vec<usize>)] {
204 &self.state.inputs
205 }
206
207 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
208 self.inner.raw_input_shapes()
209 }
210
211 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
212 self.inner.raw_output_shapes()
213 }
214
215 fn begin_agent(&self, id: u64) {
216 self.inner.begin_agent(id);
217 }
218
219 fn end_agent(&self, id: u64) {
220 self.inner.end_agent(id);
221 }
222}
223
224pub struct EpsilonInjectorWrapper<Inner: InfererWrapper, NG: NoiseGenerator> {
225 inner: Inner,
226 state: EpsilonInjectorState<NG>,
227}
228
229impl<Inner: InfererWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
230 pub fn wrap(
238 inner: Inner,
239 inferer: &dyn Inferer,
240 key: &str,
241 ) -> Result<EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator>> {
242 Self::with_generator(inner, inferer, HighQualityNoiseGenerator::default(), key)
243 }
244}
245
246impl<Inner, NG> EpsilonInjectorWrapper<Inner, NG>
247where
248 Inner: InfererWrapper,
249 NG: NoiseGenerator,
250{
251 pub fn with_generator(
257 inner: Inner,
258 inferer: &dyn Inferer,
259 generator: NG,
260 key: &str,
261 ) -> Result<Self> {
262 let inputs = inner.input_shapes(inferer);
263
264 let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
265 Some((index, (_, shape))) => (index, shape.iter().product()),
266 None => bail!("model has no input key {:?}", key),
267 };
268
269 let inputs = inputs
270 .iter()
271 .filter(|(k, _)| *k != key)
272 .map(|(k, v)| (k.to_owned(), v.to_owned()))
273 .collect::<Vec<_>>();
274
275 Ok(Self {
276 inner,
277 state: EpsilonInjectorState {
278 index,
279 count,
280 generator,
281 inputs,
282 },
283 })
284 }
285}
286
287impl<Inner, NG> InfererWrapper for EpsilonInjectorWrapper<Inner, NG>
288where
289 Inner: InfererWrapper,
290 NG: NoiseGenerator,
291{
292 fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
293 self.inner.invoke(inferer, batch)?;
294 let total_count = self.state.count * batch.len();
295 let output = batch.input_slot_mut(self.state.index);
296 self.state.generator.generate(total_count, output);
297
298 self.inner.invoke(inferer, batch)
299 }
300
301 fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
302 self.state.inputs.as_ref()
303 }
304
305 fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
306 self.inner.output_shapes(inferer)
307 }
308
309 fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
310 self.inner.begin_agent(inferer, id);
311 }
312
313 fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
314 self.inner.end_agent(inferer, id);
315 }
316}