use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::{
arithmetic::{fast_divide::FastDivide, float_exp::Exp},
boolean::utils::{neg_abs, sign_bit},
traits::arithmetic_circuit::ArithmeticCircuit,
},
expressions::expr::EvalFailure,
global_value::value::FieldValue,
},
traits::{GreaterEqual, Select},
types::DOUBLE_PRECISION_MANTISSA,
utils::{number::Number, used_field::UsedField},
};
/// Saturation threshold for the sigmoid input, expressed with `DOUBLE_PRECISION_MANTISSA` (52)
/// fractional bits.
///
/// Numerically this is ≈ 36.0437 · 2^52, i.e. ≈ 52·ln(2)·2^52. Past |x| ≈ 52·ln 2 the term
/// `exp(-|x|)` falls below 2^-52, so the sigmoid is presumably indistinguishable from 0 or 1 at
/// 52-bit output precision. [`Sigmoid::sigmoid`] rescales it to the input precision.
///
/// Typed as `u64` rather than `usize`: the value exceeds `u32::MAX`, so a `usize` constant is
/// rejected by the deny-by-default `overflowing_literals` lint on 32-bit targets.
const LIMIT: u64 = 162326183972299341;

/// Fixed-point logistic sigmoid `1 / (1 + e^{-x})`, usable both as a circuit and in plaintext.
#[derive(Clone, Debug)]
pub struct Sigmoid {
    /// Number of fractional bits of the input.
    precision_in: usize,
    /// Number of fractional bits of the output; always `DOUBLE_PRECISION_MANTISSA`.
    precision_out: usize,
}

impl Sigmoid {
    /// Creates a sigmoid for inputs with `precision_in` fractional bits.
    ///
    /// The output precision is fixed to `DOUBLE_PRECISION_MANTISSA` fractional bits.
    ///
    /// # Panics
    ///
    /// Panics if `precision_in > DOUBLE_PRECISION_MANTISSA`.
    // NOTE(review): `unused` is presumably allowed because some build configurations never call
    // this constructor outside of tests — confirm.
    #[allow(unused)]
    pub const fn new(precision_in: usize) -> Self {
        // Panics in a const context cannot format integer arguments, hence the literal 52
        // (the value of DOUBLE_PRECISION_MANTISSA) in the message.
        if precision_in > DOUBLE_PRECISION_MANTISSA {
            panic!("precision_in must be at most 52");
        }
        Sigmoid {
            precision_in,
            precision_out: DOUBLE_PRECISION_MANTISSA,
        }
    }

    /// Builds the circuit computing `sigmoid(x)` with `precision_out` fractional bits.
    ///
    /// Two shortcuts depend on the static bounds of `x`:
    /// - if `x` is always at least `limit`, the result is the constant `1` (`2^precision_out`);
    /// - if `x` is always at most `-limit`, the result is the constant `0`.
    ///
    /// Otherwise the circuit proceeds in steps:
    /// 1. It computes `s = sigmoid(|x|) = 1 / (1 + e^{-|x|})` from `-|x|`, clamped to
    ///    `[-limit, 0]`.
    /// 2. It returns `s` for non-negative `x` and `1 - s` for negative `x`.
    ///
    /// The output bounds are always `[0, 2^precision_out]`.
    pub fn sigmoid<F: ActuallyUsedField>(&self, x: FieldValue<F>) -> FieldValue<F> {
        let bounds = x.bounds();
        let (min, max) = bounds.min_and_max(true);
        // `LIMIT` carries `precision_out` fractional bits; shift it down to the input precision.
        let limit = F::from(LIMIT >> (self.precision_out - self.precision_in));
        // The leading sign checks presumably rule out wrap-around in the field subtraction /
        // addition that follows.
        if min.is_ge_zero() && (min - limit).is_ge_zero() {
            // x >= limit on its whole range: saturate to 1.
            FieldValue::from(F::power_of_two(self.precision_out))
        } else if max.is_lt_zero() && (max + limit).is_le_zero() {
            // x <= -limit on its whole range: saturate to 0.
            FieldValue::from(F::ZERO)
        } else {
            // `sign` is taken to be 1 when x < 0 (see the final combination below).
            let sign = sign_bit(x);
            let neg_abs_x = neg_abs(x);
            // Bounds of -|x| (not of |x|, despite the names).
            let (min_abs, max_abs) = neg_abs_x.bounds().min_and_max(true);
            // Clamp -|x| to at least -limit unless its bounds already guarantee it. Presumably
            // this bounds the range `Exp` must handle; below -limit the sigmoid is saturated.
            let neg_abs_x = if (min_abs + limit).is_ge_zero() {
                neg_abs_x
            } else {
                (neg_abs_x.lt(-limit))
                    .select(FieldValue::from(-limit), neg_abs_x)
                    .with_bounds((-limit, (-limit).max(max_abs, true)))
            };
            let exp_neg_abs_x = Exp::new(self.precision_in).exp(neg_abs_x);
            let one = FieldValue::<F>::from(Number::power_of_two(self.precision_out));
            // The exponential is added to `one`, so it is assumed to share `precision_out`
            // fractional bits. The halved divisor (1 + e^{-|x|}) / 2 then lies in (0.5, 1].
            // NOTE(review): the exact scaling convention of `inv_approx` is defined elsewhere.
            let sigmo_abs_x = FastDivide::new(0, 0).inv_approx(
                (one + exp_neg_abs_x) >> 1,
                self.precision_out,
                self.precision_out,
            );
            // Branch-free `if sign { one - s } else { s }`: sign * one + s - sign * 2s.
            // This uses sigmoid(x) = 1 - sigmoid(|x|) for x < 0.
            (FieldValue::<F>::from(sign) * one + sigmo_abs_x
                - sign.select(2 * sigmo_abs_x, FieldValue::<F>::from(0)))
            .with_bounds(Self::sigmoid_bounds(self))
        }
    }

