// moose 0.2.2
//
// Encrypted learning and data processing framework
// Documentation
use super::*;
use crate::execution::RuntimeSession;

/// Conversion between a plain value of type `T` and a scaled representation
/// of it stored in `Self`.
///
/// Implementors choose the integer-like type used for the scaling factor via
/// [`Convert::Scale`].
pub trait Convert<T> {
    /// Type of the scaling factor passed to [`Convert::encode`] and
    /// [`Convert::decode`].
    type Scale: Clone + num_traits::One;

    /// Builds the scaled representation of `x` using `scaling_factor`.
    fn encode(x: &T, scaling_factor: Self::Scale) -> Self;

    /// Recovers a plain `T` from the scaled representation `x`, undoing a
    /// scaling by `scaling_factor`.
    fn decode(x: &Self, scaling_factor: Self::Scale) -> T;
}

/// Fixed-point conversion between `f64` host tensors and 64-bit ring tensors.
impl Convert<HostFloat64Tensor> for HostRing64Tensor {
    type Scale = u64;

    /// Scales every element of `x` by `scaling_factor` and stores it as a
    /// wrapping `u64`.
    ///
    /// The scaled float is cast to `i64` first (Rust's `as` cast truncates
    /// toward zero and saturates at the `i64` bounds) and then reinterpreted
    /// as `u64`, so negative values map to their two's-complement bit pattern.
    /// The second field of `x` is cloned into the result unchanged.
    fn encode(x: &HostFloat64Tensor, scaling_factor: Self::Scale) -> HostRing64Tensor {
        let scale = scaling_factor as f64;
        let ring_values = x.0.mapv(|value| Wrapping(((value * scale) as i64) as u64));
        HostRingTensor(ring_values.into_shared(), x.1.clone())
    }

    /// Converts the ring elements of `x` to `f64` and divides each of them by
    /// `scaling_factor`.
    ///
    /// The elements are first obtained as `i64` through the `ArrayD::from`
    /// conversion defined elsewhere in the crate (presumably a signed
    /// reinterpretation of the ring values — confirm against that impl).
    fn decode(x: &Self, scaling_factor: Self::Scale) -> HostFloat64Tensor {
        let scale = scaling_factor as f64;
        let signed: ArrayD<i64> = ArrayD::from(x);
        let floats = signed.mapv(|value| value as f64 / scale);
        HostTensor(floats.into_shared(), x.1.clone())
    }
}

/// Fixed-point conversion between `f64` host tensors and 128-bit ring tensors.
impl Convert<HostFloat64Tensor> for HostRing128Tensor {
    type Scale = u128;

    /// Scales every element of `x` by `scaling_factor` and stores it as a
    /// wrapping `u128`.
    ///
    /// The scaled float is cast to `i128` first (Rust's `as` cast truncates
    /// toward zero and saturates at the `i128` bounds) and then reinterpreted
    /// as `u128`, so negative values map to their two's-complement bit
    /// pattern. The second field of `x` is cloned into the result unchanged.
    fn encode(x: &HostFloat64Tensor, scaling_factor: Self::Scale) -> HostRing128Tensor {
        let scale = scaling_factor as f64;
        let ring_values = x.0.mapv(|value| Wrapping(((value * scale) as i128) as u128));
        HostRingTensor(ring_values.into_shared(), x.1.clone())
    }

    /// Converts the ring elements of `x` to `f64` and divides each of them by
    /// `scaling_factor`.
    ///
    /// The elements are first obtained as `i128` through the `ArrayD::from`
    /// conversion defined elsewhere in the crate (presumably a signed
    /// reinterpretation of the ring values — confirm against that impl).
    fn decode(x: &Self, scaling_factor: Self::Scale) -> HostFloat64Tensor {
        let scale = scaling_factor as f64;
        let signed: ArrayD<i128> = ArrayD::from(x);
        let floats = signed.mapv(|value| value as f64 / scale);
        HostTensor(floats.into_shared(), x.1.clone())
    }
}

