use std::cell::RefCell;
use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper};
use anyhow::{bail, Result};
use perchance::PerchanceContext;
use rand::thread_rng;
use rand_distr::{Distribution, StandardNormal};
pub trait NoiseGenerator {
fn generate(&self, count: usize, out: &mut [f32]);
}
pub struct ConstantGenerator {
value: f32,
}
impl ConstantGenerator {
pub fn for_value(value: f32) -> Self {
Self { value }
}
pub fn zeros() -> Self {
Self::for_value(1.0)
}
pub fn ones() -> Self {
Self::for_value(1.0)
}
}
impl NoiseGenerator for ConstantGenerator {
fn generate(&self, _count: usize, out: &mut [f32]) {
for o in out {
*o = self.value;
}
}
}
pub struct LowQualityNoiseGenerator {
ctx: RefCell<PerchanceContext>,
}
impl LowQualityNoiseGenerator {
pub fn new(seed: u128) -> Self {
Self {
ctx: RefCell::new(PerchanceContext::new(seed)),
}
}
}
impl Default for LowQualityNoiseGenerator {
fn default() -> Self {
Self {
ctx: RefCell::new(PerchanceContext::new(perchance::gen_time_seed())),
}
}
}
impl NoiseGenerator for LowQualityNoiseGenerator {
fn generate(&self, _count: usize, out: &mut [f32]) {
let mut ctx = self.ctx.borrow_mut();
for o in out {
*o = ctx.normal_f32();
}
}
}
pub struct HighQualityNoiseGenerator {
normal_distribution: StandardNormal,
}
impl Default for HighQualityNoiseGenerator {
fn default() -> Self {
Self {
normal_distribution: StandardNormal,
}
}
}
impl NoiseGenerator for HighQualityNoiseGenerator {
fn generate(&self, _count: usize, out: &mut [f32]) {
let mut rng = thread_rng();
for o in out {
*o = self.normal_distribution.sample(&mut rng);
}
}
}
struct EpsilonInjectorState<NG: NoiseGenerator> {
count: usize,
index: usize,
generator: NG,
inputs: Vec<(String, Vec<usize>)>,
}
pub struct EpsilonInjector<T: Inferer, NG: NoiseGenerator = HighQualityNoiseGenerator> {
inner: T,
state: EpsilonInjectorState<NG>,
}
impl<T> EpsilonInjector<T, HighQualityNoiseGenerator>
where
T: Inferer,
{
pub fn wrap(inferer: T, key: &str) -> Result<EpsilonInjector<T, HighQualityNoiseGenerator>> {
Self::with_generator(inferer, HighQualityNoiseGenerator::default(), key)
}
}
impl<T, NG> EpsilonInjector<T, NG>
where
T: Inferer,
NG: NoiseGenerator,
{
pub fn with_generator(inferer: T, generator: NG, key: &str) -> Result<Self> {
let inputs = inferer.input_shapes();
let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
Some((index, (_, shape))) => (index, shape.iter().product()),
None => bail!("model has no input key {:?}", key),
};
let inputs = inputs
.iter()
.filter(|(k, _)| *k != key)
.map(|(k, v)| (k.to_owned(), v.to_owned()))
.collect::<Vec<_>>();
Ok(Self {
inner: inferer,
state: EpsilonInjectorState {
index,
count,
generator,
inputs,
},
})
}
}
impl<T, NG> Inferer for EpsilonInjector<T, NG>
where
T: Inferer,
NG: NoiseGenerator,
{
fn select_batch_size(&self, max_count: usize) -> usize {
self.inner.select_batch_size(max_count)
}
fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
let total_count = self.state.count * batch.len();
let output = batch.input_slot_mut(self.state.index);
self.state.generator.generate(total_count, output);
self.inner.infer_raw(batch)
}
fn input_shapes(&self) -> &[(String, Vec<usize>)] {
&self.state.inputs
}
fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
self.inner.raw_input_shapes()
}
fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
self.inner.raw_output_shapes()
}
fn begin_agent(&self, id: u64) {
self.inner.begin_agent(id);
}
fn end_agent(&self, id: u64) {
self.inner.end_agent(id);
}
}
pub struct EpsilonInjectorWrapper<Inner: InfererWrapper, NG: NoiseGenerator> {
inner: Inner,
state: EpsilonInjectorState<NG>,
}
impl<Inner: InfererWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
pub fn wrap(
inner: Inner,
inferer: &dyn Inferer,
key: &str,
) -> Result<EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator>> {
Self::with_generator(inner, inferer, HighQualityNoiseGenerator::default(), key)
}
}
impl<Inner, NG> EpsilonInjectorWrapper<Inner, NG>
where
Inner: InfererWrapper,
NG: NoiseGenerator,
{
pub fn with_generator(
inner: Inner,
inferer: &dyn Inferer,
generator: NG,
key: &str,
) -> Result<Self> {
let inputs = inner.input_shapes(inferer);
let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
Some((index, (_, shape))) => (index, shape.iter().product()),
None => bail!("model has no input key {:?}", key),
};
let inputs = inputs
.iter()
.filter(|(k, _)| *k != key)
.map(|(k, v)| (k.to_owned(), v.to_owned()))
.collect::<Vec<_>>();
Ok(Self {
inner,
state: EpsilonInjectorState {
index,
count,
generator,
inputs,
},
})
}
}
impl<Inner, NG> InfererWrapper for EpsilonInjectorWrapper<Inner, NG>
where
Inner: InfererWrapper,
NG: NoiseGenerator,
{
fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
self.inner.invoke(inferer, batch)?;
let total_count = self.state.count * batch.len();
let output = batch.input_slot_mut(self.state.index);
self.state.generator.generate(total_count, output);
self.inner.invoke(inferer, batch)
}
fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
self.state.inputs.as_ref()
}
fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
self.inner.output_shapes(inferer)
}
fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
self.inner.begin_agent(inferer, id);
}
fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
self.inner.end_agent(inferer, id);
}
}