use crate::computation::*;
use crate::error::Result;
#[cfg(feature = "compile")]
use crate::execution::symbolic::Symbolic;
use crate::execution::Session;
use crate::kernels::*;
use crate::types::*;
use crate::{BitArray, Const, Ring, TensorLike, N128, N64};
use ndarray::prelude::*;
use ndarray::LinalgScalar;
use ndarray::Slice;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::num::Wrapping;
mod bitarray;
mod fixedpoint;
mod ops;
mod prim;
pub use bitarray::*;
pub use fixedpoint::Convert;
pub use prim::*;
pub type ArcArrayD<A> = ArcArray<A, IxDyn>;
/// Placement of a value on a single host, identified by the role owning it.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct HostPlacement {
    /// Role that owns values placed here.
    pub owner: Role,
}
impl From<&str> for HostPlacement {
    /// Builds a host placement whose owner is the role named by `role`.
    fn from(role: &str) -> Self {
        let owner = Role::from(role);
        Self { owner }
    }
}
impl<S: Session> PlacementPlace<S, HostUnit> for HostPlacement {
    /// Returns `x` unchanged when it already lives on this placement.
    ///
    /// Otherwise it returns a fresh `HostUnit` on this placement. That includes
    /// the case where querying the placement of `x` fails.
    fn place(&self, _sess: &S, x: HostUnit) -> HostUnit {
        let already_here = matches!(x.placement(), Ok(ref plc) if plc == self);
        if already_here {
            x
        } else {
            HostUnit(self.clone())
        }
    }
}
/// A string value located on a host.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HostString(pub String, pub HostPlacement);

impl Placed for HostString {
    type Placement = HostPlacement;

    /// The placement stored alongside the string; always `Ok`.
    fn placement(&self) -> Result<Self::Placement> {
        let HostString(_, plc) = self;
        Ok(plc.clone())
    }
}

impl<S: Session> PlacementPlace<S, HostString> for HostPlacement {
    /// Returns `string` as-is if it is already on this placement.
    ///
    /// Otherwise it re-tags the same text with this placement.
    fn place(&self, _sess: &S, string: HostString) -> HostString {
        let already_here = matches!(string.placement(), Ok(ref plc) if self == plc);
        if already_here {
            string
        } else {
            let HostString(text, _) = string;
            HostString(text, self.clone())
        }
    }
}
/// Placement-free tensor shape: the length of each axis, outermost first.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Hash)]
pub struct RawShape(pub Vec<usize>);

impl RawShape {
    /// Inserts a singleton (size 1) axis for every entry of `axis`.
    ///
    /// Entries are consumed from the back of `axis`. Each one inserts a `1` at
    /// that position of the shape built so far. Callers in this file sort
    /// `axis` in descending order first, so insertions run from the smallest
    /// position upwards.
    ///
    /// # Panics
    /// If a position exceeds the number of axes at the time it is inserted.
    fn extend_singletons(self, mut axis: Vec<usize>) -> Self {
        let mut dims = self.0;
        while let Some(pos) = axis.pop() {
            dims.insert(pos, 1);
        }
        RawShape(dims)
    }

    /// Returns the sub-shape made of axes `begin..end`.
    ///
    /// # Panics
    /// If `begin > end` or `end` exceeds the number of axes.
    fn slice(self, begin: usize, end: usize) -> Self {
        RawShape(self.0[begin..end].to_vec())
    }

    /// Removes singleton axes.
    ///
    /// With `Some(axis)`, exactly that axis is removed. With `None`, every axis
    /// of size 1 is removed.
    ///
    /// # Panics
    /// If `axis` is out of range, or if the selected axis does not have size 1.
    fn squeeze(mut self, axis: Option<usize>) -> Self {
        if let Some(axis) = axis {
            let removed_axis = self.0.remove(axis);
            if removed_axis != 1 {
                panic!(
                    "The axis selected has a value of {:?}. Cannot select an axis to squeeze out\nwhich has size not equal to one",
                    removed_axis
                );
            }
        } else {
            self.0.retain(|&dim| dim != 1);
        }
        self
    }
}
/// A [`RawShape`] located on a host.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HostShape(pub RawShape, pub HostPlacement);

impl Placed for HostShape {
    type Placement = HostPlacement;

    /// The placement stored alongside the shape; always `Ok`.
    fn placement(&self) -> Result<Self::Placement> {
        let HostShape(_, plc) = self;
        Ok(plc.clone())
    }
}

impl<S: Session> PlacementPlace<S, HostShape> for HostPlacement {
    /// Returns `shape` as-is if it is already on this placement.
    ///
    /// Otherwise it re-tags the same raw shape with this placement.
    fn place(&self, _sess: &S, shape: HostShape) -> HostShape {
        let already_here = matches!(shape.placement(), Ok(ref plc) if self == plc);
        if already_here {
            shape
        } else {
            let HostShape(raw, _) = shape;
            HostShape(raw, self.clone())
        }
    }
}
/// Slicing parameters for one axis, converted into an `ndarray::Slice`.
///
/// Negative indices count from the end of the axis, as in `ndarray::Slice`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SliceInfoElem {
    /// First index included in the slice.
    pub start: isize,
    /// Index one past the last included element; `None` slices to the end.
    pub end: Option<isize>,
    /// Stride between selected elements; `None` is treated as 1 on conversion.
    pub step: Option<isize>,
}
/// A multi-axis slice: one [`SliceInfoElem`] per sliced axis, outermost first.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SliceInfo(pub Vec<SliceInfoElem>);

impl From<SliceInfo> for ndarray::SliceInfo<Vec<ndarray::SliceInfoElem>, IxDyn, IxDyn> {
    /// Converts every element into an `ndarray` slice, defaulting a missing step to 1.
    fn from(s: SliceInfo) -> ndarray::SliceInfo<Vec<ndarray::SliceInfoElem>, IxDyn, IxDyn> {
        let mut elems: Vec<ndarray::SliceInfoElem> = Vec::with_capacity(s.0.len());
        for elem in s.0.iter() {
            let step = elem.step.unwrap_or(1);
            elems.push(Slice::new(elem.start, elem.end, step).into());
        }
        // Both input and output dimensionality are dynamic here.
        ndarray::SliceInfo::<Vec<ndarray::SliceInfoElem>, IxDyn, IxDyn>::try_from(elems).unwrap()
    }
}
/// A dense, shared (`ArcArray`) tensor of `T` located on a host.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Hash)]
pub struct HostTensor<T>(pub ArcArrayD<T>, pub HostPlacement);