impl<T> HostRingTensor<T>
where
    Wrapping<T>: Clone + num_traits::Zero + std::ops::Mul<Wrapping<T>, Output = Wrapping<T>>,
    HostRingTensor<T>: Convert<HostFloat64Tensor>,
{
    fn mul(self, other: HostRingTensor<T>) -> HostRingTensor<T> {
        HostRingTensor(self.0 * other.0, self.1)
    }

    fn compute_mean_weight(&self, axis: &Option<usize>) -> Result<HostFloat64Tensor> {
        let shape: &[usize] = self.0.shape();
        if let Some(ax) = axis {
            let dim_len = shape[*ax] as f64;
            Ok(HostTensor(
                Array::from_elem([], 1.0 / dim_len).into_shared().into_dyn(),
                self.1.clone(),
            ))
        } else {
            let dim_prod: usize = std::iter::Product::product(shape.iter());
            let prod_inv = 1.0 / dim_prod as f64;
            Ok(HostTensor(
                Array::from_elem([], prod_inv).into_shared().into_dyn(),
                self.1.clone(),
            ))
        }
    }

    fn fixedpoint_mean(
        x: Self,
        axis: Option<usize>,
        scaling_factor: <HostRingTensor<T> as Convert<HostFloat64Tensor>>::Scale,
    ) -> Result<HostRingTensor<T>> {
        let mean_weight = x.compute_mean_weight(&axis)?;
        let encoded_weight = HostRingTensor::<T>::encode(&mean_weight, scaling_factor);
        let operand_sum = x.sum(axis)?;
        Ok(operand_sum.mul(encoded_weight))
    }
}

/// Host kernels for the fixed-point mean operation on ring tensors.
impl RingFixedpointMeanOp {
    /// Fixed-point mean kernel for 64-bit ring tensors.
    ///
    /// Places `x` on `plc`, derives the scaling factor as
    /// `scaling_base ^ scaling_exp` in `u64`, and delegates to
    /// `HostRing64Tensor::fixedpoint_mean` with `axis` widened to `usize`.
    ///
    /// NOTE: the power is computed with plain `pow`, so an overflowing
    /// base/exponent pair follows Rust's default integer-overflow behavior.
    pub(crate) fn ring64_kernel<S: RuntimeSession>(
        sess: &S,
        plc: &HostPlacement,
        axis: Option<u32>,
        scaling_base: u64,
        scaling_exp: u32,
        x: HostRing64Tensor,
    ) -> Result<HostRing64Tensor>
    where
        HostPlacement: PlacementPlace<S, HostRing64Tensor>,
    {
        let placed = plc.place(sess, x);
        let scaling_factor = scaling_base.pow(scaling_exp);
        HostRing64Tensor::fixedpoint_mean(placed, axis.map(|a| a as usize), scaling_factor)
    }

    /// Fixed-point mean kernel for 128-bit ring tensors.
    ///
    /// Places `x` on `plc`, widens `scaling_base` to `u128`, derives the
    /// scaling factor as `scaling_base ^ scaling_exp`, and delegates to
    /// `HostRing128Tensor::fixedpoint_mean` with `axis` widened to `usize`.
    ///
    /// NOTE: the power is computed with plain `pow`, so an overflowing
    /// base/exponent pair follows Rust's default integer-overflow behavior.
    pub(crate) fn ring128_kernel<S: RuntimeSession>(
        sess: &S,
        plc: &HostPlacement,
        axis: Option<u32>,
        scaling_base: u64,
        scaling_exp: u32,
        x: HostRing128Tensor,
    ) -> Result<HostRing128Tensor>
    where
        HostPlacement: PlacementPlace<S, HostRing128Tensor>,
    {
        let placed = plc.place(sess, x);
        let scaling_factor = u128::from(scaling_base).pow(scaling_exp);
        HostRing128Tensor::fixedpoint_mean(placed, axis.map(|a| a as usize), scaling_factor)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes the 2x2 fixture `[[1, 2], [3, 4]]` with scale `2^16`, takes its
    /// fixed-point mean over `axis`, and decodes the result.
    ///
    /// Decoding uses `2^32` because the mean multiplies two values that are
    /// each scaled by `2^16`.
    fn decoded_mean(plc: &HostPlacement, axis: Option<usize>) -> HostFloat64Tensor {
        let input: HostFloat64Tensor = plc.from_raw(array![[1., 2.], [3., 4.]]);
        let encoding_factor = 1u64 << 16;
        let decoding_factor = 1u64 << 32;

        let encoded = HostRing64Tensor::encode(&input, encoding_factor);
        let mean_encoded =
            HostRing64Tensor::fixedpoint_mean(encoded, axis, encoding_factor).unwrap();
        HostRing64Tensor::decode(&mean_encoded, decoding_factor)
    }

    /// Mean along axis 0 averages the two rows element-wise.
    #[test]
    fn fixedpoint_mean_with_axis() {
        let plc = HostPlacement::from("alice");

        let mean = decoded_mean(&plc, Some(0));
        assert_eq!(mean, plc.from_raw(array![2., 3.]));
    }

    /// Mean without an axis averages all four elements into a scalar.
    #[test]
    fn fixedpoint_mean_no_axis() {
        let plc = HostPlacement::from("alice");

        let mean = decoded_mean(&plc, None);
        assert_eq!(mean, plc.from_raw(Array::from_elem([], 2.5)));
    }
}