    /// Plaintext reference for [`Sigmoid::sigmoid`], evaluated in `f64`.
    ///
    /// `x` is read as a signed fixed-point number with `precision_in` fractional bits. The two
    /// branches are the numerically stable forms of the logistic function: the argument
    /// passed to `exp` is never positive.
    ///
    /// The result is converted back with `F::from(f64)`. This presumably encodes it with
    /// `DOUBLE_PRECISION_MANTISSA` fractional bits, matching `precision_out` — confirm.
    fn sigmoid_public<F: UsedField>(&self, x: F) -> F {
        let x_float = f64::from(x.to_signed_number()) * 2f64.powi(-(self.precision_in as i32));
        let res_float = if x_float >= 0.0 {
            1.0 / (1.0 + (-x_float).exp())
        } else {
            let exp_x = x_float.exp();
            exp_x / (exp_x + 1.0)
        };
        F::from(res_float)
    }

    /// Output bounds `[0, 2^precision_out]`, the fixed-point encodings of 0 and 1.
    fn sigmoid_bounds<F: UsedField>(&self) -> FieldBounds<F> {
        FieldBounds::new(F::ZERO, F::power_of_two(self.precision_out))
    }
}
/// Plugs [`Sigmoid`] into the generic arithmetic-circuit interface.
///
/// The circuit takes exactly one input; `eval` and `run` panic on any other arity.
impl<F: UsedField> ArithmeticCircuit<F> for Sigmoid {
    /// Evaluates the sigmoid in plaintext using the `f64` reference implementation.
    ///
    /// # Panics
    ///
    /// Panics unless `x` holds exactly one element.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        match x.as_slice() {
            [input] => Ok(vec![self.sigmoid_public(*input)]),
            _ => panic!("Sigmoid requires one input"),
        }
    }

    /// Constant tolerated gap of 5, whatever the input.
    ///
    /// NOTE(review): presumably measured in units of the output's least significant bit —
    /// confirm against `ArithmeticCircuit::eval_gap`.
    fn eval_gap(&self, _x: &[F]) -> F {
        F::from(5)
    }

    /// The output always lies in `[0, 2^precision_out]`, so the input bounds are ignored.
    fn bounds(&self, _bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let output_bounds = self.sigmoid_bounds();
        vec![output_bounds]
    }

    /// Builds the sigmoid circuit over the single input value.
    ///
    /// # Panics
    ///
    /// Panics unless `vals` holds exactly one element.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        match vals.as_slice() {
            [value] => vec![self.sigmoid(*value)],
            _ => panic!("Sigmoid requires one input"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::field::ScalarField,
    };
    use rand::Rng;

    impl TestedArithmeticCircuit<ScalarField> for Sigmoid {
        /// Draws a random input precision with a geometric-like distribution.
        ///
        /// The precision starts at `DOUBLE_PRECISION_MANTISSA` and loses one bit per successful
        /// fair coin flip, never going below zero.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            // Count down from the same constant `Sigmoid::new` validates against, so the
            // generator stays in sync with it. Using a `usize` that stops at 0 avoids going
            // negative and wrapping through an `as usize` cast.
            let mut precision: usize = DOUBLE_PRECISION_MANTISSA;
            while precision > 0 && rng.gen_bool(0.5) {
                precision -= 1;
            }
            Self::new(precision)
        }

        /// The sigmoid is unary.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        /// The single input uses the circuit's input precision.
        fn input_precisions(&self) -> Vec<usize> {
            vec![self.precision_in]
        }
    }

    #[test]
    fn tested_sigmoid() {
        Sigmoid::test(1, 16)
    }
}