impl<T> Placed for HostTensor<T> {
    type Placement = HostPlacement;

    /// The placement stored alongside the data; always `Ok`.
    fn placement(&self) -> Result<Self::Placement> {
        let HostTensor(_, plc) = self;
        Ok(plc.clone())
    }
}

impl<S: Session, T> PlacementPlace<S, HostTensor<T>> for HostPlacement {
    /// Returns `x` as-is if it is already on this placement.
    ///
    /// Otherwise it re-tags the same array with this placement.
    fn place(&self, _sess: &S, x: HostTensor<T>) -> HostTensor<T> {
        let already_here = matches!(x.placement(), Ok(ref plc) if plc == self);
        if already_here {
            x
        } else {
            let HostTensor(data, _) = x;
            HostTensor(data, self.clone())
        }
    }
}
impl<T> HostTensor<T>
where
T: LinalgScalar,
{
pub(crate) fn place(plc: &HostPlacement, x: ArcArrayD<T>) -> HostTensor<T> {
HostTensor::<T>(x, plc.clone())
}
pub(crate) fn reshape(self, newshape: HostShape) -> Self {
HostTensor::<T>(self.0.into_shape(newshape.0 .0).unwrap(), self.1) }
pub(crate) fn shape(&self) -> HostShape {
HostShape(RawShape(self.0.shape().into()), self.1.clone())
}
}
/// A tensor of bits located on a host.
#[derive(Serialize, Deserialize, Clone, PartialEq, Hash)]
pub struct HostBitTensor(pub BitArrayRepr, pub HostPlacement);

impl std::fmt::Debug for HostBitTensor {
    /// Formats only the underlying bits; the shape and placement are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bits = self.0.data.as_bitslice();
        bits.fmt(f)
    }
}

impl TensorLike for HostBitTensor {
    type Scalar = u8;
}

impl Placed for HostBitTensor {
    type Placement = HostPlacement;

    /// The placement stored alongside the bits; always `Ok`.
    fn placement(&self) -> Result<Self::Placement> {
        let HostBitTensor(_, plc) = self;
        Ok(plc.clone())
    }
}

impl<S: Session> PlacementPlace<S, HostBitTensor> for HostPlacement {
    /// Returns `x` as-is if it is already on this placement.
    ///
    /// Otherwise it re-tags the same bits with this placement.
    fn place(&self, _sess: &S, x: HostBitTensor) -> HostBitTensor {
        let already_here = matches!(x.placement(), Ok(ref plc) if plc == self);
        if already_here {
            x
        } else {
            let HostBitTensor(bits, _) = x;
            HostBitTensor(bits, self.clone())
        }
    }
}
impl HostBitTensor {
pub(crate) fn place(plc: &HostPlacement, x: BitArrayRepr) -> HostBitTensor {
HostBitTensor(x, plc.clone())
}
fn reshape(self, newshape: HostShape) -> Self {
let arr = BitArrayRepr {
data: self.0.data.clone(),
dim: std::sync::Arc::new(IxDyn(&newshape.0 .0)),
};
HostBitTensor(arr, self.1)
}
fn expand_dims(self, mut axis: Vec<usize>) -> Self {
let plc = self.1.clone();
axis.sort_by_key(|ax| Reverse(*ax));
let newshape = self.shape().0.extend_singletons(axis);
self.reshape(HostShape(newshape, plc))
}
fn shape(&self) -> HostShape {
HostShape(RawShape(self.0.shape().into()), self.1.clone())
}
}
/// A bit tensor tagged, at the type level, with a bit length `N`.
///
/// The `FromRaw` constructor for `HostBitArray<HostBitTensor, N>` in this file
/// checks that the raw data's leading axis has length `N::VALUE`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HostBitArray<HostBitTensorT, N>(pub HostBitTensorT, pub PhantomData<N>);

impl<HostBitT: CanonicalType, N> CanonicalType for HostBitArray<HostBitT, N> {
    type Type = HostBitArray<<HostBitT as CanonicalType>::Type, N>;
}

impl<HostBitT, N: Const> BitArray for HostBitArray<HostBitT, N> {
    type Len = N;
}

#[cfg(feature = "compile")]
impl<HostBitT: Placed, N: Const> BitArray for Symbolic<HostBitArray<HostBitT, N>> {
    type Len = N;
}

impl<HostBitTensorT: Placed, N> Placed for HostBitArray<HostBitTensorT, N> {
    type Placement = HostBitTensorT::Placement;

    /// Delegates to the placement of the wrapped bit tensor.
    fn placement(&self) -> Result<Self::Placement> {
        let HostBitArray(bits, _) = self;
        bits.placement()
    }
}
#[cfg(feature = "compile")]
impl<HostBitTensorT, N: Const> PartiallySymbolicType for HostBitArray<HostBitTensorT, N>
where
    HostBitTensorT: SymbolicType,
{
    type Type = HostBitArray<<HostBitTensorT as SymbolicType>::Type, N>;
}

#[cfg(feature = "compile")]
impl<HostBitT, N> From<HostBitArray<HostBitT, N>> for Symbolic<HostBitArray<HostBitT, N>>
where
    HostBitT: Placed<Placement = HostPlacement>,
{
    /// Lifts a concrete bit array into its symbolic wrapper.
    fn from(x: HostBitArray<HostBitT, N>) -> Self {
        Self::Concrete(x)
    }
}

#[cfg(feature = "compile")]
impl<HostBitT, N> TryFrom<Symbolic<HostBitArray<HostBitT, N>>> for HostBitArray<HostBitT, N>
where
    HostBitT: Placed<Placement = HostPlacement>,
{
    type Error = crate::error::Error;

    /// Extracts the concrete value.
    ///
    /// Any non-concrete symbolic value yields `Error::Unexpected(None)`.
    fn try_from(v: Symbolic<HostBitArray<HostBitT, N>>) -> crate::error::Result<Self> {
        if let Symbolic::Concrete(x) = v {
            Ok(x)
        } else {
            Err(crate::error::Error::Unexpected(None))
        }
    }
}
/// An AES key on a host, stored as a bit array.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct AbstractHostAesKey<HostBitArrayT>(pub(crate) HostBitArrayT);

