moose 0.2.2

Encrypted learning and data processing framework
Documentation
//! Support for comparison operators

use super::*;
use crate::computation::EqualOp;
use crate::error::Result;
use crate::execution::Session;
use crate::Const;

/// Multiplies all elements of a slice together on a placement.
///
/// `S` is the session the multiplications are issued in, `T` is the element type and `O`
/// is the type of the resulting product.
pub(crate) trait TreeReduceMul<S, T, O>
where
    S: Session,
{
    /// Returns the product of every element of `factors`.
    fn reduce_mul(&self, sess: &S, factors: &[T]) -> O;
}

/// Product of a slice of replicated values.
///
/// The pairwise combining step is the placement's own element-wise `mul`. The way pairs
/// are grouped and combined is left entirely to `tree_reduce`.
impl<S, T> TreeReduceMul<S, T, T> for ReplicatedPlacement
where
    S: Session,
    T: Clone,
    ReplicatedPlacement: PlacementMul<S, T, T, T>,
{
    fn reduce_mul(&self, sess: &S, x: &[T]) -> T {
        self.tree_reduce(
            sess,
            x,
            // Explicit parameter types keep the closure signature identical to the one
            // `tree_reduce` was originally handed.
            |placement: &ReplicatedPlacement, session: &S, lhs: &T, rhs: &T| -> T {
                placement.mul(session, lhs, rhs)
            },
        )
    }
}

impl EqualOp {
    /// Element-wise equality of two replicated ring values, returned as a replicated bit.
    ///
    /// `x == y` holds exactly when `x - y` is zero. The difference is therefore
    /// bit-decomposed, and the bits are passed to the equal-zero protocol.
    pub(crate) fn rep_kernel<S: Session, RepRingT, RepBitT, RepBitArrayT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepRingT,
        y: RepRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementBitDecompose<S, RepRingT, RepBitArrayT>,
        ReplicatedPlacement: PlacementSub<S, RepRingT, RepRingT, RepRingT>,
        ReplicatedPlacement: PlacementEqualZero<S, RepBitArrayT, RepBitT>,
    {
        // The subtraction is an argument of `bit_decompose`, so it is still issued first.
        let difference_bits = rep.bit_decompose(sess, &rep.sub(sess, &x, &y));
        let is_zero = rep.equal_zero(sess, &difference_bits);
        Ok(is_zero)
    }

    /// Element-wise equality of two replicated ring values, returned as a ring value.
    ///
    /// The replicated bit produced by `equal` is lifted back into the ring with
    /// `ring_inject`. NOTE(review): the `0` argument is presumably the bit position,
    /// which would give a 0/1 ring value. Confirm against `PlacementRingInject`.
    pub(crate) fn rep_ring_kernel<S: Session, RepRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepRingT,
        y: RepRingT,
    ) -> Result<RepRingT>
    where
        ReplicatedPlacement: PlacementEqual<S, RepRingT, RepRingT, RepBitT>,
        ReplicatedPlacement: PlacementRingInject<S, RepBitT, RepRingT>,
    {
        let equal_bit = rep.equal(sess, &x, &y);
        let as_ring = rep.ring_inject(sess, 0, &equal_bit);
        Ok(as_ring)
    }
}

impl EqualZeroOp {
    /// Tests whether a bit-decomposed replicated value is zero.
    ///
    /// Every one of the `N::VALUE` bits is extracted and flipped by XOR-ing it with a
    /// mirrored tensor of ones. The flipped bits are then multiplied together. The product
    /// is 1 exactly when every original bit was 0.
    ///
    /// # Panics
    /// Panics if `N::VALUE == 0`, because the all-ones tensor is shaped from the first bit.
    pub(crate) fn bitdec_bit_kernel<S: Session, RepBitArrayT, RepBitT, MirBitT, N: Const>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepBitArrayT,
    ) -> Result<RepBitT>
    where
        RepBitArrayT: BitArray<Len = N>,
        ReplicatedPlacement: PlacementIndex<S, RepBitArrayT, RepBitT>,
        ReplicatedPlacement: ShapeFill<S, RepBitT, Result = MirBitT>,
        ReplicatedPlacement: PlacementXor<S, MirBitT, RepBitT, RepBitT>,
        ReplicatedPlacement: TreeReduceMul<S, RepBitT, RepBitT>,
    {
        // Pull every position of the bit array out as its own replicated bit.
        let bits: Vec<RepBitT> = (0..N::VALUE).map(|i| rep.index(sess, i, &x)).collect();

        // XOR with a ones tensor shaped like the bits acts as a bitwise NOT.
        let ones = rep.shape_fill(sess, 1u8, &bits[0]);
        let mut negated: Vec<RepBitT> = Vec::with_capacity(bits.len());
        for bit in &bits {
            negated.push(rep.xor(sess, &ones, bit));
        }

        // AND over all negated bits (multiplication of bits is AND).
        Ok(rep.reduce_mul(sess, &negated))
    }

    /// Tests whether a bit-decomposed replicated value is zero, returned as a ring value.
    ///
    /// The bit from `equal_zero` is lifted into the ring with `ring_inject`.
    /// NOTE(review): the `0` argument is presumably the bit position. Confirm against
    /// `PlacementRingInject`.
    pub(crate) fn bitdec_ring_kernel<S: Session, RepBitArrayT, RepRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepBitArrayT,
    ) -> Result<RepRingT>
    where
        ReplicatedPlacement: PlacementEqualZero<S, RepBitArrayT, RepBitT>,
        ReplicatedPlacement: PlacementRingInject<S, RepBitT, RepRingT>,
    {
        // `equal_zero` runs first because it is an argument of `ring_inject`.
        Ok(rep.ring_inject(sess, 0, &rep.equal_zero(sess, &x)))
    }
}

