use super::*;
use crate::fixedpoint::{FixedpointTensor, PolynomialEval};
use lazy_static::lazy_static;
use num::bigint::BigInt;
use num::rational::Ratio;
use num_traits::{FromPrimitive, One, ToPrimitive};
use std::convert::TryInto;
impl Pow2Op {
pub(crate) fn rep_rep_kernel<S: Session, RepRingT, RepBitT, RepBitArrayT, RepShapeT, N: Const>(
sess: &S,
rep: &ReplicatedPlacement,
x: RepFixedTensor<RepRingT>,
) -> Result<RepFixedTensor<RepRingT>>
where
RepRingT: Ring<BitLength = N>,
RepRingT: Clone,
ReplicatedPlacement: PlacementBitDecompose<S, RepRingT, RepBitArrayT>,
ReplicatedPlacement: PlacementIndex<S, RepBitArrayT, RepBitT>,
ReplicatedPlacement: PlacementRingInject<S, RepBitT, RepRingT>,
ReplicatedPlacement: PlacementMux<S, RepRingT, RepRingT, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementNeg<S, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementShape<S, RepRingT, RepShapeT>,
ReplicatedPlacement: PlacementFill<S, RepShapeT, RepRingT>,
ReplicatedPlacement: PlacementSub<S, RepRingT, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementAdd<S, RepRingT, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementShl<S, RepRingT, RepRingT>,
ReplicatedPlacement: Pow2FromBits<S, RepRingT>,
ReplicatedPlacement: ExpFromParts<S, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementDiv<
S,
RepFixedTensor<RepRingT>,
RepFixedTensor<RepRingT>,
RepFixedTensor<RepRingT>,
>,
{
let integral_precision = x.integral_precision as usize;
let fractional_precision = x.fractional_precision as usize;
let bit_len_prec = fractional_precision + integral_precision;
let x_bits = rep.bit_decompose(sess, &x.tensor);
let msb_bit = rep.index(sess, RepRingT::BitLength::VALUE - 1, &x_bits);
let msb = rep.ring_inject(sess, 0, &msb_bit);
let abs_x = rep.mux(sess, &msb, &rep.neg(sess, &x.tensor), &x.tensor);
let absolute_bits = rep.bit_decompose(sess, &abs_x);
let x_bits_vec: Vec<_> = (0..RepRingT::BitLength::VALUE)
.map(|i| rep.index(sess, i, &absolute_bits))
.collect();
let higher_bits: Vec<_> = (fractional_precision..RepRingT::BitLength::VALUE)
.map(|i| rep.ring_inject(sess, 0, &x_bits_vec[i]))
.collect();
let x_shape = rep.shape(sess, &x.tensor);
let zero = rep.fill(sess, 0_u8.into(), &x_shape);
let higher_composed = higher_bits
.clone()
.into_iter()
.enumerate()
.fold(zero, |acc, (i, item)| {
rep.add(sess, &acc, &rep.shl(sess, fractional_precision + i, &item))
});
let fractional_part = rep.sub(sess, &abs_x, &higher_composed);
let d = rep.pow2_from_bits(sess, &higher_bits[0..integral_precision]);
let g = rep.exp_from_parts(
sess,
&d,
&fractional_part,
fractional_precision as u32,
bit_len_prec as u32,
);
let one = rep.fill(sess, 1_u8.into(), &x_shape);
let one_fixed = RepFixedTensor {
tensor: rep.shl(sess, x.fractional_precision as usize, &one),
integral_precision: x.integral_precision,
fractional_precision: x.fractional_precision,
};
let g_fixed = RepFixedTensor {
tensor: g.clone(),
integral_precision: x.integral_precision,
fractional_precision: x.fractional_precision,
};
let inverse = rep.div(sess, &one_fixed, &g_fixed);
let switch = rep.mux(sess, &msb, &inverse.tensor, &g);
Ok(RepFixedTensor {
tensor: switch,
integral_precision: x.integral_precision,
fractional_precision: x.fractional_precision,
})
}
}
pub(crate) trait Pow2FromBits<S: Session, RepRingT> {
fn pow2_from_bits(&self, sess: &S, x: &[RepRingT]) -> RepRingT;
}
impl<S: Session, RepRingT> Pow2FromBits<S, RepRingT> for ReplicatedPlacement
where
ReplicatedShape: KnownType<S>,
ReplicatedPlacement: PlacementMul<S, RepRingT, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementShl<S, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementFill<S, m!(ReplicatedShape), RepRingT>,
ReplicatedPlacement: PlacementShape<S, RepRingT, m!(ReplicatedShape)>,
ReplicatedPlacement: PlacementSub<S, RepRingT, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementAdd<S, RepRingT, RepRingT, RepRingT>,
{
fn pow2_from_bits(&self, sess: &S, x: &[RepRingT]) -> RepRingT {
let rep = self;
let ones = rep.fill(sess, 1_u8.into(), &rep.shape(sess, &x[0]));
let selectors: Vec<_> = x
.iter()
.enumerate()
.map(|(i, bit)| {
let pos = rep.shl(sess, 1 << i, bit);
let neg = rep.sub(sess, &ones, bit);
rep.add(sess, &pos, &neg)
})
.collect();
selectors.iter().fold(ones, |acc, y| rep.mul(sess, &acc, y))
}
}
lazy_static! {
static ref P_1045: Vec<f64> = (0..100).map(|coefficient_index| {
let ln2i: Ratio<BigInt> = Ratio::from_float(2_f64.ln().powf(coefficient_index as f64)).unwrap();
let fact = (1..coefficient_index + 1).fold(One::one(), |acc: BigInt, i| {
let i_th: BigInt = FromPrimitive::from_usize(i).unwrap();
acc * i_th
});
let coefficient_value: Ratio<BigInt> = ln2i / fact;
coefficient_value.to_f64().unwrap()
}).collect();
}
pub(crate) trait ExpFromParts<S: Session, T, O> {
fn exp_from_parts(&self, sess: &S, e_int: &T, e_frac: &T, f: u32, k: u32) -> O;
}
impl<S: Session, RepRingT> ExpFromParts<S, RepRingT, RepRingT> for ReplicatedPlacement
where
RepRingT: Clone,
RepFixedTensor<RepRingT>: CanonicalType,
<RepFixedTensor<RepRingT> as CanonicalType>::Type: KnownType<S>,
m!(c!(RepFixedTensor<RepRingT>)): TryInto<RepFixedTensor<RepRingT>>,
m!(c!(RepFixedTensor<RepRingT>)): From<RepFixedTensor<RepRingT>>,
ReplicatedPlacement: PolynomialEval<S, m!(c!(RepFixedTensor<RepRingT>))>,
ReplicatedPlacement: PlacementShl<S, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementTruncPr<S, RepRingT, RepRingT>,
ReplicatedPlacement: PlacementMul<S, RepRingT, RepRingT, RepRingT>,
{
fn exp_from_parts(
&self,
sess: &S,
e_int: &RepRingT,
e_frac: &RepRingT,
f: u32,
k: u32,
) -> RepRingT {
let amount = k - 2 - f;
let x = RepFixedTensor {
tensor: self.shl(sess, amount as usize, e_frac),
integral_precision: 2,
fractional_precision: k - 2,
};
let e_approx = self
.polynomial_eval(sess, P_1045.to_vec(), x.into())
.try_into()
.ok()
.unwrap();
let e_prod = self.mul(sess, e_int, &e_approx.tensor);
self.trunc_pr(sess, amount, &e_prod)
}
}
impl ExpOp {
pub(crate) fn rep_rep_kernel<S: Session, RepFixedT, MirFixedT>(
sess: &S,
rep: &ReplicatedPlacement,
x: RepFixedT,
) -> Result<RepFixedT>
where
RepFixedT: FixedpointTensor,
ReplicatedPlacement: ShapeFill<S, RepFixedT, Result = MirFixedT>,
ReplicatedPlacement: PlacementMul<S, MirFixedT, RepFixedT, RepFixedT>,
ReplicatedPlacement: PlacementTruncPr<S, RepFixedT, RepFixedT>,
ReplicatedPlacement: PlacementPow2<S, RepFixedT, RepFixedT>,
{
let log2e = rep.shape_fill(
sess,
1.0_f64
.exp()
.log2()
.as_fixedpoint(x.fractional_precision() as usize),
&x,
);
let shifted_exponent = rep.mul(sess, &log2e, &x);
let exponent = rep.trunc_pr(sess, x.fractional_precision(), &shifted_exponent);
Ok(rep.pow2(sess, &exponent))
}
}
#[cfg(feature = "sync_execute")]
#[cfg(test)]
mod tests {
use super::*;
use crate::prelude::*;
use ndarray::prelude::*;
#[test]
fn test_pow2_from_bits() {
let alice = HostPlacement::from("alice");
let rep = ReplicatedPlacement::from(["alice", "bob", "carole"]);
let x: HostRing64Tensor = alice.from_raw(array![[0u64], [1], [1], [1]]);
let target: HostRing64Tensor = alice.from_raw(array![16384u64]);
let sess = SyncSession::default();
let x_shared = rep.share(&sess, &x);
let x0 = rep.index_axis(&sess, 0, 0, &x_shared);
let x1 = rep.index_axis(&sess, 0, 1, &x_shared);
let x2 = rep.index_axis(&sess, 0, 2, &x_shared);
let x3 = rep.index_axis(&sess, 0, 3, &x_shared);
let x_vec = vec![x0, x1, x2, x3];
let pow2_shared = rep.pow2_from_bits(&sess, &x_vec);
assert_eq!(target, alice.reveal(&sess, &pow2_shared));
}
}