impl<HostBitArrayT: Placed<Placement = HostPlacement>> Placed
    for AbstractHostAesKey<HostBitArrayT>
{
    type Placement = HostBitArrayT::Placement;

    /// Delegates to the placement of the key's bit array.
    fn placement(&self) -> Result<Self::Placement> {
        let AbstractHostAesKey(bits) = self;
        bits.placement()
    }
}
/// A fixed-point tensor in bit-array form on a host, with its precision parameters.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HostFixedAesTensor<HostBitArrayT> {
    /// The encoded bits.
    pub tensor: HostBitArrayT,
    /// Integral precision of the fixed-point encoding.
    pub integral_precision: u32,
    /// Fractional precision of the fixed-point encoding.
    pub fractional_precision: u32,
}

impl<HostBitArrayT: Placed> Placed for HostFixedAesTensor<HostBitArrayT>
where
    <HostBitArrayT as Placed>::Placement: Into<Placement>,
{
    type Placement = Placement;

    /// The placement of the inner tensor, converted into a general `Placement`.
    fn placement(&self) -> Result<Self::Placement> {
        self.tensor.placement().map(Into::into)
    }
}
/// A fixed-point tensor on a host: a ring tensor plus its precision parameters.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HostFixedTensor<HostRingT> {
    /// The encoded ring values.
    pub tensor: HostRingT,
    /// Fractional precision of the fixed-point encoding.
    pub fractional_precision: u32,
    /// Integral precision of the fixed-point encoding.
    pub integral_precision: u32,
}

impl<RingT: Placed> Placed for HostFixedTensor<RingT> {
    type Placement = RingT::Placement;

    /// Delegates to the placement of the underlying ring tensor.
    fn placement(&self) -> Result<Self::Placement> {
        let HostFixedTensor { tensor, .. } = self;
        tensor.placement()
    }
}

impl<S: Session, RingT> PlacementPlace<S, HostFixedTensor<RingT>> for HostPlacement
where
    HostFixedTensor<RingT>: Placed<Placement = HostPlacement>,
    HostPlacement: PlacementPlace<S, RingT>,
{
    /// Returns `x` as-is if it is already on this placement.
    ///
    /// Otherwise it places the inner tensor here and keeps both precisions.
    fn place(&self, sess: &S, x: HostFixedTensor<RingT>) -> HostFixedTensor<RingT> {
        if matches!(x.placement(), Ok(ref plc) if self == plc) {
            return x;
        }
        let HostFixedTensor {
            tensor,
            fractional_precision,
            integral_precision,
        } = x;
        HostFixedTensor {
            tensor: self.place(sess, tensor),
            integral_precision,
            fractional_precision,
        }
    }
}
/// A tensor of ring elements on a host.
///
/// Each element is a `T` under `Wrapping` arithmetic.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Hash)]
pub struct HostRingTensor<T>(pub ArcArrayD<Wrapping<T>>, pub HostPlacement);

impl Ring for HostRing64Tensor {
    type BitLength = N64;
}

impl Ring for HostRing128Tensor {
    type BitLength = N128;
}

impl<T> Placed for HostRingTensor<T> {
    type Placement = HostPlacement;

    /// The placement stored alongside the data; always `Ok`.
    fn placement(&self) -> Result<Self::Placement> {
        let HostRingTensor(_, plc) = self;
        Ok(plc.clone())
    }
}

impl<T> TensorLike for HostRingTensor<T> {
    type Scalar = T;
}
impl<S: Session, T> PlacementPlace<S, HostRingTensor<T>> for HostPlacement
where
HostRingTensor<T>: Placed<Placement = HostPlacement>,
{
fn place(&self, _sess: &S, x: HostRingTensor<T>) -> HostRingTensor<T> {
match x.placement() {
Ok(place) if &place == self => x,
_ => {
HostRingTensor(x.0, self.clone())
}
}
}
}
impl<T> HostRingTensor<T> {
pub(crate) fn place(plc: &HostPlacement, x: ArcArrayD<Wrapping<T>>) -> HostRingTensor<T> {
HostRingTensor::<T>(x, plc.clone())
}
fn reshape(self, newshape: HostShape) -> Self {
HostRingTensor::<T>(self.0.into_shape(newshape.0 .0).unwrap(), self.1) }
fn expand_dims(self, mut axis: Vec<usize>) -> Self {
let plc = self.1.clone();
axis.sort_by_key(|ax| Reverse(*ax));
let newshape = self.shape().0.extend_singletons(axis);
self.reshape(HostShape(newshape, plc))
}
}
#[cfg(test)]
impl<T: Clone> HostRingTensor<T> {
    /// Test helper that converts `raw_tensor` into a ring tensor on `plc`.
    ///
    /// Every element is wrapped in `Wrapping` and the dimension type is erased
    /// before the result is placed on `plc`.
    pub(crate) fn from_raw_plc<D: ndarray::Dimension, P: Into<HostPlacement>>(
        raw_tensor: Array<T, D>,
        plc: P,
    ) -> HostRingTensor<T> {
        let wrapped = raw_tensor.mapv(Wrapping);
        let shared = wrapped.into_dyn().into_shared();
        HostRingTensor(shared, plc.into())
    }
}
impl<T> HostRingTensor<T> {
    /// Shape of the tensor, placed on the tensor's own placement.
    fn shape(&self) -> HostShape {
        let dims: Vec<usize> = self.0.shape().to_vec();
        HostShape(RawShape(dims), self.1.clone())
    }
}
impl From<&HostRingTensor<u64>> for ArrayD<i64> {
fn from(r: &HostRingTensor<u64>) -> ArrayD<i64> {
r.0.mapv(|element| element.0 as i64)
}
}
impl From<&HostRingTensor<u128>> for ArrayD<i128> {
fn from(r: &HostRingTensor<u128>) -> ArrayD<i128> {
r.0.mapv(|element| element.0 as i128)
}
}
impl<T> HostRingTensor<T>
where
Wrapping<T>: Clone + num_traits::Zero,
{
fn sum(self, axis: Option<usize>) -> Result<HostRingTensor<T>> {
if let Some(i) = axis {
Ok(HostRingTensor(
self.0.sum_axis(Axis(i)).into_shared(),
self.1,
))
} else {
let out = Array::from_elem([], self.0.sum()).into_dyn();
Ok(HostRingTensor(out.into_shared(), self.1))
}
}
}
/// Construction of placed values from raw, placement-free data.
///
/// `plc.from_raw(raw)` builds a value of type `O` located on the placement
/// `plc` from `raw` of type `T`. In this file it is implemented for
/// `HostPlacement`.
pub trait FromRaw<T, O> {
    // `from_*` methods normally take no receiver. Here the receiver is the
    // placement the new value is created on.
    #![allow(clippy::wrong_self_convention)]