impl LessOp {
    /// `x < y` for two replicated ring values.
    ///
    /// Returns the most significant bit of `x - y`. NOTE(review): this reads the
    /// difference as a signed value, so it assumes `x - y` does not wrap around the
    /// ring's signed range. Confirm with callers.
    pub(crate) fn rep_kernel<S: Session, RepRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepRingT,
        y: RepRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementSub<S, RepRingT, RepRingT, RepRingT>,
        ReplicatedPlacement: PlacementMsb<S, RepRingT, RepBitT>,
    {
        Ok(rep.msb(sess, &rep.sub(sess, &x, &y)))
    }

    /// `x < y` where `x` is replicated and `y` is mirrored.
    ///
    /// Returns the most significant bit of `x - y`. The same no-wrap-around assumption
    /// applies as for [`LessOp::rep_kernel`].
    pub(crate) fn rep_mir_kernel<S: Session, RepRingT, MirRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepRingT,
        y: MirRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementSub<S, RepRingT, MirRingT, RepRingT>,
        ReplicatedPlacement: PlacementMsb<S, RepRingT, RepBitT>,
    {
        Ok(rep.msb(sess, &rep.sub(sess, &x, &y)))
    }

    /// `x < y` where `x` is mirrored and `y` is replicated.
    ///
    /// Returns the most significant bit of `x - y`. The same no-wrap-around assumption
    /// applies as for [`LessOp::rep_kernel`].
    pub(crate) fn mir_rep_kernel<S: Session, RepRingT, MirRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: MirRingT,
        y: RepRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementSub<S, MirRingT, RepRingT, RepRingT>,
        ReplicatedPlacement: PlacementMsb<S, RepRingT, RepBitT>,
    {
        Ok(rep.msb(sess, &rep.sub(sess, &x, &y)))
    }
}

impl GreaterOp {
    /// `x > y` for two replicated ring values.
    ///
    /// Uses the identity `x > y ⇔ y < x` and returns the most significant bit of `y - x`.
    /// NOTE(review): this assumes `y - x` does not wrap around the ring's signed range.
    pub(crate) fn rep_kernel<S: Session, RepRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepRingT,
        y: RepRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementSub<S, RepRingT, RepRingT, RepRingT>,
        ReplicatedPlacement: PlacementMsb<S, RepRingT, RepBitT>,
    {
        // Operands are deliberately swapped relative to `LessOp`.
        let reversed_difference = rep.sub(sess, &y, &x);
        let sign_bit = rep.msb(sess, &reversed_difference);
        Ok(sign_bit)
    }

    /// `x > y` where `x` is replicated and `y` is mirrored.
    ///
    /// Returns the most significant bit of `y - x` (mirrored minus replicated). This
    /// matches the `PlacementSub<S, MirRingT, RepRingT, RepRingT>` bound.
    pub(crate) fn rep_mir_kernel<S: Session, RepRingT, MirRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: RepRingT,
        y: MirRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementSub<S, MirRingT, RepRingT, RepRingT>,
        ReplicatedPlacement: PlacementMsb<S, RepRingT, RepBitT>,
    {
        let reversed_difference = rep.sub(sess, &y, &x);
        let sign_bit = rep.msb(sess, &reversed_difference);
        Ok(sign_bit)
    }

    /// `x > y` where `x` is mirrored and `y` is replicated.
    ///
    /// Returns the most significant bit of `y - x` (replicated minus mirrored). This
    /// matches the `PlacementSub<S, RepRingT, MirRingT, RepRingT>` bound.
    pub(crate) fn mir_rep_kernel<S: Session, RepRingT, MirRingT, RepBitT>(
        sess: &S,
        rep: &ReplicatedPlacement,
        x: MirRingT,
        y: RepRingT,
    ) -> Result<RepBitT>
    where
        ReplicatedPlacement: PlacementSub<S, RepRingT, MirRingT, RepRingT>,
        ReplicatedPlacement: PlacementMsb<S, RepRingT, RepBitT>,
    {
        let reversed_difference = rep.sub(sess, &y, &x);
        let sign_bit = rep.msb(sess, &reversed_difference);
        Ok(sign_bit)
    }
}

#[cfg(feature = "sync_execute")]
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use ndarray::prelude::*;

    /// Secret-shares two host tensors, compares them element-wise on a replicated
    /// placement, and checks that only the position where the inputs match reveals as 1.
    #[test]
    fn test_equal() {
        let sess = SyncSession::default();

        let rep = ReplicatedPlacement::from(["alice", "bob", "carole"]);
        let alice = HostPlacement::from("alice");
        let bob = HostPlacement::from("bob");

        // Only index 0 holds equal values on both sides.
        let lhs_plain: HostRing64Tensor = alice.from_raw(array![1024u64, 5, 4]);
        let rhs_plain: HostRing64Tensor = bob.from_raw(array![1024u64, 4, 5]);

        let lhs = rep.share(&sess, &lhs_plain);
        let rhs = rep.share(&sess, &rhs_plain);

        let eq_bits: ReplicatedBitTensor = rep.equal(&sess, &lhs, &rhs);
        let revealed = alice.reveal(&sess, &eq_bits);

        assert_eq!(revealed, alice.from_raw(array![1, 0, 0]));
    }
}