    /// Builds a value of type `O` on this placement from `raw`.
    fn from_raw(&self, raw: T) -> O;
}
impl<T: Clone, O> FromRaw<&[T], O> for HostPlacement
where
HostPlacement: FromRaw<Array1<T>, O>,
{
fn from_raw(&self, raw: &[T]) -> O {
self.from_raw(Array::from_vec(raw.to_vec()))
}
}
impl<T, O> FromRaw<Vec<T>, O> for HostPlacement
where
HostPlacement: FromRaw<Array1<T>, O>,
{
fn from_raw(&self, raw: Vec<T>) -> O {
self.from_raw(Array::from_vec(raw))
}
}
impl<T: Clone, D: ndarray::Dimension> FromRaw<Array<T, D>, HostTensor<T>> for HostPlacement {
    /// Erases the dimension type and stores the array as a [`HostTensor`] on this placement.
    fn from_raw(&self, raw: Array<T, D>) -> HostTensor<T> {
        let shared = raw.into_dyn().into_shared();
        HostTensor(shared, self.clone())
    }
}

impl<T: Clone, D: ndarray::Dimension> FromRaw<Array<T, D>, HostRingTensor<T>> for HostPlacement {
    /// Wraps every element in `Wrapping` and stores the result as a
    /// [`HostRingTensor`] on this placement.
    fn from_raw(&self, raw: Array<T, D>) -> HostRingTensor<T> {
        let wrapped = raw.mapv(Wrapping).into_dyn();
        HostRingTensor(wrapped.into_shared(), self.clone())
    }
}
impl<D: ndarray::Dimension> FromRaw<Array<u8, D>, HostBitTensor> for HostPlacement {
    /// Converts a `u8` array into a [`HostBitTensor`] on this placement.
    ///
    /// Every non-zero entry becomes a set bit. Bits are laid out in the array's
    /// logical (row-major) order and keep the array's shape.
    fn from_raw(&self, raw: Array<u8, D>) -> HostBitTensor {
        let raw = raw.into_dyn();
        // `iter()` already visits elements in logical row-major order whatever
        // the memory layout. There is no need to copy into standard layout
        // first and unwrap `as_slice()`.
        let data = raw.iter().map(|&ai| ai != 0).collect();
        HostBitTensor(BitArrayRepr::from_raw(data, raw.dim()), self.clone())
    }
}
impl<T: Clone, D: ndarray::Dimension, N: Const> FromRaw<Array<T, D>, HostBitArray<HostBitTensor, N>>
    for HostPlacement
where
    HostPlacement: FromRaw<Array<T, D>, HostBitTensor>,
{
    /// Builds an `N`-bit [`HostBitArray`] on this placement from `raw`.
    ///
    /// # Panics
    /// If `raw` is zero-dimensional, or if its leading axis does not have
    /// exactly `N::VALUE` entries.
    fn from_raw(&self, raw: Array<T, D>) -> HostBitArray<HostBitTensor, N> {
        // `first()` avoids an opaque index-out-of-bounds panic on 0-D input.
        let leading = raw.shape().first().copied();
        assert_eq!(
            leading,
            Some(N::VALUE),
            "expected a leading axis of length {} for a {}-bit array, got shape {:?}",
            N::VALUE,
            N::VALUE,
            raw.shape()
        );
        let raw_bits: HostBitTensor = self.from_raw(raw);
        HostBitArray(raw_bits, PhantomData)
    }
}
impl FromRaw<RawShape, HostShape> for HostPlacement {
fn from_raw(&self, raw: RawShape) -> HostShape {
HostShape(raw, self.clone())
}
}
impl FromRaw<RawSeed, HostSeed> for HostPlacement {
fn from_raw(&self, raw: RawSeed) -> HostSeed {
HostSeed(raw, self.clone())
}
}
// Unit tests for the host kernels. They are compiled only when both the
// `sync_execute` feature and the test configuration are enabled. Most tests
// drive the kernels through a `SyncSession`.
#[cfg(feature = "sync_execute")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use rstest::rstest;

    // The shape of a 1-D ring tensor with three elements is `[3]`.
    #[test]
    fn test_host_shape_op() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostRing64Tensor = plc.from_raw(array![1024, 5, 4]);
        let shape = plc.shape(&sess, &x).0;
        let expected = RawShape(vec![3]);
        assert_eq!(expected, shape);
    }

    // Matrix product of a 2x2 float matrix with itself.
    #[test]
    fn dot_prod_f32() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostFloat32Tensor = plc.from_raw(array![[1.0, -2.0], [3.0, -4.0]]);
        let z = plc.dot(&sess, &x, &x);
        let expected: HostFloat32Tensor = plc.from_raw(array![[-5.0, 6.0], [-9.0, 10.0]]);
        assert_eq!(z, expected);
    }

    // Matrix inverse of [[1, 2], [3, 4]].
    #[test]
    fn test_inverse() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostFloat32Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let x_inv = plc.inverse(&sess, &x);
        let expected: HostFloat32Tensor = plc.from_raw(array![[-2.0, 1.0], [1.5, -0.5]]);
        assert_eq!(x_inv, expected);
    }

    // `RawShape::slice(begin, end)` keeps axes `begin..end`.
    #[test]
    fn test_shape_slice() {
        let x_shape = RawShape(vec![1, 2, 3]);
        let x_slice = x_shape.slice(1, 3);
        assert_eq!(x_slice, RawShape(vec![2, 3]))
    }

    // Rows from index 1 onwards and all columns: only the last row remains.
    #[test]
    fn test_tensor_slice() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let slice = SliceInfo(vec![
            SliceInfoElem {
                start: 1,
                end: None,
                step: None,
            },
            SliceInfoElem {
                start: 0,
                end: None,
                step: None,
            },
        ]);
        let y = plc.slice(&sess, slice, &x);
        let expected: HostRing64Tensor = plc.from_raw(array![[3, 4]]);
        assert_eq!(y, expected);
    }

    // A negative start counts from the end of the axis. Start -1 with step 2
    // on both axes selects only the bottom-right element.
    #[test]
    fn test_tensor_slice_neg_indicies() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let slice = SliceInfo(vec![
            SliceInfoElem {
                start: -1,
                end: None,
                step: Some(2),
            },
            SliceInfoElem {
                start: -1,
                end: None,
                step: Some(2),
            },
        ]);
        let y = plc.slice(&sess, slice, &x);
        let expected: HostRing64Tensor = plc.from_raw(array![[4]]);
        assert_eq!(y, expected);
    }

    // Slicing with more axes (3) than the tensor has (2) must panic.
    #[test]
    #[should_panic]
    fn test_tensor_slice_index_out_of_range() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let slice = SliceInfo(vec![
            SliceInfoElem {
                start: -1,
                end: None,
                step: Some(2),
            },
            SliceInfoElem {
                start: -1,
                end: None,
                step: Some(2),
            },
            SliceInfoElem {
                start: -1,
                end: None,
                step: Some(2),
            },
        ]);
        let _y = plc.slice(&sess, slice, &x);
    }

    // `diag` extracts the main diagonal of a float matrix.
    #[test]
    fn test_diag() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostFloat64Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let y = plc.diag(&sess, &x);
        let expected: HostFloat64Tensor = plc.from_raw(array![1.0, 4.0]);
        assert_eq!(y, expected);
    }

    // `diag` extracts the main diagonal of a bit matrix.
    #[test]
    fn test_bit_diag() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostBitTensor = plc.from_raw(array![[1, 1], [1, 0]]);
        let y = plc.diag(&sess, &x);
        let expected: HostBitTensor = plc.from_raw(array![1, 0]);
        assert_eq!(y, expected);
    }

    // `index_axis` drops the indexed axis. Indexing a 1-D tensor gives a
    // 0-D tensor, and `expand_dims` at axis 0 turns it back into shape [1].
    #[test]
    fn test_index() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostRing64Tensor = plc.from_raw(array![[[1_u64, 2], [3, 4]], [[4, 5], [6, 7]]]);
        let y = plc.index_axis(&sess, 0, 1, &x);
        let expected: HostRing64Tensor = plc.from_raw(array![[4, 5], [6, 7]]);
        assert_eq!(y, expected);
        let x: HostRing64Tensor = plc.from_raw(array![0_u64, 1, 2, 3]);
        let y = plc.index_axis(&sess, 0, 1, &x);
        let expected: HostRing64Tensor = plc.from_raw(array![1]);
        let y = plc.expand_dims(&sess, [0].to_vec(), &y);
        assert_eq!(y, expected);
    }

    // `index_axis` on bit tensors: select each 2x2 block along axis 0.
    #[rstest]
    #[case(
        array![[[0, 1], [0, 0]], [[1, 1], [0, 0]]].into_dyn(),
        0,
        1,
        array![[1, 1], [0, 0]].into_dyn(),
    )]
    #[case(
        array![[[0, 1], [0, 0]], [[1, 1], [0, 0]]].into_dyn(),
        0,
        0,
        array![[0, 1], [0, 0]].into_dyn(),
    )]
    fn test_index_bit(
        #[case] x: ArrayD<u8>,
        #[case] axis: usize,
        #[case] index: usize,
        #[case] expected: ArrayD<u8>,
    ) {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostBitTensor = plc.from_raw(x);
        let y = plc.index_axis(&sess, axis, index, &x);
        let expected: HostBitTensor = plc.from_raw(expected);
        assert_eq!(y, expected);
    }

    // Transpose of a 2x2 float matrix.
    #[test]
    fn test_transpose() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x: HostFloat32Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let y = plc.transpose(&sess, &x);
        let expected: HostFloat32Tensor = plc.from_raw(array![[1.0, 3.0], [2.0, 4.0]]);
        assert_eq!(y, expected);
    }

    // Concatenating two 1x2x2 tensors along axis 0 yields a 2x2x2 tensor.
    #[test]
    fn test_concatenate() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let a: HostFloat32Tensor = plc.from_raw(array![[[1.0, 2.0], [3.0, 4.0]]]);
        let b: HostFloat32Tensor = plc.from_raw(array![[[1.0, 2.0], [3.0, 4.0]]]);
        let conc = plc.concatenate(&sess, 0, &vec![a, b]);
        let expected: HostFloat32Tensor =
            plc.from_raw(array![[[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]]]);
        assert_eq!(conc, expected)
    }

    // `at_least_2d` behaviour by input rank:
    // - 2-D input is returned unchanged.
    // - 1-D input becomes a row vector when the flag is false, and a column
    //   vector when it is true.
    // - 0-D input becomes a 1x1 matrix.
    #[test]
    fn test_atleast_2d() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let a: HostFloat32Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let a_exp = a.clone();
        let b: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0, 3.0, 4.0]);
        let b_exp: HostFloat32Tensor = plc.from_raw(array![[1.0, 2.0, 3.0, 4.0]]);
        let c: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0, 3.0, 4.0]);
        let c_exp: HostFloat32Tensor = plc.from_raw(array![[1.0], [2.0], [3.0], [4.0]]);
        let d: HostFloat32Tensor = plc.from_raw(Array::from_elem([], 1.0));
        let d_exp: HostFloat32Tensor = plc.from_raw(array![[1.0]]);
        let ax = plc.at_least_2d(&sess, true, &a);
        let bx = plc.at_least_2d(&sess, false, &b);
        let cx = plc.at_least_2d(&sess, true, &c);
        let dx = plc.at_least_2d(&sess, true, &d);
        assert_eq!(ax, a_exp);
        assert_eq!(bx, b_exp);
        assert_eq!(cx, c_exp);
        assert_eq!(dx, d_exp);
    }

    // Addition broadcasts a length-1 operand against a length-2 operand, on either side.
    #[test]
    fn test_add_broadcasting() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x_1: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let y_1: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        let z_1 = plc.add(&sess, &x_1, &y_1);
        let z_1_exp: HostFloat32Tensor = plc.from_raw(array![3.0, 4.0]);
        assert_eq!(z_1, z_1_exp);
        let x_2: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        let y_2: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let z_2 = plc.add(&sess, &x_2, &y_2);
        let z_2_exp: HostFloat32Tensor = plc.from_raw(array![3.0, 4.0]);
        assert_eq!(z_2, z_2_exp);
    }

    // Subtraction broadcasts a length-1 operand, on either side.
    #[test]
    fn test_sub_broadcasting() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x_1: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let y_1: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        let z_1 = plc.sub(&sess, &x_1, &y_1);
        let z_1_exp: HostFloat32Tensor = plc.from_raw(array![1.0, 0.0]);
        assert_eq!(z_1, z_1_exp);
        let x_2: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        let y_2: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let z_2 = plc.sub(&sess, &x_2, &y_2);
        let z_2_exp: HostFloat32Tensor = plc.from_raw(array![-1.0, 0.0]);
        assert_eq!(z_2, z_2_exp);
    }

    // Multiplication broadcasts a length-1 operand, on either side.
    #[test]
    fn test_mul_broadcasting() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x_1: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let y_1: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        let z_1 = plc.mul(&sess, &x_1, &y_1);
        let z_1_exp: HostFloat32Tensor = plc.from_raw(array![2.0, 4.0]);
        assert_eq!(z_1, z_1_exp);
        let x_2: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        let y_2: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let z_2 = plc.mul(&sess, &x_2, &y_2);
        let z_2_exp: HostFloat32Tensor = plc.from_raw(array![2.0, 4.0]);
        assert_eq!(z_2, z_2_exp);
    }

    // Division broadcasts a length-1 operand, on either side.
    #[test]
    fn test_div_broadcasting() {
        let sess = SyncSession::default();
        let plc = HostPlacement::from("host");
        let x_1: HostFloat32Tensor = plc.from_raw(array![1.0]);
        let y_1: HostFloat32Tensor = plc.from_raw(array![2.0, 4.0]);
        let z_1 = plc.div(&sess, &x_1, &y_1);
        let z_1_exp: HostFloat32Tensor = plc.from_raw(array![0.5, 0.25]);
        assert_eq!(z_1, z_1_exp);
        let x_2: HostFloat32Tensor = plc.from_raw(array![2.0, 4.0]);
        let y_2: HostFloat32Tensor = plc.from_raw(array![2.0]);
        let z_2 = plc.div(&sess, &x_2, &y_2);
        let z_2_exp: HostFloat32Tensor = plc.from_raw(array![1.0, 2.0]);
        assert_eq!(z_2, z_2_exp);
    }

    // Elementwise square root of perfect squares.
    #[test]
    fn test_kernel_sqrt() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostFloat64Tensor = plc.from_raw(array![[4.0, 9.0], [16.0, 25.0]]);
        let sqrt = plc.sqrt(&sess, &x);
        let expected: HostFloat64Tensor = plc.from_raw(array![[2.0, 3.0], [4.0, 5.0]]);
        assert_eq!(expected, sqrt)
    }

    // `expand_dims` at axis 2 followed by `squeeze` restores the original
    // 2x2 shape. The squeeze removes either every singleton axis or axis 2
    // explicitly.
    #[rstest]
    #[case(None)]
    #[case(Some(2))]
    fn test_kernel_squeeze(#[case] axis: Option<usize>) {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostFloat64Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let x_expanded = plc.expand_dims(&sess, vec![2], &x);
        let x_squeezed = plc.squeeze(&sess, axis, &x_expanded);
        let shape = plc.shape(&sess, &x_squeezed);
        let expected = RawShape(vec![2, 2]);
        assert_eq!(expected, shape.0)
    }

    // Transpose of a 2x2 f64 matrix.
    #[test]
    fn test_kernel_transpose() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostFloat64Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let t = plc.transpose(&sess, &x);
        let expected: HostFloat64Tensor = plc.from_raw(array![[1.0, 3.0], [2.0, 4.0]]);
        assert_eq!(expected, t);
    }

    // Stacking two 2x2 matrices along axis 0 yields a 4x2 matrix.
    #[test]
    fn test_kernel_concatenate() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostFloat64Tensor = plc.from_raw(array![[1.0, 2.0], [3.0, 4.0]]);
        let y: HostFloat64Tensor = plc.from_raw(array![[5.0, 6.0], [7.0, 8.0]]);
        let c = plc.concatenate(&sess, 0, &[x, y]);
        let expected: HostFloat64Tensor =
            plc.from_raw(array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
        assert_eq!(expected, c);
    }

    // Golden values: seeded bit sampling with an all-zero seed must be
    // deterministic. NOTE(review): update these if the seeded PRG changes.
    #[test]
    fn bit_sample() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape: HostShape = plc.from_raw(RawShape(vec![5]));
        let seed: HostSeed = plc.from_raw(RawSeed([0u8; 16]));
        let r: HostBitTensor = plc.sample_uniform_seeded(&sess, &shape, &seed);
        let expected: HostBitTensor = plc.from_raw(array![0, 1, 1, 0, 0]);
        assert_eq!(r, expected);
    }

    // Filling a bit tensor with 1.
    #[test]
    fn bit_fill() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape: HostShape = plc.from_raw(RawShape(vec![2]));
        let x: HostBitTensor = plc.fill(&sess, 1_u8.into(), &shape);
        let expected: HostBitTensor = plc.from_raw(array![1, 1]);
        assert_eq!(x, expected);
    }

    // XOR truth table on all-zero / all-one bit tensors.
    #[test]
    fn bit_xor() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape: HostShape = plc.from_raw(RawShape(vec![5]));
        let zero = plc.fill(&sess, 0_u8.into(), &shape);
        let one = plc.fill(&sess, 1_u8.into(), &shape);
        assert_eq!(&plc.xor(&sess, &zero, &one), &one);
        assert_eq!(&plc.xor(&sess, &one, &zero), &one);
        assert_eq!(&plc.xor(&sess, &one, &one), &zero);
        assert_eq!(&plc.xor(&sess, &zero, &zero), &zero);
    }

    // OR truth table on all-zero / all-one bit tensors.
    #[test]
    fn bit_or() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape: HostShape = plc.from_raw(RawShape(vec![5]));
        let zero: HostBitTensor = plc.fill(&sess, 0_u8.into(), &shape);
        let one: HostBitTensor = plc.fill(&sess, 1_u8.into(), &shape);
        assert_eq!(&plc.or(&sess, &zero, &one), &one);
        assert_eq!(&plc.or(&sess, &one, &zero), &one);
        assert_eq!(&plc.or(&sess, &one, &one), &one);
        assert_eq!(&plc.or(&sess, &zero, &zero), &zero);
    }

    // AND truth table on all-zero / all-one bit tensors.
    #[test]
    fn bit_and() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape: HostShape = plc.from_raw(RawShape(vec![5]));
        let zero: HostBitTensor = plc.fill(&sess, 0_u8.into(), &shape);
        let one: HostBitTensor = plc.fill(&sess, 1_u8.into(), &shape);
        assert_eq!(&plc.and(&sess, &zero, &one), &zero);
        assert_eq!(&plc.and(&sess, &one, &zero), &zero);
        assert_eq!(&plc.and(&sess, &one, &one), &one);
        assert_eq!(&plc.and(&sess, &zero, &zero), &zero);
    }

    // Ring matrix-vector product: row sums of [[1, 2], [3, 4]].
    #[test]
    fn ring_matrix_vector_prod() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let y: HostRing64Tensor = plc.from_raw(array![1, 1]);
        let z = plc.dot(&sess, &x, &y);
        let expected: HostRing64Tensor = plc.from_raw(array![3, 7]);
        assert_eq!(expected, z);
    }

    // Ring matrix-matrix product with the identity is the identity map.
    #[test]
    fn ring_matrix_matrix_prod() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let y: HostRing64Tensor = plc.from_raw(array![[1, 0], [0, 1]]);
        let z = plc.dot(&sess, &x, &y);
        let expected: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        assert_eq!(expected, z);
    }

    // Ring vector-vector product yields a 0-D tensor.
    #[test]
    fn ring_vector_prod() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![1, 2]);
        let y: HostRing64Tensor = plc.from_raw(array![1, 1]);
        let z = plc.dot(&sess, &x, &y);
        let expected: HostRing64Tensor = plc.from_raw(Array::from_elem([], 3));
        assert_eq!(expected, z);
    }

    // Golden values for seeded ring sampling with an all-zero seed, in both
    // rings. The sampled bits match those expected in `bit_sample` for the
    // same seed. NOTE(review): update these if the seeded PRG changes.
    #[test]
    fn ring_sample() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape = plc.from_raw(RawShape(vec![5]));
        let seed = plc.from_raw(RawSeed([0u8; 16]));
        let r64: HostRing64Tensor = plc.sample_uniform_seeded(&sess, &shape, &seed);
        assert_eq!(
            r64,
            plc.from_raw(array![
                4263935709876578662,
                3326810793440857224,
                17325099178452873543,
                15208531650305571673,
                9619880027406922172
            ])
        );
        let r128: HostRing128Tensor = plc.sample_uniform_seeded(&sess, &shape, &seed);
        assert_eq!(
            r128,
            plc.from_raw(array![
                78655730786844307471556614669614075016,
                319591670596555766473793801091584867161,
                177455464885365520564027128957528354027,
                72628979995024532377123578937486303732,
                299726520301515014350190124791858941972
            ])
        );
        let r64_bits: HostRing64Tensor = plc.sample_bits_seeded(&sess, &shape, &seed);
        assert_eq!(r64_bits, plc.from_raw(array![0, 1, 1, 0, 0]));
        let r128_bits: HostRing128Tensor = plc.sample_bits_seeded(&sess, &shape, &seed);
        assert_eq!(r128_bits, plc.from_raw(array![0, 1, 1, 0, 0]));
    }

    // Filling a ring tensor with 1.
    #[test]
    fn ring_fill() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape = plc.from_raw(RawShape(vec![2]));
        let r: HostRing64Tensor = plc.fill(&sess, 1_u64.into(), &shape);
        assert_eq!(r, plc.from_raw(array![1, 1]));
    }

    // Summing along axis 0 gives the column sums.
    #[test]
    fn ring_sum_with_axis() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let out = plc.sum(&sess, Some(0), &x);
        assert_eq!(out, plc.from_raw(array![4, 6]))
    }

    // Summing without an axis gives a 0-D total.
    #[test]
    fn ring_sum_without_axis() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![[1, 2], [3, 4]]);
        let out = plc.sum(&sess, None, &x);
        assert_eq!(out, plc.from_raw(Array::from_elem([], 10_u64)))
    }

    // `add_n` of a single tensor returns that tensor. The later cases sum
    // three 64-bit tensors and four 128-bit tensors elementwise.
    #[test]
    fn ring_add_n() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![[1, 4], [9, 16], [25, 36]]);
        let expected: HostRing64Tensor = plc.from_raw(array![[1, 4], [9, 16], [25, 36]]);
        let out = plc.add_n(&sess, &[x]);
        assert_eq!(out, expected);
        let x: HostRing64Tensor = plc.from_raw(array![[1, 4], [9, 16], [25, 36]]);
        let y: HostRing64Tensor = plc.from_raw(array![[1, 3], [6, 10], [15, 21]]);
        let z: HostRing64Tensor = plc.from_raw(array![[1, 36], [1225, 41616], [1413721, 48024900]]);
        let expected: HostRing64Tensor =
            plc.from_raw(array![[3, 43], [1240, 41642], [1413761, 48024957]]);
        let out = plc.add_n(&sess, &[x, y, z]);
        assert_eq!(out, expected);
        let w: HostRing128Tensor = plc.from_raw(array![[6, 3, 10], [5, 16, 8], [4, 2, 1]]);
        let x: HostRing128Tensor = plc.from_raw(array![[40, 20, 10], [5, 16, 8], [4, 2, 1]]);
        let y: HostRing128Tensor = plc.from_raw(array![[42, 21, 64], [32, 16, 8], [4, 2, 1]]);
        let z: HostRing128Tensor = plc.from_raw(array![[256, 128, 64], [32, 16, 8], [4, 2, 1]]);
        let expected: HostRing128Tensor =
            plc.from_raw(array![[344, 172, 148], [74, 64, 32], [16, 8, 4]]);
        let out = plc.add_n(&sess, &[w, x, y, z]);
        assert_eq!(out, expected);
    }

    // 7 = 0b111, so bits 0, 1 and 2 are set and bit 3 is clear.
    #[test]
    fn bit_extract() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let shape = plc.from_raw(RawShape(vec![5]));
        let value: HostRing64Tensor = plc.fill(&sess, 7_u64.into(), &shape);
        let r0 = plc.bit_extract(&sess, 0, &value);
        let r0_expected: HostBitTensor = plc.from_raw(array![1, 1, 1, 1, 1]);
        assert_eq!(r0, r0_expected);
        let r1 = plc.bit_extract(&sess, 1, &value);
        let r1_expected: HostBitTensor = plc.from_raw(array![1, 1, 1, 1, 1]);
        assert_eq!(r1, r1_expected);
        let r2 = plc.bit_extract(&sess, 2, &value);
        let r2_expected: HostBitTensor = plc.from_raw(array![1, 1, 1, 1, 1]);
        assert_eq!(r2, r2_expected);
        let r3 = plc.bit_extract(&sess, 3, &value);
        let r3_expected: HostBitTensor = plc.from_raw(array![0, 0, 0, 0, 0]);
        assert_eq!(r3, r3_expected);
    }

    // Slice i along axis 0 of `bit_decompose` equals `bit_extract(i)` for all
    // 64 bit positions.
    #[test]
    fn bit_decompose1() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![[[1, 2], [3, 4]], [[4, 5], [6, 7]]]);
        let x_bits: HostBitTensor = plc.bit_decompose(&sess, &x);
        let targets: Vec<_> = (0..64).map(|i| plc.bit_extract(&sess, i, &x)).collect();
        for (i, target) in targets.iter().enumerate() {
            let sliced = plc.index_axis(&sess, 0, i, &x_bits);
            assert_eq!(&sliced, target);
        }
    }

    // The 64 bits of a single value, least-significant bit first, as a
    // (64, 1) tensor.
    #[test]
    fn bit_decompose2() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![6743216615002642708]);
        let x_bits: HostBitTensor = plc.bit_decompose(&sess, &x);
        let expected: HostBitTensor = plc.from_raw(
            array![
                0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0,
                0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1,
                1, 0, 1, 1, 1, 0, 1, 0
            ]
            .into_shape((64, 1))
            .unwrap(),
        );
        assert_eq!(x_bits, expected);
    }

    // The expected result is the `bit_decompose2` bits moved up by 6 rows:
    // six leading zeros, then the original rows with the top 6 dropped.
    #[test]
    fn bit_shl_dim() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![6743216615002642708]);
        let x_bits: HostBitTensor = plc.bit_decompose(&sess, &x);
        let x_bits = plc.shl_dim(&sess, 6, 64, &x_bits);
        let expected: HostBitTensor = plc.from_raw(
            array![
                0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1,
                0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0,
                1, 0, 1, 0, 0, 1, 1, 0,
            ]
            .into_shape((64, 1))
            .unwrap(),
        );
        assert_eq!(x_bits, expected);
    }

    // Transposing the (64, 1) bit decomposition gives a (1, 64) tensor.
    #[test]
    fn bit_transpose() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostRing64Tensor = plc.from_raw(array![6743216615002642708]);
        let x_bits: HostBitTensor = plc.bit_decompose(&sess, &x);
        let x_bits = plc.transpose(&sess, &x_bits);
        let expected: HostBitTensor = plc.from_raw(
            array![
                0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0,
                0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1,
                1, 0, 1, 1, 1, 0, 1, 0
            ]
            .into_shape((64, 1))
            .unwrap()
            .reversed_axes(),
        );
        assert_eq!(x_bits, expected);
    }

    // Transpose of a 2x3 bit matrix.
    #[test]
    fn bit_transpose2() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostBitTensor =
            plc.from_raw(array![[1, 1, 1], [0, 0, 0]].into_shape((2, 3)).unwrap());
        let xt = plc.transpose(&sess, &x);
        let expected: HostBitTensor =
            plc.from_raw(array![[1, 0], [1, 0], [1, 0]].into_shape((3, 2)).unwrap());
        assert_eq!(xt, expected);
    }

    // Transposing a 3-D bit tensor reverses all of its axes.
    #[test]
    fn bit_transpose3() {
        let plc = HostPlacement::from("host");
        let sess = SyncSession::default();
        let x: HostBitTensor = plc.from_raw(
            array![[[1, 1, 1], [0, 0, 0]], [[0, 1, 0], [1, 0, 0]]]
                .into_shape((2, 2, 3))
                .unwrap(),
        );
        let xt = plc.transpose(&sess, &x);
        let expected: HostBitTensor = plc.from_raw(
            array![[[1, 1, 1], [0, 0, 0]], [[0, 1, 0], [1, 0, 0]]]
                .into_shape((2, 2, 3))
                .unwrap()
                .reversed_axes(),
        );
        assert_eq!(xt, expected);
    }

    // x = 2^128 - 192384912939857, i.e. -192384912939857 in the 128-bit ring.
    // So x * x wraps around to 192384912939857^2
    // = 37011954726876357358499180449.
    #[test]
    fn test_host_mul() {
        let plc = HostPlacement::from("host");
        let x: HostRing128Tensor = plc.from_raw(array![340282366920938463463374415046855271599]);
        let sess = SyncSession::default();
        let y = plc.mul(&sess, &x, &x);
        let expected: HostRing128Tensor = plc.from_raw(array![37011954726876357358499180449]);
        assert_eq!(y, expected);
    }
}